Source code for decode.neuralfitter.dataset

import time

import torch
from torch.utils.data import Dataset

from decode.generic import emitter


class SMLMDataset(Dataset):
    """
    SMLM base dataset.

    Stores the processing pipeline (frame / background / emitter processing, target and weight generation)
    and provides the index and frame-window helpers shared by the concrete datasets. The data itself
    (``_frames``, ``_emitter``) is initialised to ``None`` here and must be set by subclasses, which also
    implement ``__getitem__``.
    """

    _pad_modes = (None, 'same')

    def __init__(self, *, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen,
                 frame_window: int, pad: str = None, return_em: bool):
        """
        Init new dataset.

        Args:
            em_proc: Emitter processing
            frame_proc: Frame processing
            bg_frame_proc: Background frame processing
            tar_gen: Target generator
            weight_gen: Weight generator
            frame_window: number of frames per sample / size of frame window
            pad: pad mode, applicable for first few, last few frames (relevant when frame window is used)
            return_em: return target emitter

        Raises:
            ValueError: if the pad mode is not available or the frame window is not odd
        """
        super().__init__()

        self._frames = None
        self._emitter = None

        self.em_proc = em_proc
        self.frame_proc = frame_proc
        self.bg_frame_proc = bg_frame_proc
        self.tar_gen = tar_gen
        self.weight_gen = weight_gen

        self.frame_window = frame_window
        self.pad = pad
        self.return_em = return_em

        # fail early on an unusable configuration
        self.sanity_check()

    def __len__(self):
        """
        Number of samples.

        Without padding the frames at the border cannot be the centre of a complete window, so
        ``frame_window - 1`` samples are lost. The result never goes below zero because ``len()``
        rejects negative values.

        Raises:
            ValueError: if the pad mode is not available
        """
        n_frames = self._frames.size(0)

        if self.pad is None:  # losing samples at the border
            return max(0, n_frames - self.frame_window + 1)

        if self.pad == 'same':
            return n_frames

        raise ValueError(f"Pad mode {self.pad} not available. Available pad modes are {self._pad_modes}.")

    def sanity_check(self):
        """
        Checks the sanity of the dataset, if fails, errors are raised.

        Raises:
            ValueError: if the pad mode is not available or the frame window is not odd
        """
        if self.pad not in self._pad_modes:
            raise ValueError(f"Pad mode {self.pad} not available. Available pad modes are {self._pad_modes}.")

        if self.frame_window is not None and self.frame_window % 2 != 1:
            raise ValueError(f"Unsupported frame window. Frame window must be odd integered, not {self.frame_window}.")

    def _get_frames(self, frames, index):
        """
        Get the window of ``frame_window`` frames centred on ``index``.

        Indices outside of the valid range are clamped, i.e. at the borders the first / last frame is repeated.

        Args:
            frames: frames, indexed along the first dimension
            index: index of the centre frame
        """
        hw = (self.frame_window - 1) // 2  # half window without centre

        frame_ix = torch.arange(index - hw, index + hw + 1).clamp(0, len(frames) - 1)
        return frames[frame_ix]

    def _pad_index(self, index):
        """
        Convert a sample index to the index of the centre frame.

        Without padding the first sample is centred on the first frame that has a complete window, so the
        index is shifted by half a window. With pad mode 'same' sample and frame index are identical.

        Args:
            index: sample index, must be non-negative

        Raises:
            AssertionError: if the index is negative
            ValueError: if the pad mode is not available
        """
        # Raised explicitly (instead of using an assert statement) so that the check survives `python -O`.
        # It applies to all pad modes because a negative index would be clamped to frame 0 by
        # `_get_frames` while still selecting emitters from the end of a list.
        if index < 0:
            raise AssertionError("Negative indexing not supported.")

        if self.pad is None:
            return index + (self.frame_window - 1) // 2

        if self.pad == 'same':
            return index

        raise ValueError(f"Pad mode {self.pad} not available. Available pad modes are {self._pad_modes}.")

    def _process_sample(self, frames, tar_emitter, bg_frame):
        """
        Run the processing pipeline. Every stage that is ``None`` is skipped.

        Args:
            frames: frames, forwarded through ``frame_proc``
            tar_emitter: target emitter, forwarded through ``em_proc``
            bg_frame: background, forwarded through ``bg_frame_proc``

        Returns:
            tuple of processed frames, target, weight and processed target emitter. Target and weight are
            ``None`` when the respective generator is ``None``.
        """
        if self.frame_proc is not None:
            frames = self.frame_proc.forward(frames)

        if self.bg_frame_proc is not None:
            bg_frame = self.bg_frame_proc.forward(bg_frame)

        if self.em_proc is not None:
            tar_emitter = self.em_proc.forward(tar_emitter)

        if self.tar_gen is not None:
            target = self.tar_gen.forward(tar_emitter, bg_frame)
        else:
            target = None

        if self.weight_gen is not None:
            weight = self.weight_gen.forward(tar_emitter, target)
        else:
            weight = None

        return frames, target, weight, tar_emitter

    def _return_sample(self, frame, target, weight, emitter):
        """Assemble the sample tuple; the emitter is only appended when ``return_em`` is set."""
        if self.return_em:
            return frame, target, weight, emitter

        return frame, target, weight
