Source code for decode.neuralfitter.sampling

import torch
from skimage.util.shape import view_as_windows

from typing import Sequence


[docs]def sample_crop(x_in: torch.Tensor, sample_size: Sequence[int]) -> torch.Tensor: """ Takes a 2D tensor and returns random crops Args: x_in: input tensor sample_size: size of sample, size specification (N, H, W) Returns: random crops with size sample_size """ assert x_in.dim() == 2, "Not implemented dimensionality" assert len(sample_size) == 3, "Wrong sequence dimension." windows = view_as_windows(x_in.numpy(), sample_size[-2:]) # converts array via sliding window into smaller ones n = sample_size[0] ix_max = (x_in.size(-2) - sample_size[-2], x_in.size(-1) - sample_size[-1]) x_ix = torch.randint(0, ix_max[0] + 1, size=(n,)) y_ix = torch.randint(0, ix_max[1] + 1, size=(n,)) return torch.from_numpy(windows[x_ix, y_ix])