from abc import ABC, abstractmethod # abstract class
from deprecated import deprecated
import numpy as np
import torch
from torch.distributions.exponential import Exponential
import decode.generic.emitter
from . import structure_prior
[docs]class EmitterSampler(ABC):
"""
Abstract emitter sampler. All implementations / childs must implement a sample method.
"""
def __init__(self, structure: structure_prior.StructurePrior, xy_unit: str, px_size: tuple):
super().__init__()
self.structure = structure
self.px_size = px_size
self.xy_unit = xy_unit
def __call__(self) -> decode.generic.emitter.EmitterSet:
return self.sample()
[docs] @abstractmethod
def sample(self) -> decode.generic.emitter.EmitterSet:
raise NotImplementedError
[docs]class EmitterSamplerFrameIndependent(EmitterSampler):
"""
Simple Emitter sampler. Samples emitters from a structure and puts them all on the same frame, i.e. their
blinking model is not modelled.
"""
def __init__(self, *, structure: structure_prior.StructurePrior, photon_range: tuple,
density: float = None, em_avg: float = None, xy_unit: str, px_size: tuple):
"""
Args:
structure: structure to sample from
photon_range: range of photon value to sample from (uniformly)
density: target emitter density (exactly only when em_avg is None)
em_avg: target emitter average (exactly only when density is None)
xy_unit: emitter xy unit
px_size: emitter pixel size
"""
super().__init__(structure=structure, xy_unit=xy_unit, px_size=px_size)
self._density = density
self.photon_range = photon_range
"""
Sanity Checks.
U shall not pa(rse)! (Emitter Average and Density at the same time!
"""
if (density is None and em_avg is None) or (density is not None and em_avg is not None):
raise ValueError("You must XOR parse either density or emitter average. Not both or none.")
self.area = self.structure.area
if em_avg is not None:
self._em_avg = em_avg
else:
self._em_avg = self._density * self.area
@property
def em_avg(self) -> float:
return self._em_avg
[docs] def sample(self) -> decode.generic.emitter.EmitterSet:
"""
Sample an EmitterSet.
Returns:
EmitterSet:
"""
n = np.random.poisson(lam=self._em_avg)
return self.sample_n(n=n)
[docs] def sample_n(self, n: int) -> decode.generic.emitter.EmitterSet:
"""
Sample 'n' emitters, i.e. the number of emitters is given and is not sampled from the Poisson dist.
Args:
n: number of emitters
"""
if n < 0:
raise ValueError("Negative number of samples is not well-defined.")
xyz = self.structure.sample(n)
phot = torch.randint(*self.photon_range, (n,))
return decode.generic.emitter.EmitterSet(xyz=xyz, phot=phot,
frame_ix=torch.zeros_like(phot).long(),
id=torch.arange(n).long(),
xy_unit=self.xy_unit,
px_size=self.px_size)
[docs]class EmitterSamplerBlinking(EmitterSamplerFrameIndependent):
def __init__(self, *, structure: structure_prior.StructurePrior, intensity_mu_sig: tuple, lifetime: float,
frame_range: tuple, xy_unit: str, px_size: tuple, density=None, em_avg=None, intensity_th=None):
"""
Args:
structure:
intensity_mu_sig:
lifetime:
xy_unit:
px_size:
frame_range: specifies the frame range
density:
em_avg:
intensity_th:
"""
super().__init__(structure=structure,
photon_range=None,
xy_unit=xy_unit,
px_size=px_size,
density=density,
em_avg=em_avg)
self.n_sampler = np.random.poisson
self.frame_range = frame_range
self.intensity_mu_sig = intensity_mu_sig
self.intensity_dist = torch.distributions.normal.Normal(self.intensity_mu_sig[0],
self.intensity_mu_sig[1])
self.intensity_th = intensity_th if intensity_th is not None else 1e-8
self.lifetime_avg = lifetime
self.lifetime_dist = Exponential(1 / self.lifetime_avg) # parse the rate not the scale ...
self.t0_dist = torch.distributions.uniform.Uniform(*self._frame_range_plus)
"""
Determine the total number of emitters. Depends on lifetime, frames and emitters.
(lifetime + 1) because of binning effect.
"""
self._emitter_av_total = self._em_avg * self._num_frames_plus / (self.lifetime_avg + 1)
@property
def _frame_range_plus(self):
"""
Frame range including buffer in front and end to account for build up effects.
"""
return self.frame_range[0] - 3 * self.lifetime_avg, self.frame_range[1] + 3 * self.lifetime_avg
@property
def num_frames(self):
return self.frame_range[1] - self.frame_range[0] + 1
@property
def _num_frames_plus(self):
return self._frame_range_plus[1] - self._frame_range_plus[0] + 1
[docs] def sample(self):
"""
Return sampled EmitterSet in the specified frame range.
Returns:
EmitterSet
"""
n = self.n_sampler(self._emitter_av_total)
loose_em = self.sample_loose_emitter(n=n)
em = loose_em.return_emitterset()
em = em.get_subset_frame(*self.frame_range) # because the simulated frame range is larger
return em
[docs] def sample_n(self, *args, **kwargs):
raise NotImplementedError
[docs] def sample_loose_emitter(self, n) -> decode.generic.emitter.LooseEmitterSet:
"""
Generate loose EmitterSet. Loose emitters are emitters that are not yet binned to frames.
Args:
n: number of 'loose' emitters
Returns:
LooseEmitterSet
"""
xyz = self.structure.sample(n)
"""Draw from intensity distribution but clamp the value so as not to fall below 0."""
intensity = torch.clamp(self.intensity_dist.sample((n,)), self.intensity_th)
"""Distribute emitters in time. Increase the range a bit."""
t0 = self.t0_dist.sample((n,))
ontime = self.lifetime_dist.rsample((n,))
return decode.generic.emitter.LooseEmitterSet(xyz, intensity, ontime, t0, id=torch.arange(n).long(),
xy_unit=self.xy_unit, px_size=self.px_size)
[docs] @classmethod
def parse(cls, param, structure, frames: tuple):
return cls(structure=structure,
intensity_mu_sig=param.Simulation.intensity_mu_sig,
lifetime=param.Simulation.lifetime_avg,
xy_unit=param.Simulation.xy_unit,
px_size=param.Camera.px_size,
frame_range=frames,
density=param.Simulation.density,
em_avg=param.Simulation.emitter_av,
intensity_th=param.Simulation.intensity_th)
[docs]@deprecated(reason="Deprecated in favour of EmitterSamplerFrameIndependent.", version="0.1.dev")
class EmitterPopperSingle:
pass
[docs]@deprecated(reason="Deprecated in favour of EmitterSamplerBlinking.", version="0.1.dev")
class EmitterPopperMultiFrame:
pass