class SMLMStaticDataset(SMLMDataset):
    """
    A simple and static SMLMDataset, i.e. frames and emitters are handed over once at construction.

    Attributes:
        frame_window (int): width of frame window
        tar_gen: target generator function
        frame_proc: frame processing function
        em_proc: emitter processing / filter function
        weight_gen: weight generator function
        return_em (bool): return EmitterSet in getitem method.
    """

    def __init__(self, *, frames, emitter: (None, list, tuple), frame_proc=None, bg_frame_proc=None, em_proc=None,
                 tar_gen=None, bg_frames=None, weight_gen=None, frame_window=3, pad: (str, None) = None,
                 return_em=True):
        """
        Args:
            frames (torch.Tensor): frames. N x H x W
            emitter (list or tuple, optional): ground-truth emitters, one item per frame
            frame_proc: frame processing function
            bg_frame_proc: background frame processing function
            em_proc: emitter processing / filter function
            tar_gen: target generator function
            bg_frames (optional): background frames, indexed like the frames
            weight_gen: weight generator function
            frame_window (int): width of frame window
            pad: pad mode
            return_em (bool): return EmitterSet in getitem method.

        Raises:
            ValueError: if frames are given but not 3 dimensional
            TypeError: if emitters are given but neither list nor tuple
        """
        super().__init__(em_proc=em_proc, frame_proc=frame_proc, bg_frame_proc=bg_frame_proc,
                         tar_gen=tar_gen, weight_gen=weight_gen,
                         frame_window=frame_window, pad=pad, return_em=return_em)

        self._frames = frames
        self._emitter = emitter
        self._bg_frames = bg_frames

        frames_given = self._frames is not None
        if frames_given and self._frames.dim() != 3:
            raise ValueError("Frames must be 3 dimensional, i.e. N x H x W.")

        emitter_given = self._emitter is not None
        if emitter_given and not isinstance(self._emitter, (list, tuple)):
            raise TypeError("Please split emitters in list of emitters by their frame index first.")

    def __getitem__(self, ix):
        """
        Get a training sample.

        Args:
            ix (int): index

        Returns:
            the processed sample as assembled by ``_return_sample``, i.e. processed frames, target and
            weight, followed by the target emitters if ``return_em`` is set.
        """
        # translate the sample index to the index of the centre frame
        centre = self._pad_index(ix)

        # emitters and background are optional, frames are always looked up
        if self._emitter is None:
            em_centre = None
        else:
            em_centre = self._emitter[centre]

        window = self._get_frames(self._frames, centre)

        if self._bg_frames is None:
            bg_centre = None
        else:
            bg_centre = self._bg_frames[centre]

        window, tar, wgt, em_centre = self._process_sample(window, em_centre, bg_centre)

        return self._return_sample(window, tar, wgt, em_centre)
class InferenceDataset(SMLMStaticDataset):
    """
    A SMLM dataset without ground truth data.

    Thin wrapper around the static dataset: all ground-truth related components are disabled, pad mode is
    'same' and a sample consists of the frames only.
    """

    def __init__(self, *, frames, frame_proc, frame_window):
        """
        Args:
            frames (torch.Tensor): frames
            frame_proc: frame processing function
            frame_window (int): frame window
        """
        # everything that relates to ground truth is switched off for inference
        no_ground_truth = dict(emitter=None, bg_frame_proc=None, em_proc=None, tar_gen=None, return_em=False)

        super().__init__(frames=frames, frame_proc=frame_proc, frame_window=frame_window,
                         pad='same', **no_ground_truth)

    def _return_sample(self, frame, target, weight, emitter):
        """Return the frames only; target, weight and emitter are ignored."""
        return frame
