본문 바로가기
IT 프로그래밍 관련/딥러닝

Transfer Learning을 위한 코드와 설명

by 지나는행인 2021. 3. 4.
728x90

Transfer Learning은 기존에 잘 훈련되어 있는 모델을 base부분만 가져와 head부분은 따로 만들어서

이용 하는 것이다.

잘 만들어진 모델을 사용하여 새로운 모델을 만들시 학습을 빠르게 하며, 예측을 더 높일 수 있다.

모델을 가져올시 내가 지금 하려는 문제해결 방법과 유사한 문제해결 모델인지 확인해야 한다.

 

 

keras에서 모델을 불러와 base_model로 .   ex)MobileNetV2

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights="imagenet")

IMG_SHAPE( 128, 128, 3)으로 처리 되어있다.

include_top = False  :  가져오는 모델 상단에 F.C레이어를 연결 할건지 결정하는 파라미터.

                              우리가 만든걸 사용할거기 때문에 False

weights  이미지넷 그대로의 가중치를 쓴다.

 

 

#base_model

base_model.trainable = False

base_model.trainable = False  가져온 basemodel은 훈련하지 않는다.

 

 

#head_model

headModel = baseModel.output

headModel = AveragePooling2D((4,4))(headModel)
headModel = Flatten()(headModel)
headModel = Dense(128, activation='relu')(headModel)
headModel = BatchNormalization()(headModel)
headModel = Dense(64, activation='relu')(headModel)
headModel = Dropout(0.5)(headModel)
headModel = BatchNormalization()(headModel)
headModel = Dense(7, activation='softmax')(headModel)


model = Model(inputs = baseModel.input, outputs = headModel)

basemodel의 아웃풋을 받아 풀링, 플래튼한후, ANN모델링을 한다.

후에 그 둘을 묶는다.

 

이후에는 이 전에 해왔던 것과 같이 compile하고 훈련시키면 끝~

댓글