Source code for decode.neuralfitter.utils.processing

from operator import itemgetter
from typing import Callable


class TransformSequence:
    """
    Simple class which calls the forward method of all its components sequentially.
    """

    def __init__(self, components, input_slice=None):
        """
        Args:
            components: components with a forward method
            input_slice: optional list of index lists selecting the input of the i-th
                component; e.g. [[0, 1], [0]] means that the first component gets the
                0th and 1st element of the inputs to this instance's forward method, and
                the second component gets the 0th output of the first component.
                A component's slice is ignored when the value it would receive is not
                a tuple (i.e. the previous component returned a single object).

        Raises:
            AssertionError: if ``input_slice`` is given and its length differs from the
                number of components (checked with ``assert``, so skipped under ``-O``)
        """
        self.com = components
        self._input_slice = input_slice

        # Sanity check: exactly one slice per component.
        if self._input_slice is not None:
            assert len(self._input_slice) == len(self), "Input slices must be the same number as components"

    @classmethod
    def parse(cls, components, param: dict, **kwargs):
        """
        Build the whole sequence at once by calling ``parse(param)`` on every component.

        Only usable if all components implement a ``parse`` method.

        Args:
            components: component references (uninitialised) providing a ``parse`` method
            param (dict): parameters passed to each component's ``parse``
            **kwargs: arbitrary keyword arguments forwarded to this class's constructor

        Returns:
            TransformSequence or subclass of it
        """
        return cls([cpt.parse(param) for cpt in components], **kwargs)

    def __len__(self):
        """
        Returns the number of components
        """
        return len(self.com)

    def forward(self, *x):
        """
        Forwards the input data sequentially through all components.

        Args:
            *x: arbitrary input data

        Returns:
            Any: output of the last component
        """
        for i, com in enumerate(self.com):
            if not isinstance(x, tuple):
                # The previous component returned a single object: pass it on as-is.
                x = com.forward(x)
            elif self._input_slice is None:
                x = com.forward(*x)
            else:
                x = self._forward_sliced(com, self._input_slice[i], x)
        return x

    @staticmethod
    def _forward_sliced(com, idx, x):
        """Call ``com.forward`` on the elements of tuple ``x`` selected by indices ``idx``."""
        # itemgetter returns a tuple for >= 2 indices but a bare element for one index,
        # so only unpack in the multi-index case.
        com_in = itemgetter(*idx)(x)
        if len(idx) >= 2:
            return com.forward(*com_in)
        return com.forward(com_in)
class ParallelTransformSequence(TransformSequence):
    """
    Simple processing class that forwards data through all of its components in
    parallel (in the logical, not the hardware, sense) and returns a list of the
    outputs, or combines them if a merging function is specified.
    A merging function needs to accept a list as its argument.
    """

    def __init__(self, components, input_slice=None, merger=None):
        """
        Args:
            components: components with a forward method
            input_slice: optional list of index lists, one per component; the i-th list
                selects which positional inputs of ``forward`` are passed to the i-th
                component. If None (default), every component receives all inputs.
                Defaulting to None also lets the inherited ``parse`` be called without
                an explicit ``input_slice`` keyword.
            merger: optional callable taking the list of component outputs and
                returning the combined result
        """
        super().__init__(components=components, input_slice=input_slice)
        self.merger = merger

    def forward(self, *x):
        """
        Forwards the same input data through every component independently.

        Args:
            *x: arbitrary input data

        Returns:
            list of component outputs in component order, or ``merger(outputs)`` if a
            merger was given
        """
        outputs = [self._forward_component(i, com, x) for i, com in enumerate(self.com)]
        if self.merger is not None:
            return self.merger(outputs)
        return outputs

    def _forward_component(self, i, com, x):
        """Run the i-th component ``com`` on its (optionally sliced) share of inputs ``x``."""
        if self._input_slice is None:
            return com.forward(*x)

        idx = self._input_slice[i]
        com_in = itemgetter(*idx)(x)
        # itemgetter yields a tuple only for >= 2 indices, so unpack just in that case.
        if len(idx) >= 2:
            return com.forward(*com_in)
        return com.forward(com_in)
def wrap_callable(func: Callable):
    """
    Wrap a plain callable so that it exposes a ``forward`` method.

    This is mainly a helper to make arbitrary functions fit into a transform
    sequence as defined in this module.

    Args:
        func: callable to wrap

    Returns:
        _TrafoWrapper: object whose ``forward`` method delegates to ``func``
    """
    wrapper = _TrafoWrapper(func=func)
    return wrapper
class _TrafoWrapper:
    """
    Wraps a callable so that it provides a ``forward`` method, which allows it to be
    an element of a transform sequence. Intended to be created via ``wrap_callable``.
    """

    def __init__(self, func: Callable):
        """
        Args:
            func: the callable invoked by ``forward``
        """
        self._wrapped_callable = func

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with all given arguments and return its result."""
        target = self._wrapped_callable
        result = target(*args, **kwargs)
        return result