"""decode.utils.model_io — utilities for hashing, warmstart-loading and checkpoint-saving torch models."""
import hashlib
import math
import pathlib
import time
from typing import Union
import torch
def hash_model(modelfile):
    """
    Compute and return the SHA-1 hex digest of the file at ``modelfile``,
    reading it in fixed-size chunks so arbitrarily large model files never
    have to fit into memory at once.
    (https://www.pythoncentral.io/hashing-files-with-python/)
    """
    chunk_size = 65536
    sha = hashlib.sha1()
    with open(modelfile, 'rb') as fh:
        # iter() with a b'' sentinel yields chunks until EOF
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            sha.update(chunk)
    return sha.hexdigest()
class LoadSaveModel:
    """
    Manages warmstart loading and rolling checkpoint saving of a torch model.

    Checkpoints are written to ``output_file`` with a rotating integer suffix
    (``..._0.pt``, ``..._1.pt``, ...). The suffix only advances every
    ``name_time_interval`` seconds, so frequent saves within that window
    overwrite the same file, and at most ``max_files`` distinct files are kept.
    """

    def __init__(self, model_instance, output_file: Union[str, pathlib.Path], input_file=None,
                 name_time_interval=(60 * 60), better_th=1e-6, max_files=3, state_dict_update=None):
        """
        Args:
            model_instance: model to be loaded / saved
            output_file: target path for checkpoints (the rotating suffix is
                inserted before the '.pt' extension)
            input_file: optional path to a pretrained state dict for warmstart
            name_time_interval: seconds after which the checkpoint suffix advances
            better_th: relative improvement threshold; a provided metric must
                improve by at least this fraction to trigger a save
            max_files: maximum number of rotating checkpoint files;
                ``None`` or ``-1`` means unlimited
            state_dict_update: optional dict merged into the loaded state dict
                before it is applied to the model
        """
        self.warmstart_file = pathlib.Path(input_file) if input_file is not None else None
        self.output_file = pathlib.Path(output_file) if output_file is not None else None
        self.output_file_suffix = -1  # because will be increased to one in the first round
        self.model = model_instance
        self.name_time_interval = name_time_interval
        self._last_saved = 0  # timestamp when it was saved last
        self._new_name_time = 0  # timestamp when the name changed last
        self._best_metric_val = math.inf
        self.better_th = better_th
        # fix: 'None or -1' must disable the file limit. The original used 'or',
        # which is truthy for both None and -1, keeping them verbatim and breaking
        # the `suffix > max_files - 1` comparison in save().
        self.max_files = max_files if (max_files is not None and max_files != -1) else float('inf')
        self.state_dict_update = state_dict_update

    def _create_target_folder(self):
        """
        Creates the target folder for the network output .pt file, if it does not exists already
        """
        p = pathlib.Path(self.output_file)
        try:
            # only the last path component is created; deeper missing parents are an error
            p.parent.mkdir(parents=False, exist_ok=True)
        except FileNotFoundError:
            raise FileNotFoundError("I will only create the last folder for model saving. But the path you specified "
                                    "lacks more folders or is completely wrong.")

    def load_init(self, device: Union[str, torch.device, None] = None):
        """
        Init and warmstart model (if possible) and ship to specified device

        Args:
            device: target device; defaults to 'cuda:0' when CUDA is available, else 'cpu'

        Returns:
            the model in eval mode, warmstarted from ``input_file`` if one was given
        """
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        model = self.model
        print('Model instantiated.')

        if self.warmstart_file is None:
            print('Model initialised as specified in the constructor.')
        else:
            # hash the checkpoint so the user can identify exactly which file was loaded
            hashv = hash_model(self.warmstart_file)
            print(f'Model SHA-1 hash: {hashv}')
            model.hash = hashv

            state_dict = torch.load(self.warmstart_file, map_location=device)
            if self.state_dict_update is not None:
                # allow callers to override individual entries of the checkpoint
                state_dict.update(self.state_dict_update)
            model.load_state_dict(state_dict)
            print('Loaded pretrained model: {}'.format(self.warmstart_file))

        model.eval()
        return model

    def save(self, model, metric_val=None):
        """
        Save model (conditioned on a better metric if one is provided)

        Args:
            model: model whose state_dict is written to disk
            metric_val: optional metric (lower is better); when given, the model
                is only saved if it improved by at least ``better_th`` relative
                to the best value seen so far
        """
        # create folder if it does not exist
        self._create_target_folder()

        if metric_val is not None:
            # if the relative difference to the previous best is below the threshold, do not save
            rel_diff = metric_val / self._best_metric_val
            if rel_diff <= 1 - self.better_th:
                self._best_metric_val = metric_val
            else:
                return

        # after a certain period (or for unconditional saves), advance the rotating suffix
        if (time.time() > self._new_name_time + self.name_time_interval) or metric_val is None:
            self.output_file_suffix += 1
            if self.output_file_suffix > self.max_files - 1:
                self.output_file_suffix = 0  # wrap around to bound the number of files
            self._new_name_time = time.time()

        # determine file name and save
        fname = pathlib.Path(str(self.output_file.with_suffix('')) + '_' + str(self.output_file_suffix) + '.pt')
        torch.save(model.state_dict(), fname)
        print('Saved model to file: {}'.format(fname))
        self._last_saved = time.time()