# decode.neuralfitter.utils.log_train_val_progress — logging helpers for train/validation progress.

import warnings

import matplotlib.pyplot as plt
import torch
import seaborn as sns

import decode.generic.emitter
from decode.evaluation.evaluation import WeightedErrors
from decode.evaluation import predict_dist
from decode.plot import frame_coord

from decode.evaluation import evaluation


def log_frames(x, y_out, y_tar, weight, em_out, em_tar, tp, tp_match, logger, step, colorbar=True):
    """
    Log a randomly picked sample of a batch: raw input/output frames, target and
    weight channels, and overlays of predicted vs. ground-truth emitters.

    Args:
        x: input frames, 4D (batch, channel, H, W) — rank enforced by the assert below
        y_out: raw model output for the same batch (indexable per sample, iterable per channel)
        y_tar: target channels per sample, or None
        weight: weight channels per sample, or None
        em_out: predicted emitters (EmitterSet-like, supports get_subset_frame)
        em_tar: ground-truth emitters; must be a decode EmitterSet
        tp: true-positive emitters
        tp_match: ground-truth emitters matched to the true positives
        logger: logger exposing add_figure(tag, figure, step)
        step: global step used for all logged figures
        colorbar: whether to draw a colorbar on the per-channel frame plots
    """
    # pick one random sample index from the batch
    r_ix = torch.randint(0, len(x), (1, )).long().item()

    assert x.dim() == 4

    # rm batch dimension, i.e. select one sample
    x = x[r_ix]
    y_out = y_out[r_ix]
    y_tar = y_tar[r_ix] if y_tar is not None else None
    weight = weight[r_ix] if weight is not None else None

    # restrict emitter sets to the frames of the selected sample
    assert isinstance(em_tar, decode.generic.emitter.EmitterSet)
    em_tar = em_tar.get_subset_frame(r_ix, r_ix)
    em_out = em_out.get_subset_frame(r_ix, r_ix)
    em_tp = tp.get_subset_frame(r_ix, r_ix)
    em_tp_match = tp_match.get_subset_frame(r_ix, r_ix)

    # loop over all input channels
    for i, xc in enumerate(x):
        f_input = plt.figure()
        frame_coord.PlotFrameCoord(xc, pos_tar=em_tar.xyz_px, plot_colorbar_frame=colorbar).plot()
        logger.add_figure('input/raw_input_ch_' + str(i), f_input, step)

    # loop over all output channels
    for i, yc in enumerate(y_out):
        f_out = plt.figure()
        frame_coord.PlotFrameCoord(yc, plot_colorbar_frame=colorbar).plot()
        logger.add_figure('output/raw_output_ch_' + str(i), f_out, step)

    # record tar / output emitters
    # presumably the middle input channel corresponds to the target frame — TODO confirm
    tar_ch = (x.size(0) - 1) // 2

    f_em_out = plt.figure(figsize=(10, 8))
    frame_coord.PlotFrameCoord(x[tar_ch], pos_tar=em_tar.xyz_px, pos_out=em_out.xyz_px).plot()
    logger.add_figure('em_out/em_out_tar', f_em_out, step)

    f_em_out3d = plt.figure(figsize=(10, 8))
    frame_coord.PlotCoordinates3D(pos_tar=em_tar.xyz_px, pos_out=em_out.xyz_px).plot()
    logger.add_figure('em_out/em_out_tar_3d', f_em_out3d, step)

    # true positives connected to their matched ground-truth emitters
    f_match = plt.figure(figsize=(10, 8))
    frame_coord.PlotFrameCoord(x[tar_ch], pos_tar=em_tp_match.xyz_px, pos_out=em_tp.xyz_px,
                               match_lines=True, labels=('TP match', 'TP')).plot()
    logger.add_figure('em_out/em_match', f_match, step)

    f_match_3d = plt.figure(figsize=(10, 8))
    frame_coord.PlotCoordinates3D(pos_tar=em_tp_match.xyz_px, pos_out=em_tp.xyz_px,
                                  match_lines=True, labels=('TP match', 'TP')).plot()
    logger.add_figure('em_out/em_match_3d', f_match_3d, step)

    # loop over all target channels
    if y_tar is not None:
        for i, yct in enumerate(y_tar):
            f_tar = plt.figure()
            frame_coord.PlotFrameCoord(yct, plot_colorbar_frame=colorbar).plot()
            logger.add_figure('target/target_ch_' + str(i), f_tar, step)

    # loop over all weight channels
    if weight is not None:
        for i, w in enumerate(weight):
            f_w = plt.figure()
            frame_coord.PlotFrameCoord(w, plot_colorbar_frame=colorbar).plot()
            logger.add_figure('weight/weight_ch_' + str(i), f_w, step)

    # plot dist of probability channel
    # ToDo: Histplots seem to cause trouble with memory. Deactivated for now. If reactivate: change back to distplot
    # f_prob_dist, ax_prob_dist = plt.subplots()
    # sns.histplot(y_out[0].reshape(-1).numpy(), kde=False, ax=ax_prob_dist)
    # plt.xlabel('prob')
    # logger.add_figure('output_dist/prob', f_prob_dist)
    #
    # f_prob_dist_log, ax_prob_dist_log = plt.subplots()
    # sns.histplot(y_out[0].reshape(-1).numpy(), kde=False, ax=ax_prob_dist_log)
    # plt.yscale('log')
    # plt.xlabel('prob')
    # logger.add_figure('output_dist/prob_log', f_prob_dist_log)
