import warnings
from abc import ABC, abstractmethod  # abstract class
from typing import Union, Callable

import scipy
import scipy.spatial.distance  # explicit: squareform is used below; do not rely on sklearn importing it
import torch
from deprecated import deprecated
from sklearn.cluster import AgglomerativeClustering

import decode.simulation.background
from decode.evaluation import match_emittersets
from decode.generic.emitter import EmitterSet, EmptyEmitterSet
from decode.neuralfitter.utils.probability import binom_pdiverse
class PostProcessing(ABC):
    """
    Abstract base class of all post-processing implementations. Converts raw network output into emitters.
    """

    _return_types = ('batch-set', 'frame-set')

    def __init__(self, xy_unit, px_size, return_format: str):
        """
        Args:
            xy_unit: unit of the xy coordinates
            px_size: pixel size
            return_format (str): return format of forward function. Must be 'batch-set', 'frame-set'. If 'batch-set'
                one instance of EmitterSet will be returned per forward call, if 'frame-set' a tuple of EmitterSet one
                per frame will be returned
        """
        super().__init__()

        self.xy_unit = xy_unit
        self.px_size = px_size
        self.return_format = return_format

    def sanity_check(self):
        """
        Sanity checks

        Raises:
            ValueError: if the return format is not one of _return_types
        """
        if self.return_format not in self._return_types:
            raise ValueError("Not supported return type.")

    def skip_if(self, x) -> bool:
        """
        Skip post-processing when a certain condition is met and implementation would fail, i.e. to many
        bright pixels in the detection channel. Default implementation returns False always.

        Args:
            x: network output

        Returns:
            bool: returns true when post-processing should be skipped
        """
        return False

    @deprecated(reason="Not of interest for the post-processing.", version="0.1.dev")
    def _return_as_type(self, em, ix_low, ix_high):
        """
        Returns in the type specified in constructor

        Args:
            em (EmitterSet): emitters
            ix_low (int): lower frame_ix
            ix_high (int): upper frame_ix

        Returns:
            EmitterSet or list: Returns as EmitterSet or as list of EmitterSets
        """
        if self.return_format == 'batch-set':
            return em
        elif self.return_format == 'frame-set':
            return em.split_in_frames(ix_low=ix_low, ix_up=ix_high)
        else:
            raise ValueError

    @abstractmethod
    def forward(self, x: torch.Tensor) -> Union[EmitterSet, list]:
        """
        Forward anything through the post-processing and return an EmitterSet

        Args:
            x:

        Returns:
            EmitterSet or list: Returns as EmitterSet or as list of EmitterSets
        """
        raise NotImplementedError
class NoPostProcessing(PostProcessing):
    """
    The 'No' Post-Processing post-processing. Will always return an empty EmitterSet.
    """

    def __init__(self, xy_unit=None, px_size=None, return_format='batch-set'):
        super().__init__(xy_unit=xy_unit, px_size=px_size, return_format=return_format)

    def forward(self, x: torch.Tensor):
        """
        Ignore the network output entirely and return an empty emitter collection.

        Args:
            x (torch.Tensor): any input tensor where the first dim is the batch-dim.

        Returns:
            EmptyEmitterSet: An empty EmitterSet
        """
        empty = EmptyEmitterSet(xy_unit=self.xy_unit, px_size=self.px_size)
        return empty
