import argparse
import copy
import datetime
import os
import shutil
import socket
import sys
from pathlib import Path
import torch
import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint
def parse_args(args=None):
    """
    Parse the command line arguments of the training script.

    Args:
        args: optional sequence of argument strings to parse. If None (default),
            argparse falls back to the process' command line (``sys.argv[1:]``),
            which is the behaviour of the script entry point.

    Returns:
        argparse.Namespace with the attributes ``device``, ``param_file``, ``debug``,
        ``num_worker_override``, ``no_log``, ``log_folder`` and ``log_comment``.
    """
    parser = argparse.ArgumentParser(description='Training Args')

    parser.add_argument('-i', '--device', default=None,
                        help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.',
                        type=str, required=False)

    # the parameter file is the only mandatory argument
    parser.add_argument('-p', '--param_file',
                        help='Specify your parameter file (.yml or .json).',
                        required=True)

    parser.add_argument('-d', '--debug', default=False, action='store_true',
                        help='Debug the specified parameter file. Will reduce ds size for example.')

    parser.add_argument('-w', '--num_worker_override',
                        help='Override the number of workers for the dataloaders.',
                        type=int)

    parser.add_argument('-n', '--no_log', default=False, action='store_true',
                        help='Set no log if you do not want to log the current run.')

    parser.add_argument('-l', '--log_folder', default='runs',
                        help='Specify the (parent) folder you want to log to. If rel-path, relative to DECODE root.')

    parser.add_argument('-c', '--log_comment', default=None,
                        help='Add a log_comment to the run.')

    return parser.parse_args(args)
