import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
import numpy as np
import torch
from decode.generic import emitter as emitter
class EmitterMatcher(ABC):
    """
    Abstract base class for emitter matchers.

    A matcher compares an output (predicted) set of emitters against a
    reference (ground truth) set and partitions them into true positives,
    false positives and false negatives.
    """

    # Result container: tp / fp / fn are emitter subsets; tp_match holds the
    # ground-truth emitters matched one-to-one to the true positives.
    _return_match = namedtuple('MatchResult', ['tp', 'fp', 'fn', 'tp_match'])  # return-type as namedtuple

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, output: 'emitter.EmitterSet', target: 'emitter.EmitterSet') -> '_return_match':
        """
        All implementations shall implement this forward method which takes output and reference set of emitters
        and outputs true positives, false positives, false negatives and matching ground truth (matched to the true positives).

        Args:
            output: output set of emitters
            target: reference set of emitters

        Returns:
            (emitter.EmitterSet, emitter.EmitterSet, emitter.EmitterSet, emitter.EmitterSet)

                - **tp**: true positives
                - **fp**: false positives
                - **fn**: false negatives
                - **tp_match**: ground truths that have been matched to the true positives
        """
        raise NotImplementedError
class GreedyHungarianMatching(EmitterMatcher):
    """
    Matching emitters in a greedy 'hungarian' fashion, by using best first search:
    the globally shortest remaining pairwise distance is assigned first, then the
    involved output and target emitters are ruled out, frame by frame.
    """

    def __init__(self, *, match_dims: int, dist_ax: float = None, dist_lat: float = None, dist_vol: float = None):
        """
        Args:
            match_dims: match in 2D or 3D (dimensionality of the assignment distance matrix)
            dist_lat: lateral (xy) tolerance radius
            dist_ax: axial (z) tolerance threshold
            dist_vol: volumetric (xyz) tolerance radius
        """
        super().__init__()

        self.match_dims = match_dims
        self.dist_ax = dist_ax
        self.dist_lat = dist_lat
        self.dist_vol = dist_vol

        # Sanity checks
        if self.match_dims not in (2, 3):
            raise ValueError("Not supported match dimensionality.")

        if self.dist_lat is not None and self.dist_ax is not None and self.dist_vol is not None:
            warnings.warn("You specified a lateral, axial and volumetric threshold. "
                          "While this is allowed; are you sure?")

        if self.dist_lat is None and self.dist_ax is None and self.dist_vol is None:
            warnings.warn("You specified neither a lateral, axial nor volumetric threshold. Are you sure about this?")

    @classmethod
    def parse(cls, param):
        """Alternate constructor from a parameter object exposing an ``Evaluation`` namespace."""
        return cls(match_dims=param.Evaluation.match_dims,
                   dist_lat=param.Evaluation.dist_lat,
                   dist_ax=param.Evaluation.dist_ax,
                   dist_vol=param.Evaluation.dist_vol)

    def filter(self, xyz_out, xyz_tar) -> torch.Tensor:
        """
        Filter kernel to rule out unwanted matches. Batch implemented, i.e. input can be 2 or 3 dimensional, where the
        latter dimensions are the dimensions of interest.

        Args:
            xyz_out: output coordinates, shape :math: `(B x) N x 3`
            xyz_tar: target coordinates, shape :math: `(B x) M x 3`

        Returns:
            filter_mask (torch.Tensor): boolean of size (B x) N x M
        """
        if xyz_out.dim() == 3:
            assert xyz_out.size(0) == xyz_tar.size(0)
            sque_ret = False  # no squeeze before return
        else:
            xyz_out = xyz_out.unsqueeze(0)
            xyz_tar = xyz_tar.unsqueeze(0)
            sque_ret = True  # squeeze before return

        # Start from all-allowed and veto every pair violating an active threshold.
        filter_mask = torch.ones((xyz_out.size(0), xyz_out.size(1), xyz_tar.size(1)), dtype=torch.bool)  # B x N x M

        if self.dist_lat is not None:
            dist_mat = torch.cdist(xyz_out[:, :, :2], xyz_tar[:, :, :2], p=2)
            filter_mask[dist_mat > self.dist_lat] = False

        if self.dist_ax is not None:
            dist_mat = torch.cdist(xyz_out[:, :, [2]], xyz_tar[:, :, [2]], p=2)
            filter_mask[dist_mat > self.dist_ax] = False

        if self.dist_vol is not None:
            dist_mat = torch.cdist(xyz_out, xyz_tar, p=2)
            filter_mask[dist_mat > self.dist_vol] = False

        if sque_ret:
            filter_mask = filter_mask.squeeze(0)

        return filter_mask

    @staticmethod
    def _rule_out_kernel(dists):
        """
        Kernel which goes through the distance matrix, picks the shortest distance and assigns the match,
        then rules out the involved row and column. Actual 'greedy' kernel.

        Args:
            dists: 2D distance matrix; ``inf`` entries mark forbidden pairs

        Returns:
            (torch.LongTensor, torch.LongTensor): row (output) and column (target) indices of assigned pairs
        """
        assert dists.dim() == 2
        if dists.numel() == 0:
            return torch.zeros((0,)).long(), torch.zeros((0,)).long()

        dists_ = dists.clone()

        match_list = []
        while not (dists_ == float('inf')).all():
            ix = np.unravel_index(dists_.argmin(), dists_.shape)
            # Rule out the assigned output (row) and target (column) for further matches.
            dists_[ix[0]] = float('inf')
            dists_[:, ix[1]] = float('inf')

            match_list.append(ix)

        if match_list:  # idiomatic emptiness check (was: match_list.__len__() >= 1)
            assignment = torch.tensor(match_list).long()
        else:
            assignment = torch.zeros((0, 2)).long()

        return assignment[:, 0], assignment[:, 1]

    def _match_kernel(self, xyz_out, xyz_tar, filter_mask):
        """
        Computes the (2D or 3D, per ``self.match_dims``) distance matrix, masks forbidden pairs
        and runs the greedy assignment kernel.

        Args:
            xyz_out: N x 3 - no batch implementation currently
            xyz_tar: M x 3 - no batch implementation currently
            filter_mask: N x M - not batched

        Returns:
            tp_ix: (long) indices of matched outputs
            tp_match_ix: (long) indices of matched targets
            tp_ix_bool: (boolean) index for xyz_out
            tp_match_ix_bool: (boolean) index for matching xyz_tar
        """
        assert filter_mask.dim() == 2
        assert filter_mask.size() == torch.Size([xyz_out.size(0), xyz_tar.size(0)])

        if self.match_dims == 2:
            dist_mat = torch.cdist(xyz_out[None, :, :2], xyz_tar[None, :, :2], p=2).squeeze(0)
        elif self.match_dims == 3:
            dist_mat = torch.cdist(xyz_out[None, :, :], xyz_tar[None, :, :], p=2).squeeze(0)
        else:
            raise ValueError

        dist_mat[~filter_mask] = float('inf')  # rule out matches by filter

        tp_ix, tp_match_ix = self._rule_out_kernel(dist_mat)

        tp_ix_bool = torch.zeros(xyz_out.size(0), dtype=torch.bool)
        tp_ix_bool[tp_ix] = True
        tp_match_ix_bool = torch.zeros(xyz_tar.size(0), dtype=torch.bool)
        tp_match_ix_bool[tp_match_ix] = True

        return tp_ix, tp_match_ix, tp_ix_bool, tp_match_ix_bool

    def forward(self, output: emitter.EmitterSet, target: emitter.EmitterSet):
        """
        Matches an output set of emitters against a target set, frame by frame.

        Args:
            output: output set of emitters
            target: reference set of emitters

        Returns:
            MatchResult namedtuple of (tp, fp, fn, tp_match) emitter sets
        """

        # Setup split in frames. Determine the frame range automatically so as to cover everything.
        if len(output) >= 1 and len(target) >= 1:
            frame_low = output.frame_ix.min() if output.frame_ix.min() < target.frame_ix.min() else target.frame_ix.min()
            frame_high = output.frame_ix.max() if output.frame_ix.max() > target.frame_ix.max() else target.frame_ix.max()
        elif len(output) >= 1:
            frame_low = output.frame_ix.min()
            frame_high = output.frame_ix.max()
        elif len(target) >= 1:
            frame_low = target.frame_ix.min()
            frame_high = target.frame_ix.max()
        else:
            # Both sets are empty: nothing to match.
            # Bugfix: the original passed ``target.xyz`` (a coordinate tensor) as ``xy_unit``;
            # the unit designator of the target set is meant here.
            # NOTE(review): assumes EmitterSet exposes ``.xy_unit`` — confirm against decode.generic.emitter.
            return (emitter.EmptyEmitterSet(xy_unit=target.xy_unit, px_size=target.px_size),) * 4

        out_pframe = output.split_in_frames(frame_low.item(), frame_high.item())
        tar_pframe = target.split_in_frames(frame_low.item(), frame_high.item())

        tpl, fpl, fnl, tpml = [], [], [], []  # true positive list, false positive list, false neg. ...

        # Match the emitters framewise
        for out_f, tar_f in zip(out_pframe, tar_pframe):
            filter_mask = self.filter(out_f.xyz_nm, tar_f.xyz_nm)  # batch implemented
            tp_ix, tp_match_ix, tp_ix_bool, tp_match_ix_bool = self._match_kernel(out_f.xyz_nm, tar_f.xyz_nm,
                                                                                  filter_mask)  # non batch impl.

            tpl.append(out_f[tp_ix])
            tpml.append(tar_f[tp_match_ix])
            fpl.append(out_f[~tp_ix_bool])
            fnl.append(tar_f[~tp_match_ix_bool])

        # Concat the per-frame results back together
        tp = emitter.EmitterSet.cat(tpl)
        fp = emitter.EmitterSet.cat(fpl)
        fn = emitter.EmitterSet.cat(fnl)
        tp_match = emitter.EmitterSet.cat(tpml)

        # Let tp and tp_match share the same ids. IDs of ground truth are copied to true positives.
        if (tp_match.id == -1).all().item():
            tp_match.id = torch.arange(len(tp_match)).type(tp_match.id.dtype)

        tp.id = tp_match.id.type(tp.id.dtype)

        return self._return_match(tp=tp, fp=fp, fn=fn, tp_match=tp_match)