import torch
from torch import nn as nn
from . import unet_param
from ..utils import last_layer_dynamics as lyd
class SimpleSMLMNet(unet_param.UNet2d):
    """
    Single U-Net backbone followed by one single-channel output head per output channel.

    Output channel layout as handled by ``apply_nonlin``:
        0: detection (raw logit; ``p_nl`` is applied in ``forward`` only when not training),
        1: photons (``phot_nl``),
        2:5: xyz (``xyz_nl``),
        5: background (``bg_nl``, only if ``ch_out == 6``).
    """

    def __init__(self, ch_in, ch_out, depth=3, initial_features=64, inter_features=64, p_dropout=0.,
                 activation=nn.ReLU(), use_last_nl=True, norm=None, norm_groups=None, norm_head=None,
                 norm_head_groups=None, pool_mode='StrideConv', upsample_mode='bilinear', skip_gn_level=None):
        """
        Args:
            ch_in: number of input channels of the U-Net
            ch_out: number of output channels (one head each); must be 5 or 6
            depth: U-Net depth
            initial_features: number of features of the first U-Net level
            inter_features: number of U-Net output features, fed into every head
            p_dropout: dropout probability passed to the U-Net
            activation: activation module used in the U-Net and in the heads
            use_last_nl: if True, ``forward`` applies the photon / xyz / bg non-linearities
            norm: normalisation of the U-Net
            norm_groups: number of groups for the U-Net normalisation
            norm_head: normalisation of the heads (None or 'GroupNorm')
            norm_head_groups: number of groups for the head normalisation
            pool_mode: pooling mode of the U-Net
            upsample_mode: upsampling mode of the U-Net
            skip_gn_level: passed through to the U-Net
        """
        assert ch_out in (5, 6)

        super().__init__(in_channels=ch_in,
                         out_channels=inter_features,
                         depth=depth,
                         initial_features=initial_features,
                         pad_convs=True,
                         norm=norm,
                         norm_groups=norm_groups,
                         p_dropout=p_dropout,
                         pool_mode=pool_mode,
                         upsample_mode=upsample_mode,  # previously accepted but never forwarded
                         activation=activation,
                         skip_gn_level=skip_gn_level)

        self.ch_out = ch_out
        # one head per output channel; each head emits a single channel through a 1x1 output conv
        self.mt_heads = nn.ModuleList([
            MLTHeads(inter_features, out_channels=1, last_kernel=1,
                     norm=norm_head, norm_groups=norm_head_groups,
                     padding="same", activation=activation)
            for _ in range(self.ch_out)
        ])

        self._use_last_nl = use_last_nl

        # detection non-linearity: applied only at inference (see forward / apply_pnl)
        self.p_nl = torch.sigmoid
        self.phot_nl = torch.sigmoid
        self.xyz_nl = torch.tanh
        self.bg_nl = torch.sigmoid

    @staticmethod
    def parse(param):
        """
        Build a model from a parameter object (``param.HyperParameter`` / ``.arch_param`` fields).

        Args:
            param: parameter object

        Returns:
            SimpleSMLMNet instance
        """
        arch = param.HyperParameter.arch_param
        # NOTE(review): eval executes arbitrary code taken from the configuration (e.g. "nn.ReLU()").
        # Only use with trusted parameter files.
        activation = eval(arch.activation)

        return SimpleSMLMNet(
            ch_in=param.HyperParameter.channels_in,
            ch_out=param.HyperParameter.channels_out,
            depth=arch.depth,
            initial_features=arch.initial_features,
            inter_features=arch.inter_features,
            p_dropout=arch.p_dropout,
            pool_mode=arch.pool_mode,
            upsample_mode=arch.upsample_mode,
            activation=activation,
            use_last_nl=arch.use_last_nl,
            norm=arch.norm,
            norm_groups=arch.norm_groups,
            norm_head=arch.norm_head,
            norm_head_groups=arch.norm_head_groups,
            skip_gn_level=arch.skip_gn_level
        )

    @staticmethod
    def check_target(y_tar):
        """
        Sanity-check a 4D, 6-channel target tensor for value ranges.

        Checked ranges: channel 0 (probability) and 1 (photons) in [0, 1], channels 2:5 (xyz) in [-1, 1],
        channel 5 (background) in [0, 1].

        Args:
            y_tar: target tensor

        Raises:
            AssertionError: on wrong dimensionality, channel count or out-of-range values
        """
        assert y_tar.dim() == 4, "Wrong dim."
        assert y_tar.size(1) == 6, "Wrong num. of channels"

        def _within(t, low, high):
            return ((t >= low) & (t <= high)).all()

        assert _within(y_tar[:, 0], 0., 1.), "Probability outside of the range."
        assert _within(y_tar[:, 1], 0., 1.), "Photons outside of the range."
        assert _within(y_tar[:, 2:5], -1., 1.), "XYZ outside of the range."
        # background lives in channel 5 (previously channel 1 was checked twice by mistake)
        assert _within(y_tar[:, 5], 0., 1.), "BG outside of the range."

    def rescale_last_layer_grad(self, loss, optimizer):
        """
        Weight the loss by the gradient of the heads' last layers.

        Args:
            loss: non-reduced loss of size N x C x H x W
            optimizer: optimizer of the model

        Returns:
            weight, channelwise loss, channelwise weighted loss
        """
        return lyd.weight_by_gradient(self.mt_heads, loss, optimizer)

    def apply_pnl(self, o):
        """
        Apply the detection non-linearity (``p_nl``, sigmoid) to channel 0, in place.

        During training this is expected to be part of the loss function, so only use it when not training.

        Args:
            o: model output (channels along dim 1); modified in place

        Returns:
            the same tensor ``o``
        """
        o[:, [0]] = self.p_nl(o[:, [0]])
        return o

    def apply_nonlin(self, o):
        """
        Apply the non-linearities of all channels except the detection channel.

        Channel 0 is passed through unchanged, channel 1 gets ``phot_nl``, channels 2:5 get ``xyz_nl``
        and, if ``ch_out == 6``, channel 5 gets ``bg_nl``.

        Args:
            o: model output (channels along dim 1)

        Returns:
            new tensor with ``ch_out`` channels

        Raises:
            ValueError: if ``ch_out`` is neither 5 nor 6
        """
        if self.ch_out not in (5, 6):
            raise ValueError(f"apply_nonlin supports ch_out of 5 or 6, got {self.ch_out}.")

        channels = [
            o[:, [0]],  # detection channel, left untouched
            self.phot_nl(o[:, [1]]),
            self.xyz_nl(o[:, 2:5]),
        ]
        if self.ch_out == 6:
            channels.append(self.bg_nl(o[:, [5]]))

        return torch.cat(channels, 1)

    def forward(self, x, force_no_p_nl=False):
        """
        Args:
            x: input tensor
            force_no_p_nl: if True, do not apply the detection non-linearity even when not training

        Returns:
            concatenated head outputs, with the final non-linearities applied as configured
        """
        features = super().forward(x)
        o = torch.cat([head(features) for head in self.mt_heads], 1)

        # apply the final non-linearities
        if not self.training and not force_no_p_nl:
            o = self.apply_pnl(o)

        if self._use_last_nl:
            o = self.apply_nonlin(o)

        return o
