from abc import ABC
from typing import Tuple
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import hsv_to_rgb
from matplotlib.colors import rgb_to_hsv
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
from tqdm import tqdm
from ..generic import emitter
[docs]class Renderer(ABC):
def __init__(self, plot_axis: tuple, xextent: tuple, yextent: tuple, zextent: tuple,
px_size: float, abs_clip: float, rel_clip: float, contrast: float):
"""Renderer. Takes emitters and outputs a rendered image."""
super().__init__()
self.xextent = xextent
self.yextent = yextent
self.zextent = zextent
self.px_size = px_size
self.plot_axis = plot_axis
self.abs_clip = abs_clip
self.rel_clip = rel_clip
self.contrast = contrast
assert (
self.abs_clip is None or self.rel_clip is None
), "Define either an absolute or a relative value for clipping, but not both"
[docs] def forward(self, em: emitter.EmitterSet) -> torch.Tensor:
"""
Forward emitterset through rendering and output rendered data.
Args:
em: emitter set
"""
raise NotImplementedError
[docs] def render(self, em: emitter.EmitterSet, ax=None):
"""
Render emitters
Args:
em: emitter set
ax: plot axis
"""
raise NotImplementedError
[docs]class Renderer2D(Renderer):
def __init__(self, px_size, sigma_blur, plot_axis=(0, 1, 2),
xextent=None, yextent=None, zextent=None, colextent=None,
abs_clip=None, rel_clip=None, contrast=1):
"""
2D histogram renderer with constant gaussian blur.
Args:
px_size: pixel size of the output image in nm
sigma_blur: sigma of the gaussian blur applied in nm
plot_axis: determines which dimensions get plotted. 0,1,2 = x,y,z. (0,1) is x over y.
xextent: extent in x in nm
yextent: extent in y in nm
zextent: extent in z in nm.
cextent: extent of the color variable. Values outside of this range get clipped.
abs_clip: absolute clipping value of the histogram in counts
rel_clip: clipping value relative to the maximum count. i.e. rel_clip = 0.8 clips at 0.8*hist.max()
contrast: scaling factor to increase contrast
"""
super().__init__(
plot_axis=plot_axis,
xextent=xextent,
yextent=yextent,
zextent=zextent,
px_size=px_size,
abs_clip=abs_clip,
rel_clip=rel_clip,
contrast=contrast,
)
self.sigma_blur = sigma_blur
self.colextent = colextent
self.jet_hue = self._get_jet_cmap()
[docs] def render(self, em: emitter.EmitterSet, col_vec=None, ax=None):
"""
Forward emitterset through rendering and output rendered data.
Args:
em: emitter set
col_vec: torch tensor (1 dim) with the same length as em
ax: plot axis
"""
hist = self.forward(em, col_vec).numpy()
if ax is None:
ax = plt.gca()
if col_vec is not None:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size=0.25, pad=-0.25)
colb = mpl.colorbar.ColorbarBase(
cax, cmap=plt.get_cmap("jet"), values=np.linspace(0, 1.0, 101),
norm=mpl.colors.Normalize(0.0, 1.0)
)
colb.outline.set_visible(False)
cax.text(
0.12, 0.04, f"{self.colextent[0]}", rotation=90, color="white", fontsize=15,
transform=cax.transAxes
)
cax.text(
0.12, 0.88, f"{self.colextent[1]}", rotation=90, color="white", fontsize=15,
transform=cax.transAxes
)
cax.axis("off")
ax.imshow(np.transpose(hist, [1, 0, 2]))
else:
# because imshow use different ordering
ax.imshow(np.transpose(hist), cmap="gray")
return ax
[docs] def forward(self, em: emitter.EmitterSet, col_vec=None) -> torch.Tensor:
"""
Forward emitterset through rendering and output rendered data.
Args:
em: emitter set
col_vec: torch tensor (1 dim) with the same length as em
"""
xyz_extent = self.get_extent(em)
ind_mask = (
(em.xyz_nm[:, 0] >= xyz_extent[0][0])
* (em.xyz_nm[:, 0] <= xyz_extent[0][1])
* (em.xyz_nm[:, 1] >= xyz_extent[1][0])
* (em.xyz_nm[:, 1] <= xyz_extent[1][1])
* (em.xyz_nm[:, 2] >= xyz_extent[2][0])
* (em.xyz_nm[:, 2] <= xyz_extent[2][1])
)
em_sub = em[ind_mask]
if col_vec is not None:
col_vec = col_vec[ind_mask]
self.colextent = (
col_vec.min(), col_vec.max()) if self.colextent is None else self.colextent
int_hist, col_hist = self._hist2d(
em_sub, col_vec, xyz_extent[self.plot_axis[0]], xyz_extent[self.plot_axis[1]],
self.colextent
)
with np.errstate(divide="ignore", invalid="ignore"):
c_avg = col_hist / int_hist
if self.rel_clip is not None:
int_hist = np.clip(int_hist * self.contrast, 0.0, int_hist.max() * self.rel_clip)
val = int_hist / int_hist.max()
elif self.abs_clip is not None:
int_hist = np.clip(int_hist, 0.0, self.abs_clip)
val = int_hist / self.abs_clip
else:
val = int_hist / int_hist.max()
val *= self.contrast
c_avg[np.isnan(c_avg)] = 0
sat = np.ones(int_hist.shape)
hue = np.interp(c_avg, np.linspace(0, 1, 256), self.jet_hue)
HSV = np.concatenate((hue[:, :, None], sat[:, :, None], val[:, :, None]), -1)
RGB = hsv_to_rgb(HSV)
if self.sigma_blur:
RGB = np.array(
[
gaussian_filter(
RGB[:, :, i],
sigma=[self.sigma_blur / self.px_size, self.sigma_blur / self.px_size]
)
for i in range(3)
]
).transpose(1, 2, 0)
RGB = np.clip(RGB, 0, 1)
return torch.from_numpy(RGB)
else:
hist = self._hist2d(em_sub, None, xyz_extent[self.plot_axis[0]],
xyz_extent[self.plot_axis[1]])
if self.rel_clip is not None:
hist = np.clip(hist, 0.0, hist.max() * self.rel_clip)
if self.abs_clip is not None:
hist = np.clip(hist, 0.0, self.abs_clip)
if self.sigma_blur is not None:
hist = gaussian_filter(hist, sigma=[self.sigma_blur / self.px_size,
self.sigma_blur / self.px_size])
hist = np.clip(hist, 0, hist.max() / self.contrast)
return torch.from_numpy(hist)
[docs] def get_extent(self, em) -> Tuple[tuple, tuple, tuple]:
xextent = (
em.xyz_nm[:, 0].min(), em.xyz_nm[:, 0].max()) if self.xextent is None else self.xextent
yextent = (
em.xyz_nm[:, 1].min(), em.xyz_nm[:, 1].max()) if self.yextent is None else self.yextent
zextent = (
em.xyz_nm[:, 2].min(), em.xyz_nm[:, 2].max()) if self.zextent is None else self.zextent
return xextent, yextent, zextent
def _hist2d(self, em: emitter.EmitterSet, col_vec, x_hist_ext, y_hist_ext, c_range=None):
xy = em.xyz_nm[:, self.plot_axis].numpy()
hist_bins_x = np.arange(x_hist_ext[0], x_hist_ext[1] + self.px_size, self.px_size)
hist_bins_y = np.arange(y_hist_ext[0], y_hist_ext[1] + self.px_size, self.px_size)
int_hist, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=(hist_bins_x, hist_bins_y))
if col_vec is not None:
c_pos = np.clip(col_vec, c_range[0], c_range[1])
c_weight = (c_pos - c_range[0]) / (c_range[1] - c_range[0])
col_hist, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=(hist_bins_x, hist_bins_y),
weights=c_weight)
return int_hist, col_hist
else:
return int_hist
@staticmethod
def _get_jet_cmap():
lin_hue = np.linspace(0, 1, 256)
cmap = plt.get_cmap("jet", lut=256)
cmap = cmap(lin_hue)
cmap_hsv = rgb_to_hsv(cmap[:, :3])
jet_hue = cmap_hsv[:, 0]
_, b = np.unique(jet_hue, return_index=True)
jet_hue = [jet_hue[index] for index in sorted(b)]
jet_hue = np.interp(np.linspace(0, len(jet_hue), 256), np.arange(len(jet_hue)), jet_hue)
return jet_hue
[docs]class RendererIndividual2D(Renderer2D):
def __init__(self, px_size, batch_size=1000, filt_size=10, plot_axis=(0, 1),
xextent=None, yextent=None, zextent=None, colextent=None,
abs_clip=None, rel_clip=None, contrast=1, intensity_field="sigma", device="cpu"):
"""
2D histogram renderer. Each localization is individually rendered as a 2D Gaussian corresponding to a
respective field.
Args:
px_size: pixel size of the output image in nm
batch_size: number of localization processed in parallel
filt_size: each gaussian is calculated as a patch with size filt_size*filt_size (in pixels)
plot_axis: determines which dimensions get plotted. 0,1,2 = x,y,z. (0,1) is x over y.
xextent: extent in x in nm
yextent: extent in y in nm
zextent: extent in z in nm.
cextent: extent of the color variable. Values outside of this range get clipped.
abs_clip: absolute clipping value of the histogram in counts
rel_clip: clipping value relative to the maximum count. i.e. rel_clip = 0.8 clips at 0.8*hist.max()
contrast: scaling factor to increase contrast
intensity_field: field of emitter that should be used for rendering
device: render on cpu or cuda
"""
super().__init__(
px_size=px_size,
sigma_blur=None,
plot_axis=plot_axis,
xextent=xextent,
yextent=yextent,
zextent=zextent,
colextent=colextent,
abs_clip=abs_clip,
rel_clip=rel_clip,
contrast=contrast,
)
self.bs = batch_size
self.fs = filt_size
self.device = device
self.intensity_field = intensity_field
[docs] def calc_gaussians(self, xy_mu, xy_sig, mesh):
xy_mu = xy_mu[:, :2] % self.px_size / self.px_size
xy_sig = xy_sig[:, :2] / self.px_size
dist = torch.distributions.Normal(xy_mu, xy_sig)
W = torch.exp(dist.log_prob(mesh[:, :, None]).sum(-1)).permute(2, 0, 1)
return W / torch.clamp_min(W.sum(-1).sum(-1), 1.0)[:, None, None]
@torch.jit.script
def _place_gaussians(int_hist, inds, W, fs):
for i in range(len(W)):
int_hist[inds[i, 1]: inds[i, 1] + fs, inds[i, 0]: inds[i, 0] + fs] += W[i]
return int_hist
@torch.jit.script
def _place_gaussians_weighted(comb_hist, inds, weights, W, fs):
for i in range(len(W)):
comb_hist[inds[i, 1]: inds[i, 1] + fs, inds[i, 0]: inds[i, 0] + fs] \
+= torch.stack([W[i], W[i] * weights[i]], -1)
return comb_hist
def _hist2d(self, em: emitter.EmitterSet, col_vec, x_hist_ext, y_hist_ext, c_range=None):
ym, xm = torch.meshgrid(
torch.linspace(-(self.fs // 2), self.fs // 2, self.fs, device=self.device),
torch.linspace(-(self.fs // 2), self.fs // 2, self.fs, device=self.device),
)
mesh = torch.cat([(xm)[..., None], (ym)[..., None]], -1)
xy_mus = em.xyz_nm[:, self.plot_axis].to(self.device)
xy_sigs = em.xyz_sig_nm[:, self.plot_axis].to(self.device)
w = int((x_hist_ext[1] - x_hist_ext[0]) // self.px_size + 1)
h = int((y_hist_ext[1] - y_hist_ext[0]) // self.px_size + 1)
s_inds = xy_mus - torch.Tensor([x_hist_ext[0], y_hist_ext[0]], device=self.device)
s_inds = torch.div(s_inds, self.px_size, rounding_mode="trunc").long()
if col_vec is not None:
c_pos = torch.clip(col_vec, c_range[0], c_range[1])
c_weight = ((c_pos - c_range[0]) / (c_range[1] - c_range[0])).to(self.device)
comb_hist = torch.zeros([h + self.fs, w + self.fs, 2], device=self.device,
dtype=torch.float)
for i in tqdm(range(len(xy_mus) // self.bs + 1)):
sl = np.s_[i * self.bs: (i + 1) * self.bs]
sub_inds = s_inds[sl]
W = self.calc_gaussians(xy_mus[sl], xy_sigs[sl], mesh)
c_ws = c_weight[sl]
comb_hist = self._place_gaussians_weighted(comb_hist, sub_inds, c_ws, W,
torch.tensor(self.fs))
comb_hist = comb_hist[self.fs // 2: -(self.fs // 2 + 1),
self.fs // 2: -(self.fs // 2 + 1)]
int_hist = comb_hist[:, :, 0]
col_hist = comb_hist[:, :, 1]
return int_hist.T.cpu().numpy(), col_hist.T.cpu().numpy()
else:
int_hist = torch.zeros([h + self.fs, w + self.fs], device=self.device,
dtype=torch.float)
for i in tqdm(range(len(xy_mus) // self.bs + 1)):
sl = np.s_[i * self.bs: (i + 1) * self.bs]
sub_inds = s_inds[sl]
W = self.calc_gaussians(xy_mus[sl], xy_sigs[sl], mesh)
int_hist = self._place_gaussians(int_hist, sub_inds, W, torch.tensor(self.fs))
int_hist = int_hist[self.fs // 2: -(self.fs // 2 + 1),
self.fs // 2: -(self.fs // 2 + 1)]
return int_hist.T.cpu().numpy()