class LookUpPostProcessing(PostProcessing):
    """
    Simple post-processing in which we threshold the probability output (raw threshold) and then look-up the features
    in the respective channels.
    """

    def __init__(self, raw_th: float, xy_unit: str, px_size=None,
                 pphotxyzbg_mapping: Union[list, tuple] = (0, 1, 2, 3, 4, -1),
                 photxyz_sigma_mapping: Union[list, tuple, None] = (5, 6, 7, 8)):
        """
        Args:
            raw_th: initial raw threshold
            xy_unit: xy unit unit
            px_size: pixel size
            pphotxyzbg_mapping: channel index mapping of detection (p), photon, x, y, z, bg
            photxyz_sigma_mapping: channel index mapping of the sigma channels (photon, x, y, z)
        """
        super().__init__(xy_unit=xy_unit, px_size=px_size, return_format='batch-set')

        self.raw_th = raw_th
        self.pphotxyzbg_mapping = pphotxyzbg_mapping
        self.photxyz_sigma_mapping = photxyz_sigma_mapping

        assert len(self.pphotxyzbg_mapping) == 6, "Wrong length of mapping."
        if self.photxyz_sigma_mapping is not None:
            assert len(self.photxyz_sigma_mapping) == 4, "Wrong length of sigma mapping."

    def _filter(self, detection) -> torch.BoolTensor:
        """
        Threshold the detection channel.

        Args:
            detection: any tensor that should be thresholded

        Returns:
            boolean with active px
        """
        is_active = detection >= self.raw_th
        return is_active

    @staticmethod
    def _lookup_features(features: torch.Tensor, active_px: torch.Tensor) -> tuple:
        """
        Gather feature vectors at the active pixel positions.

        Args:
            features: size :math:`(N, C, H, W)`
            active_px: size :math:`(N, H, W)`

        Returns:
            torch.Tensor: batch-ix, size :math: `M`
            torch.Tensor: extracted features size :math:`(C, M)`
        """
        assert features.dim() == 4
        assert active_px.dim() == features.dim() - 1

        # first column of the nonzero coordinates is the batch index
        batch_ix = active_px.nonzero(as_tuple=False)[:, 0]
        # channels first, then mask over (N, H, W) in one go
        features_active = features.permute(1, 0, 2, 3)[:, active_px]

        return batch_ix, features_active

    def forward(self, x: torch.Tensor) -> EmitterSet:
        """
        Forward model output tensor through post-processing and return EmitterSet. Will include sigma values in
        EmitterSet if mapping was provided initially.

        Args:
            x: model output

        Returns:
            EmitterSet
        """
        # reorder features channel-wise to (p, phot, x, y, z, bg)
        x_mapped = x[:, self.pphotxyzbg_mapping]

        # threshold the detection channel (0th) and keep the surviving probabilities
        active_px = self._filter(x_mapped[:, 0])
        prob = x_mapped[:, 0][active_px]

        # look up the remaining channels at the active pixels
        frame_ix, features = self._lookup_features(x_mapped[:, 1:], active_px)
        xyz = features[1:4].transpose(0, 1)

        # if sigma mapping is present, gather those channels as well
        if self.photxyz_sigma_mapping is not None:
            sigma = x[:, self.photxyz_sigma_mapping]
            _, features_sigma = self._lookup_features(sigma, active_px)

            xyz_sigma = features_sigma[1:4].transpose(0, 1).cpu()
            phot_sigma = features_sigma[0].cpu()
        else:
            xyz_sigma = None
            phot_sigma = None

        return EmitterSet(xyz=xyz.cpu(), frame_ix=frame_ix.cpu(), phot=features[0, :].cpu(),
                          xyz_sig=xyz_sigma, phot_sig=phot_sigma, bg_sig=None,
                          bg=features[4, :].cpu() if features.size(0) == 5 else None,
                          prob=prob.cpu(), xy_unit=self.xy_unit, px_size=self.px_size)
