import time
import warnings
from functools import partial
from typing import Union, Callable
import torch
from tqdm import tqdm
from .. import dataset
from ...generic import emitter
from ...utils import hardware, frames_io
class Infer:

    def __init__(self, model, ch_in: int, frame_proc, post_proc, device: Union[str, torch.device],
                 batch_size: Union[int, str] = 'auto', num_workers: int = 0, pin_memory: bool = False,
                 forward_cat: Union[str, Callable] = 'emitter'):
        """
        Convenience class for inference.

        Args:
            model: pytorch model
            ch_in: number of input channels
            frame_proc: frame pre-processing pipeline
            post_proc: post-processing pipeline
            device: device where to run inference
            batch_size: batch-size or 'auto' if the batch size should be determined automatically
                (only use in combination with cuda)
            num_workers: number of workers
            pin_memory: pin memory in dataloader
            forward_cat: method which concatenates the output batches. Can be string, Callable or None.
                Use 'emitter' when the post-processor outputs an EmitterSet, or 'frames' when you don't use
                post-processing or if the post-processor outputs frames. A callable is used as is;
                None returns the list of per-batch outputs unchanged.
        """
        self.model = model
        self.ch_in = ch_in
        self.batch_size = batch_size
        self.device = device
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.frame_proc = frame_proc
        self.post_proc = post_proc

        # the actual concatenation function is built in forward() because it needs the batch size
        self.forward_cat = None
        self._forward_cat_mode = forward_cat

        # compare the device type only, so that 'cpu:0' / torch.device('cpu', 0) are recognised as cpu as well
        is_cpu = str(self.device).split(':')[0] == 'cpu'
        if is_cpu and self.batch_size == 'auto':
            warnings.warn("Automatically determining the batch size does not make sense on cpu device. "
                          "Falling back to reasonable value.")
            self.batch_size = 64

    def forward(self, frames: torch.Tensor) -> "emitter.EmitterSet":
        """
        Forward frames through model, pre- and post-processing and output EmitterSet

        Args:
            frames: frames to run inference on; they are wrapped in a ``dataset.InferenceDataset``

        Returns:
            Whatever the configured ``forward_cat`` produces from the list of per-batch outputs
            (an EmitterSet for the default 'emitter' mode).
        """
        # move model
        model = self.model.to(self.device)
        model.eval()

        # form dataset and dataloader
        ds = dataset.InferenceDataset(frames=frames, frame_proc=self.frame_proc, frame_window=self.ch_in)

        if self.batch_size == 'auto':
            # include safety factor of 20%, but never drop below one sample per batch
            # (the DataLoader rejects a batch size of 0)
            bs = max(1, int(0.8 * self.get_max_batch_size(model, ds[0].size(), 1, 512)))
        else:
            bs = self.batch_size

        # generate concatenate function here because we need batch size for this
        self.forward_cat = self._setup_forward_cat(self._forward_cat_mode, bs)

        dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=False, drop_last=False,
                                         num_workers=self.num_workers, pin_memory=self.pin_memory)

        out = []
        with torch.no_grad():
            for sample in tqdm(dl):
                x_in = sample.to(self.device)

                # compute output
                y_out = model(x_in)

                # In post processing we need to make sure that we get a single EmitterSet for each batch,
                # so that we can easily concatenate.
                if self.post_proc is not None:
                    out.append(self.post_proc.forward(y_out))
                else:
                    out.append(y_out.detach().cpu())

        # cat to single emitterset / frame tensor depending on the specification of the forward_cat attr.
        out = self.forward_cat(out)

        return out

    def _setup_forward_cat(self, forward_cat, batch_size: int):
        """
        Build the function that merges the list of per-batch outputs.

        Args:
            forward_cat: None (identity), 'emitter', 'frames' or a callable
            batch_size: batch size used by the dataloader; passed as ``step_frame_ix`` to
                ``EmitterSet.cat`` in 'emitter' mode

        Raises:
            ValueError: if ``forward_cat`` is a string other than 'emitter' or 'frames'
            TypeError: if ``forward_cat`` is neither None, a string nor callable
        """
        if forward_cat is None:
            return lambda x: x

        if isinstance(forward_cat, str):
            if forward_cat == 'emitter':
                return partial(emitter.EmitterSet.cat, step_frame_ix=batch_size)

            if forward_cat == 'frames':
                return partial(torch.cat, dim=0)

            raise ValueError("Unsupported forward_cat value.")

        if callable(forward_cat):
            return forward_cat

        raise TypeError("Specified forward cat method was wrong.")

    @staticmethod
    def get_max_batch_size(model: torch.nn.Module, frame_size: Union[tuple, torch.Size],
                           limit_low: int, limit_high: int):
        """
        Get maximum batch size for inference.

        Args:
            model: model on correct device
            frame_size: size of frames (without batch dimension)
            limit_low: lower batch size limit
            limit_high: upper batch size limit
        """

        def model_forward_no_grad(x: torch.Tensor):
            """
            Helper function because we need to account for torch.no_grad()
            """
            with torch.no_grad():
                o = model.forward(x)

            return o

        assert next(model.parameters()).is_cuda, \
            "Auto determining the max batch size makes only sense when running on CUDA device."

        return hardware.get_max_batch_size(model_forward_no_grad, frame_size, next(model.parameters()).device,
                                           limit_low, limit_high)