def live_engine_setup(param_file: str, device_overwrite: str = None, debug: bool = False,
                      no_log: bool = False,
                      num_worker_override: int = None,
                      log_folder: str = 'runs', log_comment: str = None):
    """
    Sets up the engine to train DECODE. Includes sample simulation and the actual training.

    The function loads the parameter file, creates (or re-uses, when resuming from a checkpoint)
    the experiment folder, backs up the parameters there, configures the hardware, builds
    simulation / trainer / dataloaders and then runs the epoch loop. If the convergence check
    fails, model and trainer are set up anew and training is restarted, at most
    ``param.HyperParameter.auto_restart_param.num_restarts`` times.

    Args:
        param_file: parameter file path
        device_overwrite: overwrite cuda index specified by param file
        debug: activate debug mode (i.e. less samples) for fast testing
        no_log: disable logging
        num_worker_override: overwrite number of workers for dataloader
        log_folder: folder for logging (where tensorboard puts its stuff)
        log_comment: comment to the experiment

    Raises:
        FileExistsError: if a fresh (non-debug) experiment folder already exists
        ValueError: if the simulation mode is neither 'acquisition' nor 'samples', or if
            training did not converge within the allowed number of restarts
    """

    # --- Load parameters and back them up to the network output directory
    param_file = Path(param_file)
    param = decode.utils.param_io.ParamHandling().load_params(param_file)

    # auto-set some parameters (will be stored in the backup copy)
    param = decode.utils.param_io.autoset_scaling(param)

    # add meta information
    param.Meta.version = decode.utils.bookkeeping.decode_state()

    # --- Experiment ID
    if debug:
        # debug runs always use the same id / folder and never resume from a checkpoint
        experiment_id = 'debug'
        from_ckpt = False
    elif param.InOut.checkpoint_init is None:
        from_ckpt = False
        experiment_id = datetime.datetime.now().strftime(
            "%Y-%m-%d_%H-%M-%S") + '_' + socket.gethostname()
        if log_comment:
            experiment_id = experiment_id + '_' + log_comment
    else:
        # resume: the experiment is identified by the folder that holds the checkpoint
        from_ckpt = True
        experiment_id = Path(param.InOut.checkpoint_init).parent.name

    # --- Set up unique folder for experiment
    if from_ckpt:
        experiment_path = Path(param.InOut.checkpoint_init).parent
    else:
        experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)

    # create the output root including any missing intermediate directories
    experiment_path.parent.mkdir(parents=True, exist_ok=True)

    if not from_ckpt:
        # only the debug folder may be re-used; a regular run must not overwrite an old one
        experiment_path.mkdir(exist_ok=bool(debug))

    model_out = experiment_path / Path('model.pt')
    ckpt_path = experiment_path / Path('ckpt.pt')

    # Backup the parameter file under the network output path with the experiments ID
    param_backup_in = experiment_path / Path('param_run_in').with_suffix(param_file.suffix)
    shutil.copy(param_file, param_backup_in)

    # the second backup contains the auto-set values and the meta information added above
    param_backup = experiment_path / Path('param_run').with_suffix(param_file.suffix)
    decode.utils.param_io.ParamHandling().write_params(param_backup, param)

    # the following modifications happen after the backup and are therefore not stored in it
    if debug:
        decode.utils.param_io.ParamHandling.convert_param_debug(param)

    if num_worker_override is not None:
        param.Hardware.num_worker_train = num_worker_override

    # --- Hardware / Server stuff
    if device_overwrite is not None:
        device = device_overwrite
        param.Hardware.device_simulation = device_overwrite  # lazy assumption
    else:
        device = param.Hardware.device

    if torch.cuda.is_available():
        _, device_ix = decode.utils.hardware._specific_device_by_str(device)
        if device_ix is not None:
            # do this instead of set env variable, because torch is inevitably already imported
            torch.cuda.set_device(device)
    else:
        # no CUDA available: train on CPU regardless of what was requested
        device = 'cpu'

    if param.Hardware.torch_multiprocessing_sharing_strategy is not None:
        torch.multiprocessing.set_sharing_strategy(
            param.Hardware.torch_multiprocessing_sharing_strategy)

    # os.nice() requires an int, so only touch the niceness if one is configured
    niceness = param.Hardware.unix_niceness
    if niceness is not None:
        if sys.platform in ('linux', 'darwin'):
            os.nice(niceness)
        else:
            print(f"Cannot set niceness on platform {sys.platform}. You probably do not need to worry.")

    torch.set_num_threads(param.Hardware.torch_threads)

    # --- Setup Log System
    if no_log:
        logger = decode.neuralfitter.utils.logger.NoLog()
    else:
        log_folder = log_folder + '/' + experiment_id

        # the second logger (DictLogger) provides the log_dict that is stored in the checkpoint
        logger = decode.neuralfitter.utils.logger.MultiLogger(
            [decode.neuralfitter.utils.logger.SummaryWriter(log_dir=log_folder,
                                                            filter_keys=["dx_red_mu", "dx_red_sig",
                                                                         "dy_red_mu",
                                                                         "dy_red_sig", "dz_red_mu",
                                                                         "dz_red_sig",
                                                                         "dphot_red_mu",
                                                                         "dphot_red_sig"]),
             decode.neuralfitter.utils.logger.DictLogger()])

    # --- Simulation, trainer and dataloaders
    sim_train, sim_test = setup_random_simulation(param)
    ds_train, ds_test, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, ckpt = \
        setup_trainer(sim_train, sim_test, logger, model_out, ckpt_path, device, param)
    dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

    if from_ckpt:
        # restore model, optimiser and scheduler state and continue after the stored step
        ckpt = decode.utils.checkpoint.CheckPoint.load(param.InOut.checkpoint_init)
        model.load_state_dict(ckpt.model_state)
        optimizer.load_state_dict(ckpt.optimizer_state)
        lr_scheduler.load_state_dict(ckpt.lr_sched_state)
        first_epoch = ckpt.step + 1
        model = model.train()
        print(f'Resuming training from checkpoint {experiment_id}')
    else:
        first_epoch = 0

    # --- Training with automatic restart
    converges = False
    n = 0
    n_max = param.HyperParameter.auto_restart_param.num_restarts

    while not converges and n < n_max:
        n += 1

        conv_check = decode.neuralfitter.utils.progress.GMMHeuristicCheck(
            ref_epoch=1,
            emitter_avg=sim_train.em_sampler.em_avg,
            threshold=param.HyperParameter.auto_restart_param.restart_treshold,
        )

        for i in range(first_epoch, param.HyperParameter.epochs):
            logger.add_scalar('learning/learning_rate', optimizer.param_groups[0]['lr'], i)

            # epoch 0 is evaluation only, i.e. the untrained model is tested first
            if i >= 1:
                _ = decode.neuralfitter.train_val_impl.train(
                    model=model,
                    optimizer=optimizer,
                    loss=criterion,
                    dataloader=dl_train,
                    grad_rescale=param.HyperParameter.moeller_gradient_rescale,
                    grad_mod=grad_mod,
                    epoch=i,
                    device=torch.device(device),
                    logger=logger
                )

            val_loss, test_out = decode.neuralfitter.train_val_impl.test(
                model=model,
                loss=criterion,
                dataloader=dl_test,
                epoch=i,
                device=torch.device(device))

            mean_loss = test_out.loss[:, 0].mean()
            if not conv_check(mean_loss, i):
                print(f"The model will be reinitialized and retrained due to a pathological loss."
                      f"The max. allowed loss per emitter is {conv_check.threshold:.1f} vs."
                      f" {(mean_loss / conv_check.emitter_avg):.1f} (observed).")

                # set everything up anew and leave the epoch loop; the while loop restarts it
                ds_train, ds_test, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, ckpt = \
                    setup_trainer(sim_train, sim_test, logger, model_out, ckpt_path, device, param)
                dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

                converges = False
                break

            converges = True

            # Post-Process and Evaluate
            log_train_val_progress.post_process_log_test(loss_cmp=test_out.loss,
                                                         loss_scalar=val_loss,
                                                         x=test_out.x, y_out=test_out.y_out,
                                                         y_tar=test_out.y_tar,
                                                         weight=test_out.weight,
                                                         em_tar=ds_test.emitter,
                                                         px_border=-0.5, px_size=1.,
                                                         post_processor=post_processor,
                                                         matcher=matcher, logger=logger,
                                                         step=i)

            if i >= 1:
                # ReduceLROnPlateau needs the monitored metric, other schedulers do not
                if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    lr_scheduler.step(val_loss)
                else:
                    lr_scheduler.step()

            # save model and checkpoint after every epoch
            model_ls.save(model, None)
            if no_log:
                ckpt.dump(model.state_dict(), optimizer.state_dict(), lr_scheduler.state_dict(),
                          step=i)
            else:
                ckpt.dump(model.state_dict(), optimizer.state_dict(), lr_scheduler.state_dict(),
                          log=logger.logger[1].log_dict, step=i)

            # Draw new samples (exact comparison; a substring test would accept e.g. '')
            if param.Simulation.mode == 'acquisition':
                ds_train.sample(True)
            elif param.Simulation.mode != 'samples':
                raise ValueError

    if converges:
        print("Training finished after reaching maximum number of epochs.")
    else:
        raise ValueError(f"Training aborted after {n_max} restarts. "
                         "You can try to reduce the learning rate by a factor of 2."
                         "\nIt is also possible that the simulated data is to challenging. "
                         "Check if your background and intensity values are correct "
                         "and possibly lower the average number of emitters.")
