Source code for decode.generic.emitter

import copy
import warnings
from pathlib import Path
from typing import Union, Optional, Iterable

import numpy as np
import torch

import decode.generic.utils
from . import slicing as gutil, test_utils as tutil


class EmitterSet:
    """
    Class, storing a set of emitters and its attributes.
    Probably the most commonly used class of this framework.

    Attributes:
        xyz: Coordinates of size N x [2,3].
        phot: Photon count of size N
        frame_ix: size N. Index on which the emitter appears.
        id: size N. Identity of the emitter.
        prob: size N. Probability estimate of the emitter.
        bg: size N. Background estimate of emitter.
        xyz_cr: size N x 3. Cramer-Rao estimate of the emitter's position.
        phot_cr: size N. Cramer-Rao estimate of the emitter's photon count.
        bg_cr: size N. Cramer-Rao estimate of the emitter's background value.
        sanity_check: performs a sanity check if true.
        xy_unit: Unit of the x and y coordinate.
        px_size: Pixel size for unit conversion. If not specified, derived attributes
            (xyz_px and xyz_nm) can not be accessed.
    """
    _eq_precision = 1E-8
    _power_auto_conversion_attrs = {'xyz_cr': 2, 'xyz_sig': 1}
    _xy_units = ('px', 'nm')

    def __init__(self, xyz: torch.Tensor, phot: torch.Tensor, frame_ix: torch.LongTensor,
                 id: torch.LongTensor = None, prob: torch.Tensor = None, bg: torch.Tensor = None,
                 xyz_cr: torch.Tensor = None, phot_cr: torch.Tensor = None, bg_cr: torch.Tensor = None,
                 xyz_sig: torch.Tensor = None, phot_sig: torch.Tensor = None, bg_sig: torch.Tensor = None,
                 sanity_check: bool = True,
                 xy_unit: str = None, px_size: Union[tuple, torch.Tensor] = None):
        """
        Initialises EmitterSet of :math:`N` emitters.

        Args:
            xyz: Coordinates of size :math:`(N,3)`
            phot: Photon count of size :math:`N`
            frame_ix: Index on which the emitter appears. Must be integer type. Size :math:`N`
            id: Identity of the emitter. Must be tensor integer type and the same type as frame_ix.
                Size :math:`N`
            prob: Probability estimate of the emitter. Size :math:`N`
            bg: Background estimate of emitter. Size :math:`N`
            xyz_cr: Cramer-Rao estimate of the emitter's position. Size :math:`(N,3)`
            phot_cr: Cramer-Rao estimate of the emitter's photon count. Size :math:`N`
            bg_cr: Cramer-Rao estimate of the emitter's background value. Size :math:`N`
            xyz_sig: Error estimate of the emitter's position. Size :math:`(N,3)`
            phot_sig: Error estimate of the photon count. Size :math:`N`
            bg_sig: Error estimate of the background value. Size :math:`N`
            sanity_check: performs a sanity check.
            xy_unit: Unit of the x and y coordinate.
            px_size: Pixel size for unit conversion. If not specified, derived attributes
                (xyz_px and xyz_nm) may not be accessed because one can not convert units
                without pixel size.
        """
        self.xyz = None
        self.phot = None
        self.frame_ix = None
        self.id = None
        self.prob = None
        self.bg = None

        # Cramer-Rao values
        self.xyz_cr = None
        self.phot_cr = None
        self.bg_cr = None

        # Error estimates
        self.xyz_sig = None
        self.phot_sig = None
        self.bg_sig = None

        self._set_typed(xyz=xyz, phot=phot, frame_ix=frame_ix, id=id, prob=prob, bg=bg,
                        xyz_cr=xyz_cr, phot_cr=phot_cr, bg_cr=bg_cr,
                        xyz_sig=xyz_sig, phot_sig=phot_sig, bg_sig=bg_sig)

        self._sorted = False

        # get at least one_dim tensors
        at_least_one_dim(self.xyz, self.phot, self.frame_ix, self.id, self.prob, self.bg,
                         self.xyz_cr, self.phot_cr, self.bg_cr)

        self.xy_unit = xy_unit
        self.px_size = px_size
        if self.px_size is not None:
            if not isinstance(self.px_size, torch.Tensor):
                self.px_size = torch.Tensor(self.px_size)

        if sanity_check:
            self._sanity_check()

    @property
    def xyz_px(self) -> torch.Tensor:
        """Returns xyz in pixel coordinates and performs respective transformations if needed."""
        return self._pxnm_conversion(self.xyz, in_unit=self.xy_unit, tar_unit='px', power=1.)
    @xyz_px.setter
    def xyz_px(self, xyz):
        self.xyz = xyz
        self.xy_unit = 'px'

    @property
    def xyz_nm(self) -> torch.Tensor:
        """Returns xyz in nanometres and performs respective transformations if needed."""
        return self._pxnm_conversion(self.xyz, in_unit=self.xy_unit, tar_unit='nm', power=1.)

    @xyz_nm.setter
    def xyz_nm(self, xyz):  # xyz in nanometres
        self.xyz = xyz
        self.xy_unit = 'nm'

    @property
    def xyz_scr(self) -> torch.Tensor:
        """Square-root Cramer-Rao of xyz."""
        return self.xyz_cr.sqrt()

    @property
    def xyz_scr_px(self) -> torch.Tensor:
        """Square-root Cramer-Rao of xyz in px units."""
        return self.xyz_cr_px.sqrt()

    @property
    def xyz_scr_nm(self) -> torch.Tensor:
        return self.xyz_cr_nm.sqrt()

    @property
    def xyz_sig_tot_nm(self) -> torch.Tensor:
        return (self.xyz_sig_nm ** 2).sum(1).sqrt()

    @property
    def xyz_sig_weighted_tot_nm(self) -> torch.Tensor:
        return self._calc_sigma_weighted_total(self.xyz_sig_nm, self.dim() == 3)

    @property
    def phot_scr(self) -> torch.Tensor:  # sqrt Cramer-Rao of photon count
        return self.phot_cr.sqrt()

    @property
    def bg_scr(self) -> torch.Tensor:  # sqrt Cramer-Rao of bg count
        return self.bg_cr.sqrt()

    def __getattr__(self, item):
        """Auto unit convert a couple of attributes by trailing unit specification"""
        attr_base = item.rstrip('_nm').rstrip('_px')

        if attr_base in self._power_auto_conversion_attrs.keys():
            tar_unit = item[-2:]
            if tar_unit not in ('nm', 'px'):
                raise NotImplementedError

            return self._pxnm_conversion(
                getattr(self, attr_base), in_unit=self.xy_unit, tar_unit=tar_unit,
                power=self._power_auto_conversion_attrs[attr_base])

        raise AttributeError

    @property
    def meta(self) -> dict:
        """Return metadata of EmitterSet"""
        return {
            'xy_unit': self.xy_unit,
            'px_size': self.px_size
        }

    @property
    def data(self) -> dict:
        """Return intrinsic data (without metadata)"""
        return {
            'xyz': self.xyz,
            'phot': self.phot,
            'frame_ix': self.frame_ix,
            'id': self.id,
            'prob': self.prob,
            'bg': self.bg,
            'xyz_cr': self.xyz_cr,
            'phot_cr': self.phot_cr,
            'bg_cr': self.bg_cr,
            'xyz_sig': self.xyz_sig,
            'phot_sig': self.phot_sig,
            'bg_sig': self.bg_sig,
        }
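    # Editor's note, not part of the original source: a minimal usage sketch of the unit
    # handling above, assuming a pixel size of 100 nm in x and y. Accessing ``xyz_nm`` /
    # ``xyz_px`` converts on the fly; attributes such as ``xyz_cr_nm`` are resolved through
    # ``__getattr__`` with the powers defined in ``_power_auto_conversion_attrs``.
    #
    # >>> em = EmitterSet(xyz=torch.rand(5, 3), phot=torch.ones(5),
    # ...                 frame_ix=torch.zeros(5).long(),
    # ...                 xy_unit='px', px_size=(100., 100.))
    # >>> em.xyz_nm     # x/y multiplied by the pixel size
    # >>> em.xyz_cr_nm  # Cramer-Rao converts with px_size ** 2 (NaN here since none were passed)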
    def dim(self) -> int:
        """
        Returns dimensionality of coordinates. If z is 0 everywhere, it returns 2, else 3.

        Note:
            Does not do PCA or other sophisticated things.
        """
        if (self.xyz[:, 2] == 0).all():
            return 2
        else:
            return 3
    def to_dict(self) -> dict:
        """
        Returns a dictionary representation of this EmitterSet whose keys and values correspond
        to the arguments with which an EmitterSet is initialised.

        Example:
            >>> em_dict = em.to_dict()  # any emitterset instance
            >>> em_clone = EmitterSet(**em_dict)  # returns a clone of the emitterset
        """
        em_dict = {}
        em_dict.update(self.meta)
        em_dict.update(self.data)

        return em_dict
    def save(self, file: Union[str, Path]):
        """
        Saves the dictionary representation of this instance to file. No legacy guarantees given.
        Should only be used for short-term storage.

        Args:
            file: path where to save
        """
        from decode.utils import emitter_io

        if not isinstance(file, Path):
            file = Path(file)

        if file.suffix == '.pt':
            emitter_io.save_torch(file, self.data, self.meta)
        elif file.suffix in ('.h5', '.hdf5'):
            emitter_io.save_h5(file, self.data, self.meta)
        elif file.suffix == '.csv':
            emitter_io.save_csv(file, self.data, self.meta)
        else:
            raise ValueError
    @staticmethod
    def load(file: Union[str, Path]):
        """
        Loads the set of emitters which was saved by the 'save' method.

        Args:
            file: path to the emitterset

        Returns:
            EmitterSet
        """
        from decode.utils import emitter_io

        file = Path(file) if not isinstance(file, Path) else file

        if file.suffix == '.pt':
            em_dict, meta, _ = emitter_io.load_torch(file)
        elif file.suffix in ('.h5', '.hdf5'):
            em_dict, meta, _ = emitter_io.load_h5(file)
        elif file.suffix == '.csv':
            warnings.warn("For .csv files, implicit usage of .load() is discouraged. "
                          "Please use 'decode.utils.emitter_io.load_csv' explicitly.")
            em_dict, meta, _ = emitter_io.load_csv(file)
        else:
            raise ValueError

        em_dict.update(meta)

        return EmitterSet(**em_dict)
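    # Editor's note, not part of the original source: a short save/load round-trip sketch.
    # The file name is hypothetical; the suffix selects the backend (.pt, .h5/.hdf5, .csv).
    #
    # >>> em.save('emitters.h5')
    # >>> em_reloaded = EmitterSet.load('emitters.h5')
    # >>> em_reloaded == em  # __eq__ compares meta and all data tensors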
    def _set_typed(self, xyz, phot, frame_ix, id, prob, bg, xyz_cr, phot_cr, bg_cr, xyz_sig, phot_sig, bg_sig):
        """Sets the attributes in the correct type and with default argument if None"""

        if xyz.dtype not in (torch.float, torch.double, torch.half):
            raise ValueError("XYZ coordinates must be float type.")
        else:
            f_type = xyz.dtype

        if frame_ix.dtype not in (torch.int16, torch.int32, torch.int64):
            raise ValueError(f"Frame index must be integer type and not {frame_ix.dtype}.")

        if id is not None and (id.dtype not in (torch.int16, torch.int32, torch.int64)):
            raise ValueError(f"ID must be None or integer type not {id.dtype}.")

        i_type = torch.int64

        # make xyz always 3 dim
        xyz = xyz if xyz.shape[1] == 3 else torch.cat((xyz, torch.zeros_like(xyz[:, [0]])), 1)

        num_input = int(xyz.shape[0]) if xyz.shape[0] != 0 else 0

        """Set values"""
        if num_input != 0:
            self.xyz = xyz
            self.phot = phot.type(f_type)
            self.frame_ix = frame_ix.type(i_type)

            # Optionals
            self.id = id if id is not None else -torch.ones_like(frame_ix)
            self.prob = prob.type(f_type) if prob is not None else torch.ones_like(frame_ix).type(f_type)
            self.bg = bg.type(f_type) if bg is not None else float('nan') * torch.ones_like(frame_ix).type(f_type)

            self.xyz_cr = xyz_cr.type(f_type) if xyz_cr is not None else float('nan') * torch.ones_like(self.xyz)
            self.phot_cr = phot_cr.type(f_type) if phot_cr is not None else float('nan') * torch.ones_like(self.phot)
            self.bg_cr = bg_cr.type(f_type) if bg_cr is not None else float('nan') * torch.ones_like(self.bg)

            self.xyz_sig = xyz_sig.type(f_type) if xyz_sig is not None else float('nan') * torch.ones_like(self.xyz)
            self.phot_sig = phot_sig.type(f_type) if phot_sig is not None else float('nan') * torch.ones_like(self.phot)
            self.bg_sig = bg_sig.type(f_type) if bg_sig is not None else float('nan') * torch.ones_like(self.bg)

        else:
            self.xyz = torch.zeros((0, 3)).type(f_type)
            self.phot = torch.zeros((0,)).type(f_type)
            self.frame_ix = torch.zeros((0,)).type(i_type)

            # Optionals
            self.id = -torch.ones((0,)).type(i_type)
            self.prob = torch.ones((0,)).type(f_type)
            self.bg = float('nan') * torch.ones_like(self.prob)

            self.xyz_cr = float('nan') * torch.ones((0, 3)).type(f_type)
            self.phot_cr = float('nan') * torch.ones_like(self.prob)
            self.bg_cr = float('nan') * torch.ones_like(self.bg)

            self.xyz_sig = float('nan') * torch.ones((0, 3)).type(f_type)
            self.phot_sig = float('nan') * torch.ones_like(self.prob)
            self.bg_sig = float('nan') * torch.ones_like(self.bg)

    def _inplace_replace(self, em):
        """
        Inplace replacement of this self instance. Does not work for inherited methods ...

        Args:
            em: other EmitterSet instance that should replace self
        """
        self.__init__(**em.to_dict(), sanity_check=False)

    def _sanity_check(self, check_uniqueness=False):
        """
        Performs several integrity tests on the EmitterSet.

        Args:
            check_uniqueness: (bool) check the uniqueness of the ID

        Returns:
            (bool) sane or not sane
        """
        if not same_shape_tensor(0, self.xyz, self.phot, self.frame_ix, self.id, self.bg,
                                 self.xyz_cr, self.phot_cr, self.bg_cr):
            raise ValueError("Coordinates, photons, frame ix, id and prob are not of equal shape in 0th dimension.")

        if not same_dim_tensor(torch.ones(1), self.phot, self.prob, self.frame_ix, self.id):
            raise ValueError("Expected photons, probability, frame index and id to be 1D.")

        # Motivate the user to specify an xyz unit.
        if len(self) > 0:
            if self.xy_unit is None:
                warnings.warn("No xyz unit specified. No guarantees given ...")
            else:
                if self.xy_unit not in self._xy_units:
                    raise ValueError("XY unit not supported.")

        # check uniqueness of identity (ID)
        if check_uniqueness:
            if torch.unique(self.id).numel() != self.id.numel():
                raise ValueError("IDs are not unique.")

        return True

    def __len__(self):
        """
        Implements length of EmitterSet. Length of EmitterSet is number of rows of xyz container.

        Returns:
            (int) length of EmitterSet
        """
        return int(self.xyz.shape[0]) if self.xyz.shape[0] != 0 else 0

    def __str__(self):
        """
        Friendly representation of EmitterSet

        Returns:
            (string) representation of this class
        """
        print_str = f"EmitterSet" \
                    f"\n::num emitters: {len(self)}"

        if len(self) >= 1:
            print_str += f"\n::xy unit: {self.xy_unit}"
            print_str += f"\n::px size: {self.px_size}"
            print_str += f"\n::frame range: {self.frame_ix.min().item()} - {self.frame_ix.max().item()}" \
                         f"\n::spanned volume: {self.xyz.min(0)[0].numpy()} - {self.xyz.max(0)[0].numpy()}"

        return print_str

    def __add__(self, other):
        return self.cat((self, other), None, None)

    def __iadd__(self, other):
        self._inplace_replace(self + other)
        return self

    def __eq__(self, other) -> bool:
        """
        Implements equalness check. Returns true if all attributes are the same and in the same order.
        If it fails, you may want to sort by the ID first and then check again.

        Args:
            other: (emitterset)

        Returns:
            true if as stated above.
        """

        def check_em_dict_equality(em_a: dict, em_b: dict) -> bool:
            for k in em_a.keys():
                if not tutil.tens_almeq(em_a[k], em_b[k], nan=True):
                    return False

            return True

        if not self.eq_attr(other):
            return False

        if not check_em_dict_equality(self.data, other.data):
            return False

        return True
    def eq_attr(self, other) -> bool:
        """
        Tests whether the meta attributes (xy_unit and px size) are the same

        Args:
            other: the EmitterSet to compare to
        """
        if self.px_size is None:
            if other.px_size is not None:
                return False
        elif not (self.px_size == other.px_size).all():
            return False

        if not self.xy_unit == other.xy_unit:
            return False

        return True
    def __iter__(self):
        """
        Implements iterator bookkeeping.

        Returns:
            (self)
        """
        self.n = 0
        return self

    def __next__(self):
        """
        Implements next element in iterator method

        Returns:
            (EmitterSet) next element of iterator
        """
        if self.n <= len(self) - 1:
            self.n += 1
            return self._get_subset(self.n - 1)
        else:
            raise StopIteration

    def __getitem__(self, item):
        """
        Implements array indexing for this class.

        Args:
            item: (int), or indexing

        Returns:
            EmitterSet
        """
        if isinstance(item, int) and item >= len(self):
            raise IndexError(f"Index {item} out of bounds of EmitterSet of size {len(self)}")

        return self._get_subset(item)

    def __setitem__(self, key, value):
        raise NotImplementedError
    def clone(self):
        """
        Returns a deep copy of this EmitterSet.

        Returns:
            EmitterSet
        """
        return copy.deepcopy(self)
    def _calc_sigma_weighted_total(self, xyz_sigma_nm, use_3d):

        x_sig_var = torch.var(xyz_sigma_nm[:, 0])
        y_sig_var = torch.var(xyz_sigma_nm[:, 1])
        tot_var = xyz_sigma_nm[:, 0] ** 2 + (torch.sqrt(x_sig_var / y_sig_var) * xyz_sigma_nm[:, 1]) ** 2

        if use_3d:
            z_sig_var = torch.var(xyz_sigma_nm[:, 2])
            tot_var += (torch.sqrt(x_sig_var / z_sig_var) * xyz_sigma_nm[:, 2]) ** 2

        return torch.sqrt(tot_var)
    @staticmethod
    def cat(emittersets: Iterable, remap_frame_ix: Union[None, torch.Tensor] = None, step_frame_ix: int = None):
        """
        Concatenate multiple emittersets into one emitterset which is returned.
        Optionally modify the frame indices by the arguments.

        Args:
            emittersets: iterable of emittersets to be concatenated
            remap_frame_ix: new index of the 0th frame of each iterable
            step_frame_ix: step size between 0th frame of each iterable

        Returns:
            concatenated emitters
        """

        meta = []
        data = []
        for em in emittersets:
            meta.append(em.meta)
            data.append(em.data)

        n_chunks = len(data)

        if remap_frame_ix is not None and step_frame_ix is not None:
            raise ValueError("You cannot specify remap frame ix and step frame ix at the same time.")
        elif remap_frame_ix is not None:
            shift = remap_frame_ix.clone()
        elif step_frame_ix is not None:
            shift = torch.arange(0, n_chunks) * step_frame_ix
        else:
            shift = torch.zeros(n_chunks).int()

        # apply shift
        for d, s in zip(data, shift):
            d['frame_ix'] = d['frame_ix'] + s

        # list of dicts to dict of lists
        data = {k: torch.cat([x[k] for x in data], 0) for k in data[0]}
        # meta = {k: [x[k] for x in meta] for k in meta[0]}

        # px_size and xy unit is taken from the first element that is not None
        xy_unit = None
        px_size = None
        for m in meta:
            if m['xy_unit'] is not None:
                xy_unit = m['xy_unit']
                break
        for m in meta:
            if m['px_size'] is not None:
                px_size = m['px_size']
                break

        return EmitterSet(xy_unit=xy_unit, px_size=px_size, **data)
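    # Editor's note, not part of the original source: a sketch of concatenation. With
    # ``step_frame_ix`` the i-th set's frame indices are shifted by ``i * step_frame_ix``,
    # e.g. to stack acquisitions of 100 frames each back to back (values are illustrative).
    #
    # >>> em_all = EmitterSet.cat([em_a, em_b, em_c], step_frame_ix=100)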
    def sort_by_frame_(self):
        """
        Inplace sort this emitterset by its frame index.
        """
        em = self.sort_by_frame()
        self._inplace_replace(em)
    def sort_by_frame(self):
        """
        Sort a deepcopy of this emitterset and return it.

        Returns:
            Sorted copy of this emitterset
        """
        _, ix = self.frame_ix.sort()
        em = self[ix]
        em._sorted = True

        return em
    def _get_subset(self, ix):
        """
        Returns subset of emitterset. Implementation of __getitem__ and __next__ methods.

        Args:
            ix: (int, list) integer index or list of indices

        Returns:
            (EmitterSet)
        """
        if isinstance(ix, int):
            ix = [ix]

        # PyTorch single element support
        if not isinstance(ix, torch.BoolTensor) and isinstance(ix, torch.Tensor) and ix.numel() == 1:
            ix = [int(ix)]

        # Todo: Check for numpy boolean array
        if isinstance(ix, (np.ndarray, np.generic)) and ix.size == 1:  # numpy support
            ix = [int(ix)]

        return EmitterSet(xyz=self.xyz[ix], phot=self.phot[ix], frame_ix=self.frame_ix[ix],
                          id=self.id[ix], prob=self.prob[ix], bg=self.bg[ix],
                          xyz_sig=self.xyz_sig[ix], phot_sig=self.phot_sig[ix], bg_sig=self.bg_sig[ix],
                          xyz_cr=self.xyz_cr[ix], phot_cr=self.phot_cr[ix], bg_cr=self.bg_cr[ix],
                          sanity_check=False,
                          xy_unit=self.xy_unit, px_size=self.px_size)
    def get_subset_frame(self, frame_start, frame_end, frame_ix_shift=None):
        """
        Returns emitters that are in the frame range as specified.

        Args:
            frame_start: (int) lower frame index limit
            frame_end: (int) upper frame index limit (inclusive)
            frame_ix_shift: shift to apply to the frame index of the returned emitters

        Returns:
            EmitterSet
        """
        ix = (self.frame_ix >= frame_start) * (self.frame_ix <= frame_end)
        em = self[ix]

        if not frame_ix_shift:
            return em
        elif len(em) != 0:  # only shift if there is actually something
            em.frame_ix += frame_ix_shift

        return em
    @property
    def single_frame(self) -> bool:
        """
        Check if all emitters are on the same frame.

        Returns:
            bool
        """
        return True if torch.unique(self.frame_ix).shape[0] == 1 else False
    def chunks(self, chunks: int):
        """
        Splits the EmitterSet into (almost) equal chunks

        Args:
            chunks (int): number of splits

        Returns:
            list: of emittersets
        """
        n = len(self)
        l = self
        k = chunks
        # https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length/37414115#37414115
        return [l[i * (n // k) + min(i, n % k):(i + 1) * (n // k) + min(i + 1, n % k)] for i in range(k)]
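    # Editor's note, not part of the original source: a sketch of ``chunks``. The split is
    # near-equal; the first ``len(em) % k`` chunks hold one extra emitter.
    #
    # >>> parts = em.chunks(4)
    # >>> sum(len(p) for p in parts) == len(em)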
    def filter_by_sigma(self, fraction: float, dim: Optional[int] = None, return_low=True):
        """
        Filter by sigma values. Returns EmitterSet.

        Args:
            fraction: relative fraction of emitters remaining after filtering. Ranges from 0. to 1.
            dim: 2 or 3 for taking into account z. If None, it will be autodetermined.
            return_low:
                if True return the fraction of emitters with the lowest sigma values.
                if False return the (1 - fraction) with the highest sigma values.
        """
        if dim is None:
            is_3d = False if self.dim() == 2 else True
        else:
            is_3d = False if dim == 2 else True

        if fraction == 1.:
            return self

        xyz_sig = self.xyz_sig

        x_sig_var = torch.var(xyz_sig[:, 0])
        y_sig_var = torch.var(xyz_sig[:, 1])
        z_sig_var = torch.var(xyz_sig[:, 2])
        tot_var = xyz_sig[:, 0] ** 2 + (torch.sqrt(x_sig_var / y_sig_var) * xyz_sig[:, 1]) ** 2

        if is_3d:
            tot_var += (np.sqrt(x_sig_var / z_sig_var) * xyz_sig[:, 2]) ** 2

        max_s = np.percentile(tot_var.cpu().numpy(), fraction * 100.)

        if return_low:
            filt_sig = torch.where(tot_var < max_s)
        else:
            filt_sig = torch.where(tot_var > max_s)

        return self[filt_sig]
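    # Editor's note, not part of the original source: a sketch of sigma filtering, keeping the
    # 40 % of emitters with the smallest variance-weighted total sigma (fraction is illustrative).
    #
    # >>> em_filtered = em.filter_by_sigma(fraction=0.4)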
    def hist_detection(self) -> dict:
        """Compute hist for detection associated attributes."""

        return {
            'prob': np.histogram(self.prob),
            'sigma_x': np.histogram(self.xyz_sig[:, 0]),
            'sigma_y': np.histogram(self.xyz_sig[:, 1]),
            'sigma_z': np.histogram(self.xyz_sig[:, 2]),
        }
    def split_in_frames(self, ix_low: int = 0, ix_up: int = None) -> list:
        """
        Splits a set of emitters in a list of emittersets based on their respective frame index.

        Args:
            ix_low: (int, 0) lower bound
            ix_up: (int, None) upper bound

        Returns:
            list
        """

        """The first frame is assumed to be 0. If it's negative go to the lowest negative."""
        ix_low = ix_low if ix_low is not None else self.frame_ix.min().item()
        ix_up = ix_up if ix_up is not None else self.frame_ix.max().item()

        return gutil.split_sliceable(x=self, x_ix=self.frame_ix, ix_low=ix_low, ix_high=ix_up)
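    # Editor's note, not part of the original source: a sketch of framewise splitting for an
    # assumed frame range 0..9; each list entry is expected to hold the emitters of one frame index.
    #
    # >>> em_per_frame = em.split_in_frames(ix_low=0, ix_up=9)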
    def _pxnm_conversion(self, xyz, in_unit, tar_unit, power: float = 1.):

        if in_unit is None:
            raise ValueError("Conversion not possible if unit not specified.")

        if in_unit == tar_unit:
            return xyz
        elif in_unit == 'nm' and tar_unit == 'px':
            """px check needs to happen here, because in _convert_coordinates, factor is an optional argument."""
            if self.px_size is None:
                raise ValueError("Conversion not possible if px size is not specified.")
            return self._convert_coordinates(factor=1 / self.px_size ** power, xyz=xyz)
        elif in_unit == 'px' and tar_unit == 'nm':
            if self.px_size is None:
                raise ValueError("Conversion not possible if px size is not specified.")
            return self._convert_coordinates(factor=self.px_size ** power, xyz=xyz)
        else:
            raise ValueError("Unsupported conversion.")

    def _convert_coordinates(self, factor=None, shift=None, axis=None, xyz=None):
        """
        Convert coordinates. Order: factor -> shift -> axis

        Args:
            factor: (torch.Tensor, None)
            shift: (torch.Tensor, None)
            axis: (list)
            xyz (torch.Tensor, None): use different coordinates (not self.xyz)

        Returns:
            xyz (torch.Tensor) modified coordinates.
        """
        if xyz is None:
            xyz = self.xyz.clone()

        if factor is not None:
            if factor.size(0) == 2:
                factor = torch.cat((factor, torch.tensor([1.])), 0)

            xyz = xyz * factor.float().unsqueeze(0)

        if shift is not None:
            xyz += shift.float().unsqueeze(0)

        if axis is not None:
            xyz = xyz[:, axis]

        return xyz
    def populate_crlb(self, psf, **kwargs):
        """
        Populate the CRLB values by the PSF function.

        Args:
            psf (PSF): Point Spread function with CRLB implementation
            **kwargs: additional arguments to be parsed to the CRLB method

        Returns:

        """
        crlb, _ = psf.crlb(self.xyz, self.phot, self.bg, **kwargs)
        self.xyz_cr = crlb[:, :3]
        self.phot_cr = crlb[:, 3]
        self.bg_cr = crlb[:, 4]
class RandomEmitterSet(EmitterSet):
    """
    A helper class when we only want to provide a number of emitters.
    """

    def __init__(self, num_emitters: int, extent: float = 32, xy_unit: str = 'px', px_size: tuple = None):
        xyz = torch.rand((num_emitters, 3)) * extent
        super().__init__(xyz, torch.ones_like(xyz[:, 0]), torch.zeros_like(xyz[:, 0]).long(),
                         xy_unit=xy_unit, px_size=px_size)

    def _inplace_replace(self, em):
        super().__init__(**em.to_dict(), sanity_check=False)
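# Editor's note, not part of the original source: a sketch of container-style usage with the
# helper class above (sizes are illustrative).
#
# >>> em_a = RandomEmitterSet(100)   # 100 random emitters on frame 0, xy_unit 'px'
# >>> em_b = RandomEmitterSet(50)
# >>> em_ab = em_a + em_b            # concatenated via EmitterSet.cat
# >>> len(em_ab), len(em_ab[:10])    # slicing returns a new EmitterSet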
class CoordinateOnlyEmitter(EmitterSet):
    """
    A helper class when we only want to provide xyz, but not photons and frame_ix.
    Useful for testing. Photons will be tensor of 1, frame_ix tensor of 0.
    """

    def __init__(self, xyz: torch.Tensor, xy_unit: str = None, px_size=None):
        """
        :param xyz: (torch.tensor) N x 2, N x 3
        """
        super().__init__(xyz, torch.ones_like(xyz[:, 0]), torch.zeros_like(xyz[:, 0]).int(),
                         xy_unit=xy_unit, px_size=px_size)

    def _inplace_replace(self, em):
        super().__init__(**em.to_dict(), sanity_check=False)
class EmptyEmitterSet(CoordinateOnlyEmitter):
    """An empty emitter set."""

    def __init__(self, xy_unit=None, px_size=None):
        super().__init__(torch.zeros((0, 3)), xy_unit=xy_unit, px_size=px_size)

    def _inplace_replace(self, em):
        super().__init__(**em.to_dict())
class LooseEmitterSet:
    """
    Related to the standard EmitterSet. However, here we do not specify a frame_ix but rather a
    (non-integer) initial point in time where the emitter starts to blink and an on-time.

    Attributes:
        xyz (torch.Tensor): coordinates. Dimension: N x 3
        intensity (torch.Tensor): intensity, i.e. photon flux per time unit. Dimension N
        id (torch.Tensor, int): identity of the emitter. Dimension: N
        t0 (torch.Tensor, float): initial blink event. Dimension: N
        ontime (torch.Tensor): duration in frame-time units how long the emitter blinks. Dimension N
        xy_unit (string): unit of the coordinates
    """

    def __init__(self, xyz: torch.Tensor, intensity: torch.Tensor, ontime: torch.Tensor, t0: torch.Tensor,
                 xy_unit: str, px_size, id: torch.Tensor = None, sanity_check=True):
        """

        Args:
            xyz (torch.Tensor): coordinates. Dimension: N x 3
            intensity (torch.Tensor): intensity, i.e. photon flux per time unit. Dimension N
            t0 (torch.Tensor, float): initial blink event. Dimension: N
            ontime (torch.Tensor): duration in frame-time units how long the emitter blinks. Dimension N
            id (torch.Tensor, int, optional): identity of the emitter. Dimension: N
            xy_unit (string): unit of the coordinates
        """

        """If no ID specified, give them one."""
        if id is None:
            id = torch.arange(xyz.shape[0])

        self.xyz = xyz
        self.xy_unit = xy_unit
        self.px_size = px_size
        self._phot = None
        self.intensity = intensity
        self.id = id
        self.t0 = t0
        self.ontime = ontime

        if sanity_check:
            self.sanity_check()
    def sanity_check(self):

        """Check IDs"""
        if self.id.unique().numel() != self.id.numel():
            raise ValueError("IDs are not unique.")

        """Check xyz"""
        if self.xyz.dim() != 2 or self.xyz.size(1) != 3:
            raise ValueError("Wrong xyz dimension.")

        """Check intensity"""
        if (self.intensity < 0).any():
            raise ValueError("Negative intensity values encountered.")

        """Check timings"""
        if (self.ontime < 0).any():
            raise ValueError("Negative ontime encountered.")
    @property
    def te(self):  # end time
        return self.t0 + self.ontime

    def _distribute_framewise(self):
        """
        Distributes the emitters framewise and prepares them for EmitterSet format.

        Returns:
            xyz_ (torch.Tensor): coordinates
            phot_ (torch.Tensor): photon count
            frame_ (torch.Tensor): frame indices (the actual distribution)
            id_ (torch.Tensor): identities
        """

        frame_start = torch.floor(self.t0).long()
        frame_last = torch.floor(self.te).long()
        frame_count = (frame_last - frame_start).long()

        frame_count_full = frame_count - 2
        ontime_first = torch.min(self.te - self.t0, frame_start + 1 - self.t0)
        ontime_last = torch.min(self.te - self.t0, self.te - frame_last)

        """Repeat by full-frame duration"""

        # kick out everything that has no full frame_duration
        ix_full = frame_count_full >= 0
        xyz_ = self.xyz[ix_full, :]
        flux_ = self.intensity[ix_full]
        id_ = self.id[ix_full]
        frame_start_full = frame_start[ix_full]
        frame_dur_full_clean = frame_count_full[ix_full]

        xyz_ = xyz_.repeat_interleave(frame_dur_full_clean + 1, dim=0)
        phot_ = flux_.repeat_interleave(frame_dur_full_clean + 1, dim=0)  # because intensity * 1 = phot
        id_ = id_.repeat_interleave(frame_dur_full_clean + 1, dim=0)  # because 0 is first occurrence
        frame_ix_ = frame_start_full.repeat_interleave(frame_dur_full_clean + 1, dim=0) \
                    + decode.generic.utils.cum_count_per_group(id_) + 1

        """First frame"""
        # first
        xyz_ = torch.cat((xyz_, self.xyz), 0)
        phot_ = torch.cat((phot_, self.intensity * ontime_first), 0)
        id_ = torch.cat((id_, self.id), 0)
        frame_ix_ = torch.cat((frame_ix_, frame_start), 0)

        # last (only if frame_last != frame_first)
        ix_with_last = frame_last >= frame_start + 1
        xyz_ = torch.cat((xyz_, self.xyz[ix_with_last]))
        phot_ = torch.cat((phot_, self.intensity[ix_with_last] * ontime_last[ix_with_last]), 0)
        id_ = torch.cat((id_, self.id[ix_with_last]), 0)
        frame_ix_ = torch.cat((frame_ix_, frame_last[ix_with_last]))

        return xyz_, phot_, frame_ix_, id_
    def return_emitterset(self):
        """
        Returns EmitterSet with distributed emitters. The ID is preserved such that localisations
        coming from the same fluorophore will have the same ID.

        Returns:
            EmitterSet
        """

        xyz_, phot_, frame_ix_, id_ = self._distribute_framewise()
        return EmitterSet(xyz_, phot_, frame_ix_.long(), id_.long(), xy_unit=self.xy_unit, px_size=self.px_size)
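# Editor's note, not part of the original source: a sketch of converting continuous blinking
# events into a framewise EmitterSet (all values are illustrative assumptions).
#
# >>> loose = LooseEmitterSet(xyz=torch.rand(10, 3) * 32,
# ...                         intensity=torch.full((10,), 1000.),
# ...                         ontime=torch.rand(10) * 3, t0=torch.rand(10) * 5,
# ...                         xy_unit='px', px_size=None)
# >>> em = loose.return_emitterset()  # photon count per frame = flux * on-time on that frame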
def at_least_one_dim(*args) -> None:
    """Make tensors at least one dimensional (inplace)"""
    for arg in args:
        if arg.dim() == 0:
            arg.unsqueeze_(0)
def same_shape_tensor(dim, *args) -> bool:
    """Test if tensors are of same size in a certain dimension."""
    for i in range(args.__len__() - 1):
        if args[i].size(dim) == args[i + 1].size(dim):
            continue
        else:
            return False

    return True
def same_dim_tensor(*args) -> bool:
    """Test if tensors are of same dimensionality"""
    for i in range(args.__len__() - 1):
        if args[i].dim() == args[i + 1].dim():
            continue
        else:
            return False

    return True