# Source code for decode.neuralfitter.frame_processing
from abc import ABC, abstractmethod
import torch
from typing import Tuple, List
class FrameProcessing(ABC):
    """Abstract interface for frame (pre-/post-)processing steps."""

    @abstractmethod
    def forward(self, frame: torch.Tensor) -> torch.Tensor:
        """
        Forward frame through processing implementation.

        Args:
            frame: input frame tensor

        Returns:
            processed frame tensor
        """
        raise NotImplementedError
class Mirror2D(FrameProcessing):
    """Mirror (flip) a fixed set of tensor dimensions."""

    def __init__(self, dims: Tuple):
        """
        Mirror the specified dimensions. Providing dim index in negative format is recommended.
        Given format N x C x H x W and you want to mirror H and W set dims=(-2, -1).

        Args:
            dims: dimensions to flip
        """
        super().__init__()
        self.dims = dims

    def forward(self, frame: torch.Tensor) -> torch.Tensor:
        """Return ``frame`` flipped along the configured dimensions."""
        return torch.flip(frame, self.dims)
class AutoCenterCrop(FrameProcessing):
    """Center-crop frames so that H and W become multiples of ``px_fold``."""

    def __init__(self, px_fold: int):
        """
        Automatic cropping in centre. Specify pixel_fold which the target frame size must satisfy
        and the frame will be center-cropped to this size.

        Args:
            px_fold: integer in which multiple the frame must dimensioned (H, W dimension)

        Raises:
            ValueError: if ``px_fold`` is not an integer
        """
        super().__init__()
        self.px_fold = px_fold

        # kept as ValueError (not TypeError) for backward compatibility with callers
        if not isinstance(self.px_fold, int):
            raise ValueError(f"px_fold must be an integer, got {type(self.px_fold)}.")

    def forward(self, frame: torch.Tensor) -> torch.Tensor:
        """
        Process frames

        Args:
            frame: size [*, H, W]

        Returns:
            center-cropped frame of size [*, H', W'] with H', W' multiples of ``px_fold``

        Raises:
            ValueError: if H or W is smaller than ``px_fold`` (crop target would be empty)
        """
        if self.px_fold == 1:
            return frame

        size_is = torch.tensor(frame.size())[-2:]
        # largest multiple of px_fold that fits in each of H, W
        size_tar = torch.div(size_is, self.px_fold, rounding_mode="trunc") * self.px_fold

        if (size_tar <= 0).any():
            raise ValueError(
                f"Frame of spatial size {tuple(size_is.tolist())} is smaller than "
                f"px_fold={self.px_fold}; cannot crop to a positive multiple of it."
            )

        """Crop"""
        # symmetric crop; on an odd residue the front index gets the extra pixel
        ix_front = ((size_is - size_tar).float() / 2).ceil().long()
        ix_back = ix_front + size_tar

        return frame[..., ix_front[0]:ix_back[0], ix_front[1]:ix_back[1]]
class AutoPad(AutoCenterCrop):
    """Pad (rather than crop) frames so that H and W become multiples of ``px_fold``."""

    def __init__(self, px_fold: int, mode: str = 'constant'):
        """
        Pad frame to a size that is divisible by px_fold. Useful to prepare
        an experimental frame for forwarding through network.

        Args:
            px_fold: number of pixels the resulting frame size should be divisible by
            mode: torch mode for padding. refer to docs of `torch.nn.functional.pad`
        """
        super().__init__(px_fold=px_fold)
        self.mode = mode

    def forward(self, frame: torch.Tensor) -> torch.Tensor:
        """Pad the last two dims of ``frame`` up to the next multiple of ``px_fold``."""
        if self.px_fold == 1:
            return frame

        hw = torch.tensor(frame.size())[-2:]
        target = (torch.ceil(hw / self.px_fold) * self.px_fold).long()

        pad_total = target - hw
        pad_front = pad_total // 2          # on odd totals the front (left/top) gets less
        pad_back = pad_total - pad_front

        # torch.nn.functional.pad expects (left, right, top, bottom) for the last two dims
        pad_spec = [pad_front[1].item(), pad_back[1].item(),
                    pad_front[0].item(), pad_back[0].item()]

        return torch.nn.functional.pad(frame, pad_spec, mode=self.mode)
def get_frame_extent(size, func) -> torch.Size:
    """
    Get frame extent after processing pipeline.

    Args:
        size: input frame size (e.g. ``torch.Size`` or tuple)
        func: processing callable mapping a frame tensor to a frame tensor

    Returns:
        torch.Size: output extent of ``func`` for an input of size ``size``
    """
    if len(size) != 4:
        return func(torch.zeros(*size)).size()

    # avoid to forward large batches just to get the output extent:
    # probe with a dummy batch of 2 and splice the true batch size back in
    n_batch = size[0]
    probed = func(torch.zeros(2, *size[1:])).size()
    return torch.Size([n_batch, *probed[1:]])