def setup_trainer(simulator_train, simulator_test, logger, model_out, ckpt_path, device, param):
    """
    Set model, optimiser, loss and schedulers, as well as datasets, post-processing and matcher.

    Args:
        simulator_train: simulator handed to the training dataset
        simulator_test: simulator handed to the test dataset
        logger: logger, used here to log the model graph
        model_out: output file of the model
        ckpt_path: path of the checkpoint file
        device: device string the model and the loss are put on
        param: parameters

    Returns:
        tuple of (train dataset, test dataset, model, model load/save helper, optimizer,
        criterion, learning rate scheduler, gradient modification, post processor, matcher,
        checkpoint)

    Raises:
        KeyError: if architecture, optimizer or learning rate scheduler is not among the
            available ones
        ValueError: if ``param.Simulation.mode`` is neither 'acquisition' nor 'samples'
        NotImplementedError: if ``param.PostProcessing`` is not supported
    """

    # --- Model
    models_available = {
        'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet,
        'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
        'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
    }

    model = models_available[param.HyperParameter.architecture]
    model = model.parse(param)

    model_ls = decode.utils.model_io.LoadSaveModel(model,
                                                   output_file=model_out)

    model = model_ls.load_init()
    model = model.to(torch.device(device))

    # Small collection of optimisers
    optimizer_available = {
        'Adam': torch.optim.Adam,
        'AdamW': torch.optim.AdamW
    }

    optimizer = optimizer_available[param.HyperParameter.optimizer]
    optimizer = optimizer(model.parameters(), **param.HyperParameter.opt_param)

    # --- Loss function
    criterion = decode.neuralfitter.loss.GaussianMMLoss(
        xextent=param.Simulation.psf_extent[0],
        yextent=param.Simulation.psf_extent[1],
        img_shape=param.Simulation.img_size,
        device=device,
        chweight_stat=param.HyperParameter.chweight_stat)

    # --- Learning Rate and Simulation Scheduling
    lr_scheduler_available = {
        'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'StepLR': torch.optim.lr_scheduler.StepLR
    }
    lr_scheduler = lr_scheduler_available[param.HyperParameter.learning_rate_scheduler]
    lr_scheduler = lr_scheduler(optimizer, **param.HyperParameter.learning_rate_scheduler_param)

    # --- Checkpointing
    checkpoint = CheckPoint(path=ckpt_path)

    # --- Setup gradient modification
    grad_mod = param.HyperParameter.grad_mod

    # --- Log the model (best effort, a failure here must not stop the training)
    try:
        dummy = torch.rand((2, param.HyperParameter.channels_in,
                            *param.Simulation.img_size), requires_grad=False).to(
            torch.device(device))
        logger.add_graph(model, dummy)

    except Exception:
        # do not use a bare except, KeyboardInterrupt / SystemExit must still propagate
        print("Did not log graph.")

    # --- Transform input data, compute weight mask and target data
    frame_proc = decode.neuralfitter.scale_transform.AmplitudeRescale.parse(param)
    bg_frame_proc = None

    if param.HyperParameter.emitter_label_photon_min is not None:
        em_filter = decode.neuralfitter.em_filter.PhotonFilter(
            param.HyperParameter.emitter_label_photon_min)
    else:
        em_filter = decode.neuralfitter.em_filter.NoEmitterFilter()

    # frame index limits (low, high) for the target generators of train and test set
    tar_frame_ix_train = (0, 0)
    tar_frame_ix_test = (0, param.TestSet.test_size)

    # --- Setup Target generator consisting possibly multiple steps in a transformation sequence.
    tar_gen = decode.neuralfitter.utils.processing.TransformSequence(
        [
            decode.neuralfitter.target_generator.ParameterListTarget(
                n_max=param.HyperParameter.max_number_targets,
                xextent=param.Simulation.psf_extent[0],
                yextent=param.Simulation.psf_extent[1],
                ix_low=tar_frame_ix_train[0],
                ix_high=tar_frame_ix_train[1],
                squeeze_batch_dim=True),

            decode.neuralfitter.target_generator.DisableAttributes.parse(param),

            decode.neuralfitter.scale_transform.ParameterListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max)
        ])

    # setup target for test set in similar fashion, however test-set is static.
    tar_gen_test = copy.deepcopy(tar_gen)
    tar_gen_test.com[0].ix_low = tar_frame_ix_test[0]
    tar_gen_test.com[0].ix_high = tar_frame_ix_test[1]
    tar_gen_test.com[0].squeeze_batch_dim = False
    tar_gen_test.com[0].sanity_check()

    # --- Datasets
    if param.Simulation.mode == 'acquisition':
        train_ds = decode.neuralfitter.dataset.SMLMLiveDataset(
            simulator=simulator_train,
            em_proc=em_filter,
            frame_proc=frame_proc,
            bg_frame_proc=bg_frame_proc,
            tar_gen=tar_gen, weight_gen=None,
            frame_window=param.HyperParameter.channels_in,
            pad=None, return_em=False)

        train_ds.sample(True)

    elif param.Simulation.mode == 'samples':
        train_ds = decode.neuralfitter.dataset.SMLMLiveSampleDataset(
            simulator=simulator_train,
            em_proc=em_filter,
            frame_proc=frame_proc,
            bg_frame_proc=bg_frame_proc,
            tar_gen=tar_gen,
            weight_gen=None,
            frame_window=param.HyperParameter.channels_in,
            return_em=False,
            ds_len=param.HyperParameter.pseudo_ds_size)

    else:
        # fail explicitly instead of running into an unbound train_ds at the return statement
        raise ValueError(f"Unsupported simulation mode: {param.Simulation.mode!r}. "
                         f"Must be 'acquisition' or 'samples'.")

    test_ds = decode.neuralfitter.dataset.SMLMAPrioriDataset(
        simulator=simulator_test,
        em_proc=em_filter,
        frame_proc=frame_proc,
        bg_frame_proc=bg_frame_proc,
        tar_gen=tar_gen_test, weight_gen=None,
        frame_window=param.HyperParameter.channels_in,
        pad=None, return_em=False)

    test_ds.sample(True)

    # --- Set up post processor
    if param.PostProcessing is None:
        post_processor = decode.neuralfitter.post_processing.NoPostProcessing(xy_unit='px',
                                                                              px_size=param.Camera.px_size)

    elif param.PostProcessing == 'LookUp':
        post_processor = decode.neuralfitter.utils.processing.TransformSequence([

            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),

            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

            decode.neuralfitter.post_processing.LookUpPostProcessing(
                raw_th=param.PostProcessingParam.raw_th,
                pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1],
                xy_unit='px',
                px_size=param.Camera.px_size)
        ])

    elif param.PostProcessing in ('SpatialIntegration', 'NMS'):  # NMS as legacy support
        post_processor = decode.neuralfitter.utils.processing.TransformSequence([

            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),

            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

            decode.neuralfitter.post_processing.SpatialIntegration(
                raw_th=param.PostProcessingParam.raw_th,
                xy_unit='px',
                px_size=param.Camera.px_size)
        ])

    else:
        raise NotImplementedError

    # --- Evaluation Specification
    matcher = decode.evaluation.match_emittersets.GreedyHungarianMatching.parse(param)

    return train_ds, test_ds, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, checkpoint