class LiveInfer(Infer):

    def __init__(self,
                 model, ch_in: int, *,
                 stream, time_wait=5, safety_buffer: int = 20,
                 frame_proc=None, post_proc=None,
                 device: Union[str, torch.device] = 'cuda:0' if torch.cuda.is_available() else 'cpu',
                 batch_size: Union[int, str] = 'auto', num_workers: int = 0, pin_memory: bool = False,
                 forward_cat: Union[str, Callable] = 'emitter'):
        """
        Inference from memory mapped tensor, where the mapped file is possibly live being written to.

        Args:
            model: pytorch model
            ch_in: number of input channels
            stream: output stream. Called as ``stream(out, ix_start, ix_stop)`` after each fitted chunk,
                i.e. it will typically get emitters (along with starting and stopping index)
            time_wait: time to wait (passed to ``time.sleep``) if length of mapped tensor has not changed
            safety_buffer: buffer distance to end of tensor to avoid conflicts when the file is actively being
                written to
            frame_proc: frame pre-processing pipeline
            post_proc: post-processing pipeline
            device: device where to run inference
            batch_size: batch-size or 'auto' if the batch size should be determined automatically
                (only use in combination with cuda)
            num_workers: number of workers
            pin_memory: pin memory in dataloader
            forward_cat: method which concatenates the output batches. Can be string or Callable.
                Use 'emitter' when the post-processor outputs an EmitterSet, or 'frames' when you don't use
                post-processing or if the post-processor outputs frames.
        """
        super().__init__(
            model=model, ch_in=ch_in, frame_proc=frame_proc, post_proc=post_proc,
            device=device, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
            forward_cat=forward_cat)

        self._stream = stream
        self._time_wait = time_wait
        self._buffer_length = safety_buffer

    def forward(self, frames: Union[torch.Tensor, frames_io.TiffTensor]):
        """
        Repeatedly fit the frames that became available and hand every result to the stream.

        Frames closer than the safety buffer to the current end are left untouched while polling. Once
        the number of fittable frames has not changed for three consecutive checks, the remaining frames
        (up to the last observed length) are fitted as well.

        Args:
            frames: frames (or live growing frame container) supporting ``len`` and slicing
        """
        n_fitted = 0
        n_waited = 0

        while n_waited <= 2:
            n = len(frames)

            # Clamp the stop index: while fewer frames than the safety buffer exist, ``n - buffer`` is
            # negative and would otherwise be interpreted as a from-the-end slice index.
            n_2fit = max(n - self._buffer_length, n_fitted)

            if n_2fit == n_fitted:
                # nothing new outside of the safety buffer
                n_waited += 1
                time.sleep(self._time_wait)  # wait
                continue

            out = super().forward(frames[n_fitted:n_2fit])
            self._stream(out, n_fitted, n_2fit)

            n_fitted = n_2fit
            n_waited = 0

        # fit remaining frames (skip when there are none, e.g. with a safety buffer of 0)
        if n > n_fitted:
            out = super().forward(frames[n_fitted:n])
            self._stream(out, n_fitted, n)
if __name__ == '__main__':
    # Command line entry point: loads model, frames and camera meta data for inference.
    import argparse
    import yaml

    import decode.neuralfitter.models
    import decode.utils

    parse = argparse.ArgumentParser(
        description="Inference. This uses the default, suggested implementation. "
                    "For anything else, consult the fitting notebook and make your changes there.")
    parse.add_argument('frame_path', help='Path to the tiff file of the frames')
    parse.add_argument('frame_meta_path', help='Path to the meta of the tiff (i.e. camera parameters)')
    parse.add_argument('model_path', help='Path to the model file')
    parse.add_argument('param_path', help='Path to the parameters of the training')
    parse.add_argument('device', help='Device on which to do inference (e.g. "cpu" or "cuda:0"')
    parse.add_argument('-o', '--online', action='store_true')

    args = parse.parse_args()
    # argparse derives the attribute name from the long option, i.e. ``online`` (there is no ``args.o``)
    online = args.online

    # load the model
    param = decode.utils.param_io.load_params(args.param_path)

    model = decode.neuralfitter.models.SigmaMUNet.parse(param)
    model = decode.utils.model_io.LoadSaveModel(
        model, input_file=args.model_path, output_file=None).load_init(args.device)

    # load the frames
    if not online:
        frames = decode.utils.frames_io.load_tif(args.frame_path)
    else:
        frames = decode.utils.frames_io.TiffTensor(args.frame_path)

    # load meta
    with open(args.frame_meta_path) as meta_file:
        meta = yaml.safe_load(meta_file)

    # merge the 'Camera' section of the meta file with the training parameters
    param = decode.utils.param_io.autofill_dict(meta['Camera'], param.to_dict(), mode_missing='include')
    param = decode.utils.param_io.RecursiveNamespace(**param)

    camera = decode.simulation.camera.Photon2Camera.parse(param)
    camera.device = 'cpu'

    # Prepare Pre and post-processing

    # Fit

    # Return