def log_kpi(loss_scalar: float, loss_cmp: torch.Tensor, eval_set: dict, logger, step):
    """
    Log the scalar test loss, the per-channel mean of the loss components and
    a dict of evaluation metrics.

    Args:
        loss_cmp: loss components, at least 2D with channels on dim 1
            (the original annotation said ``dict``, but the body indexes it
            as a tensor — annotation fixed accordingly)
        eval_set: metric-name -> value mapping, logged under the 'eval/' prefix
        logger: logger exposing add_scalar(tag, value, step) and
            add_scalar_dict(prefix, dict, step)
        step: global step
    """
    logger.add_scalar('learning/test_ep', loss_scalar, step)

    # validate with a real exception instead of assert (asserts vanish under -O)
    if loss_cmp.dim() < 2:
        raise ValueError(f"loss_cmp must be at least 2D, got {loss_cmp.dim()}D.")

    # channel-wise mean of the loss components
    for i in range(loss_cmp.size(1)):
        logger.add_scalar('loss_cmp/test_ep_loss_ch_' + str(i), loss_cmp[:, i].mean(), step)

    logger.add_scalar_dict('eval/', eval_set, step)
def log_dists(tp, tp_match, pred, px_border, px_size, logger, step):
    """
    Log residual distributions (x/y/z/photons, prediction vs. matched ground
    truth) and the distribution of detection probabilities.

    Args:
        tp: true-positive emitters
        tp_match: ground-truth emitters matched to tp
        pred: predicted emitters; only ``pred.prob`` is read here
        px_border: pixel border, forwarded to predict_dist
        px_size: pixel size, forwarded to predict_dist
        logger: logger exposing add_figure(tag, figure, step)
        step: global step
    """
    # residuals of x/y/z/photons between prediction and matched ground truth
    f_x, ax_x = plt.subplots()
    f_y, ax_y = plt.subplots()
    f_z, ax_z = plt.subplots()
    f_phot, ax_phot = plt.subplots()

    predict_dist.emitter_deviations(tp, tp_match, px_border=px_border, px_size=px_size,
                                    axes=[ax_x, ax_y, ax_z, ax_phot])

    logger.add_figure('dist/x_offset', f_x, step)
    logger.add_figure('dist/y_offset', f_y, step)
    logger.add_figure('residuals/z_gt_pred', f_z, step)
    logger.add_figure('residuals/phot_gt_pred', f_phot, step)

    # distribution of detection probabilities
    f_prob, ax_prob = plt.subplots()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # sns.distplot was deprecated in seaborn 0.11 and removed in 0.14;
        # histplot(stat='density') reproduces distplot(norm_hist=True, kde=False)
        sns.histplot(pred.prob, bins=50, stat='density', ax=ax_prob)
    logger.add_figure('dist/prob', f_prob, step)
def log_train(*, loss_p_batch: (list, tuple), loss_mean: float, logger, step: int):
    """
    Log the epoch-mean training loss plus a subsampled per-batch loss curve.

    Args:
        loss_p_batch: per-batch losses of one epoch
        loss_mean: mean loss over the epoch
        logger: logger exposing add_scalar(tag, value, step)
        step: epoch index; per-batch steps are derived as step * n_batches + batch_ix
    """
    logger.add_scalar('learning/train_ep', loss_mean, step)

    n_batches = len(loss_p_batch)
    # only every 10th batch loss is logged to keep the event file small
    for batch_ix in range(0, n_batches, 10):
        logger.add_scalar('learning/train_batch',
                          loss_p_batch[batch_ix],
                          step * n_batches + batch_ix)
def post_process_log_test(*, loss_cmp, loss_scalar, x, y_out, y_tar, weight, em_tar,
                          px_border, px_size, post_processor, matcher, logger, step):
    """
    Run post-processing and matching on test output, then log everything:
    raw frames, scalar KPIs and residual/probability distributions.

    Args:
        loss_cmp: loss components (forwarded to log_kpi)
        loss_scalar: scalar test loss (forwarded to log_kpi)
        x, y_out, y_tar, weight: input, raw output, target and weight frames
        em_tar: ground-truth emitters
        px_border, px_size: pixel geometry (forwarded to log_dists)
        post_processor: converts raw output into emitters via .forward(y_out)
        matcher: matches prediction to ground truth via .forward(pred, tar)
        logger: logger passed through to all log_* helpers
        step: global step
    """
    # post-process raw network output into emitters
    emitters = post_processor.forward(y_out)

    # match prediction against ground truth, then evaluate the match
    tp, fp, fn, tp_match = matcher.forward(emitters, em_tar)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        evaluator = evaluation.SMLMEvaluation(
            weighted_eval=WeightedErrors(mode='crlb', reduction='gaussian'))
        metrics = evaluator.forward(tp, fp, fn, tp_match)

    # raw frames
    log_frames(x=x, y_out=y_out, y_tar=y_tar, weight=weight, em_out=emitters, em_tar=em_tar,
               tp=tp, tp_match=tp_match, logger=logger, step=step)

    # KPIs
    log_kpi(loss_scalar=loss_scalar, loss_cmp=loss_cmp, eval_set=metrics._asdict(),
            logger=logger, step=step)

    # distributions
    log_dists(tp=tp, tp_match=tp_match, pred=emitters, px_border=px_border, px_size=px_size,
              logger=logger, step=step)