decode.neuralfitter.models package#

Submodules#

decode.neuralfitter.models.model_param module#

class decode.neuralfitter.models.model_param.DoubleMUnet(ch_in, ch_out, ext_features=0, depth_shared=3, depth_union=3, initial_features=64, inter_features=64, activation=ReLU(), use_last_nl=True, norm=None, norm_groups=None, norm_head=None, norm_head_groups=None, pool_mode='Conv2d', upsample_mode='bilinear', skip_gn_level=None)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.
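
A minimal construction sketch (the channel counts and input sizes below are hypothetical choices, not defaults):

import torch
from decode.neuralfitter.models.model_param import DoubleMUnet

model = DoubleMUnet(ch_in=3, ch_out=6)   # hypothetical channel counts
x = torch.rand(2, 3, 64, 64)             # 2 samples, 3 frames, 64 x 64 px
out = model(x)                           # output has ch_out channels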

apply_detection_nonlin(x)[source]#

Apply the detection non-linearity. Useful outside of training; when a BCEWithLogits loss is used, do not apply it during training, because the sigmoid is already included in the loss.

Parameters:

x – model output

Return type:

Tensor
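
A hedged inference sketch: request raw logits from forward and apply the detection non-linearity manually (model and frames are hypothetical names):

model.eval()
with torch.no_grad():
    raw = model(frames, force_no_p_nl=True)  # detection channel stays as logits
    out = model.apply_detection_nonlin(raw)  # sigmoid applied to the detection channel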

apply_nonlin(o)[source]#

Apply non-linearity to all but the detection channel.

Parameters:

o (Tensor) – model output

Return type:

Tensor

forward(x, force_no_p_nl=False)[source]#

Parameters:
  • x – input frames

  • force_no_p_nl – if True, do not apply the non-linearity to the detection (p) channel

Returns:

network output

classmethod parse(param, **kwargs)[source]#
rescale_last_layer_grad(loss, optimizer)[source]#

Rescales the weight as by the last layer’s gradient

Parameters:
  • loss

  • optimizer

Returns:

weight, channelwise loss, channelwise weighted loss
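
A hedged call sketch (hypothetical names; loss is the non-reduced N x C x H x W loss, optimizer the optimizer in use):

weight, loss_ch, loss_ch_weighted = model.rescale_last_layer_grad(loss, optimizer)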

training: bool#
class decode.neuralfitter.models.model_param.MLTHeads(in_channels, out_channels, last_kernel, norm, norm_groups, padding, activation)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.model_param.SimpleSMLMNet(ch_in, ch_out, depth=3, initial_features=64, inter_features=64, p_dropout=0.0, activation=ReLU(), use_last_nl=True, norm=None, norm_groups=None, norm_head=None, norm_head_groups=None, pool_mode='StrideConv', upsample_mode='bilinear', skip_gn_level=None)[source]#

Bases: UNet2d

Initializes internal Module state, shared by both nn.Module and ScriptModule.
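
A minimal construction sketch (ch_out=5 is a hypothetical choice, e.g. detection plus photon count and x, y, z):

import torch
from decode.neuralfitter.models.model_param import SimpleSMLMNet

model = SimpleSMLMNet(ch_in=1, ch_out=5)
out = model(torch.rand(2, 1, 64, 64))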

apply_nonlin(o)[source]#

Apply the non-linearity to all channels other than the detection channel.

Parameters:

o – model output

apply_pnl(o)[source]#

Apply the sigmoid non-linearity to the p (detection) channel. During training this is part of the loss function; only use this when not training.

Parameters:

o – model output

static check_target(y_tar)[source]#
forward(x, force_no_p_nl=False)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static parse(param)[source]#
rescale_last_layer_grad(loss, optimizer)[source]#

Parameters:
  • loss – non-reduced loss of size N x C x H x W

  • optimizer – optimizer in use

Returns:

weight, channelwise loss, channelwise weighted loss

training: bool#

decode.neuralfitter.models.model_speced_impl module#

class decode.neuralfitter.models.model_speced_impl.SigmaMUNet(ch_in, *, depth_shared, depth_union, initial_features, inter_features, norm=None, norm_groups=None, norm_head=None, norm_head_groups=None, pool_mode='StrideConv', upsample_mode='bilinear', skip_gn_level=None, activation=ReLU(), kaiming_normal=True)[source]#

Bases: DoubleMUnet

Initializes internal Module state, shared by both nn.Module and ScriptModule.
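
A minimal construction sketch (the depth and feature hyperparameters are hypothetical choices; ch_out is fixed to 10 by the class):

import torch
from decode.neuralfitter.models.model_speced_impl import SigmaMUNet

model = SigmaMUNet(ch_in=3, depth_shared=2, depth_union=2,
                   initial_features=48, inter_features=48)
out = model(torch.rand(1, 3, 64, 64))  # 10 output channels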

apply_detection_nonlin(x)[source]#

Apply the detection non-linearity. Useful outside of training; when a BCEWithLogits loss is used, do not apply it during training, because the sigmoid is already included in the loss.

Parameters:

x – model output

Return type:

Tensor

apply_nonlin(o)[source]#

Apply non-linearity to all but the detection channel.

Parameters:

o (Tensor) – model output

Return type:

Tensor

bg_ch_ix = [10]#
ch_out = 10#
forward(x)[source]#

Parameters:

x (Tensor) – input frames

Return type:

Tensor

mt_heads#

Register sigma as parameter such that it is stored in the model's state dict and loaded correctly.

out_channels_heads = (1, 4, 4, 1)#
p_ch_ix = [0]#
classmethod parse(param, **kwargs)[source]#
pxyz_mu_ch_ix = slice(1, 5, None)#
pxyz_sig_ch_ix = slice(5, 9, None)#
sigma_eps_default = 0.001#
sigmoid_ch_ix = [0, 1, 5, 6, 7, 8, 9]#
tanh_ch_ix = [2, 3, 4]#
training: bool#
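
The channel index attributes above can be used to split a model output into its semantic groups; a hedged sketch (out is a hypothetical forward output of shape N x 10 x H x W):

p = out[:, SigmaMUNet.p_ch_ix]               # detection probability, channel 0
pxyz_mu = out[:, SigmaMUNet.pxyz_mu_ch_ix]   # photon count and x, y, z estimates, channels 1-4
pxyz_sig = out[:, SigmaMUNet.pxyz_sig_ch_ix] # corresponding uncertainties, channels 5-8
# the background estimate occupies the remaining, last channel
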
static weight_init(m)[source]#

Apply Kaiming normal initialization. Call it recursively via model.apply(model.weight_init).

Parameters:

m – module to initialize

decode.neuralfitter.models.unet_param module#

class decode.neuralfitter.models.unet_param.UNet2d(in_channels, out_channels, depth=4, initial_features=64, gain=2, pad_convs=False, norm=None, norm_groups=None, p_dropout=None, final_activation=None, activation=ReLU(), pool_mode='MaxPool', skip_gn_level=None, upsample_mode='bilinear')[source]#

Bases: UNetBase

Initializes internal Module state, shared by both nn.Module and ScriptModule.
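
A minimal construction sketch (sizes are hypothetical; with pad_convs=True the convolutions are padded, so the spatial size should be preserved):

import torch
from decode.neuralfitter.models.unet_param import UNet2d

unet = UNet2d(in_channels=1, out_channels=4, depth=3,
              initial_features=32, pad_convs=True)
y = unet(torch.rand(1, 1, 64, 64))  # expected shape (1, 4, 64, 64)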

training: bool#
class decode.neuralfitter.models.unet_param.UNetBase(in_channels, out_channels, depth=4, initial_features=64, gain=2, pad_convs=False, norm=None, norm_groups=None, p_dropout=None, final_activation=None, activation=ReLU(), pool_mode='MaxPool', skip_gn_level=None, upsample_mode='bilinear')[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_parts(parts)[source]#
norms = ('BatchNorm', 'GroupNorm')#
pool_modules = ('MaxPool', 'StrideConv')#
training: bool#
class decode.neuralfitter.models.unet_param.Upsample(scale_factor, mode='nearest', in_channels=None, out_channels=None, align_corners=False, ndim=3)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.
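
A hedged usage sketch, assuming this module wraps torch.nn.functional.interpolate; with the default ndim=3 the input would be a 5D N x C x D x H x W tensor:

import torch
from decode.neuralfitter.models.unet_param import Upsample

up = Upsample(scale_factor=2, mode='nearest')  # ndim defaults to 3
y = up(torch.rand(1, 1, 8, 16, 16))            # expected shape (1, 1, 16, 32, 32)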

forward(input)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
decode.neuralfitter.models.unet_param.get_activation(activation)[source]#

Get an activation function from a string or an nn.Module instance.
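
A hedged sketch, assuming a string argument is resolved to the activation of that name in torch.nn:

import torch.nn as nn
from decode.neuralfitter.models.unet_param import get_activation

act = get_activation('ReLU')    # resolved from the string name (assumption)
act = get_activation(nn.ELU())  # an nn.Module is accepted directly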

decode.neuralfitter.models.unet_parts module#

class decode.neuralfitter.models.unet_parts.Upsample(scale_factor, mode, align_corners)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.unet_parts.double_conv(in_ch, out_ch)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.unet_parts.down(in_ch, out_ch)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.unet_parts.down_3d(in_ch, out_ch)[source]#

Bases: down

Initializes internal Module state, shared by both nn.Module and ScriptModule.

training: bool#
class decode.neuralfitter.models.unet_parts.inconv(in_ch, out_ch)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.unet_parts.inconv_3d(in_ch, out_ch)[source]#

Bases: inconv

Initializes internal Module state, shared by both nn.Module and ScriptModule.

training: bool#
class decode.neuralfitter.models.unet_parts.outconv(in_ch, out_ch)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
class decode.neuralfitter.models.unet_parts.up(in_ch, out_ch, bilinear=True)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x1, x2)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
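
For orientation, a hedged sketch composing these parts into a tiny two-level U-Net; the channel bookkeeping assumes that up concatenates the upsampled tensor with the skip connection before its convolution, and all sizes are hypothetical:

import torch
import torch.nn as nn
from decode.neuralfitter.models import unet_parts as parts

class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.inc = parts.inconv(1, 16)
        self.down1 = parts.down(16, 32)
        self.up1 = parts.up(48, 16)  # 32 upsampled + 16 skip channels
        self.outc = parts.outconv(16, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        return self.outc(self.up1(x2, x1))

y = TinyUNet()(torch.rand(1, 1, 32, 32))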

Module contents#