class SMLMLiveDataset(SMLMStaticDataset):
    """
    A SMLM dataset where new datasets is sampleable via the sample() method of the simulation instance.
    The final processing on frame, emitters and target is done online.
    """

    def __init__(self, *, simulator, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, pad,
                 return_em=False):
        """
        Args:
            simulator: simulation instance; its ``sample()`` method must return emitter, frames and
                background frames
            em_proc: emitter processing / filter function
            frame_proc: frame processing function
            bg_frame_proc: background frame processing function
            tar_gen: target generator function
            weight_gen: weight generator function
            frame_window: width of frame window
            pad: pad mode
            return_em: return target emitter in getitem method
        """
        # the dataset starts empty, data is only available after the first call to sample()
        super().__init__(emitter=None, frames=None,
                         em_proc=em_proc, frame_proc=frame_proc, bg_frame_proc=bg_frame_proc,
                         tar_gen=tar_gen, weight_gen=weight_gen,
                         frame_window=frame_window, pad=pad, return_em=return_em)

        self.simulator = simulator
        self._bg_frames = None

    def sanity_check(self):
        """
        Checks the sanity of the dataset, if fails, errors are raised.

        Raises:
            ValueError: raised by the parent check
            TypeError: if emitters are present but neither list nor tuple
        """
        super().sanity_check()

        if self._emitter is not None and not isinstance(self._emitter, (list, tuple)):
            raise TypeError("EmitterSet shall be stored in list format, where each list item is one target emitter.")

    def sample(self, verbose: bool = False):
        """
        Sample new acquisition, i.e. a whole dataset.

        Args:
            verbose: print performance / verification information
        """

        def set_frame_ix(em):  # helper function
            em.frame_ix = torch.zeros_like(em.frame_ix)
            return em

        # sample new dataset
        t0 = time.time()
        emitter, frames, bg_frames = self.simulator.sample()

        if verbose:
            print(f"Sampled dataset in {time.time() - t0:.2f}s. {len(emitter)} emitters on {frames.size(0)} frames.")

        # split emitters into list of emitters (per frame) and set frame_ix to 0
        emitter = emitter.split_in_frames(0, frames.size(0) - 1)
        emitter = [set_frame_ix(em) for em in emitter]

        self._emitter = emitter
        self._frames = frames.cpu()
        # Background is optional throughout this dataset (it is guarded against None when a sample is
        # retrieved), so a sample without background must not fail here either.
        self._bg_frames = bg_frames.cpu() if bg_frames is not None else None
