Source code for decode.generic.slicing

import numpy as np
import torch


[docs]def split_sliceable(x, x_ix: torch.Tensor, ix_low: int, ix_high: int): """ Split a sliceable / iterable according to an index into list of elements between lower and upper bound. Not present elements will be filled with empty instances of the iterable itself. This function is mainly used to split the EmitterSet in list of EmitterSets according to its frame index. This function can also be called with arguments x and x_ix being the same. In this case you get a list of indices out which can be used for further indexing. Args: x: sliceable / iterable x_ix (torch.Tensor): index according to which to split ix_low (int): lower bound ix_high (int): upper bound Returns: x_list: list of instances sliced as specified by the x_ix """ """Safety checks""" if x_ix.numel() >= 1 and not isinstance(x_ix, (torch.IntTensor, torch.ShortTensor, torch.LongTensor)): raise TypeError("Index must be subtype of integer.") if len(x_ix) != len(x): raise ValueError("Index and sliceable are not of same length (along first index).") """Sort iterable by x_ix""" x_ix, re = torch.sort(x_ix) x = x[re] """ arange( + 2) because + 1 for pythonic and another + 1 because the loop before return below goes from 0 to on range('len' - 1) """ picker = np.arange(ix_low, ix_high + 2) ix_sort = np.searchsorted(x_ix, picker) return [x[ix_sort[i]:ix_sort[i + 1]] for i in range(ix_sort.shape[0] - 1)]
[docs]def ix_split(ix: torch.Tensor, ix_min: int, ix_max: int): """ Splits an index rather than a sliceable (as above). Might be slower than splitting the sliceable because here we can not just sort once and return the element of interest but must rather return the index. Args: ix (torch.Tensor): index to split ix_min (int): lower limit ix_max (int): upper limit (inclusive) Returns: list of logical(!) indices """ assert ix.dtype in (torch.short, torch.int, torch.long) n = ix_max - ix_min + 1 log_ix = [ix == ix_c for ix_c in range(ix_min, ix_max + 1)] return log_ix, n