Source code for decode.neuralfitter.target_generator

from abc import ABC, abstractmethod
from typing import Union

import torch

from decode.evaluation import predict_dist
from decode.generic import EmitterSet
from decode.generic import process
from decode.generic.process import RemoveOutOfFOV
from decode.simulation.psf_kernel import DeltaPSF


class TargetGenerator(ABC):
    """
    Abstract base class for target generators.

    Stores the default frame / batch index bounds and provides shared helpers that limit an EmitterSet to
    those frames and optionally squeeze the batch dimension of the generated target.
    """

    def __init__(self, xy_unit='px', ix_low: Union[int, None] = None, ix_high: Union[int, None] = None,
                 squeeze_batch_dim: bool = False):
        """
        Args:
            xy_unit: Which unit to use for target generator
            ix_low: lower bound of frame / batch index
            ix_high: upper bound of frame / batch index
            squeeze_batch_dim: if lower and upper frame_ix are the same, squeeze out the batch dimension
                before return

        Raises:
            ValueError: if `squeeze_batch_dim` is set while `ix_low` and `ix_high` differ
        """
        super().__init__()

        self.xy_unit = xy_unit
        self.ix_low = ix_low
        self.ix_high = ix_high
        self.squeeze_batch_dim = squeeze_batch_dim

        self.sanity_check()

    def sanity_check(self):
        """
        Check the configuration for consistency.

        Raises:
            ValueError: if automatic batch squeezing is requested but lower and upper index differ
        """
        if self.squeeze_batch_dim and self.ix_low != self.ix_high:
            raise ValueError("Automatic batch squeeze can only be used when upper and lower ix fall together.")

    def _filter_forward(self, em: EmitterSet, ix_low: Union[int, None], ix_high: Union[int, None]):
        """
        Filter emitters and auto-set frame bounds

        Args:
            em: set of emitters
            ix_low: lower frame index; falls back to `self.ix_low` when None
            ix_high: upper frame index; falls back to `self.ix_high` when None

        Returns:
            em (EmitterSet): filtered EmitterSet
            ix_low (int): lower frame index
            ix_high (int): upper frame index
        """
        if ix_low is None:
            ix_low = self.ix_low

        if ix_high is None:
            ix_high = self.ix_high

        # limit the emitters to the frames of interest and shift the frame index to start at 0
        em = em.get_subset_frame(ix_low, ix_high, -ix_low)

        return em, ix_low, ix_high

    def _postprocess_output(self, x: torch.Tensor) -> torch.Tensor:
        """
        Some simple post-processing steps before return.

        Args:
            x: input of size :math:`(N,C,H,W)`

        Returns:
            `x` with the first dimension squeezed out if `squeeze_batch_dim` is set, otherwise `x` unchanged

        Raises:
            ValueError: if squeezing is requested but the first dimension is not of size 1
        """
        if self.squeeze_batch_dim:
            if x.size(0) != 1:
                raise ValueError("First, batch dimension, not singular.")

            return x.squeeze(0)

        return x

    @abstractmethod
    def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None,
                ix_high: int = None) -> torch.Tensor:
        """
        Forward calculate target as by the emitters and background. Overwrite the default frame ix boundaries.

        Args:
            em: set of emitters
            bg: background frame
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            target frames
        """
        raise NotImplementedError