def setup_dataloader(param, train_ds, test_ds=None):
    """
    Build the dataloaders for training and (optionally) testing.

    Args:
        param: parameters; batch size and number of workers are read from it
        train_ds: training dataset
        test_ds: test dataset, optional

    Returns:
        tuple of (train dataloader, test dataloader); the latter is None if no
        test dataset was specified
    """
    # settings both dataloaders have in common
    shared_kwargs = {
        'batch_size': param.HyperParameter.batch_size,
        'num_workers': param.Hardware.num_worker_train,
        'collate_fn': decode.neuralfitter.utils.dataloader_customs.smlm_collate,
    }

    # training: shuffled, incomplete batches are dropped, memory is pinned
    train_dl = torch.utils.data.DataLoader(
        dataset=train_ds,
        drop_last=True,
        shuffle=True,
        pin_memory=True,
        **shared_kwargs)

    # testing: fixed order, every sample is kept, memory is not pinned
    test_dl = None if test_ds is None else torch.utils.data.DataLoader(
        dataset=test_ds,
        drop_last=False,
        shuffle=False,
        pin_memory=False,
        **shared_kwargs)

    return train_dl, test_dl
if __name__ == '__main__':
    # Script entry point: read the command line and hand everything to the training engine.
    cli_args = parse_args()
    live_engine_setup(
        param_file=cli_args.param_file,
        device_overwrite=cli_args.device,
        debug=cli_args.debug,
        no_log=cli_args.no_log,
        num_worker_override=cli_args.num_worker_override,
        log_folder=cli_args.log_folder,
        log_comment=cli_args.log_comment,
    )