class SpatialIntegration(LookUpPostProcessing):
    """
    Spatial Integration post processing.
    """

    _p_aggregations = ('sum', 'norm_sum')  # , 'max', 'pbinom_cdf', 'pbinom_pdf')
    _split_th = 0.6

    def __init__(self, raw_th: float, xy_unit: str, px_size=None,
                 pphotxyzbg_mapping: Union[list, tuple] = (0, 1, 2, 3, 4, -1),
                 photxyz_sigma_mapping: Union[list, tuple, None] = (5, 6, 7, 8),
                 p_aggregation: Union[str, Callable] = 'norm_sum'):
        """
        Args:
            raw_th: probability threshold from where detections are considered
            xy_unit: unit of the xy coordinates
            px_size: pixel size
            pphotxyzbg_mapping: channel index mapping
            photxyz_sigma_mapping: channel index mapping of sigma channels
            p_aggregation: aggregation method to aggregate probabilities. can be 'sum', 'max', 'norm_sum'
        """
        super().__init__(raw_th=raw_th, xy_unit=xy_unit, px_size=px_size,
                         pphotxyzbg_mapping=pphotxyzbg_mapping,
                         photxyz_sigma_mapping=photxyz_sigma_mapping)
        self.p_aggregation = self.set_p_aggregation(p_aggregation)

    def forward(self, x: torch.Tensor) -> EmitterSet:
        # collapse the probability map onto discrete peaks, then do the plain look-up
        x[:, 0] = self._nms(x[:, 0], self.p_aggregation, self.raw_th, self._split_th)
        return super().forward(x)

    @staticmethod
    def _nms(p: torch.Tensor, p_aggregation, raw_th, split_th) -> torch.Tensor:
        """
        Non-Maximum Suppression on the detection channel.

        Args:
            p: detection probability map, size :math:`(N, H, W)`
            p_aggregation: callable merging the two candidate probability maps
            raw_th: initial probability threshold for candidate pixels
            split_th: threshold above which leftover probability counts as a second emitter
        """
        with torch.no_grad():
            residue = p.clone()

            # probability values above the raw threshold are regarded as possible locations
            p_clip = torch.where(p > raw_th, p, torch.zeros_like(p))[:, None]

            # a pixel is a local maximum if it survives a 3x3 max-pool unchanged
            pool = torch.nn.functional.max_pool2d(p_clip, 3, 1, padding=1)
            max_mask1 = torch.eq(p[:, None], pool).float()

            # accumulate probability mass from the 4 adjacent pixels (cross-shaped kernel)
            diag = 0.  # 1/np.sqrt(2)
            filt = torch.tensor([[diag, 1., diag], [1, 1, 1], [diag, 1, diag]]).unsqueeze(0).unsqueeze(0).to(p.device)
            conv = torch.nn.functional.conv2d(p[:, None], filt, padding=1)
            p_ps1 = max_mask1 * conv

            # to identify two fluorophores in adjacent pixels, strip the first maxima and
            # look for leftover probability above the split threshold
            residue *= (1 - max_mask1[:, 0])
            # p_clip = torch.where(residue > split_th, residue, torch.zeros_like(residue))[:, None]
            max_mask2 = torch.where(residue > split_th, torch.ones_like(residue), torch.zeros_like(residue))[:, None]
            p_ps2 = max_mask2 * conv

            # final clustered probability; the caller thresholds it (normally > 0.7)
            # to obtain discrete locations
            p_ps = p_aggregation(p_ps1, p_ps2)
            assert p_ps.size(1) == 1

            return p_ps.squeeze(1)

    @classmethod
    def set_p_aggregation(cls, p_aggr: Union[str, Callable]) -> Callable:
        """
        Sets the p_aggregation by string or callable. Returns a Callable.

        Args:
            p_aggr: probability aggregation
        """
        if not isinstance(p_aggr, str):
            return p_aggr

        if p_aggr == 'sum':
            return torch.add
        if p_aggr == 'max':
            return torch.max
        if p_aggr == 'norm_sum':
            def norm_sum(*args):
                return torch.clamp(torch.add(*args), 0., 1.)

            return norm_sum

        raise ValueError
