import functools
from typing import Tuple
import torch
[docs]class SpatialInterpolation:
"""
Up or downscales by a given method.
Attributes:
dim (int): dimensionality for safety checks
"""
def __init__(self, mode='nearest', size=None, scale_factor=None, impl=None):
"""
Args:
mode (string, None): mode which is used for interpolation. Those are the modes by the torch interpolation
function
impl (optional): override function for interpolation
"""
if impl is not None:
self._inter_impl = impl
else:
self._inter_impl = functools.partial(torch.nn.functional.interpolate,
mode=mode, size=size, scale_factor=scale_factor)
@staticmethod
def _unsq_call_sq(func, x: torch.Tensor, dim: int) -> any:
"""
Unsqueeze input tensor until dimensionality 'dim' is matched and squeeze output before return
Args:
func:
x:
dim:
Returns:
"""
n_unsq = 0
while x.dim() < dim:
x.unsqueeze_(0)
n_unsq += 1
out = func(x)
for _ in range(n_unsq):
out.squeeze_(0)
return out
[docs] def forward(self, x: torch.Tensor):
"""
Forward a tensor through the interpolation process.
Args:
x (torch.Tensor): arbitrary tensor complying with the interpolation function.
Must have a batch and channel dimension.
Returns:
x_inter: interpolated tensor
"""
return self._unsq_call_sq(self._inter_impl, x, 4)
[docs]class AmplitudeRescale:
"""
Simple Processing that rescales the amplitude, i.e. the pixel values.
Attributes:
norm-value (float): Value to which to norm the data.
"""
def __init__(self, scale: float = 1., offset: float = 0.):
"""
Args:
offset:
scale (float): reference value
"""
self.scale = scale if scale is not None else 1.
self.offset = offset if offset is not None else 0.
[docs] @staticmethod
def parse(param):
return AmplitudeRescale(scale=param.Scaling.input_scale,
offset=param.Scaling.input_offset)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward the tensor and rescale it.
Args:
x (torch.Tensor):
Returns:
x_ (torch.Tensor): rescaled tensor
"""
return (x - self.offset) / self.scale
[docs]class OffsetRescale:
"""
The purpose of this class is to rescale the (target) data from the network value world back to the real values.
This class is used if we want to know the actual values and do not want to just use it for the loss.
"""
def __init__(self, *, scale_x: float, scale_y: float, scale_z: float, scale_phot: float, mu_sig_bg=(None, None),
buffer=1., power=1.):
"""
Assumes scale_x, scale_y, scale_z to be symmetric ranged, scale_phot, ranged between 0-1
Args:
scale_x (float): scale factor in x
scale_y: scale factor in y
scale_z: scale factor in z
scale_phot: scale factor for photon values
mu_sig_bg: offset and scaling for background
buffer: buffer to extend the scales overall
power: power factor
"""
self.sc_x = scale_x
self.sc_y = scale_y
self.sc_z = scale_z
self.sc_phot = scale_phot
self.mu_sig_bg = mu_sig_bg
self.buffer = buffer
self.power = power
[docs] @staticmethod
def parse(param):
return OffsetRescale(scale_x=param.Scaling.dx_max,
scale_y=param.Scaling.dy_max,
scale_z=param.Scaling.z_max,
scale_phot=param.Scaling.phot_max,
mu_sig_bg=param.Scaling.mu_sig_bg,
buffer=param.Scaling.linearisation_buffer)
[docs] def return_inverse(self):
"""
Returns the inverse counterpart of this class (instance).
Returns:
InverseOffSetRescale: Inverse counterpart.
"""
return InverseOffsetRescale(scale_x=self.sc_x,
scale_y=self.sc_y,
scale_z=self.sc_z,
scale_phot=self.sc_phot,
mu_sig_bg=self.mu_sig_bg,
buffer=self.buffer,
power=self.power)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Scale the input (typically after the network).
Args:
x (torch.Tensor): input tensor N x 5/6 x H x W
Returns:
x_ (torch.Tensor): scaled
"""
if x.dim() == 3:
x.unsqueeze_(0)
squeeze_before_return = True
else:
squeeze_before_return = False
x_ = x.clone()
x_[:, 1, :, :] *= (self.sc_phot * self.buffer) ** self.power
x_[:, 2, :, :] *= (self.sc_x * self.buffer) ** self.power
x_[:, 3, :, :] *= (self.sc_y * self.buffer) ** self.power
x_[:, 4, :, :] *= (self.sc_z * self.buffer) ** self.power
if x_.size(1) == 6:
x_[:, 5, :, :] *= (self.mu_sig_bg[1] * self.buffer) ** self.power
x_[:, 5, :, :] += self.mu_sig_bg[0]
if squeeze_before_return:
return x_.squeeze(0)
else:
return x_
[docs]class InverseOffsetRescale(OffsetRescale):
"""
The purpose of this class is to scale the target data for the loss to an apropriate range.
"""
def __init__(self, *, scale_x: float, scale_y: float, scale_z: float, scale_phot: float, mu_sig_bg=(None, None),
buffer=1., power=1.):
"""
Assumes scale_x, scale_y, scale_z to be symmetric ranged, scale_phot, ranged between 0-1
Args:
scale_x (float): scale factor in x
scale_y: scale factor in y
scale_z: scale factor in z
scale_phot: scale factor for photon values
mu_sig_bg: offset and scaling for background
buffer: buffer to extend the scales overall
power: power factor
"""
super().__init__(scale_x=scale_x, scale_y=scale_y, scale_z=scale_z, scale_phot=scale_phot,
mu_sig_bg=mu_sig_bg, buffer=buffer, power=power)
[docs] @classmethod
def parse(cls, param):
return cls(scale_x=param.Scaling.dx_max,
scale_y=param.Scaling.dy_max,
scale_z=param.Scaling.z_max,
scale_phot=param.Scaling.phot_max,
mu_sig_bg=param.Scaling.mu_sig_bg,
buffer=param.Scaling.linearisation_buffer)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Inverse scale transformation (typically before the network).
Args:
x (torch.Tensor): input tensor N x 5/6 x H x W
Returns:
x_ (torch.Tensor): (inverse) scaled
"""
if x.dim() == 3:
x.unsqueeze_(0)
squeeze_before_return = True
else:
squeeze_before_return = False
x_ = x.clone()
x_[:, 1, :, :] /= (self.sc_phot * self.buffer) ** self.power
x_[:, 2, :, :] /= (self.sc_x * self.buffer) ** self.power
x_[:, 3, :, :] /= (self.sc_y * self.buffer) ** self.power
x_[:, 4, :, :] /= (self.sc_z * self.buffer) ** self.power
if x_.size(1) == 6:
x_[:, 5, :, :] -= self.mu_sig_bg[0]
x_[:, 5, :, :] /= (self.mu_sig_bg[1] * self.buffer) ** self.power
if squeeze_before_return:
return x_.squeeze(0)
else:
return x_
[docs]class FourFoldInverseOffsetRescale(InverseOffsetRescale):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
x_no_bg = torch.cat([super(FourFoldInverseOffsetRescale, self).forward(x[..., i:(i + 5), :, :]) for i in range(0, 20, 5)], dim=1 if x.dim() == 4 else 0)
bg = (x[..., [20], :, :] - self.mu_sig_bg[0]) / (self.mu_sig_bg[1] * self.buffer) ** self.power
return torch.cat([x_no_bg, bg], 1 if x.dim() == 4 else 0)
[docs]class ParameterListRescale:
def __init__(self, phot_max, z_max, bg_max):
self.phot_max = phot_max
self.z_max = z_max
self.bg_max = bg_max
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor, bg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if x.dim() not in (2, 3) or x.size(-1) != 4:
raise ValueError(f"Unsupported shape of input {x.size()}")
x = x.clone()
x[..., 0] = x[..., 0] / self.phot_max
x[..., 3] = x[..., 3] / self.z_max
bg = bg / self.bg_max
return x, mask, bg
[docs] @classmethod
def parse(cls, param):
return cls(phot_max=param.Scaling.phot_max,
z_max=param.Scaling.z_max,
bg_max=param.Scaling.bg_max)
[docs]class InverseParamListRescale(ParameterListRescale):
"""
Rescale network output trained with GMM Loss.
"""
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: model output
Returns:
torch.Tensor (rescaled model output)
"""
if x.dim() != 4 or x.size(1) != 10:
raise ValueError(f"Unsupported size of input {x.size()}")
x = x.clone()
x[:, 1] *= self.phot_max
x[:, 5] *= self.phot_max # sigma rescaling
x[:, 4] *= self.z_max
x[:, 8] *= self.z_max
x[:, -1] *= self.bg_max
return x