본문 바로가기

Programming

[pytorch] 학습 결과 및 모델 저장하기

728x90
반응형

model.save()

  • 학습 결과 저장 함수
  • 모델 형태(architecture)와 파라미터를 저장할 수 있음
  • 모델 학습 중간 과정을 저장함으로써 최선의 결과를 가지는 모델을 선택할 수 있음
  • 만들어진 모델을 외부 연구자와 공유해 학습 재연성을 향상할 수 있음
  • state_dict : 모델의 파라미터를 의미
# 모델의 파라미터를 저장
torch.save(model.state_dict(),
           os.path.join(MODEL_PATH, "model.pt"))

# 동일한 형태의 모델을 만들어 파라미터만 로드하여 사용
new_model = TheModelClass() 
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))


# 모델 architecture와 함께 저장
torch.save(model, os.path.join(MODEL_PATH, "model_pickle.pt"))
model = torch.load(os.path.join(MODEL_PATH, "model_pickle.pt"))

 


 

Checkpoints

  • 학습의 중간 결과를 저장 → 최선의 결과를 선택
  • early stopping 기법 사용 시 이전 학습의 결과물을 저장할 수 있음
  • loss와 metric 값을 지속적으로 확인하고 저장할 수 있음
  • 일반적으로 epoch, loss, metric을 함께 저장하여 확인
  torch.save({
      "epoch" : e,
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
      "loss": epoch_loss
  },
  # 파일명에서 성능을 확인할 수 있도록 정의해주면 편함
  f"saved/checkpoint_model{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt"
  )
728x90
반응형