Maxima's Lab

[Python, Tensorflow] Custom Callback 만들기 (Classification) 본문

Python/Tensorflow

[Python, Tensorflow] Custom Callback 만들기 (Classification)

Minima 2024. 1. 30. 21:57
728x90
SMALL

안녕하세요, 오늘은 Tensorflow 2에서 Classification 모델 학습 시 Callback을 Customize 하는 방법에 대해서 알아보겠습니다.

 

MNIST 데이터 셋을 활용하여, 모델 학습하는 예시 입니다.

 

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.callbacks import Callback
import numpy as np

class SimpleCallBack(Callback):
    def __init__(self, patience=0):
        super(SimpleCallBack, self).__init__()
        self.best_weights = None

    def on_train_begin(self, logs=None):
        self.best_train_accuracy = 0.0
        self.best_validation_accuracy = 0.0

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        
        epoch_accuracy = logs.get("accuracy")
        epoch_validation_accuracy = logs.get("val_accuracy")

        print(f"[EPOCH : {epoch+1:3d}] >>> Train Accuracy : {epoch_accuracy:.3f}, Validation Accuracy : {epoch_validation_accuracy:.3f}")

        if np.less(self.best_validation_accuracy, epoch_validation_accuracy):
          self.best_weights = self.model.get_weights()
          self.model.save("./mnist_best")

          self.best_validation_accuracy = epoch_validation_accuracy
          print("MNIST Best Model Saved...")


    def on_train_end(self, logs=None):
        print("Train Finished...")

 

위의 코드를 구성하는 함수에 대해서 알아보게습니다.

 

  • on_train_begin : 모델 학습이 시작될 때 실행되는 함수
  • on_epoch_begin : 모델 학습 중 특정 Epoch이 시작될 때 실행되는 함수
  • on_epoch_end : 모델 학습 중 특정 Epoch이 종료될 때 실행되는 함수
  • on_train_end : 모델 학습이 종료될 때 실행되는 함수

 

위의 코드에서 유심히 봐야할 코드는 다음과 같습니다.

 

    def on_epoch_end(self, epoch, logs=None):
        
        epoch_accuracy = logs.get("accuracy")
        epoch_validation_accuracy = logs.get("val_accuracy")

        print(f"[EPOCH : {epoch+1:3d}] >>> Train Accuracy : {epoch_accuracy:.3f}, Validation Accuracy : {epoch_validation_accuracy:.3f}")

        if np.less(self.best_validation_accuracy, epoch_validation_accuracy):
          self.best_weights = self.model.get_weights()
          self.model.save("./mnist_best")

          self.best_validation_accuracy = epoch_validation_accuracy
          print("MNIST Best Model Saved...")

 

위의 코드는 특정 Epoch에 대하여 학습이 완료되었을 때 logs에서 train/validation accuracy를 불러와서 validation accuracy가 개선이 되었을 때 Model Weights을 저장하는 코드입니다.

 


위의 코드들을 활용한 전체 코드에 대해서 보여드리겠습니다.

 

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.callbacks import Callback
import numpy as np

class SimpleCallBack(Callback):
    def __init__(self, patience=0):
        super(SimpleCallBack, self).__init__()
        self.best_weights = None

    def on_train_begin(self, logs=None):
        self.best_train_accuracy = 0.0
        self.best_validation_accuracy = 0.0

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        
        epoch_accuracy = logs.get("accuracy")
        epoch_validation_accuracy = logs.get("val_accuracy")

        print(f"[EPOCH : {epoch+1:3d}] >>> Train Accuracy : {epoch_accuracy:.3f}, Validation Accuracy : {epoch_validation_accuracy:.3f}")

        if np.less(self.best_validation_accuracy, epoch_validation_accuracy):
          self.best_weights = self.model.get_weights()
          self.model.save("./mnist_best")

          self.best_validation_accuracy = epoch_validation_accuracy
          print("MNIST Best Model Saved...")


    def on_train_end(self, logs=None):
        print("Train Finished...")


# MNIST 데이터 셋 로드
(x_train, y_train), (x_val, y_val) = mnist.load_data()
x_train, x_val = x_train / 255.0, x_val / 255.0


model = Sequential([
    Flatten(input_shape=(28, 28)), 
    Dense(128, activation='relu'), 
    Dense(64, activation='relu'),  
    Dense(10, activation='sigmoid') 
])

# 모델 컴파일
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

simple_callback = SimpleCallBack()

# 모델 학습
model.fit(x_train, y_train, batch_size=16, epochs=100, verbose=0, validation_data = (x_val, y_val), callbacks=[simple_callback])

 

모델 학습 시 출력 결과

 

이상으로, Tensorflow 2에서 Callback을 Customize 하는 방법에 대해서 알아보았습니다.

감사드립니다.

728x90
LIST
Comments