class DoubleMUnet(nn.Module):
    """
    Two-stage U-Net with per-channel output heads.

    A shared U-Net (``unet_shared``) is applied to every input channel separately (or to the whole input
    if ``ch_in == 1``). The resulting features are concatenated and passed through a second U-Net
    (``unet_union``), followed by ``ch_out`` single-channel heads whose outputs are concatenated.
    """

    def __init__(self, ch_in, ch_out, ext_features=0, depth_shared=3, depth_union=3, initial_features=64,
                 inter_features=64,
                 activation=nn.ReLU(), use_last_nl=True, norm=None, norm_groups=None, norm_head=None,
                 norm_head_groups=None, pool_mode='Conv2d', upsample_mode='bilinear', skip_gn_level=None):
        """
        Args:
            ch_in: number of input channels; must be 1 or 3
            ch_out: number of output channels / heads (not restricted here, but ``apply_nonlin``
                supports only 5 or 6)
            ext_features: additional input channels of the shared U-Net (its input has
                ``1 + ext_features`` channels)
            depth_shared: depth of the shared U-Net
            depth_union: depth of the union U-Net
            initial_features: number of features of the first U-Net level
            inter_features: number of output features of both U-Nets
            activation: activation module used in the U-Nets and in the heads
            use_last_nl: if True, ``forward`` applies the photon / xyz / bg non-linearities
            norm: normalisation of the U-Nets
            norm_groups: number of groups for the U-Net normalisation
            norm_head: normalisation of the heads (None or 'GroupNorm')
            norm_head_groups: number of groups for the head normalisation
            pool_mode: pooling mode of the U-Nets
            upsample_mode: upsampling mode of the U-Nets
            skip_gn_level: passed through to the U-Nets
        """
        super().__init__()

        # checked up front so an unsupported value fails before building the U-Nets
        assert ch_in in (1, 3)

        self.unet_shared = unet_param.UNet2d(1 + ext_features, inter_features, depth=depth_shared, pad_convs=True,
                                             initial_features=initial_features,
                                             activation=activation, norm=norm, norm_groups=norm_groups,
                                             pool_mode=pool_mode, upsample_mode=upsample_mode,
                                             skip_gn_level=skip_gn_level)

        self.unet_union = unet_param.UNet2d(ch_in * inter_features, inter_features, depth=depth_union, pad_convs=True,
                                            initial_features=initial_features,
                                            activation=activation, norm=norm, norm_groups=norm_groups,
                                            pool_mode=pool_mode, upsample_mode=upsample_mode,
                                            skip_gn_level=skip_gn_level)

        self.ch_in = ch_in
        self.ch_out = ch_out  # intentionally not restricted to (5, 6) here
        self.mt_heads = nn.ModuleList(
            [MLTHeads(inter_features, out_channels=1, last_kernel=1,
                      norm=norm_head, norm_groups=norm_head_groups,
                      padding="same", activation=activation) for _ in range(self.ch_out)])

        self._use_last_nl = use_last_nl

        # detection non-linearity: applied only at inference, because during training it is expected to be
        # part of the loss (e.g. BCEWithLogits)
        self.p_nl = torch.sigmoid
        self.phot_nl = torch.sigmoid
        self.xyz_nl = torch.tanh
        self.bg_nl = torch.sigmoid

    @classmethod
    def parse(cls, param, **kwargs):
        """
        Build a model from a parameter object (``param.HyperParameter`` / ``.arch_param`` fields).

        Args:
            param: parameter object
            **kwargs: additional keyword arguments forwarded to the constructor

        Returns:
            instance of ``cls``
        """
        arch = param.HyperParameter.arch_param
        # NOTE(review): eval executes arbitrary code taken from the configuration (e.g. "nn.ReLU()").
        # Only use with trusted parameter files.
        activation = eval(arch.activation)

        return cls(
            ch_in=param.HyperParameter.channels_in,
            ch_out=param.HyperParameter.channels_out,
            ext_features=0,
            depth_shared=arch.depth_shared,
            depth_union=arch.depth_union,
            initial_features=arch.initial_features,
            inter_features=arch.inter_features,
            activation=activation,
            use_last_nl=arch.use_last_nl,
            norm=arch.norm,
            norm_groups=arch.norm_groups,
            norm_head=arch.norm_head,
            norm_head_groups=arch.norm_head_groups,
            pool_mode=arch.pool_mode,
            upsample_mode=arch.upsample_mode,
            skip_gn_level=arch.skip_gn_level,
            **kwargs
        )

    def rescale_last_layer_grad(self, loss, optimizer):
        """
        Rescales the weight as by the last layer's gradient.

        Args:
            loss: non-reduced loss
            optimizer: optimizer of the model

        Returns:
            weight, channelwise loss, channelwise weighted loss
        """
        return lyd.weight_by_gradient(self.mt_heads, loss, optimizer)

    def apply_detection_nonlin(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the detection non-linearity (``p_nl``) to channel 0, in place.

        Useful for non-training situations. When BCEWithLogits loss is used, do not use this during
        training (because it's already included in the loss).

        Args:
            x: model output (channels along dim 1); modified in place

        Returns:
            the same tensor ``x``
        """
        x[:, [0]] = self.p_nl(x[:, [0]])
        return x

    def apply_nonlin(self, o: torch.Tensor) -> torch.Tensor:
        """
        Apply the non-linearities of all channels except the detection channel.

        Channel 0 is passed through unchanged, channel 1 gets ``phot_nl``, channels 2:5 get ``xyz_nl``
        and, if ``ch_out == 6``, channel 5 gets ``bg_nl``.

        Args:
            o: model output (channels along dim 1)

        Returns:
            new tensor with ``ch_out`` channels

        Raises:
            ValueError: if ``ch_out`` is neither 5 nor 6 (previously this silently returned None)
        """
        if self.ch_out not in (5, 6):
            raise ValueError(f"apply_nonlin supports ch_out of 5 or 6, got {self.ch_out}.")

        channels = [
            o[:, [0]],  # detection channel, left untouched
            self.phot_nl(o[:, [1]]),
            self.xyz_nl(o[:, 2:5]),
        ]
        if self.ch_out == 6:
            channels.append(self.bg_nl(o[:, [5]]))

        return torch.cat(channels, 1)

    def forward(self, x, force_no_p_nl=False):
        """
        Args:
            x: input tensor with ``ch_in`` channels (along dim 1)
            force_no_p_nl: if True, do not apply the detection non-linearity even when not training

        Returns:
            concatenated head outputs, with the final non-linearities applied as configured
        """
        features = self._forward_core(x)
        o = torch.cat([head(features) for head in self.mt_heads], 1)

        # apply the final non-linearities
        if not self.training and not force_no_p_nl:
            o = self.apply_detection_nonlin(o)

        if self._use_last_nl:
            o = self.apply_nonlin(o)

        return o

    def _forward_core(self, x) -> torch.Tensor:
        """Run the shared U-Net per input channel, concatenate the features and run the union U-Net."""
        if self.ch_in == 1:
            o = self.unet_shared(x)
        else:
            # the same (weight-shared) U-Net processes every input channel separately
            # NOTE(review): single channels are sliced here, so ext_features > 0 would not fit the
            # shared U-Net's input size in this branch, confirm intended use.
            o = torch.cat([self.unet_shared(x[:, [i]]) for i in range(self.ch_in)], 1)

        return self.unet_union(o)
class MLTHeads(nn.Module):
    """
    Output head: a 3x3 convolution block (optionally preceded by GroupNorm) followed by an output
    convolution with ``out_channels`` channels and kernel size ``last_kernel`` (no padding).
    """

    def __init__(self, in_channels, out_channels, last_kernel, norm, norm_groups, padding, activation):
        """
        Args:
            in_channels: number of input feature channels (kept by the core block)
            out_channels: number of channels of the output convolution
            last_kernel: kernel size of the output convolution, applied with ``padding="valid"``
            norm: None or 'GroupNorm'; anything else raises NotImplementedError
            norm_groups: number of GroupNorm groups, clipped to ``in_channels``; required if ``norm`` is set
            padding: padding of the 3x3 core convolution (e.g. "same")
            activation: activation module applied after the core convolution

        Raises:
            NotImplementedError: for unsupported ``norm`` values
        """
        super().__init__()
        self.norm = norm
        self.norm_groups = norm_groups

        if self.norm is not None:
            groups_1 = min(in_channels, self.norm_groups)
            # groups_2 is forwarded for signature compatibility only; _make_core does not use it
            groups_2 = min(1, self.norm_groups)
        else:
            groups_1 = None
            groups_2 = None

        self.core = self._make_core(in_channels, groups_1, groups_2, activation, padding, self.norm)
        self.out_conv = nn.Conv2d(in_channels, out_channels, kernel_size=last_kernel, padding="valid")

    def forward(self, x):
        # call the submodules (instead of .forward) so that registered hooks are run
        o = self.core(x)
        return self.out_conv(o)

    @staticmethod
    def _make_core(in_channels, groups_1, groups_2, activation, padding, norm):
        """Build the core block: [GroupNorm] -> 3x3 Conv2d (in_channels -> in_channels) -> activation."""
        if norm == 'GroupNorm':
            return nn.Sequential(nn.GroupNorm(groups_1, in_channels),
                                 nn.Conv2d(in_channels, in_channels,
                                           kernel_size=3, padding=padding),
                                 activation)

        if norm is None:
            return nn.Sequential(nn.Conv2d(in_channels, in_channels,
                                           kernel_size=3, padding=padding),
                                 activation)

        raise NotImplementedError