# Source code for decode.evaluation.evaluation

# from abc import ABC
import warnings
from collections import namedtuple

import matplotlib.pyplot as plt
import scipy.stats
import seaborn as sns
import torch

from decode.evaluation.metric import precision_recall_jaccard, rmse_mad_dist, efficiency
from ..generic import emitter as emitter
from ..generic.emitter import EmitterSet


class SegmentationEvaluation:
    """
    Collects the segmentation (detection) metrics of an evaluation run in one place.

    The inputs and results of the most recent :meth:`forward` call are cached on the
    instance so that ``str(instance)`` can summarise them.
    """

    _seg_eval_return = namedtuple("seg_eval", ["prec", "rec", "jac", "f1"])

    def __init__(self):
        # emitter sets passed to the last forward call
        self._tp = self._fp = self._fn = None
        # metric values computed by the last forward call
        self._prec = self._rec = self._jac = self._f1 = None

    def __str__(self):
        if any(em_set is None for em_set in (self._tp, self._fp, self._fn)):
            return "Segmentation evaluation unavailable. Run .forward(tp, fp, fn)"

        n_tp, n_fp, n_fn = len(self._tp), len(self._fp), len(self._fn)

        # actual emitters = TP + FN, predicted emitters = TP + FP
        lines = [
            "Segmentation evaluation (cached values)\n",
            f"Number of actual emitters: {n_tp + n_fn} Predicted emitters: {n_tp + n_fp}\n",
            f"Number of TP: {n_tp} FP: {n_fp} FN: {n_fn}\n",
            f"Jacquard: {self._jac:.3f}\n",
            f"F1Score: {self._f1:.3f}\n",
            f"Precision: {self._prec:.3f}, Recall: {self._rec:.3f}\n",
        ]
        return "".join(lines)

    def forward(self, tp: EmitterSet, fp: EmitterSet, fn: EmitterSet):
        """
        Compute precision, recall, Jaccard index and F1 score from matched emitter sets.

        Only the sizes of the three sets enter ``precision_recall_jaccard``; the sets
        themselves are cached for ``__str__``.

        Args:
            tp: true positives
            fp: false positives
            fn: false negatives

        Returns:
            namedtuple ``seg_eval`` with fields ``prec``, ``rec``, ``jac`` and ``f1``, holding
            the values returned by ``precision_recall_jaccard``.
        """
        prec, rec, jac, f1 = precision_recall_jaccard(len(tp), len(fp), len(fn))

        # cache the inputs and the results of this call
        self._tp, self._fp, self._fn = tp, fp, fn
        self._prec, self._rec, self._jac, self._f1 = prec, rec, jac, f1

        return self._seg_eval_return(prec=prec, rec=rec, jac=jac, f1=f1)
class DistanceEvaluation:
    """
    Small wrapper that evaluates localisation distances between two sets of emitters.

    The results of the most recent :meth:`forward` call are cached on the instance
    so that ``str(instance)`` can summarise them.
    """

    _dist_eval_return = namedtuple("dist_eval", ["rmse_lat", "rmse_ax", "rmse_vol", "mad_lat", "mad_ax", "mad_vol"])

    def __init__(self):
        self._rmse_lat = self._rmse_ax = self._rmse_vol = None
        self._mad_lat = self._mad_ax = self._mad_vol = None

    def __str__(self):
        if self._rmse_lat is None:
            return "Distance Evaluation unavailable. Run .forward(tp, tp_match)."

        return (
            "Distance Evaluation (cached values)\n"
            f"RMSE: Lat. {self._rmse_lat:.3f} Axial. {self._rmse_ax:.3f} Vol. {self._rmse_vol:.3f}\n"
            f"MAD: Lat. {self._mad_lat:.3f} Axial. {self._mad_ax:.3f} Vol. {self._mad_vol:.3f}\n"
        )

    def forward(self, tp: EmitterSet, tp_match: EmitterSet):
        """
        Compute RMSE and MAD between the ``xyz_nm`` coordinates of two matched emitter sets.

        Args:
            tp: true positives
            tp_match: matching ground truths

        Returns:
            namedtuple ``dist_eval`` with fields ``rmse_lat``, ``rmse_ax``, ``rmse_vol``,
            ``mad_lat``, ``mad_ax`` and ``mad_vol``, in the order returned by ``rmse_mad_dist``.
        """
        rmse_lat, rmse_ax, rmse_vol, mad_lat, mad_ax, mad_vol = rmse_mad_dist(tp.xyz_nm, tp_match.xyz_nm)

        # cache the results of this call
        self._rmse_lat, self._rmse_ax, self._rmse_vol = rmse_lat, rmse_ax, rmse_vol
        self._mad_lat, self._mad_ax, self._mad_vol = mad_lat, mad_ax, mad_vol

        return self._dist_eval_return(rmse_lat, rmse_ax, rmse_vol, mad_lat, mad_ax, mad_vol)
