import matplotlib.pyplot as plt
import seaborn as sns
import torch
import warnings
from . import utils
[docs]def deviation_dist(x: torch.Tensor, x_gt: torch.Tensor, residuals=False, kde=True, ax=None, nan_okay=True):
"""Log z vs z_gt"""
if ax is None:
ax = plt.gca()
if len(x) == 0:
ax.set_ylabel('no data')
return ax
if residuals:
x = x - x_gt
if not torch.isnan(x).any():
if kde:
utils.kde_sorted(x_gt, x, True, ax, sub_sample=10000, nan_inf_ignore=True)
else:
ax.plot(x_gt, x, 'x')
else:
if not nan_okay:
raise ValueError(f"Some of the values are NaN.")
if residuals:
ax.plot([x_gt.min(), x_gt.max()], [0, 0], 'green')
ax.set_ylabel('residuals')
else:
ax.plot([x_gt.min(), x_gt.max()], [x_gt.min(), x_gt.max()], 'green')
ax.set_ylabel('prediction')
ax.set_xlabel('ground truth')
return ax
[docs]def px_pointer_dist(pointer, px_border: float, px_size: float):
"""
Args:
pointer:
px_border: lower limit of pixel (most commonly -0.5)
px_size: size of pixel (most commonly 1.)
Returns:
"""
x = (pointer - px_border) % px_size + px_border
return x
[docs]def emitter_deviations(tp, tp_match, px_border: float, px_size: float, axes, residuals=False, kde=True):
"""Plot within px distribution"""
assert len(axes) == 4
"""XY within px"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
sns.distplot(px_pointer_dist(tp.xyz_px[:, 0], px_border=px_border, px_size=px_size), norm_hist=True, ax=axes[0], bins=50)
sns.distplot(px_pointer_dist(tp.xyz_px[:, 1], px_border=px_border, px_size=px_size), norm_hist=True, ax=axes[1], bins=50)
"""Z and Photons"""
deviation_dist(tp.xyz_nm[:, 2], tp_match.xyz_nm[:, 2], residuals=residuals, kde=kde, ax=axes[2])
deviation_dist(tp.phot, tp_match.phot, residuals=residuals, kde=kde, ax=axes[3])