decode.neuralfitter package#

Subpackages#

Submodules#

decode.neuralfitter.coord_transform module#

class decode.neuralfitter.coord_transform.Offset2Coordinate(xextent, yextent, img_shape)[source]#

Bases: object

Parameters:
  • xextent (tuple) – extent in x

  • yextent (tuple) – extent in y

  • img_shape (tuple) – image shape

forward(x)[source]#

Forward frames through post-processor.

Parameters:

x (torch.Tensor) – features to be converted. Expecting x/y coordinates in channel index 2, 3. expected shape \((N, C, H, W)\)

Return type:

Tensor

classmethod parse(param)[source]#
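
A minimal usage sketch (extents, channel count and tensor values are illustrative; it assumes the package is importable as decode):

```python
# Minimal sketch: convert per-pixel x/y offset predictions in channels 2 and 3
# into absolute coordinates on a 32 x 32 px frame.
import torch

from decode.neuralfitter.coord_transform import Offset2Coordinate

offset2coord = Offset2Coordinate(xextent=(-0.5, 31.5), yextent=(-0.5, 31.5), img_shape=(32, 32))

features = torch.rand(8, 6, 32, 32)             # dummy model output, (N, C, H, W)
features_abs = offset2coord.forward(features)   # channels 2/3 now hold absolute x/y
```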

decode.neuralfitter.dataset module#

class decode.neuralfitter.dataset.InferenceDataset(*, frames, frame_proc, frame_window)[source]#

Bases: SMLMStaticDataset

Parameters:
  • frames (torch.Tensor) – frames

  • frame_proc – frame processing function

  • frame_window (int) – frame window
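
A minimal sketch of wrapping a raw camera stack for inference (passing frame_proc=None, i.e. no pre-processing, is an assumption here; frame sizes are illustrative):

```python
# Minimal sketch: 3-frame windows over a stack of 100 raw frames.
import torch

from decode.neuralfitter.dataset import InferenceDataset

frames = torch.rand(100, 64, 64)   # (N, H, W) camera frames
ds = InferenceDataset(frames=frames, frame_proc=None, frame_window=3)

window = ds[0]                     # frame window centred on frame 0
print(len(ds), window.shape)
```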

class decode.neuralfitter.dataset.SMLMAPrioriDataset(*, simulator, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, pad, return_em=False)[source]#

Bases: SMLMLiveDataset

Parameters:
  • frames (torch.Tensor) – frames. N x H x W

  • em (list of EmitterSets) – ground-truth emitter-sets

  • frame_proc – frame processing function

  • em_proc – emitter processing / filter function

  • tar_gen – target generator function

  • weight_gen – weight generator function

  • frame_window (int) – width of frame window

  • return_em (bool) – return EmitterSet in getitem method.

property emitter: EmitterSet#

Return the emitters with the same indexing as the frames are returned; i.e. when pad ‘same’ is used, the emitters’ frame index is not changed. When pad is None, the respective frame index is corrected for the frame window.

sample(verbose=False)[source]#

Sample a new dataset and process it immediately.

Parameters:

verbose (bool) –

class decode.neuralfitter.dataset.SMLMDataset(*, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, pad=None, return_em)[source]#

Bases: Dataset

Init new dataset.

Parameters:
  • em_proc – Emitter processing

  • frame_proc – Frame processing

  • bg_frame_proc – Background frame processing

  • tar_gen – Target generator

  • weight_gen – Weight generator

  • frame_window (int) – number of frames per sample / size of frame window

  • pad (Optional[str]) – pad mode, applicable for first few, last few frames (relevant when frame window is used)

  • return_em (bool) – return target emitter

return_em#

Whether the target emitters are returned by the getitem method.

sanity_check()[source]#

Checks the sanity of the dataset; if a check fails, an error is raised.

class decode.neuralfitter.dataset.SMLMLiveDataset(*, simulator, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, pad, return_em=False)[source]#

Bases: SMLMStaticDataset

Parameters:
  • frames (torch.Tensor) – frames. N x H x W

  • em (list of EmitterSets) – ground-truth emitter-sets

  • frame_proc – frame processing function

  • em_proc – emitter processing / filter function

  • tar_gen – target generator function

  • weight_gen – weight generator function

  • frame_window (int) – width of frame window

  • return_em (bool) – return EmitterSet in getitem method.

sample(verbose=False)[source]#

Sample new acquisition, i.e. a whole dataset.

Parameters:

verbose (bool) – print performance / verification information

sanity_check()[source]#

Checks the sanity of the dataset; if a check fails, an error is raised.

class decode.neuralfitter.dataset.SMLMLiveSampleDataset(*, simulator, ds_len, em_proc, frame_proc, bg_frame_proc, tar_gen, weight_gen, frame_window, return_em=False)[source]#

Bases: SMLMDataset

Init new dataset.

Parameters:
  • em_proc – Emitter processing

  • frame_proc – Frame processing

  • bg_frame_proc – Background frame processing

  • tar_gen – Target generator

  • weight_gen – Weight generator

  • frame_window – number of frames per sample / size of frame window

  • pad – pad mode, applicable for first few, last few frames (relevant when frame window is used)

  • return_em – return target emitter

class decode.neuralfitter.dataset.SMLMStaticDataset(*, frames, emitter, frame_proc=None, bg_frame_proc=None, em_proc=None, tar_gen=None, bg_frames=None, weight_gen=None, frame_window=3, pad=None, return_em=True)[source]#

Bases: SMLMDataset

Parameters:
  • frames (torch.Tensor) – frames. N x H x W

  • em (list of EmitterSets) – ground-truth emitter-sets

  • frame_proc – frame processing function

  • em_proc – emitter processing / filter function

  • tar_gen – target generator function

  • weight_gen – weight generator function

  • frame_window (int) – width of frame window

  • return_em (bool) – return EmitterSet in getitem method.
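
A minimal sketch of a static dataset over pre-rendered frames (the EmitterSet constructor arguments and its split_in_frames method are assumptions about decode.generic.emitter; target and weight generators are omitted):

```python
# Minimal sketch: frames plus per-frame ground-truth emitters, default 3-frame window.
import torch

from decode.generic.emitter import EmitterSet
from decode.neuralfitter.dataset import SMLMStaticDataset

frames = torch.rand(100, 64, 64)
em = EmitterSet(
    xyz=torch.rand(500, 3) * 60,
    phot=torch.ones(500) * 1000.,
    frame_ix=torch.randint(0, 100, (500,)),
    xy_unit='px',
)

ds = SMLMStaticDataset(frames=frames, emitter=em.split_in_frames(0, 99))
sample = ds[10]   # frame window around frame 10 (plus target emitters by default)
```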

decode.neuralfitter.de_bias module#

class decode.neuralfitter.de_bias.UniformizeOffset(n_bins)[source]#

Bases: object

Parameters:
  • n_bins (int) – number of bins. The bias scales with the uncertainty of the localization, therefore all detections are binned according to their predicted uncertainty; detections within different bins are then rescaled separately.

cdf_get(cdf, val)[source]#
forward(x)[source]#

Rescales x and y offsets (in place) so that they are distributed uniformly within [-0.5, 0.5] to correct for biased outputs.

Parameters:

x (torch.Tensor) – features to be converted. Expecting x/y coordinates in channel index 2, 3 and x/y sigma values in channel index 6, 7. Expected shape \((N, C, H, W)\)

Return type:

Tensor

histedges_equal_n(x)[source]#
uniformize(x)[source]#
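
A minimal sketch (channel layout as documented above; values are random and purely illustrative):

```python
# Minimal sketch: uniformize the x/y offsets of a prediction whose channels 2/3
# hold the offsets and channels 6/7 the predicted x/y uncertainties.
import torch

from decode.neuralfitter.de_bias import UniformizeOffset

debias = UniformizeOffset(n_bins=50)

pred = torch.rand(4, 10, 40, 40)
pred[:, 2:4] -= 0.5                  # offsets roughly in [-0.5, 0.5]
pred_uniform = debias.forward(pred)  # offsets now distributed uniformly
```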

decode.neuralfitter.em_filter module#

Here we provide some filtering on EmitterSets.

class decode.neuralfitter.em_filter.EmitterFilter[source]#

Bases: ABC

abstract forward(em)[source]#

Forwards a set of emitters through the filter implementation

Parameters:

em (EmitterSet) – emitters

Return type:

EmitterSet

class decode.neuralfitter.em_filter.NoEmitterFilter[source]#

Bases: EmitterFilter

The no filter

forward(em)[source]#

Forwards a set of emitters through the filter implementation

Parameters:

em – emitters

class decode.neuralfitter.em_filter.PhotonFilter(th)[source]#

Bases: EmitterFilter

Parameters:

th – (int, float) photon threshold

forward(em)[source]#
Parameters:

em – (EmitterSet)

Returns:

(EmitterSet) filtered set of emitters

Return type:

em
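
A minimal sketch (RandomEmitterSet from decode.generic.emitter is assumed here only to obtain a dummy EmitterSet; any EmitterSet works):

```python
# Minimal sketch: keep only emitters above a photon threshold of ~500.
from decode.generic.emitter import RandomEmitterSet
from decode.neuralfitter.em_filter import PhotonFilter

em = RandomEmitterSet(1000)
em_bright = PhotonFilter(th=500.).forward(em)
print(len(em), len(em_bright))
```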

class decode.neuralfitter.em_filter.TarEmitterFilter(tar_ix=0)[source]#

Bases: EmitterFilter

Parameters:

tar_ix – (int) index of the target frame

forward(em)[source]#
Parameters:

em – (EmitterSet)

Returns:

(EmitterSet) filtered set of emitters

Return type:

em

decode.neuralfitter.frame_processing module#

class decode.neuralfitter.frame_processing.AutoCenterCrop(px_fold)[source]#

Bases: FrameProcessing

Automatic cropping in the centre. Specify px_fold, a multiple which the target frame size must satisfy, and the frame will be center-cropped to this size.

Parameters:

px_fold (int) – integer of which the resulting frame size (H, W dimensions) must be a multiple

forward(frame)[source]#

Process frames

Parameters:

frame (Tensor) – size [*, H, W]

Return type:

Tensor

class decode.neuralfitter.frame_processing.AutoPad(px_fold, mode='constant')[source]#

Bases: AutoCenterCrop

Pad the frame to a size that is divisible by px_fold. Useful to prepare an experimental frame for forwarding through the network.

Parameters:
  • px_fold (int) – number of pixels the resulting frame size should be divisible by

  • mode (str) – torch mode for padding. refer to docs of torch.nn.functional.pad

forward(frame)[source]#

Process frames

Parameters:

frame (Tensor) – size [*, H, W]

Return type:

Tensor
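
A minimal sketch (frame size and px_fold are illustrative): making an arbitrarily sized experimental frame compatible with a network that expects sizes divisible by 8, either by padding or by cropping.

```python
# Minimal sketch: pad up or centre-crop down to a multiple of 8 pixels.
import torch

from decode.neuralfitter.frame_processing import AutoCenterCrop, AutoPad

frame = torch.rand(1, 1, 123, 130)                   # (*, H, W)
padded = AutoPad(px_fold=8).forward(frame)           # -> (*, 128, 136)
cropped = AutoCenterCrop(px_fold=8).forward(frame)   # -> (*, 120, 128)
```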

class decode.neuralfitter.frame_processing.FrameProcessing[source]#

Bases: ABC

abstract forward(frame)[source]#

Forward frame through processing implementation.

Parameters:

frame (Tensor) –

Return type:

Tensor

class decode.neuralfitter.frame_processing.Mirror2D(dims)[source]#

Bases: FrameProcessing

Mirror the specified dimensions. Providing the dim index in negative format is recommended. Given the format N x C x H x W, to mirror H and W set dims=(-2, -1).

Parameters:

dims (Tuple) – dimensions

forward(frame)[source]#

Forward frame through processing implementation.

Parameters:

frame (Tensor) –

Return type:

Tensor
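
A minimal sketch (shapes illustrative):

```python
# Minimal sketch: mirror frames along H and W, using negative dim indices as recommended.
import torch

from decode.neuralfitter.frame_processing import Mirror2D

mirror = Mirror2D(dims=(-2, -1))
frames = torch.rand(10, 1, 64, 64)        # (N, C, H, W)
frames_mirrored = mirror.forward(frames)
```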

decode.neuralfitter.frame_processing.get_frame_extent(size, func)[source]#

Get frame extent after processing pipeline

Parameters:
  • size

  • func

Return type:

Size

Returns:

size of the frame after applying the processing function

decode.neuralfitter.losscollection module#

decode.neuralfitter.post_processing module#

class decode.neuralfitter.post_processing.ConsistencyPostprocessing(*, raw_th, em_th, xy_unit, img_shape, ax_th=None, vol_th=None, lat_th=None, p_aggregation='pbinom_cdf', px_size=None, match_dims=2, diag=0, pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1], num_workers=0, skip_th=None, return_format='batch-set', sanity_check=True)[source]#

Bases: PostProcessing

Parameters:
  • pphotxyzbg_mapping

  • raw_th

  • em_th

  • xy_unit

  • img_shape

  • ax_th

  • vol_th

  • lat_th

  • p_aggregation

  • px_size

  • match_dims

  • diag

  • num_workers

  • skip_th – relative fraction of the detection output that must be active (‘on’) for post-processing to be skipped. This is useful during training when the network has not yet converged and large parts of the detection output are ‘white’ (i.e. detections are not sparse).

  • return_format

  • sanity_check

forward(features)[source]#

Forward the feature map through the post processing and return an EmitterSet or a list of EmitterSets. For the input features we use the following convention:

0 - Detection channel

1 - Photon channel

2 - ‘x’ channel

3 - ‘y’ channel

4 - ‘z’ channel

5 - Background channel

Expecting x and y channels in nanometres.

Parameters:

features (torch.Tensor) – Features of size \((N, C, H, W)\)

Returns:

Specified by the return_format argument; EmitterSet in nanometres.

Return type:

EmitterSet or list of EmitterSets

classmethod parse(param, **kwargs)[source]#

Return an instance of this post-processing as specified by the parameters

Parameters:

param

Returns:

ConsistencyPostProcessing

sanity_check()[source]#

Performs some sanity checks. Part of the constructor; useful if you modify attributes later on and want to double check.

skip_if(x)[source]#

Skip post-processing when a certain condition is met and the implementation would fail, e.g. too many bright pixels in the detection channel. The default implementation always returns False.

Parameters:

x – network output

Returns:

returns true when post-processing should be skipped

Return type:

bool

class decode.neuralfitter.post_processing.LookUpPostProcessing(raw_th, xy_unit, px_size=None, pphotxyzbg_mapping=(0, 1, 2, 3, 4, -1), photxyz_sigma_mapping=(5, 6, 7, 8))[source]#

Bases: PostProcessing

Parameters:
  • raw_th (float) – initial raw threshold

  • xy_unit (str) – xy unit

  • px_size – pixel size

  • pphotxyzbg_mapping (Union[list, tuple]) – channel index mapping of detection (p), photon, x, y, z, bg

forward(x)[source]#

Forward model output tensor through post-processing and return EmitterSet. Will include sigma values in EmitterSet if mapping was provided initially.

Parameters:

x (Tensor) – model output

Return type:

EmitterSet

Returns:

EmitterSet
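
A minimal sketch (the 10-channel ordering of p, photons, x, y, z, the four sigma channels 5–8 and background last follows the default mappings above; threshold and pixel size are illustrative):

```python
# Minimal sketch: threshold the detection channel and look up emitter attributes.
import torch

from decode.neuralfitter.post_processing import LookUpPostProcessing

post = LookUpPostProcessing(raw_th=0.6, xy_unit='px', px_size=(110., 110.))

x = torch.rand(2, 10, 64, 64)   # dummy model output, (N, C, H, W)
em = post.forward(x)            # EmitterSet, including sigma values
```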

class decode.neuralfitter.post_processing.NoPostProcessing(xy_unit=None, px_size=None, return_format='batch-set')[source]#

Bases: PostProcessing

Parameters:
  • return_format (str) – return format of forward function. Must be ‘batch-set’ or ‘frame-set’. If ‘batch-set’, one instance of EmitterSet will be returned per forward call; if ‘frame-set’, a tuple of EmitterSets, one per frame, will be returned.

  • sanity_check (bool) – perform sanity check

forward(x)[source]#
Parameters:

x (torch.Tensor) – any input tensor where the first dim is the batch-dim.

Returns:

An empty EmitterSet

Return type:

EmptyEmitterSet

class decode.neuralfitter.post_processing.PostProcessing(xy_unit, px_size, return_format)[source]#

Bases: ABC

Parameters:
  • return_format (str) – return format of forward function. Must be ‘batch-set’ or ‘frame-set’. If ‘batch-set’, one instance of EmitterSet will be returned per forward call; if ‘frame-set’, a tuple of EmitterSets, one per frame, will be returned.

  • sanity_check (bool) – perform sanity check

abstract forward(x)[source]#

Forward anything through the post-processing and return an EmitterSet

Parameters:

x (Tensor) –

Returns:

Returns as EmitterSet or as list of EmitterSets

Return type:

EmitterSet or list

sanity_check()[source]#

Sanity checks

skip_if(x)[source]#

Skip post-processing when a certain condition is met and the implementation would fail, e.g. too many bright pixels in the detection channel. The default implementation always returns False.

Parameters:

x – network output

Returns:

returns true when post-processing should be skipped

Return type:

bool

class decode.neuralfitter.post_processing.SpatialIntegration(raw_th, xy_unit, px_size=None, pphotxyzbg_mapping=(0, 1, 2, 3, 4, -1), photxyz_sigma_mapping=(5, 6, 7, 8), p_aggregation='norm_sum')[source]#

Bases: LookUpPostProcessing

Parameters:
  • raw_th (float) – probability threshold from where detections are considered

  • xy_unit (str) – unit of the xy coordinates

  • px_size – pixel size

  • pphotxyzbg_mapping (Union[list, tuple]) – channel index mapping

  • photxyz_sigma_mapping (Union[list, tuple, None]) – channel index mapping of sigma channels

  • p_aggregation (Union[str, Callable]) – aggregation method to aggregate probabilities. Can be ‘sum’, ‘max’ or ‘norm_sum’

forward(x)[source]#

Forward model output tensor through post-processing and return EmitterSet. Will include sigma values in EmitterSet if mapping was provided initially.

Parameters:

x (Tensor) – model output

Return type:

EmitterSet

Returns:

EmitterSet

classmethod set_p_aggregation(p_aggr)[source]#

Sets the p_aggregation by string or callable. Returns a Callable.

Parameters:

p_aggr (Union[str, Callable]) – probability aggregation

Return type:

Callable
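
A minimal sketch (same illustrative channel layout and values as for LookUpPostProcessing above):

```python
# Minimal sketch: SpatialIntegration is a drop-in alternative to LookUpPostProcessing
# that aggregates detection probabilities spatially (default 'norm_sum') before look-up.
import torch

from decode.neuralfitter.post_processing import SpatialIntegration

post = SpatialIntegration(raw_th=0.6, xy_unit='px', px_size=(110., 110.))
em = post.forward(torch.rand(2, 10, 64, 64))
```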

decode.neuralfitter.sampling module#

decode.neuralfitter.sampling.sample_crop(x_in, sample_size)[source]#

Takes a 2D tensor and returns random crops

Parameters:
  • x_in (Tensor) – input tensor

  • sample_size (Sequence[int]) – size of sample, size specification (N, H, W)

Return type:

Tensor

Returns:

random crops with size sample_size
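
A minimal sketch (frame size and crop size are illustrative):

```python
# Minimal sketch: draw 16 random 32 x 32 crops from a single 2D frame.
import torch

from decode.neuralfitter.sampling import sample_crop

frame = torch.rand(256, 256)
crops = sample_crop(frame, (16, 32, 32))   # -> tensor of size (16, 32, 32)
```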

decode.neuralfitter.scale_transform module#

class decode.neuralfitter.scale_transform.AmplitudeRescale(scale=1.0, offset=0.0)[source]#

Bases: object

Parameters:
  • offset (float) –

  • scale (float) – reference value

forward(x)[source]#

Forward the tensor and rescale it.

Parameters:

x (torch.Tensor) –

Returns:

rescaled tensor

Return type:

x_ (torch.Tensor)

static parse(param)[source]#
class decode.neuralfitter.scale_transform.FourFoldInverseOffsetRescale(*args, **kwargs)[source]#

Bases: InverseOffsetRescale

Assumes scale_x, scale_y, scale_z to be symmetrically ranged and scale_phot to be ranged between 0 and 1.

Parameters:
  • scale_x (float) – scale factor in x

  • scale_y – scale factor in y

  • scale_z – scale factor in z

  • scale_phot – scale factor for photon values

  • mu_sig_bg – offset and scaling for background

  • buffer – buffer to extend the scales overall

  • power – power factor

forward(x)[source]#

Inverse scale transformation (typically before the network).

Parameters:

x (torch.Tensor) – input tensor N x 5/6 x H x W

Returns:

(inverse) scaled

Return type:

x_ (torch.Tensor)

class decode.neuralfitter.scale_transform.InverseOffsetRescale(*, scale_x, scale_y, scale_z, scale_phot, mu_sig_bg=(None, None), buffer=1.0, power=1.0)[source]#

Bases: OffsetRescale

Assumes scale_x, scale_y, scale_z to be symmetrically ranged and scale_phot to be ranged between 0 and 1.

Parameters:
  • scale_x (float) – scale factor in x

  • scale_y (float) – scale factor in y

  • scale_z (float) – scale factor in z

  • scale_phot (float) – scale factor for photon values

  • mu_sig_bg – offset and scaling for background

  • buffer – buffer to extend the scales overall

  • power – power factor

forward(x)[source]#

Inverse scale transformation (typically before the network).

Parameters:

x (torch.Tensor) – input tensor N x 5/6 x H x W

Returns:

(inverse) scaled

Return type:

x_ (torch.Tensor)

classmethod parse(param)[source]#
class decode.neuralfitter.scale_transform.InverseParamListRescale(phot_max, z_max, bg_max)[source]#

Bases: ParameterListRescale

Rescale network output trained with GMM Loss.

forward(x)[source]#
Parameters:

x (Tensor) – model output

Return type:

Tensor

Returns:

torch.Tensor (rescaled model output)

class decode.neuralfitter.scale_transform.OffsetRescale(*, scale_x, scale_y, scale_z, scale_phot, mu_sig_bg=(None, None), buffer=1.0, power=1.0)[source]#

Bases: object

Assumes scale_x, scale_y, scale_z to be symmetrically ranged and scale_phot to be ranged between 0 and 1.

Parameters:
  • scale_x (float) – scale factor in x

  • scale_y (float) – scale factor in y

  • scale_z (float) – scale factor in z

  • scale_phot (float) – scale factor for photon values

  • mu_sig_bg – offset and scaling for background

  • buffer – buffer to extend the scales overall

  • power – power factor

forward(x)[source]#

Scale the input (typically after the network).

Parameters:

x (torch.Tensor) – input tensor N x 5/6 x H x W

Returns:

scaled

Return type:

x_ (torch.Tensor)

static parse(param)[source]#
return_inverse()[source]#

Returns the inverse counterpart of this class (instance).

Returns:

Inverse counterpart.

Return type:

InverseOffSetRescale
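
A minimal sketch (all scale values are purely illustrative, not recommendations):

```python
# Minimal sketch: rescale a 6-channel network output to physical ranges and
# obtain the matching inverse transform used on the target side.
import torch

from decode.neuralfitter.scale_transform import OffsetRescale

rescale = OffsetRescale(scale_x=0.5, scale_y=0.5, scale_z=750., scale_phot=10000.,
                        mu_sig_bg=(100., 10.), buffer=1.2)

out = rescale.forward(torch.rand(2, 6, 32, 32))   # (N, 5/6, H, W)
inverse = rescale.return_inverse()                # InverseOffsetRescale counterpart
```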

class decode.neuralfitter.scale_transform.ParameterListRescale(phot_max, z_max, bg_max)[source]#

Bases: object

forward(x, mask, bg)[source]#
Return type:

Tuple[Tensor, Tensor, Tensor]

classmethod parse(param)[source]#
class decode.neuralfitter.scale_transform.SpatialInterpolation(mode='nearest', size=None, scale_factor=None, impl=None)[source]#

Bases: object

Parameters:
  • mode (string, None) – mode which is used for interpolation. These are the modes supported by the torch interpolation function.

  • impl (optional) – override function for interpolation

forward(x)[source]#

Forward a tensor through the interpolation process.

Parameters:

x (torch.Tensor) – arbitrary tensor complying with the interpolation function. Must have a batch and channel dimension.

Returns:

interpolated tensor

Return type:

x_inter

decode.neuralfitter.target_generator module#

class decode.neuralfitter.target_generator.DisableAttributes(attr_ix)[source]#

Bases: object

Allows disabling attribute prediction of the parameter list target, e.g. when you don’t want to predict z.

Parameters:

attr_ix (Union[None, int, tuple, list]) – index of the attribute you want to disable (phot, x, y, z).

forward(param_tar, mask_tar, bg)[source]#
classmethod parse(param)[source]#
class decode.neuralfitter.target_generator.FourFoldEmbedding(xextent, yextent, img_shape, rim_size, roi_size, ix_low=None, ix_high=None, squeeze_batch_dim=False)[source]#

Bases: TargetGenerator

Parameters:
  • xy_unit – Which unit to use for target generator

  • ix_low – lower bound of frame / batch index

  • ix_high – upper bound of frame / batch index

  • squeeze_batch_dim (bool) – if lower and upper frame_ix are the same, squeeze out the batch dimension before return

forward(em, bg=None, ix_low=None, ix_high=None)[source]#

Forward calculate target as by the emitters and background. Overwrite the default frame ix boundaries.

Parameters:
  • em (EmitterSet) – set of emitters

  • bg (Optional[Tensor]) – background frame

  • ix_low (Optional[int]) – lower frame index

  • ix_high (Optional[int]) – upper frame index

Return type:

Tensor

Returns:

target frames

classmethod parse(param, **kwargs)[source]#
class decode.neuralfitter.target_generator.ParameterListTarget(n_max, xextent, yextent, ix_low=None, ix_high=None, xy_unit='px', squeeze_batch_dim=False)[source]#

Bases: TargetGenerator

Target corresponding to the Gaussian Mixture Model loss. Simply concatenates all emitters’ attributes, up to a maximum number of emitters, as a list.

Parameters:
  • n_max (int) – maximum number of emitters (should be multitude of what you draw on average)

  • xextent (tuple) – extent of the emitters in x

  • yextent (tuple) – extent of the emitters in y

  • ix_low – lower frame index

  • ix_high – upper frame index

  • xy_unit (str) – xy unit

  • squeeze_batch_dim (bool) – squeeze batch dimension before return

forward(em, bg=None, ix_low=None, ix_high=None)[source]#

Forward calculate target as by the emitters and background. Overwrite the default frame ix boundaries.

Parameters:
  • em (EmitterSet) – set of emitters

  • bg (Optional[Tensor]) – background frame

  • ix_low (Optional[int]) – lower frame index

  • ix_high (Optional[int]) – upper frame index

Returns:

target frames
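
A minimal sketch (the EmitterSet constructor arguments are assumptions about decode.generic.emitter, and the exact structure of the returned target is not spelled out here; extents and values are illustrative):

```python
# Minimal sketch: build the parameter-list target for a single frame (ix 0)
# from a dummy EmitterSet, as used with the Gaussian mixture model loss.
import torch

from decode.generic.emitter import EmitterSet
from decode.neuralfitter.target_generator import ParameterListTarget

tar_gen = ParameterListTarget(n_max=100, xextent=(-0.5, 63.5), yextent=(-0.5, 63.5),
                              ix_low=0, ix_high=0, xy_unit='px')

em = EmitterSet(xyz=torch.rand(20, 3) * 60, phot=torch.ones(20) * 1000.,
                frame_ix=torch.zeros(20, dtype=torch.long), xy_unit='px')

target = tar_gen.forward(em)   # parameter-list target (exact tuple structure assumed)
```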

class decode.neuralfitter.target_generator.TargetGenerator(xy_unit='px', ix_low=None, ix_high=None, squeeze_batch_dim=False)[source]#

Bases: ABC

Parameters:
  • xy_unit – Which unit to use for target generator

  • ix_low (Optional[int]) – lower bound of frame / batch index

  • ix_high (Optional[int]) – upper bound of frame / batch index

  • squeeze_batch_dim (bool) – if lower and upper frame_ix are the same, squeeze out the batch dimension before return

abstract forward(em, bg=None, ix_low=None, ix_high=None)[source]#

Forward calculate target as by the emitters and background. Overwrite the default frame ix boundaries.

Parameters:
  • em (EmitterSet) – set of emitters

  • bg (Optional[Tensor]) – background frame

  • ix_low (Optional[int]) – lower frame index

  • ix_high (Optional[int]) – upper frame index

Return type:

Tensor

Returns:

target frames

sanity_check()[source]#
class decode.neuralfitter.target_generator.UnifiedEmbeddingTarget(xextent, yextent, img_shape, roi_size, ix_low=None, ix_high=None, squeeze_batch_dim=False)[source]#

Bases: TargetGenerator

Parameters:
  • xy_unit – Which unit to use for target generator

  • ix_low – lower bound of frame / batch index

  • ix_high – upper bound of frame / batch index

  • squeeze_batch_dim (bool) – if lower and upper frame_ix are the same, squeeze out the batch dimension before return

const_roi_target(batch_ix_roi, x_ix_roi, y_ix_roi, phot, id, batch_size)[source]#
forward(em, bg=None, ix_low=None, ix_high=None)[source]#

Forward calculate target as by the emitters and background. Overwrite the default frame ix boundaries.

Parameters:
  • em (EmitterSet) – set of emitters

  • bg (Optional[Tensor]) – background frame

  • ix_low (Optional[int]) – lower frame index

  • ix_high (Optional[int]) – upper frame index

Return type:

Tensor

Returns:

target frames

forward_(xyz, phot, frame_ix, ix_low, ix_high)[source]#

Get index of central bin for each emitter.

Return type:

Tensor

classmethod parse(param, **kwargs)[source]#
single_px_target(batch_ix, x_ix, y_ix, batch_size)[source]#
property xextent#
xy_target(batch_ix_roi, x_ix_roi, y_ix_roi, xy, id, batch_size)[source]#
property yextent#

decode.neuralfitter.train_val_impl module#

decode.neuralfitter.train_val_impl.ship_device(x, device)[source]#

Ships the input to a PyTorch-compatible device (e.g. CUDA)

Parameters:
  • x

  • device (Union[str, device]) –

Returns:

x
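
A minimal sketch:

```python
# Minimal sketch: move a batch to the GPU if one is available, otherwise keep it on the CPU.
import torch

from decode.neuralfitter.train_val_impl import ship_device

x = torch.rand(4, 6, 32, 32)
x = ship_device(x, 'cuda:0' if torch.cuda.is_available() else 'cpu')
```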

decode.neuralfitter.train_val_impl.test(model, loss, dataloader, epoch, device)[source]#

Runs one evaluation (test) epoch of the model over the given dataloader.

decode.neuralfitter.train_val_impl.train(model, optimizer, loss, dataloader, grad_rescale, grad_mod, epoch, device, logger)[source]#

Runs one training epoch of the model over the given dataloader.

Return type:

float

decode.neuralfitter.weight_generator module#

class decode.neuralfitter.weight_generator.FourFoldSimpleWeight(*, xextent, yextent, img_shape, roi_size, rim, weight_mode='const', weight_power=None)[source]#

Bases: WeightGenerator

Parameters:
  • xy_unit – Which unit to use for target generator

  • ix_low – lower bound of frame / batch index

  • ix_high – upper bound of frame / batch index

  • squeeze_batch_dim – if lower and upper frame_ix are the same, squeeze out the batch dimension before return

forward(tar_em, tar_frames, ix_low=None, ix_high=None)[source]#

Calculate weight map based on target frames and target emitters.

Parameters:
  • tar_em (EmitterSet) – target EmitterSet

  • tar_frames (torch.Tensor) – frames of size \(((N,),C,H,W)\)

Returns:

Weight mask of size \(((N,),D,H,W)\) where likely \(C=D\)

Return type:

torch.Tensor

classmethod parse(param)[source]#

Constructs a WeightGenerator from a parameter variable, which will likely be a namedtuple, dotmap or similar.

Parameters:

param

Returns:

Instance of WeightGenerator child classes.

Return type:

WeightGenerator

class decode.neuralfitter.weight_generator.SimpleWeight(*, xextent, yextent, img_shape, roi_size, weight_mode='const', weight_power=None, forward_safety=True, ix_low=None, ix_high=None, squeeze_batch_dim=False)[source]#

Bases: WeightGenerator

Parameters:
  • xextent (tuple) – extent in x

  • yextent (tuple) – extent in y

  • img_shape (tuple) – image shape

  • roi_size (int) – roi size of the target

  • weight_mode (str) – constant or phot

  • weight_power (float) – power factor of the weight

  • forward_safety (bool) – check sanity of forward arguments

check_forward_sanity(tar_em, tar_frames, ix_low, ix_high)[source]#

Check sanity of forward arguments, raise error otherwise.

Parameters:
  • tar_em (EmitterSet) – target emitters

  • tar_frames (Tensor) – target frames

  • ix_low (int) – lower frame index

  • ix_high (int) – upper frame index

check_sanity()[source]#
forward(tar_em, tar_frames, ix_low=None, ix_high=None)[source]#

Calculate weight map based on target frames and target emitters.

Parameters:
  • tar_em (EmitterSet) – target EmitterSet

  • tar_frames (torch.Tensor) – frames of size \(((N,),C,H,W)\)

Returns:

Weight mask of size \(((N,),D,H,W)\) where likely \(C=D\)

Return type:

torch.Tensor

classmethod parse(param, **kwargs)[source]#

Constructs a WeightGenerator from a parameter variable, which will likely be a namedtuple, dotmap or similar.

Parameters:

param

Returns:

Instance of WeightGenerator child classes.

Return type:

WeightGenerator
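
A minimal sketch (the channel count of the target frames and the EmitterSet constructor arguments are assumptions; extents and values are illustrative):

```python
# Minimal sketch: constant weights within a 3 px ROI around each target emitter
# of a single frame (ix 0).
import torch

from decode.generic.emitter import EmitterSet
from decode.neuralfitter.weight_generator import SimpleWeight

weight_gen = SimpleWeight(xextent=(-0.5, 31.5), yextent=(-0.5, 31.5),
                          img_shape=(32, 32), roi_size=3, weight_mode='const',
                          ix_low=0, ix_high=0)

tar_em = EmitterSet(xyz=torch.rand(5, 3) * 30, phot=torch.ones(5) * 1000.,
                    frame_ix=torch.zeros(5, dtype=torch.long), xy_unit='px')
tar_frames = torch.rand(1, 6, 32, 32)

weights = weight_gen.forward(tar_em, tar_frames, 0, 0)   # weight mask, ((N,), D, H, W)
```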

class decode.neuralfitter.weight_generator.WeightGenerator(ix_low=None, ix_high=None, squeeze_batch_dim=False)[source]#

Bases: TargetGenerator

Parameters:
  • xy_unit – Which unit to use for target generator

  • ix_low (Optional[int]) – lower bound of frame / batch index

  • ix_high (Optional[int]) – upper bound of frame / batch index

  • squeeze_batch_dim (bool) – if lower and upper frame_ix are the same, squeeze out the batch dimension before return

check_forward_sanity(tar_em, tar_frames, ix_low, ix_high)[source]#

Check sanity of forward arguments, raise error otherwise.

Parameters:
  • tar_em (EmitterSet) – target emitters

  • tar_frames (Tensor) – target frames

  • ix_low (int) – lower frame index

  • ix_high (int) – upper frame index

abstract forward(tar_em, tar_frames, ix_low, ix_high)[source]#

Calculate weight map based on target frames and target emitters.

Parameters:
  • tar_em (EmitterSet) – target EmitterSet

  • tar_frames (torch.Tensor) – frames of size \(((N,),C,H,W)\)

Returns:

Weight mask of size \(((N,),D,H,W)\) where likely \(C=D\)

Return type:

torch.Tensor

classmethod parse(param)[source]#

Constructs a WeightGenerator from a parameter variable, which will likely be a namedtuple, dotmap or similar.

Parameters:

param

Returns:

Instance of WeightGenerator child classes.

Return type:

WeightGenerator

Module contents#