class WeightedErrors:
    """
    Deviations of true positives from their matched references, weighted by an uncertainty estimate.

    Args:
        mode: ``'phot'`` divides the errors by approximations of the sqrt Cramer-Rao bound
            computed from the reference photon count / background. ``'crlb'`` divides by the
            ``xyz_scr_nm``, ``phot_scr`` and ``bg_scr`` attributes of the reference.
        reduction: ``'mstd'`` summarises the weighted errors by mean and standard deviation.
            ``'gaussian'`` summarises them by loc and scale of ``scipy.stats.norm.fit``.

    Raises:
        ValueError: if ``mode`` or ``reduction`` is not supported.
    """

    _modes_all = ('phot', 'crlb')
    _reduction_all = ('mstd', 'gaussian')
    _return = namedtuple("weighted_err", ["dxyz_red", "dphot_red", "dbg_red", "dxyz_w", "dphot_w", "dbg_w"])

    def __init__(self, mode, reduction):
        self.mode = mode
        self.reduction = reduction

        # sanity check
        if self.mode not in self._modes_all:
            raise ValueError(f"Mode {self.mode} not implemented. Available modes are {self._modes_all}")

        if self.reduction not in self._reduction_all:
            # the two literals are concatenated, hence the trailing space in the first one
            raise ValueError(f"Reduction type {self.reduction} not implemented. Available reduction types "
                             f"are {self._reduction_all}.")

    @staticmethod
    def _reduce(dxyz: torch.Tensor, dphot: torch.Tensor, dbg: torch.Tensor, reduction):
        """
        Reduce the weighted errors by the specified method.

        Args:
            dxyz (torch.Tensor): weighted err in xyz, N x 3
            dphot (torch.Tensor): weighted err in phot, N
            dbg (torch.Tensor): weighted err in bg, N
            reduction (str): ``'mstd'`` or ``'gaussian'``

        Returns:
            tuple ``((xyz_mu, xyz_sig), (phot_mu, phot_sig), (bg_mu, bg_sig))``. The xyz entries
            hold one value per column of ``dxyz``; the others are scalar tensors.

        Raises:
            ValueError: if ``reduction`` is unknown.
        """

        def norm_fit_nan(input_data):
            # Fit a normal distribution and fall back to (nan, nan) if scipy rejects the data.
            try:
                out = scipy.stats.norm.fit(input_data)
            except (ValueError, RuntimeError):
                warnings.warn("Non-Finite values encountered during fitting.")
                return float('nan') * torch.ones(2)
            return torch.tensor(out)

        if reduction == 'mstd':
            return (dxyz.mean(0), dxyz.std(0)), (dphot.mean(), dphot.std()), (dbg.mean(), dbg.std())

        if reduction == 'gaussian':
            # one (loc, scale) row per spatial dimension
            dxyz_mu_sig = torch.stack([norm_fit_nan(dxyz[:, i]) for i in range(3)], 0)
            dphot_mu_sig = norm_fit_nan(dphot)
            dbg_mu_sig = norm_fit_nan(dbg)

            return (dxyz_mu_sig[:, 0], dxyz_mu_sig[:, 1]), \
                   (dphot_mu_sig[0], dphot_mu_sig[1]), \
                   (dbg_mu_sig[0], dbg_mu_sig[1])

        raise ValueError(f"Reduction type {reduction} not implemented. Available reduction types are "
                         f"{WeightedErrors._reduction_all}.")

    @staticmethod
    def plot_error(dxyz, dphot, dbg, axes=None):
        """
        Plot normalised histograms of the weighted errors, each with a fitted normal distribution.

        Args:
            dxyz (torch.Tensor): weighted err in xyz, N x 3
            dphot (torch.Tensor): weighted err in phot, N
            dbg (torch.Tensor): weighted err in bg, N
            axes (sequence of axes, None): exactly 5 axes (x, y, z, phot, bg), or None to create them

        Returns:
            axes

        Raises:
            ValueError: if ``axes`` does not hold exactly 5 elements, or the inputs differ in length.
        """
        if axes is None:
            _, axes = plt.subplots(5)
        elif len(axes) != 5:
            raise ValueError("You must pass exactly 5 axes objects or None.")

        # nothing to plot, but the (possibly new) axes are still returned
        if len(dxyz) == 0:
            return axes

        if len(dxyz[:, 0]) != len(dphot) or len(dphot) != len(dbg):
            raise ValueError("Inconsistent number of elements.")

        series = (dxyz[:, 0], dxyz[:, 1], dxyz[:, 2], dphot, dbg)
        for values, ax in zip(series, axes):
            sns.distplot(values.numpy(), norm_hist=True, kde=False, fit=scipy.stats.norm, ax=ax)

        return axes

    def forward(self, tp: "emitter.EmitterSet", ref: "emitter.EmitterSet", plot: bool = False,
                axes=None) -> namedtuple:
        """
        Compute the weighted errors of ``tp`` with respect to ``ref`` and reduce them.

        Args:
            tp (EmitterSet): true positives
            ref (EmitterSet): matching ground truth, paired with ``tp`` element by element
            plot (bool): plot histograms of the weighted errors
            axes (list, tuple): axes to which to plot the histograms (see ``plot_error``)

        Returns:
            namedtuple ``weighted_err``:
            ``dxyz_red``, ``dphot_red``, ``dbg_red``: (mu, sig) pairs as returned by ``_reduce``;
            ``dxyz_w``, ``dphot_w``, ``dbg_w``: the unreduced weighted errors.

        Raises:
            ValueError: if ``tp`` and ``ref`` differ in length, or ``self.mode`` is unknown.
        """
        if len(tp) != len(ref):
            raise ValueError(f"Size of true positives ({len(tp)}) does not match size of reference ({len(ref)}).")

        dxyz = tp.xyz_nm - ref.xyz_nm
        dphot = tp.phot - ref.phot
        dbg = tp.bg - ref.bg

        if self.mode == 'phot':
            # 0th / 1st order approximations of the sqrt Cramer-Rao bound, derived from the
            # reference photon count (xyz, phot) and background (bg)
            xyz_scr_est = 1 / ref.phot.unsqueeze(1).sqrt()
            phot_scr_est = ref.phot.sqrt()
            bg_scr_est = ref.bg.sqrt()

            dxyz_w = dxyz / xyz_scr_est
            dphot_w = dphot / phot_scr_est
            dbg_w = dbg / bg_scr_est

        elif self.mode == 'crlb':
            dxyz_w = dxyz / ref.xyz_scr_nm
            dphot_w = dphot / ref.phot_scr
            dbg_w = dbg / ref.bg_scr

        else:
            raise ValueError(f"Mode {self.mode} not implemented. Available modes are {self._modes_all}")

        if plot:
            _ = self.plot_error(dxyz_w, dphot_w, dbg_w, axes=axes)

        dxyz_wred, dphot_wred, dbg_wred = self._reduce(dxyz_w, dphot_w, dbg_w, reduction=self.reduction)

        return self._return(dxyz_red=dxyz_wred, dphot_red=dphot_wred, dbg_red=dbg_wred,
                            dxyz_w=dxyz_w, dphot_w=dphot_w, dbg_w=dbg_w)
