My Data Story

[딥러닝] 인공신경망 - 서브클래싱 API 구현 본문

Deep Learning

[딥러닝] 인공신경망 - 서브클래싱 API 구현

Hwasss 2021. 9. 4. 22:42
728x90

◈  '인공 신경망' 목차 

1. 퍼셉트론

2. 다층 퍼셉트론

3. 케라스 API 소개

4. 시퀀셜 API 구현

5. 함수형 API 구현

6. 서브클래싱 API 구현

    모델 구현에 유연성을 더하여 여러 가지 동적인 구조를 필요로 할 때 사용할 수 있는 서브클래싱 API 구현 방법을 살펴보자.

7. 모델 저장과 복원, 콜백, 텐서보드

8. 신경망 하이퍼파라미터 튜닝하기


지금까지 살펴본 시퀀셜 API 와 함수형 API는 선언적이었다.

즉 사용할 층과 연결 방식을 먼저 정의한 후, 모델에 데이터를 주입하여 훈련이나 추론을 시작할 수 있다.

 

이 방식에는 장점이 많다.

모델을 저장하거나 복사, 공유하기 쉽다. 또 한 모델의 구조를 출력하거나 분석하기 좋다. 

프레임워크가 크기를 짐작하고 타입을 확인하여 데이터가 주입되기 전에 에러를 일찍 발견할 수 있다. 

전체 모델이 층으로 구성된 정적 프로그램이므로 디버깅하기도 쉽다. 

 

하지만 정적이라는 것이 단점이 된다. 

어떤 모델은 반복문을 포함하고 다양한 크기를 다뤄야 하며 조건문을 가지는 등 여러 가지 동적인 구조를 필요로 한다. 

이런 경우 조금 더 명령형 프로그래밍 스타일이 필요하다면 서브클래싱 API 가 정답이다. 

 

 

1. 서브 클래스 API

간단히 Model 클래스를 상속한 다음 생성자 안에서 필요한 층을 만든다.

그 다음 call() 메서드 안에 수행하려는 연산을 기술한다. 

 

서브클래스 API 를 구현한 다음 예제를 살펴보자.

class WideAndDeepModel(keras.Model) :
      def __init__(self, units=30, activation='relu', **kwargs) :
          super().__init__(**kwargs) #표준 매개 변수를 처리한다.
          self.hidden1 = keras.layers.Dense(units, activation=activation)
          self.hidden2 = kersa.layers.Dense(units, activation=activation)
          self.main_output = keras.layers.Dense(1) 
          self.aux_output = keras.layers.Dense(1)
          
      def call(self, inputs) : 
          input_A, inputs_B = inputs
          hidden1 = self.hidden1(input_B)
          hidden2 = self.hiddend2(hidden1)
          concat = keras.layers.concatenate([input_A, hidden2])
          main_output = self.main_output(concat)
          aux_output = self.aux_output(hidden2)
          
          return main_output, aux_output

model = WideAndDeepModel()

 

이 예제는 함수형 API와 매우 비슷하지만 keras.layers.Input 클래스 객체를 만들 필요가 없다. 

대신 call() 메서드와 input 매개 변수를 사용하여, 생성자에 있는 층 구성과 call() 메서드에 있는 정방향 계산을 분리하였다.

다시 말해, call() 메서드 안에서 for문, if 문 등 원하는 계산을 구현할 수 있다. 

새로운 아이디어를 실험하는 연구자들에게 잘 맞는 훌륭한 API 이다!

 

하지만 유연성이 높아진 대신, 모델 구조가 call() 메서드 안에 숨겨져 있기 때문에 케라스가 쉽게 이를 분석할 수 없다. 즉 모델을 저장하거나 복사할 수 없고   summary() 메서드 호출하면 층의 목록만 나열되고 층간의 연결 정보를 얻을 수 없다. 또한 케라스가 타입과 크기를 미리 확인할 수 없어 실수가 발생하기 쉽다. 

 

높은 유연성이 필요하지 않다면 시퀀셜 API 나 함수 API를 사용하는 것이 좋다.