Model Save and Load

yeoni·2023년 6월 27일
0

Tensorflow

목록 보기
13/15

save 함수

  • 모델구조 & weights
  • epoch마다 저정한다면 그때 생성된 모델 -> checkpoints
  • h5, pb, ckpt 등 모델 확장자
model.save("checkpoints/sample/model.h5")
model_loaded = tf.keras.models.load_model("checkpoints/sample/model.h5")

save_weights 함수

  • weights만 저장 하므로, 저장공간이 절약됨.
model.save_weights("checkpoints/sample/model.h5")

# 모델을 만들고 weights를 넣기 -> 학습된 모델로 변환
new_model = build_resnet((32, 32, 3))
new_model.load_weights("checkpoints/sample/model.h5")

Callbacks 함수 사용하기

  • save_best_only=True 이전 현재 성능을 비교해서 좋은 것만 남김
  • save_best_only=False 실제로 서비스로 갔을 때 만족도가 다른 경우가 존재해서 현업에 많이 사용
save_path = 'checkpoints/{epoch:02d}-{val_loss:.2f}.h5'
checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path,
                                                monitor='val_accuracy',
                                                save_best_only=True)
                                                
model.fit(x=train_x,
          y=train_y,
          epochs=1,
          validation_data=(test_x, test_y),
          callbacks=[checkpoint])

pb 형식으로 저장 하기

  • 모델을 protoBuffer 형식으로 저장
  • 모델 & weights
save_path = 'checkpoints/{epoch:02d}-{val_loss:.2f}'
checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path,
                                                monitor='val_accuracy',
                                                save_best_only=True)
                                                
model.fit(x=train_x,
          y=train_y,
          epochs=1,
          validation_data=(test_x, test_y),
          callbacks=[checkpoint])                         

model = tf.saved_model.load("checkpoints/01-2.32") 

Reference
1) 제로베이스 데이터스쿨 강의자료

profile
데이터 사이언스 / just do it

0개의 댓글