class ConsistencyPostprocessing(PostProcessing):
    """
    PostProcessing implementation that divides the output in hard and easy samples. Easy samples are predictions in
    which we have a single one hot pixel in the detection channel, hard samples are pixels in the detection channel
    where the adjacent pixels are also active (i.e. above a certain initial threshold).
    """

    _p_aggregations = ('sum', 'max', 'pbinom_cdf', 'pbinom_pdf')
    _xy_unit = 'nm'

    def __init__(self, *, raw_th, em_th, xy_unit: str, img_shape, ax_th=None, vol_th=None, lat_th=None,
                 p_aggregation='pbinom_cdf', px_size=None, match_dims=2, diag=0,
                 pphotxyzbg_mapping=(0, 1, 2, 3, 4, -1),
                 num_workers=0, skip_th: Union[None, float] = None, return_format='batch-set', sanity_check=True):
        """
        Args:
            raw_th: initial pixel-wise detection threshold
            em_th: final per-emitter probability threshold
            xy_unit: unit of the xy coordinates
            img_shape: image shape
            ax_th: axial distance threshold for matching
            vol_th: volumetric distance threshold for matching
            lat_th: lateral distance threshold for matching
            p_aggregation: aggregation method for cluster probabilities, one of _p_aggregations
            px_size: pixel size
            match_dims: match in 2 or 3 dimensions
            diag: diagonal weight of the neighbour-counting kernel (0 means 4-neighbourhood)
            pphotxyzbg_mapping: channel index mapping of detection (p), photon, x, y, z, bg
            num_workers: number of workers (currently only 0 is implemented)
            skip_th: relative fraction of the detection output to be on to skip post_processing.
                This is useful during training when the network has not yet converged and major parts of the
                detection output is white (i.e. non sparse detections).
            return_format: return format of the forward function
            sanity_check: perform sanity check after construction
        """
        super().__init__(xy_unit=xy_unit, px_size=px_size, return_format=return_format)

        self.raw_th = raw_th
        self.em_th = em_th
        self.p_aggregation = p_aggregation
        self.match_dims = match_dims
        self.num_workers = num_workers
        self.skip_th = skip_th
        self.pphotxyzbg_mapping = pphotxyzbg_mapping

        # pairwise filter deciding which candidates may be merged into one cluster
        self._filter = match_emittersets.GreedyHungarianMatching(match_dims=match_dims, dist_lat=lat_th,
                                                                 dist_ax=ax_th, dist_vol=vol_th).filter

        self._bg_calculator = decode.simulation.background.BgPerEmitterFromBgFrame(filter_size=13, xextent=(0., 1.),
                                                                                   yextent=(0., 1.),
                                                                                   img_shape=img_shape)

        self._neighbor_kernel = torch.tensor([[diag, 1, diag], [1, 1, 1], [diag, 1, diag]]).float().view(1, 1, 3, 3)

        self._clusterer = AgglomerativeClustering(n_clusters=None,
                                                  distance_threshold=lat_th if self.match_dims == 2 else vol_th,
                                                  affinity='precomputed',
                                                  linkage='single')

        if sanity_check:
            self.sanity_check()

    @classmethod
    def parse(cls, param, **kwargs):
        """
        Return an instance of this post-processing as specified by the parameters

        Args:
            param: parameter set

        Returns:
            ConsistencyPostprocessing
        """
        return cls(raw_th=param.PostProcessingParam.single_val_th, em_th=param.PostProcessingParam.total_th,
                   xy_unit='px', px_size=param.Camera.px_size,
                   img_shape=param.TestSet.img_size,
                   ax_th=param.PostProcessingParam.ax_th, vol_th=param.PostProcessingParam.vol_th,
                   lat_th=param.PostProcessingParam.lat_th, match_dims=param.PostProcessingParam.match_dims,
                   return_format='batch-set', **kwargs)

    def sanity_check(self):
        """
        Performs some sanity checks. Part of the constructor; useful if you modify attributes later on and want to
        double check.
        """
        super().sanity_check()

        if self.p_aggregation not in self._p_aggregations:
            raise ValueError("Unsupported probability aggregation type.")

        if len(self.pphotxyzbg_mapping) != 6:
            raise ValueError("Wrong channel mapping length.")

    def skip_if(self, x):
        """Skip post-processing when the detection channel is too dense (see skip_th)."""
        if x.dim() != 4:
            raise ValueError("Unsupported dim.")

        if self.skip_th is not None and (x[:, 0] >= self.raw_th).sum() > self.skip_th * x[:, 0].numel():
            return True
        else:
            return False

    def _cluster_batch(self, p, features):
        """
        Cluster a batch of frames

        Args:
            p (torch.Tensor): detections, size :math:`(N, 1, H, W)`
            features (torch.Tensor): features, size :math:`(N, C, H, W)`

        Returns:
            (torch.Tensor, torch.Tensor): aggregated probabilities and averaged features, input sizes
        """
        clusterer = self._clusterer
        p_aggregation = self.p_aggregation

        if p.size(1) > 1:
            raise ValueError("Not supported shape for probability.")

        p_out = torch.zeros_like(p).view(p.size(0), p.size(1), -1)
        feat_out = features.clone().view(features.size(0), features.size(1), -1)

        # cluster every frame independently
        for i in range(features.size(0)):
            ix = p[i, 0] > 0

            if (ix == 0.).all().item():
                continue

            alg_ix = (p[i].view(-1) > 0).nonzero(as_tuple=False).squeeze(1)

            p_frame = p[i, 0, ix].view(-1)
            f_frame = features[i, :, ix]
            # flatten samples and put them in the first dim
            f_frame = f_frame.reshape(f_frame.size(0), -1).permute(1, 0)

            # pairwise mask of candidates that are allowed to end up in the same cluster
            filter_mask = self._filter(f_frame[:, 1:4], f_frame[:, 1:4])

            if self.match_dims == 2:
                dist_mat = torch.pdist(f_frame[:, 1:3])
            elif self.match_dims == 3:
                dist_mat = torch.pdist(f_frame[:, 1:4])
            else:
                raise ValueError

            dist_mat = torch.from_numpy(scipy.spatial.distance.squareform(dist_mat))
            dist_mat[~filter_mask] = 999999999999.  # those who shall not match shall be separated, only finite vals ...

            if dist_mat.shape[0] == 1:
                warnings.warn("I don't know how this can happen but there seems to be a"
                              " single an isolated difficult case ...", stacklevel=3)
                n_clusters = 1
                labels = torch.tensor([0])
            else:
                clusterer.fit(dist_mat)
                n_clusters = clusterer.n_clusters_
                labels = torch.from_numpy(clusterer.labels_)

            for c in range(n_clusters):
                in_cluster = labels == c
                feat_ix = alg_ix[in_cluster]

                if p_aggregation == 'sum':
                    p_agg = p_frame[in_cluster].sum()
                elif p_aggregation == 'max':
                    p_agg = p_frame[in_cluster].max()
                elif p_aggregation == 'pbinom_cdf':
                    z = binom_pdiverse(p_frame[in_cluster].view(-1))
                    p_agg = z[1:].sum()
                elif p_aggregation == 'pbinom_pdf':
                    z = binom_pdiverse(p_frame[in_cluster].view(-1))
                    p_agg = z[1]
                else:
                    raise ValueError

                p_out[i, 0, feat_ix[0]] = p_agg  # only set first element to some probability
                # probability-weighted average of the features within the cluster
                feat_av = (feat_out[i, :, feat_ix] * p_frame[in_cluster]).sum(1) / p_frame[in_cluster].sum()
                feat_out[i, :, feat_ix] = feat_av.unsqueeze(1).repeat(1, in_cluster.sum())

        return p_out.reshape(p.size()), feat_out.reshape(features.size())

    def _forward_raw_impl(self, p, features):
        """
        Actual implementation.

        Args:
            p: detection channel, size :math:`(N, 1, H, W)`
            features: feature channels, size :math:`(N, C, H, W)`

        Returns:
            (torch.Tensor, torch.Tensor): processed detection and feature maps (on CPU)
        """
        with torch.no_grad():
            # first init by an first threshold to get rid of all the nonsense
            p_clip = torch.zeros_like(p)
            is_above_svalue = p > self.raw_th
            p_clip[is_above_svalue] = p[is_above_svalue]

            # compute local mean background if a bg channel is present
            if features.size(1) == 5:
                bg_out = features[:, [4]]
                bg_out = self._bg_calculator._mean_filter(bg_out).cpu()
            else:
                bg_out = None

            # divide the set in easy (no neighbors) and set of predictions with adjacents
            binary_mask = torch.zeros_like(p_clip)
            binary_mask[p_clip > 0] = 1.

            # count active neighbors per active pixel
            self._neighbor_kernel = self._neighbor_kernel.type(binary_mask.dtype).to(binary_mask.device)
            count = torch.nn.functional.conv2d(binary_mask, self._neighbor_kernel, padding=1) * binary_mask

            # divide in easy and difficult set
            is_easy = count == 1
            is_easy_rep = is_easy.repeat(1, features.size(1), 1, 1)
            is_diff = count > 1
            is_diff_rep = is_diff.repeat(1, features.size(1), 1, 1)

            p_easy = torch.zeros_like(p_clip)
            p_diff = p_easy.clone()
            feat_easy = torch.zeros_like(features)
            feat_diff = feat_easy.clone()

            p_easy[is_easy] = p_clip[is_easy]
            feat_easy[is_easy_rep] = features[is_easy_rep]

            p_diff[is_diff] = p_clip[is_diff]
            feat_diff[is_diff_rep] = features[is_diff_rep]

            p_out = torch.zeros_like(p_clip).cpu()
            feat_out = torch.zeros_like(feat_diff).cpu()

            # cluster the hard cases if they are consistent given euclidean affinity
            if self.num_workers == 0:
                p_out_diff, feat_out_diff = self._cluster_batch(p_diff.cpu(), feat_diff.cpu())
            else:
                raise NotImplementedError

            # add the easy ones
            p_out[is_easy] = p_easy[is_easy].cpu()
            p_out[is_diff] = p_out_diff[is_diff].cpu()

            feat_out[is_easy_rep] = feat_easy[is_easy_rep].cpu()
            feat_out[is_diff_rep] = feat_out_diff[is_diff_rep].cpu()

            # write the bg frame
            if features.size(1) == 5:
                feat_out[:, [4]] = bg_out

            return p_out, feat_out

    def _frame2emitter(self, p: torch.Tensor, features: torch.Tensor):
        """
        Convert frame based features to tensor based features (go frame from image world to emitter world)

        Args:
            p (torch.Tensor): detection channel
            features (torch.Tensor): features

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor)

            feat_out: output features
            p_out: final probabilities
            batch_ix: batch index
        """
        is_pos = (p >= self.em_th).nonzero(as_tuple=False)  # is above threshold
        p_out = p[p >= self.em_th]

        # look up features
        feat_out = features[is_pos[:, 0], :, is_pos[:, 2], is_pos[:, 3]]

        # pick corresponding batch index
        batch_ix = torch.ones_like(p) * torch.arange(p.size(0), dtype=features.dtype).view(-1, 1, 1, 1)  # bookkeep
        batch_ix = batch_ix[is_pos[:, 0], :, is_pos[:, 2], is_pos[:, 3]]

        return feat_out, p_out, batch_ix.long()

    def forward(self, features: torch.Tensor):
        """
        Forward the feature map through the post processing and return an EmitterSet or a list of EmitterSets.
        For the input features we use the following convention:

            0 - Detection channel
            1 - Photon channel
            2 - 'x' channel
            3 - 'y' channel
            4 - 'z' channel
            5 - Background channel

        Expecting x and y channels in nano-metres.

        Args:
            features (torch.Tensor): Features of size :math:`(N, C, H, W)`

        Returns:
            EmitterSet or list of EmitterSets: Specified by return_format argument, EmitterSet in nano metres.
        """
        if self.skip_if(features):
            return EmptyEmitterSet(xy_unit=self.xy_unit, px_size=self.px_size)

        if features.dim() != 4:
            raise ValueError("Wrong dimensionality. Needs to be N x C x H x W.")

        features = features[:, self.pphotxyzbg_mapping]  # change channel order if needed

        p = features[:, [0], :, :]
        features = features[:, 1:, :, :]  # phot, x, y, z, bg

        p_out, feat_out = self._forward_raw_impl(p, features)
        feature_list, prob_final, frame_ix = self._frame2emitter(p_out, feat_out)
        # squeeze only the trailing channel dim; a plain squeeze() would collapse
        # the emitter dim as well when exactly one emitter is present
        frame_ix = frame_ix.squeeze(-1)

        return EmitterSet(xyz=feature_list[:, 1:4], phot=feature_list[:, 0], frame_ix=frame_ix,
                          prob=prob_final, bg=feature_list[:, 4],
                          xy_unit=self.xy_unit, px_size=self.px_size)