class UnifiedEmbeddingTarget(TargetGenerator):
    """
    Embedding target with 5 channels per frame.

    Channel 0 marks the central pixel of each emitter with 1. Channels 1 to 4 hold, for every pixel of a
    square ROI around that central pixel, the emitter's photon count, its x and y position relative to the
    respective pixel centre, and its z value.
    """

    def __init__(self, xextent: tuple, yextent: tuple, img_shape: tuple, roi_size: int, ix_low=None, ix_high=None,
                 squeeze_batch_dim: bool = False):
        """
        Args:
            xextent: extent in x, forwarded to the DeltaPSF and the out-of-FOV filter
            yextent: extent in y, forwarded to the DeltaPSF and the out-of-FOV filter
            img_shape: size of a target frame; first entry bounds x indices, second entry bounds y indices
            roi_size: edge length (in pixels) of the ROI that is filled around each emitter's central pixel
            ix_low: lower bound of frame / batch index
            ix_high: upper bound of frame / batch index
            squeeze_batch_dim: squeeze batch dimension before return
        """
        super().__init__(xy_unit='px', ix_low=ix_low, ix_high=ix_high, squeeze_batch_dim=squeeze_batch_dim)

        self._roi_size = roi_size
        self.img_shape = img_shape

        # pixel offsets of the ROI relative to its central pixel, e.g. roi_size=3 -> (-1, 0, 1) per axis
        self.mesh_x, self.mesh_y = torch.meshgrid(
            (torch.arange(-(self._roi_size - 1) // 2, (self._roi_size - 1) // 2 + 1),) * 2)

        self._delta_psf = DeltaPSF(xextent=xextent, yextent=yextent, img_shape=img_shape)
        self._em_filter = process.RemoveOutOfFOV(xextent=xextent, yextent=yextent, zextent=None, xy_unit='px')
        self._bin_ctr_x = self._delta_psf.bin_ctr_x
        self._bin_ctr_y = self._delta_psf.bin_ctr_y

    @property
    def xextent(self):
        """Extent in x as held by the internal DeltaPSF."""
        return self._delta_psf.xextent

    @property
    def yextent(self):
        """Extent in y as held by the internal DeltaPSF."""
        return self._delta_psf.yextent

    @classmethod
    def parse(cls, param, **kwargs):
        """
        Alternative constructor from a parameter object.

        Args:
            param: parameter object; reads `Simulation.psf_extent`, `Simulation.img_size` and
                `HyperParameter.target_roi_size`
            **kwargs: forwarded to the constructor
        """
        return cls(xextent=param.Simulation.psf_extent[0],
                   yextent=param.Simulation.psf_extent[1],
                   img_shape=param.Simulation.img_size,
                   roi_size=param.HyperParameter.target_roi_size,
                   **kwargs)

    def _get_roi_px(self, batch_ix, x_ix, y_ix):
        """
        For each pixel index (aka bin), get the pixels around the center (i.e. the ROI).

        Args:
            batch_ix: batch / frame index of each emitter
            x_ix: x index of the central pixel of each emitter
            y_ix: y index of the central pixel of each emitter

        Returns:
            batch_ix_roi: batch index for every retained ROI pixel
            x_ix_roi: x index of every retained ROI pixel
            y_ix_roi: y index of every retained ROI pixel
            offset_x: x offset of the ROI pixel relative to the central pixel
            offset_y: y offset of the ROI pixel relative to the central pixel
            id: position of the emitter (in the input order) that the ROI pixel belongs to
        """
        device = batch_ix.device

        # pixel pointer relative to the ROI pixels
        xx = self.mesh_x.flatten().to(device)
        yy = self.mesh_y.flatten().to(device)
        n_roi = xx.size(0)

        # Repeat the indices and add an ID for bookkeeping.
        # The idea here is that for the ix we do 'repeat_interleave' and for the offsets we do repeat, such
        # that they overlap correctly. E.g.
        #    5  5  5  9  9  9   (indices)
        #   +1  0 -1 +1  0 -1   (offset)
        #    6  5  4 10  9  8   (final indices)
        batch_ix_roi = batch_ix.repeat_interleave(n_roi)
        x_ix_roi = x_ix.repeat_interleave(n_roi)
        y_ix_roi = y_ix.repeat_interleave(n_roi)
        # created on the same device as the other index tensors so that the boolean mask below can index it
        em_id = torch.arange(x_ix.size(0), device=device).repeat_interleave(n_roi)

        # repeat offsets accordingly and add
        offset_x = xx.repeat(x_ix.size(0))
        offset_y = yy.repeat(y_ix.size(0))
        x_ix_roi = x_ix_roi + offset_x
        y_ix_roi = y_ix_roi + offset_y

        # limit ROIs by frame dimension, i.e. drop ROI pixels that fall outside of the image
        mask = (x_ix_roi >= 0) & (x_ix_roi < self.img_shape[0]) & \
               (y_ix_roi >= 0) & (y_ix_roi < self.img_shape[1])

        return batch_ix_roi[mask], x_ix_roi[mask], y_ix_roi[mask], offset_x[mask], offset_y[mask], em_id[mask]

    def single_px_target(self, batch_ix, x_ix, y_ix, batch_size):
        """
        Detection target: 1 at the central pixel of each emitter, 0 elsewhere.

        Args:
            batch_ix: batch / frame index of each emitter
            x_ix: x index of the central pixel
            y_ix: y index of the central pixel
            batch_size: number of frames of the output

        Returns:
            tensor of size (batch_size, *img_shape)
        """
        p_tar = torch.zeros((batch_size, *self.img_shape), device=batch_ix.device)
        p_tar[batch_ix, x_ix, y_ix] = 1.

        return p_tar

    def const_roi_target(self, batch_ix_roi, x_ix_roi, y_ix_roi, phot, id, batch_size):
        """
        Target in which every ROI pixel holds the (constant) value of the emitter it belongs to.

        Args:
            batch_ix_roi: batch index of each ROI pixel
            x_ix_roi: x index of each ROI pixel
            y_ix_roi: y index of each ROI pixel
            phot: per-emitter value to write (photons, but also used for z)
            id: emitter position for each ROI pixel, used to look up `phot`
            batch_size: number of frames of the output

        Returns:
            tensor of size (batch_size, *img_shape)
        """
        phot_tar = torch.zeros((batch_size, *self.img_shape), device=batch_ix_roi.device)
        phot_tar[batch_ix_roi, x_ix_roi, y_ix_roi] = phot[id]

        return phot_tar

    def xy_target(self, batch_ix_roi, x_ix_roi, y_ix_roi, xy, id, batch_size):
        """
        Target holding, for each ROI pixel, the emitter position relative to the centre of that pixel.

        Args:
            batch_ix_roi: batch index of each ROI pixel
            x_ix_roi: x index of each ROI pixel
            y_ix_roi: y index of each ROI pixel
            xy: per-emitter xy coordinates, indexed as `xy[id, 0]` and `xy[id, 1]`
            id: emitter position for each ROI pixel
            batch_size: number of frames of the output

        Returns:
            tensor of size (batch_size, 2, *img_shape)
        """
        xy_tar = torch.zeros((batch_size, 2, *self.img_shape), device=batch_ix_roi.device)
        xy_tar[batch_ix_roi, 0, x_ix_roi, y_ix_roi] = xy[id, 0] - self._bin_ctr_x[x_ix_roi]
        xy_tar[batch_ix_roi, 1, x_ix_roi, y_ix_roi] = xy[id, 1] - self._bin_ctr_y[y_ix_roi]

        return xy_tar

    def _filter_forward(self, em: EmitterSet, ix_low: (int, None), ix_high: (int, None)):
        """
        Filter as in abstract class, plus kick out emitters that are outside the frame

        Args:
            em: set of emitters
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            filtered EmitterSet, lower frame index, upper frame index
        """
        em, ix_low, ix_high = super()._filter_forward(em, ix_low, ix_high)
        em = self._em_filter.forward(em)  # kick outside of frame out

        return em, ix_low, ix_high

    def forward_(self, xyz: torch.Tensor, phot: torch.Tensor, frame_ix: torch.LongTensor,
                 ix_low: int, ix_high: int) -> torch.Tensor:
        """
        Compute the 5 channel target from raw emitter attributes.

        Args:
            xyz: emitter coordinates; columns 0-1 are used as xy, column 2 as z
            phot: photon count per emitter
            frame_ix: frame index per emitter; must be a `torch.LongTensor`
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            tensor of size (ix_high - ix_low + 1, 5, *img_shape)
        """
        # get index of central bin for each emitter
        x_ix, y_ix = self._delta_psf.search_bin_index(xyz[:, :2])

        assert isinstance(frame_ix, torch.LongTensor)

        # get the indices of the ROIs around the central pixel
        batch_ix_roi, x_ix_roi, y_ix_roi, offset_x, offset_y, em_id = self._get_roi_px(frame_ix, x_ix, y_ix)

        batch_size = ix_high - ix_low + 1

        target = torch.zeros((batch_size, 5, *self.img_shape))
        target[:, 0] = self.single_px_target(frame_ix, x_ix, y_ix, batch_size)
        target[:, 1] = self.const_roi_target(batch_ix_roi, x_ix_roi, y_ix_roi, phot, em_id, batch_size)
        target[:, 2:4] = self.xy_target(batch_ix_roi, x_ix_roi, y_ix_roi, xyz[:, :2], em_id, batch_size)
        target[:, 4] = self.const_roi_target(batch_ix_roi, x_ix_roi, y_ix_roi, xyz[:, 2], em_id, batch_size)

        return target

    def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None,
                ix_high: int = None) -> torch.Tensor:
        """
        Compute the target for a set of emitters and optionally append the background as extra channel.

        Args:
            em: set of emitters
            bg: background frame; two singleton dimensions are prepended before it is concatenated along
                the channel dimension
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            target frames
        """
        em, ix_low, ix_high = self._filter_forward(em, ix_low, ix_high)  # filter em that are out of view
        target = self.forward_(xyz=em.xyz_px, phot=em.phot, frame_ix=em.frame_ix, ix_low=ix_low, ix_high=ix_high)

        if bg is not None:
            target = torch.cat((target, bg.unsqueeze(0).unsqueeze(0)), 1)

        return self._postprocess_output(target)
class ParameterListTarget(TargetGenerator):
    def __init__(self, n_max: int, xextent: tuple, yextent: tuple, ix_low=None, ix_high=None, xy_unit: str = 'px',
                 squeeze_batch_dim: bool = False):
        """
        Target corresponding to the Gaussian-Mixture Model Loss. Simply cat all emitter's attributes up to a
        maximum number of emitters as a list.

        Args:
            n_max: maximum number of emitters (should be a multiple of what you draw on average)
            xextent: extent of the emitters in x
            yextent: extent of the emitters in y
            ix_low: lower frame index
            ix_high: upper frame index
            xy_unit: xy unit
            squeeze_batch_dim: squeeze batch dimension before return
        """
        super().__init__(xy_unit=xy_unit, ix_low=ix_low, ix_high=ix_high, squeeze_batch_dim=squeeze_batch_dim)

        self.n_max = n_max
        self.xextent = xextent
        self.yextent = yextent

        self._fov_filter = RemoveOutOfFOV(xextent=xextent, yextent=yextent, xy_unit=xy_unit)

    def _filter_forward(self, em: EmitterSet, ix_low: (int, None), ix_high: (int, None)):
        """
        Filter as in the base class, then remove emitters outside of the field of view.

        Args:
            em: set of emitters
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            filtered EmitterSet, lower frame index, upper frame index
        """
        em, ix_low, ix_high = super()._filter_forward(em, ix_low, ix_high)
        em = self._fov_filter.forward(em)

        return em, ix_low, ix_high

    def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None, ix_high: int = None) -> tuple:
        """
        Build the parameter list target.

        Args:
            em: set of emitters
            bg: background frame; returned unchanged
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            param_tar: tensor of size (n_frames, n_max, 4) with (phot, x, y, z) per emitter, zero padded
            mask_tar: boolean tensor of size (n_frames, n_max), True where `param_tar` holds an emitter
            bg: the background as passed in

        Raises:
            ValueError: if a frame holds more than `n_max` emitters
            NotImplementedError: if `xy_unit` is neither 'px' nor 'nm'
        """
        em, ix_low, ix_high = self._filter_forward(em, ix_low, ix_high)

        n_frames = ix_high - ix_low + 1

        # setup and compute parameter target (i.e. a matrix / tensor in which all params are concatenated)
        param_tar = torch.zeros((n_frames, self.n_max, 4))
        mask_tar = torch.zeros((n_frames, self.n_max)).bool()

        if self.xy_unit == 'px':
            xyz = em.xyz_px
        elif self.xy_unit == 'nm':
            xyz = em.xyz_nm
        else:
            raise NotImplementedError

        # set number of active elements per frame; frame indices start at 0 after filtering
        for i in range(n_frames):
            ix = em.frame_ix == i
            # count via the mask that also selects the values, so no separate per-frame subset is built
            n_emitter = int(ix.sum())

            if n_emitter > self.n_max:
                raise ValueError("Number of actual emitters exceeds number of max. emitters.")

            mask_tar[i, :n_emitter] = True
            param_tar[i, :n_emitter, 0] = em[ix].phot
            param_tar[i, :n_emitter, 1:] = xyz[ix]

        return self._postprocess_output(param_tar), self._postprocess_output(mask_tar), bg
class DisableAttributes:
    """Zeroes selected attribute columns of a parameter list target."""

    def __init__(self, attr_ix: Union[None, int, tuple, list]):
        """
        Allows to disable attribute prediction of parameter list target; e.g. when you don't want to predict z.

        Args:
            attr_ix: index of the attribute you want to disable (phot, x, y, z).
        """
        # None and sequences are stored as they are, any other value is wrapped into a list
        if attr_ix is not None and not isinstance(attr_ix, (tuple, list)):
            attr_ix = [attr_ix]

        self.attr_ix = attr_ix

    def forward(self, param_tar, mask_tar, bg):
        """
        Set the disabled attributes to zero along the last dimension of `param_tar`.

        Args:
            param_tar: parameter target; modified in place if any attribute is disabled
            mask_tar: mask target, passed through
            bg: background, passed through

        Returns:
            param_tar, mask_tar, bg
        """
        if self.attr_ix is not None:
            param_tar[..., self.attr_ix] = 0.

        return param_tar, mask_tar, bg

    @classmethod
    def parse(cls, param):
        """Construct from a parameter object by reading `HyperParameter.disabled_attributes`."""
        disabled = param.HyperParameter.disabled_attributes
        return cls(attr_ix=disabled)
class FourFoldEmbedding(TargetGenerator):
    """
    Concatenation of four UnifiedEmbeddingTarget outputs along the channel dimension.

    One embedding uses the native extent, the other three use an extent shifted by 0.5 in x, in y and in
    both. Before each embedding is computed, emitters that lie within a rim around the border of the
    respective pixel grid are removed.
    """

    def __init__(self, xextent: tuple, yextent: tuple, img_shape: tuple, rim_size: float, roi_size: int,
                 ix_low=None, ix_high=None, squeeze_batch_dim: bool = False):
        """
        Args:
            xextent: native extent in x
            yextent: native extent in y
            img_shape: size of a target frame
            rim_size: width of the rim (relative to a unit pixel) in which emitters are dropped
            roi_size: ROI size forwarded to the individual embeddings
            ix_low: lower bound of frame / batch index
            ix_high: upper bound of frame / batch index
            squeeze_batch_dim: squeeze batch dimension before return
        """
        super().__init__(xy_unit='px', ix_low=ix_low, ix_high=ix_high, squeeze_batch_dim=squeeze_batch_dim)

        self.xextent_native = xextent
        self.yextent_native = yextent

        self.rim = rim_size
        self.img_shape = img_shape
        self.roi_size = roi_size

        xextent_half = (xextent[0] + 0.5, xextent[1] + 0.5)
        yextent_half = (yextent[0] + 0.5, yextent[1] + 0.5)

        self.embd_ctr = UnifiedEmbeddingTarget(xextent=xextent, yextent=yextent, img_shape=img_shape,
                                               roi_size=roi_size, ix_low=ix_low, ix_high=ix_high)

        self.embd_half_x = UnifiedEmbeddingTarget(xextent=xextent_half, yextent=yextent, img_shape=img_shape,
                                                  roi_size=roi_size, ix_low=ix_low, ix_high=ix_high)

        self.embd_half_y = UnifiedEmbeddingTarget(xextent=xextent, yextent=yextent_half, img_shape=img_shape,
                                                  roi_size=roi_size, ix_low=ix_low, ix_high=ix_high)

        self.embd_half_xy = UnifiedEmbeddingTarget(xextent=xextent_half, yextent=yextent_half,
                                                   img_shape=img_shape, roi_size=roi_size,
                                                   ix_low=ix_low, ix_high=ix_high)

    @classmethod
    def parse(cls, param, **kwargs):
        """
        Alternative constructor from a parameter object.

        Args:
            param: parameter object; reads `Simulation.psf_extent`, `Simulation.img_size`,
                `HyperParameter.target_roi_size` and `HyperParameter.target_train_rim`
            **kwargs: forwarded to the constructor
        """
        return cls(xextent=param.Simulation.psf_extent[0],
                   yextent=param.Simulation.psf_extent[1],
                   img_shape=param.Simulation.img_size,
                   roi_size=param.HyperParameter.target_roi_size,
                   rim_size=param.HyperParameter.target_train_rim,
                   **kwargs)

    @staticmethod
    def _filter_rim(xy, xy_0, rim, px_size) -> torch.BoolTensor:
        """
        Takes coordinates and checks whether they are close to a pixel border (i.e. within a rim).
        True if not in rim, false if in rim.

        Args:
            xy: coordinates; column 0 is x, column 1 is y
            xy_0: reference (x, y) of the pixel grid
            rim: rim width relative to a unit pixel
            px_size: pixel size in (x, y)

        Returns:
            boolean tensor with one entry per coordinate
        """
        # transform coordinates relative to unit px
        # NOTE(review): presumably `px_pointer_dist` returns the position within the pixel - confirm
        x_rel = (predict_dist.px_pointer_dist(xy[:, 0], xy_0[0], px_size[0]) - xy_0[0]) / px_size[0]
        y_rel = (predict_dist.px_pointer_dist(xy[:, 1], xy_0[1], px_size[1]) - xy_0[1]) / px_size[1]

        # falsify coordinates that are inside the rim
        ix = (x_rel >= rim) * (x_rel < (1 - rim)) * (y_rel >= rim) * (y_rel < (1 - rim))

        return ix

    def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None,
                ix_high: int = None) -> torch.Tensor:
        """
        Compute all four embeddings, concatenate them and optionally append the background.

        Args:
            em: set of emitters
            bg: background frame; two singleton dimensions are prepended before it is concatenated along
                the channel dimension
            ix_low: lower frame index
            ix_high: upper frame index

        Returns:
            target frames
        """
        em, ix_low, ix_high = self._filter_forward(em, ix_low=ix_low, ix_high=ix_high)

        # `em` is now limited to the requested frames and its frame index starts at 0. The individual
        # embeddings filter and shift again, so they must get the bounds in this shifted frame of
        # reference; passing the original bounds would shift a second time whenever ix_low != 0.
        sub_low, sub_high = 0, ix_high - ix_low

        px_size = (1., 1.)

        # forward through each and all targets and filter the rim
        ctr = self.embd_ctr.forward(em=em[self._filter_rim(em.xyz_px, (-0.5, -0.5), self.rim, px_size)],
                                    bg=None, ix_low=sub_low, ix_high=sub_high)

        half_x = self.embd_half_x.forward(em=em[self._filter_rim(em.xyz_px, (0., -0.5), self.rim, px_size)],
                                          bg=None, ix_low=sub_low, ix_high=sub_high)

        half_y = self.embd_half_y.forward(em=em[self._filter_rim(em.xyz_px, (-0.5, 0.), self.rim, px_size)],
                                          bg=None, ix_low=sub_low, ix_high=sub_high)

        half_xy = self.embd_half_xy.forward(em=em[self._filter_rim(em.xyz_px, (0., 0.), self.rim, px_size)],
                                            bg=None, ix_low=sub_low, ix_high=sub_high)

        target = torch.cat((ctr, half_x, half_y, half_xy), 1)

        if bg is not None:
            target = torch.cat((target, bg.unsqueeze(0).unsqueeze(0)), 1)

        return self._postprocess_output(target)