class SMLMEvaluation:
    """
    Wrapper that combines segmentation, distance and weighted-error evaluation into one call.

    Args:
        seg_eval: segmentation evaluator; a new ``SegmentationEvaluation()`` is created if None
        dist_eval: distance evaluator; a new ``DistanceEvaluation()`` is created if None
        weighted_eval: weighted-error evaluator; a new
            ``WeightedErrors(mode='crlb', reduction='gaussian')`` is created if None
    """

    alpha_lat = 1  # nm
    alpha_ax = 0.5  # nm

    _return = namedtuple("eval_set", ["prec", "rec", "jac", "f1", "effcy_lat", "effcy_ax", "effcy_vol",
                                      "rmse_lat", "rmse_ax", "rmse_vol", "mad_lat", "mad_ax", "mad_vol",
                                      "dx_red_mu", "dx_red_sig", "dy_red_mu", "dy_red_sig",
                                      "dz_red_mu", "dz_red_sig", "dphot_red_mu", "dphot_red_sig"])

    # human-readable labels, keyed by field names of _return
    descriptors = {
        'prec': 'Precision',
        'rec': 'Recall',
        'jac': 'Jaccard Index',
        'rmse_lat': 'RMSE lateral',
        'rmse_ax': 'RMSE axial',
        'rmse_vol': 'RMSE volumetric',
        'mad_lat': 'Mean average distance lateral',
        'mad_ax': 'Mean average distance axial',
        'mad_vol': 'Mean average distance in 3D',
        'dx_red_sig': 'CRLB normalised error in x',
        'dy_red_sig': 'CRLB normalised error in y',
        'dz_red_sig': 'CRLB normalised error in z',
        'dx_red_mu': 'CRLB normalised bias in x',
        'dy_red_mu': 'CRLB normalised bias in y',
        'dz_red_mu': 'CRLB normalised bias in z',
    }

    def __init__(self, seg_eval=None, dist_eval=None, weighted_eval=None):
        # Create the evaluators per instance. Instances as default arguments would be built once
        # and then shared (including their caches) by every SMLMEvaluation.
        self.seg_eval = SegmentationEvaluation() if seg_eval is None else seg_eval
        self.dist_eval = DistanceEvaluation() if dist_eval is None else dist_eval
        self.weighted_eval = WeightedErrors(mode='crlb', reduction='gaussian') if weighted_eval is None \
            else weighted_eval

        # results of the last forward call
        self.prec = None
        self.rec = None
        self.jac = None
        self.f1 = None
        self.rmse_vol = None
        self.rmse_lat = None
        self.rmse_ax = None
        self.mad_vol = None
        self.mad_lat = None
        self.mad_ax = None

    @property
    def effcy_lat(self):
        return efficiency(self.jac, self.rmse_lat, self.alpha_lat)

    @property
    def effcy_ax(self):
        return efficiency(self.jac, self.rmse_ax, self.alpha_ax)

    @property
    def effcy_vol(self):
        return (self.effcy_lat + self.effcy_ax) / 2

    def __str__(self):
        # The efficiency properties need jac and the RMSE values, which only exist after forward.
        if self.jac is None or self.rmse_lat is None or self.rmse_ax is None:
            return "Evaluation unavailable. Run .forward(tp, fp, fn, p_ref)."

        lines = [
            "------------------------ Evaluation Set ------------------------",
            f"Precision {self.prec!s}",
            f"Recall {self.rec!s}",
            f"Jaccard {self.jac!s}",
            f"F1Score {self.f1!s}",
            f"RMSE lat. {self.rmse_lat!s}",
            f"RMSE ax. {self.rmse_ax!s}",
            f"RMSE vol. {self.rmse_vol!s}",
            f"MAD lat. {self.mad_lat!s}",
            f"MAD ax. {self.mad_ax!s}",
            f"MAD vol. {self.mad_vol!s}",
            f"Efficiency lat. {self.effcy_lat!s}",
            f"Efficiency ax. {self.effcy_ax!s}",
            f"Efficiency vol. {self.effcy_vol!s}",
            "-----------------------------------------------------------------",
        ]
        return "\n".join(lines)

    def forward(self, tp, fp, fn, p_ref) -> _return:
        """
        Evaluate sets of emitters by all available metrics.

        The segmentation and distance results are also stored on the instance
        (``prec``, ``rec``, ``jac``, ``f1``, ``rmse_*``, ``mad_*``).

        Args:
            tp: true positives
            fp: false positives
            fn: false negatives
            p_ref: true positive references (i.e. the ground truth that has been matched to tp)

        Returns:
            namedtuple: A namedtuple ``eval_set`` containing

            - **prec** (*float*): Precision
            - **rec** (*float*): Recall
            - **jac** (*float*): Jaccard
            - **f1** (*float*): F1-Score
            - **effcy_lat** (*float*): Efficiency lateral
            - **effcy_ax** (*float*): Efficiency axial
            - **effcy_vol** (*float*): Efficiency volumetric
            - **rmse_lat** (*float*): RMSE lateral
            - **rmse_ax** (*float*): RMSE axial
            - **rmse_vol** (*float*): RMSE volumetric
            - **mad_lat** (*float*): MAD lateral
            - **mad_ax** (*float*): MAD axial
            - **mad_vol** (*float*): MAD volumetric
            - **dx_red_mu**, **dy_red_mu**, **dz_red_mu** (*float*): reduced location (mu) of the
              weighted error in x / y / z
            - **dx_red_sig**, **dy_red_sig**, **dz_red_sig** (*float*): reduced scale (sig) of the
              weighted error in x / y / z
            - **dphot_red_mu**, **dphot_red_sig** (*float*): reduced location / scale of the
              weighted photon error
        """
        seg_out = self.seg_eval.forward(tp, fp, fn)
        dist_out = self.dist_eval.forward(tp, p_ref)
        weight_out = self.weighted_eval.forward(tp, p_ref, plot=False)

        self.prec, self.rec, self.jac, self.f1 = seg_out.prec, seg_out.rec, seg_out.jac, seg_out.f1

        self.rmse_lat = dist_out.rmse_lat
        self.rmse_ax = dist_out.rmse_ax
        self.rmse_vol = dist_out.rmse_vol
        self.mad_lat = dist_out.mad_lat
        self.mad_ax = dist_out.mad_ax
        self.mad_vol = dist_out.mad_vol

        # dxyz_red is (mu, sig); each holds one entry per spatial dimension (x, y, z)
        dx_red = (weight_out.dxyz_red[0][0].item(), weight_out.dxyz_red[1][0].item())
        dy_red = (weight_out.dxyz_red[0][1].item(), weight_out.dxyz_red[1][1].item())
        dz_red = (weight_out.dxyz_red[0][2].item(), weight_out.dxyz_red[1][2].item())

        return self._return(prec=seg_out.prec, rec=seg_out.rec, jac=seg_out.jac, f1=seg_out.f1,
                            effcy_lat=self.effcy_lat, effcy_ax=self.effcy_ax, effcy_vol=self.effcy_vol,
                            rmse_lat=dist_out.rmse_lat, rmse_ax=dist_out.rmse_ax, rmse_vol=dist_out.rmse_vol,
                            mad_lat=dist_out.mad_lat, mad_ax=dist_out.mad_ax, mad_vol=dist_out.mad_vol,
                            dx_red_mu=dx_red[0], dx_red_sig=dx_red[1],
                            dy_red_mu=dy_red[0], dy_red_sig=dy_red[1],
                            dz_red_mu=dz_red[0], dz_red_sig=dz_red[1],
                            dphot_red_mu=weight_out.dphot_red[0].item(),
                            dphot_red_sig=weight_out.dphot_red[1].item())