본문 바로가기
IT 프로그래밍 관련/딥러닝

learning_curve 함수 설정

by 지나는행인 2021. 3. 4.
728x90
def learning_curve(history, epoch):
  # 정확도 차트
  epoch_range = np.arange(1, epoch+1 )
  plt.plot(epoch_range, history.history['accuracy'])
  plt.plot(epoch_range, history.history['val_accuracy'])
  plt.title('Model Accuracy')
  plt.xlabel('Epoch')
  plt.ylabel('Accuracy')
  plt.legend( ['Train', 'Val'])
  plt.show()
  # loss 차트
  epoch_range = np.arange(1, epoch+1 )
  plt.plot(epoch_range, history.history['loss'])
  plt.plot(epoch_range, history.history['val_loss'])
  plt.title('Model Loss')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend( ['Train', 'Val'])
  plt.show()
learning_curve(history, 50)

 

learning_curve는 모델이 훈련을 완료하고, 각 epoch마다 accuracy, loss를

 

차트로 나타낸다.

 

history 변수에는 훈련을 끝마친 모델이 저장되어 있고, 그 모델은 epoch 50으로 학습하였다.

 

    

댓글