class SMLMAPrioriDataset(SMLMLiveDataset):
    """
    A SMLM Dataset where new data is sampled and processed in an 'a priori' manner, i.e. once per epoch.
    This is useful when processing is fast. Since everything is ready, a small number of workers for the
    dataloader will suffice.
    """

    def __init__(self, *, simulator, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, pad,
                 return_em=False):
        """
        Args:
            simulator: simulation instance; its ``sample()`` method must return emitter, frames and
                background frames
            em_proc: emitter processing / filter function
            frame_proc: frame processing function
            bg_frame_proc: background frame processing function
            tar_gen: target generator function
            weight_gen: weight generator function
            frame_window: width of frame window
            pad: pad mode
            return_em: return target emitter in getitem method
        """
        super().__init__(simulator=simulator, em_proc=em_proc, frame_proc=frame_proc, bg_frame_proc=bg_frame_proc,
                         tar_gen=tar_gen, weight_gen=weight_gen, frame_window=frame_window, pad=pad,
                         return_em=return_em)

        self._em_split = None  # emitter splitted in frames
        self._target = None
        self._weight = None

    @property
    def emitter(self) -> emitter.EmitterSet:
        """
        Return emitter with same indexing frames are returned; i.e. when pad same is used, the emitters frame
        index is not changed. When pad is None, the respective frame index is corrected for the frame window.

        Raises:
            ValueError: if the pad mode is neither 'same' nor None
        """
        if self.pad == 'same':
            return self._emitter

        elif self.pad is None:
            hw = (self.frame_window - 1) // 2  # half window without centre

            # ToDo: Change here when pythonize emitter / frame indexing
            # The frame bounds are used inclusively elsewhere in this module (get_subset_frame(0, 0) selects
            # frame zero), so the last valid centre frame is hw + len(self) - 1. Passing len(self) as upper
            # bound is only equivalent for a frame window of 3.
            em = self._emitter.get_subset_frame(hw, hw + len(self) - 1)
            em.frame_ix -= hw

            return em

        else:
            raise ValueError

    def sample(self, verbose: bool = False):
        """
        Sample new dataset and process them instantaneously.

        Args:
            verbose: print performance / verification information
        """
        t0 = time.time()
        emitter, frames, bg_frames = self.simulator.sample()

        if verbose:
            print(f"Sampled dataset in {time.time() - t0:.2f}s. {len(emitter)} emitters on {frames.size(0)} frames.")

        # the whole dataset is processed at once, samples are only looked up later on
        frames, target, weight, tar_emitter = self._process_sample(frames, emitter, bg_frames)
        self._frames = frames.cpu()
        self._emitter = tar_emitter
        self._em_split = tar_emitter.split_in_frames(0, frames.size(0) - 1)
        self._target, self._weight = target, weight

    def __getitem__(self, ix):
        """
        Get a sample of the pre-processed dataset.

        Args:
            ix: sample index

        Returns:
            the sample as assembled by ``_return_sample``, i.e. frames, target and weight, followed by the
            target emitters if ``return_em`` is set.
        """
        # pad index, get frames and emitters
        ix = self._pad_index(ix)

        # target is tuple; like the weight it is None when no generator is configured
        target = [tar[ix] for tar in self._target] if self._target is not None else None
        weight = self._weight[ix] if self._weight is not None else None

        return self._return_sample(self._get_frames(self._frames, ix), target, weight, self._em_split[ix])
class SMLMLiveSampleDataset(SMLMDataset):
    """
    A SMLM dataset where a new sample is drawn per (training) sample.
    """

    def __init__(self, *, simulator, ds_len, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window,
                 return_em=False):
        """
        Args:
            simulator: simulation instance; its ``sample()`` method must return emitter, frames and
                background frames
            ds_len: length the dataset reports
            em_proc: emitter processing / filter function
            frame_proc: frame processing function
            bg_frame_proc: background frame processing function
            tar_gen: target generator function
            weight_gen: weight generator function
            frame_window: width of frame window
            return_em: return target emitter in getitem method
        """
        super().__init__(em_proc=em_proc, frame_proc=frame_proc, bg_frame_proc=bg_frame_proc,
                         tar_gen=tar_gen, weight_gen=weight_gen,
                         frame_window=frame_window, pad=None, return_em=return_em)

        self.simulator = simulator
        self.ds_len = ds_len

    def __len__(self):
        """Length as specified by ``ds_len``; samples are drawn on the fly."""
        return self.ds_len

    def __getitem__(self, ix):
        """
        Draw a new sample from the simulator and process it. The index is not used.

        Args:
            ix: sample index (ignored)

        Returns:
            the processed sample as assembled by ``_return_sample``

        Raises:
            AssertionError: if the simulator returns an even number of frames
        """
        # sample
        emitter, frames, bg_frames = self.simulator.sample()

        # Raised explicitly (instead of using an assert statement) so that the check survives `python -O`.
        if frames.size(0) % 2 != 1:
            raise AssertionError("Simulator must return an odd number of frames.")

        centre = (frames.size(0) - 1) // 2
        frames = self._get_frames(frames, centre)
        tar_emitter = emitter.get_subset_frame(0, 0)  # target emitters are the zero ones

        # Use the same centre index as for the frames so that background and frame window stay aligned
        # even if the simulator returns more frames than the frame window. Background is optional.
        # NOTE(review): assumes background frames are indexed like the frames - confirm against simulator.
        if bg_frames is not None:
            bg_frames = bg_frames[centre]

        frames, target, weight, tar_emitter = self._process_sample(frames, tar_emitter, bg_frames)

        return self._return_sample(frames, target, weight, tar_emitter)