"""Source code for ``decode.neuralfitter.models.model_speced_impl``."""

from typing import Union

import torch
from torch import nn

from . import model_param


class SigmaMUNet(model_param.DoubleMUnet):
    """Variant of ``model_param.DoubleMUnet`` with four output heads and per-channel activations.

    The heads (``out_channels_heads``) are concatenated along dim 1 into
    ``ch_out = 10`` channels, laid out as follows (per the attribute comments below):

    - 0:   p (detection) -- clamped to [-8, 8], then sigmoid
    - 1-4: phot, x, y, z means -- sigmoid for phot (1), tanh for x, y, z (2-4)
    - 5-8: phot, x, y, z sigmas -- sigmoid, then ``* 3 + sigma_eps``
    - 9:   background -- sigmoid
    """

    ch_out = 10
    out_channels_heads = (1, 4, 4, 1)  # p head, phot,xyz_mu head, phot,xyz_sig head, bg head

    # channel indices that receive the respective activation function in forward()
    sigmoid_ch_ix = [0, 1, 5, 6, 7, 8, 9]
    tanh_ch_ix = [2, 3, 4]

    # channel indices of the respective output parameters
    p_ch_ix = [0]
    pxyz_mu_ch_ix = slice(1, 5)
    pxyz_sig_ch_ix = slice(5, 9)
    bg_ch_ix = [9]  # last of the ch_out = 10 channels (previously [10], which is out of range)

    sigma_eps_default = 0.001

    def __init__(self, ch_in: int, *, depth_shared: int, depth_union: int,
                 initial_features: int, inter_features: int, norm=None, norm_groups=None,
                 norm_head=None, norm_head_groups=None, pool_mode='StrideConv',
                 upsample_mode='bilinear', skip_gn_level: Union[None, bool] = None,
                 activation=None, kaiming_normal=True):
        """
        Build the shared backbone (via the parent) and the four output heads.

        Args:
            ch_in: number of input channels.
            depth_shared: forwarded to ``DoubleMUnet``.
            depth_union: forwarded to ``DoubleMUnet``.
            initial_features: forwarded to ``DoubleMUnet``.
            inter_features: forwarded to ``DoubleMUnet``; also the number of input
                channels of every output head.
            norm: forwarded to ``DoubleMUnet``.
            norm_groups: forwarded to ``DoubleMUnet``.
            norm_head: normalisation used in the output heads.
            norm_head_groups: group count for ``norm_head``.
            pool_mode: forwarded to ``DoubleMUnet``.
            upsample_mode: forwarded to ``DoubleMUnet``.
            skip_gn_level: forwarded to ``DoubleMUnet``.
            activation: activation module for the backbone and the heads. ``None``
                (default) creates a fresh ``nn.ReLU()`` for this instance.
            kaiming_normal: if True, apply ``weight_init`` to all submodules and the
                custom initialisation of the detection head.
        """
        if activation is None:
            # a fresh module per instance instead of one ReLU shared via the default argument
            activation = nn.ReLU()

        super().__init__(ch_in=ch_in, ch_out=self.ch_out, depth_shared=depth_shared,
                         depth_union=depth_union, initial_features=initial_features,
                         inter_features=inter_features, norm=norm, norm_groups=norm_groups,
                         norm_head=norm_head, norm_head_groups=norm_head_groups,
                         pool_mode=pool_mode, upsample_mode=upsample_mode,
                         skip_gn_level=skip_gn_level, activation=activation, use_last_nl=False)

        # One head per output group. Assigned after super().__init__, so this takes the
        # place of any ``mt_heads`` attribute the parent may have set.
        self.mt_heads = torch.nn.ModuleList(
            [model_param.MLTHeads(in_channels=inter_features, out_channels=head_ch_out,
                                  activation=activation, last_kernel=1, padding="same",
                                  norm=norm_head, norm_groups=norm_head_groups)
             for head_ch_out in self.out_channels_heads]
        )

        # Register sigma_eps as a frozen parameter so that it is stored in the model's
        # state dict and loaded correctly.
        self.register_parameter('sigma_eps', torch.nn.Parameter(
            torch.tensor([self.sigma_eps_default]), requires_grad=False))

        if kaiming_normal:
            self.apply(self.weight_init)

            # Custom init of the detection (p) head:
            # - ReLU-gain Kaiming for core[0] (presumably its first conv -- verify in MLTHeads);
            # - linear-gain Kaiming for the output conv;
            # - bias of -6, so the initial detection logits are strongly negative
            #   (sigmoid(-6) ~ 0.0025).
            torch.nn.init.kaiming_normal_(self.mt_heads[0].core[0].weight, mode='fan_in',
                                          nonlinearity='relu')
            torch.nn.init.kaiming_normal_(self.mt_heads[0].out_conv.weight, mode='fan_in',
                                          nonlinearity='linear')
            torch.nn.init.constant_(self.mt_heads[0].out_conv.bias, -6.)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the shared core and all heads, then apply the per-channel non-linearities.

        Args:
            x: input tensor (assumed ``(N, ch_in, H, W)`` -- TODO confirm).

        Returns:
            The concatenated head outputs (channels along dim 1), with the activations
            listed in the class docstring applied.
        """
        x = self._forward_core(x)

        # forward through the respective heads and concatenate along the channel dim;
        # call the modules (not .forward) so that registered hooks run
        x_heads = [mt_head(x) for mt_head in self.mt_heads]
        x = torch.cat(x_heads, dim=1)

        # clamp the detection logits before the sigmoid
        x[:, self.p_ch_ix] = torch.clamp(x[:, self.p_ch_ix], min=-8., max=8.)

        # apply non-linearities
        x[:, self.sigmoid_ch_ix] = torch.sigmoid(x[:, self.sigmoid_ch_ix])
        x[:, self.tanh_ch_ix] = torch.tanh(x[:, self.tanh_ch_ix])

        # rescale the (sigmoid-activated) sigmas from (0, 1) to (0, 3) and add epsilon
        x[:, self.pxyz_sig_ch_ix] = x[:, self.pxyz_sig_ch_ix] * 3 + self.sigma_eps

        return x

    def apply_detection_nonlin(self, x: torch.Tensor) -> torch.Tensor:
        """Not supported by this model; non-linearities are applied in ``forward``."""
        raise NotImplementedError

    def apply_nonlin(self, o: torch.Tensor) -> torch.Tensor:
        """Not supported by this model; non-linearities are applied in ``forward``."""
        raise NotImplementedError

    @classmethod
    def parse(cls, param, **kwargs):
        """
        Construct the model from a parameter object.

        Args:
            param: parameter object. ``param.HyperParameter.channels_in`` and the fields of
                ``param.HyperParameter.arch_param`` are read. ``arch_param.activation`` names
                a class in ``torch.nn``, which is instantiated without arguments.
            **kwargs: accepted but unused.

        Returns:
            A new ``cls`` instance.
        """
        arch = param.HyperParameter.arch_param
        activation = getattr(torch.nn, arch.activation)()

        return cls(
            ch_in=param.HyperParameter.channels_in,
            depth_shared=arch.depth_shared,
            depth_union=arch.depth_union,
            initial_features=arch.initial_features,
            inter_features=arch.inter_features,
            activation=activation,
            norm=arch.norm,
            norm_groups=arch.norm_groups,
            norm_head=arch.norm_head,
            norm_head_groups=arch.norm_head_groups,
            pool_mode=arch.pool_mode,
            upsample_mode=arch.upsample_mode,
            skip_gn_level=arch.skip_gn_level,
            kaiming_normal=arch.init_custom,
        )

    @staticmethod
    def weight_init(m):
        """
        Apply Kaiming-normal init (fan_in, ReLU gain) to the weight of every ``Conv2d``.

        Other module types are left untouched. Call it recursively with
        ``model.apply(model.weight_init)``.

        Args:
            m: module
        """
        if isinstance(m, torch.nn.Conv2d):
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')