# Source code for decode.utils.checkpoint
from pathlib import Path
from typing import Union, Optional
import torch
class CheckPoint:
    """Checkpointing intended to resume an already started training.

    Warning:
        Checkpointing is not intended for long-term storage of models or other
        information. No version compatibility guarantees are given here at all.
    """

    def __init__(self, path: Union[str, Path]):
        """
        Args:
            path: filename / path where to dump the checkpoints
        """
        self.path = path
        # state holders; all None until update() is called
        self.model_state: Optional[dict] = None
        self.optimizer_state: Optional[dict] = None
        self.lr_sched_state: Optional[dict] = None
        self.step: Optional[int] = None
        self.log = None

    @property
    def dict(self) -> dict:
        """Serializable representation of the checkpoint (fed to torch.save)."""
        return {
            'step': self.step,
            'model_state': self.model_state,
            'optimizer_state': self.optimizer_state,
            'lr_sched_state': self.lr_sched_state,
            'log': self.log,
        }

    def update(self, model_state: dict, optimizer_state: dict,
               lr_sched_state: dict, step: int, log=None):
        """Update the in-memory checkpoint state. Does not write to disk."""
        self.model_state = model_state
        self.optimizer_state = optimizer_state
        self.lr_sched_state = lr_sched_state
        self.step = step
        self.log = log

    def save(self):
        """Serialize the current state to ``self.path``."""
        torch.save(self.dict, self.path)

    @classmethod
    def load(cls, path: Union[str, Path],
             path_out: Optional[Union[str, Path]] = None) -> 'CheckPoint':
        """Load a checkpoint from file.

        Args:
            path: path of the checkpoint file to read
            path_out: path the returned instance will save to; defaults to ``path``

        Returns:
            CheckPoint: instance populated with the stored state
        """
        # Map tensors to CPU so a checkpoint written on a GPU machine can be
        # resumed on a CPU-only machine (original torch.load(path) would raise).
        ckpt_dict = torch.load(path, map_location='cpu')

        ckpt = cls(path=path_out if path_out is not None else path)
        ckpt.update(
            model_state=ckpt_dict['model_state'],
            optimizer_state=ckpt_dict['optimizer_state'],
            lr_sched_state=ckpt_dict['lr_sched_state'],
            step=ckpt_dict['step'],
            log=ckpt_dict.get('log'),  # older checkpoints may lack 'log'
        )
        return ckpt

    def dump(self, model_state: dict, optimizer_state: dict,
             lr_sched_state: dict, step: int, log=None):
        """Updates and saves to file."""
        self.update(model_state, optimizer_state, lr_sched_state, step, log)
        self.save()