Source code for decode.generic.test_utils

import hashlib
import pathlib
from typing import Union

import torch


def tens_almeq(a: torch.Tensor, b: torch.Tensor, prec: float = 1e-8, nan: bool = False) -> bool:
    """
    Test whether two tensors are element-wise equal within a given precision.

    Two tensors are considered equal if ``|a - b| < prec`` holds for every
    element (strict inequality, so ``prec=0`` never matches). Because any
    comparison with NaN is false, tensors containing NaN compare as unequal;
    the only exception is ``nan=True`` with *both* tensors consisting entirely
    of NaN.

    Args:
        a: first tensor for comparison
        b: second tensor for comparison
        prec: absolute tolerance; differences must be strictly smaller
        nan: if true, return True when both tensors are all NaN

    Returns:
        bool: True if all elements agree within ``prec``

    Raises:
        TypeError: if the tensors are not of the same tensor type
    """
    if a.type() != b.type():
        raise TypeError("Both tensors must be of equal type.")

    # Integer and boolean tensors are cast to float so that negation and abs
    # are well defined. Floating point tensors keep their native precision:
    # down-casting float64 to float32 would hide differences larger than prec.
    if not a.is_floating_point():
        a = a.float()
        b = b.float()

    if nan and bool(torch.isnan(a).all()) and bool(torch.isnan(b).all()):
        return True

    return bool(torch.all(torch.lt(torch.abs(a - b), prec)).item())
def open_n_hash(file: Union[str, pathlib.Path]) -> str:
    """
    Compute the SHA-256 hash of a file's content.

    The file is read in fixed-size chunks so that large files do not have to
    be held in memory all at once.

    Args:
        file: path to the file that should be hashed

    Returns:
        str: hexadecimal SHA-256 digest of the file content
    """
    if not isinstance(file, pathlib.Path):
        file = pathlib.Path(file)

    chunk_size = 1 << 20  # 1 MiB per read
    digest = hashlib.sha256()
    with file.open("rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)

    return digest.hexdigest()
def file_loadable(path: Union[str, pathlib.Path], reader=None, mode=None, exceptions=None) -> bool:
    """
    Check whether a file is present and loadable.

    This function can be used in a while loop together with sleep to wait
    for a file that is still being written.

    Example:
        while not file_loadable(path, ...):
            time.sleep()

    Args:
        path: path to the file
        reader: optional callable that is invoked as ``reader(path)`` (or
            ``reader(path, mode=mode)``) to actually try loading the file.
            If None, only the presence of the file is checked.
        mode: optional mode that is forwarded to the reader as keyword argument
        exceptions: exception class or tuple of exception classes raised by
            the reader that mean "not loadable (yet)". If None, no exception
            is caught and errors raised by the reader propagate to the caller.

    Returns:
        bool: True if the file exists and (when a reader is given) could be
        loaded, False otherwise
    """
    if not isinstance(path, pathlib.Path):
        path = pathlib.Path(path)

    if not path.is_file():
        return False

    # without a reader, presence of the file is all that can be checked
    if reader is None:
        return True

    # `except None` is a TypeError at runtime; an empty tuple catches nothing
    # and lets the reader's own error surface unchanged.
    caught = () if exceptions is None else exceptions

    # try to actually load the file (or the handle)
    try:
        if mode is not None:
            loaded = reader(path, mode=mode)
        else:
            loaded = reader(path)
    except caught:
        return False

    # readers such as `open` return a handle; release it instead of leaking it
    close = getattr(loaded, "close", None)
    if callable(close):
        close()

    return True
def same_weights(model1, model2) -> bool:
    """
    Test whether model1 and model2 have the same weights.

    The parameters of both models are compared pairwise in the order in which
    ``parameters()`` yields them. The models are considered different if they
    have a different number of parameters, if a pair of parameters differs in
    shape, or if any element differs.

    Args:
        model1: first model; must provide a ``parameters()`` method
        model2: second model; must provide a ``parameters()`` method

    Returns:
        bool: True if all parameters are identical, False otherwise
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())

    # zip alone would silently ignore the surplus parameters of the larger model
    if len(params1) != len(params2):
        return False

    for p1, p2 in zip(params1, params2):
        # check shapes first so that mismatching tensors are neither broadcast
        # against each other nor raise inside the element-wise comparison
        if p1.shape != p2.shape:
            return False
        if bool(p1.data.ne(p2.data).any()):
            return False

    return True