728x90
반응형
model.save()
- 학습 결과 저장 함수
- 모델 형태(architecture)와 파라미터를 저장할 수 있음
- 모델 학습 중간 과정을 저장함으로써 최선의 결과를 가지는 모델을 선택할 수 있음
- 만들어진 모델을 외부 연구자와 공유해 학습 재연성을 향상할 수 있음
- state_dict : 모델의 파라미터를 의미
# 모델의 파라미터를 저장
torch.save(model.state_dict(),
os.path.join(MODEL_PATH, "model.pt"))
# 동일한 형태의 모델을 만들어 파라미터만 로드하여 사용
new_model = TheModelClass()
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))
# 모델 architecture와 함께 저장
torch.save(model, os.path.join(MODEL_PATH, "model_pickle.pt"))
model = torch.load(os.path.join(MODEL_PATH, "model_pickle.pt"))
Checkpoints
- 학습의 중간 결과를 저장 → 최선의 결과를 선택
- early stopping 기법 사용 시 이전 학습의 결과물을 저장할 수 있음
- loss와 metric 값을 지속적으로 확인하고 저장할 수 있음
- 일반적으로 epoch, loss, metric을 함께 저장하여 확인
torch.save({
"epoch" : e,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": epoch_loss
},
# 파일명에서 성능을 확인할 수 있도록 정의해주면 편함
f"saved/checkpoint_model{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt"
)
728x90
반응형
'Programming' 카테고리의 다른 글
[monitoring tools] Tensorboard (0) | 2022.12.05 |
---|---|
[pytorch] 전이 학습, Transfer Learning (0) | 2022.12.03 |
[pytorch] Dataset & Dataloader (0) | 2022.12.03 |
[Ray] 3-2 Ray Libraries: Ray Train (Key Concept) (0) | 2022.11.02 |
[Ray] 3-1. Ray Libraries : Ray Data (Key Concept) (1) | 2022.11.02 |