from typing import Optional
import matplotlib.pyplot as plt
import torch
"""
Convention:
x points to the right, y points down.
o----- x ----->
|
y
|
v
"""
def connect_point_set(set0, set1, threeD=False, ax=None, color='orange'):
    """
    Plots connecting lines between corresponding points of set0 and set1.

    The i-th point of set0 is connected to the i-th point of set1. In 2D mode only the first
    two columns are used; in 3D mode the first three columns are used.

    Args:
        set0: torch.Tensor / np.ndarray of dim N x 2 (N x 3 when threeD is True)
        set1: torch.Tensor / np.ndarray with (at least) the same number of rows as set0
        threeD (bool): plot / connect in 3D (ax must then provide ``plot3D``)
        ax: axis where to plot; defaults to the current axis (``plt.gca()``)
        color: line color, default 'orange'

    Returns:
        None
    """
    if ax is None:
        ax = plt.gca()

    # shape[0] works for torch tensors and numpy arrays alike, whereas .size(0) is torch-only
    # (np.ndarray.size is an int attribute, not a method).
    n_points = set0.shape[0]
    for i in range(n_points):
        xs = [set0[i, 0], set1[i, 0]]
        ys = [set0[i, 1], set1[i, 1]]
        if threeD:
            ax.plot3D(xs, ys, [set0[i, 2], set1[i, 2]], color=color)
        else:
            ax.plot(xs, ys, color=color)
class PlotFrame:
    def __init__(self, frame: torch.Tensor, extent: Optional[tuple] = None, clim=None,
                 plot_colorbar: bool = False, axes_order: Optional[str] = None):
        """
        Plots a frame.

        Args:
            frame: frame to be plotted; singleton dimensions are squeezed away
            extent: specify frame extent, tuple ((x0, x1), (y0, y1))
            clim: clim values (lower, upper)
            plot_colorbar: plot the colorbar
            axes_order: order of axis. Either default order (None) or 'future'
                (i.e. future version of decode in which we will swap axes).
                This is only a visual effect and does not change the storage scheme of the EmitterSet
        """
        # validate before touching any state
        assert axes_order is None or axes_order == 'future'

        self.frame = frame.detach().squeeze()
        self.extent = extent
        self.clim = clim
        self.plot_colorbar = plot_colorbar
        self._axes_order = axes_order

        if self._axes_order is None:
            # Default convention: swap the last two axes so that axis -2 of the input frame is
            # displayed horizontally (labelled 'x' in plot()).
            self.frame = self.frame.transpose(-1, -2)

    def plot(self) -> "plt.Axes":
        """
        Plot the frame into the current axis.

        The transposition required by the axis convention is already applied in __init__.

        Returns:
            the current matplotlib axis
        """
        imshow_kwargs = {'cmap': 'gray'}
        if self.extent is not None:
            # imshow expects (left, right, bottom, top). Passing y1 as bottom and y0 as top makes
            # y increase downwards, matching the module's 'y down' convention.
            imshow_kwargs['extent'] = (self.extent[0][0], self.extent[0][1],
                                       self.extent[1][1], self.extent[1][0])

        # .cpu() so that frames living on a GPU can be converted to numpy as well
        plt.imshow(self.frame.cpu().numpy(), **imshow_kwargs)
        plt.gca().set_aspect('equal', adjustable='box')

        if self.clim is not None:
            plt.clim(self.clim[0], self.clim[1])

        if self.plot_colorbar:
            plt.colorbar(fraction=0.046, pad=0.04)

        if self._axes_order is None:
            plt.xlabel('x')
            plt.ylabel('y')
        else:
            # In 'future' order the frame is not transposed, so the horizontal axis shows y
            plt.xlabel('y')
            plt.ylabel('x')

        return plt.gca()
class PlotCoordinates:
    _labels_default = ('Target', 'Output', 'Init')

    def __init__(self,
                 pos_tar=None, phot_tar=None,
                 pos_out=None, phot_out=None,
                 pos_ini=None, phot_ini=None,
                 extent_limit=None,
                 match_lines=False,
                 labels=None,
                 axes_order: Optional[str] = None):
        """
        Plots points in 2D projection.

        Args:
            pos_tar: target positions, tensor of dim N x 2 or more (only the first two columns are plotted)
            phot_tar: optional target photon counts, one per point; used as scatter color values
            pos_out: output positions (same layout as pos_tar)
            phot_out: optional output photon counts
            pos_ini: init positions (same layout as pos_tar)
            phot_ini: optional init photon counts
            extent_limit: optional axis limits ((x0, x1), (y0, y1)); the y limits are applied in
                reversed order so that y points down.
                NOTE(review): limits are not swapped for axes_order='future' — confirm intended.
            match_lines: draw lines connecting the i-th target point with the i-th output point
            labels: legend labels for (target, output, init); defaults to ('Target', 'Output', 'Init')
            axes_order: order of axis. Either default order (None) or 'future'
                (i.e. future version of decode in which we will swap axes).
                This is only a visual effect and does not change the storage scheme of the EmitterSet
        """
        self.extent_limit = extent_limit
        self.pos_tar = pos_tar
        self.phot_tar = phot_tar
        self.pos_out = pos_out
        self.phot_out = phot_out
        self.pos_ini = pos_ini
        self.phot_ini = phot_ini
        self.match_lines = match_lines
        self.labels = labels if labels is not None else self._labels_default
        self._axes_order = axes_order

        # marker specs are '<color><marker>' strings, e.g. 'ro' -> color 'r', marker 'o';
        # the cmaps are used instead of the color when photon counts are given
        self.tar_marker = 'ro'
        self.tar_cmap = 'winter'
        self.out_marker = 'bx'
        self.out_cmap = 'viridis'
        self.ini_marker = 'g+'
        self.ini_cmap = 'copper'

        assert self._axes_order is None or self._axes_order == 'future'

    def _display_xy(self, pos):
        """
        Return the two in-plane columns of pos in display order (horizontal, vertical) as an N x 2 tensor.

        In the default axes order this is (pos[:, 0], pos[:, 1]); in 'future' order the two columns
        are swapped. The result is detached and on the CPU so that it can be handed to matplotlib.
        """
        pos = pos.detach().cpu()
        if self._axes_order == 'future':
            return pos[:, [1, 0]]
        return pos[:, [0, 1]]

    def plot(self):
        """
        Scatter the target / output / init positions into the current axis.

        Sets with photon counts are colored by those counts using the respective colormap,
        otherwise the fixed color of the marker spec is used. The y axis is inverted so that
        y points down.

        Returns:
            the current matplotlib axis
        """
        point_sets = (
            (self.pos_tar, self.phot_tar, self.tar_marker, self.tar_cmap),
            (self.pos_out, self.phot_out, self.out_marker, self.out_cmap),
            (self.pos_ini, self.phot_ini, self.ini_marker, self.ini_cmap),
        )
        for ix, (pos, phot, marker_spec, cmap) in enumerate(point_sets):
            if pos is None:
                continue
            xy = self._display_xy(pos)
            color, marker = marker_spec[0], marker_spec[1]
            # labels are indexed lazily so that a shorter custom labels tuple works for unused sets
            if phot is not None:
                plt.scatter(xy[:, 0].numpy(), xy[:, 1].numpy(), c=phot.detach().cpu().numpy(),
                            marker=marker, facecolors='none', cmap=cmap, label=self.labels[ix])
            else:
                plt.scatter(xy[:, 0].numpy(), xy[:, 1].numpy(),
                            marker=marker, c=color, facecolors='none', label=self.labels[ix])

        if self.pos_tar is not None and self.pos_out is not None and self.match_lines:
            # connect in display coordinates so the lines meet the (possibly swapped) markers
            connect_point_set(self._display_xy(self.pos_tar), self._display_xy(self.pos_out),
                              threeD=False)

        ax = plt.gca()
        ax.set_aspect('equal', adjustable='box')

        # invert the y axis (unless it already is) so that y points down
        y_lower, y_upper = ax.get_ylim()
        if y_lower <= y_upper:
            ax.set_ylim((y_upper, y_lower))

        if self._axes_order is None:
            plt.xlabel('x')
            plt.ylabel('y')
        else:
            plt.xlabel('y')
            plt.ylabel('x')

        if self.extent_limit is not None:
            plt.xlim(*self.extent_limit[0])
            plt.ylim(*self.extent_limit[1][::-1])  # reverse tuple order to keep y pointing down

        return ax
class PlotCoordinates3D:
    _labels_default = ('Target', 'Output', 'Init')

    def __init__(self, pos_tar=None, pos_out=None, phot_out=None, match_lines=False, labels=None):
        """
        Plots points in 3D.

        Note: constructing this object adds a new 3D subplot (111) to the current figure.

        Args:
            pos_tar: target positions with (at least) 3 columns x, y, z
            pos_out: output positions with (at least) 3 columns x, y, z
            phot_out: output photon counts; stored but currently not used for plotting
            match_lines: connect the i-th target point with the i-th output point
            labels: legend labels (target, output, ...); defaults to ('Target', 'Output', 'Init')
        """
        self.pos_tar = pos_tar
        self.pos_out = pos_out
        self.phot_out = phot_out
        self.match_lines = match_lines
        self.labels = labels if labels is not None else self._labels_default

        self.fig = plt.gcf()
        self.ax = self.fig.add_subplot(111, projection='3d')

    def plot(self):
        """
        Scatter target (red circles) and output (blue triangles) positions into self.ax.

        All axis decorations are applied to self.ax rather than to whichever axis happens to be
        current, so the plot stays correct if the current axis changed after construction.

        Returns:
            the 3D axis self.ax
        """
        if self.pos_tar is not None:
            xyz = self.pos_tar
            self.ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c='red', marker='o',
                            label=self.labels[0])

        if self.pos_out is not None:
            xyz = self.pos_out
            # one opaque blue RGBA color per output point
            rgba_colors = torch.zeros((xyz.shape[0], 4))
            rgba_colors[:, 2] = 1.0
            rgba_colors[:, 3] = 1.0
            self.ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], marker='^', color=rgba_colors.numpy(),
                            label=self.labels[1])

        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.invert_yaxis()  # y points down

        if self.pos_tar is not None and self.pos_out is not None and self.match_lines:
            connect_point_set(self.pos_tar, self.pos_out, threeD=True, ax=self.ax)

        return self.ax
class PlotFrameCoord(PlotCoordinates, PlotFrame):
    def __init__(self, frame,
                 pos_tar=None, phot_tar=None,
                 pos_out=None, phot_out=None,
                 pos_ini=None, phot_ini=None,
                 extent=None, coord_limit=None,
                 norm=None, clim=None,
                 match_lines=False, labels=None,
                 plot_colorbar_frame: bool = False,
                 axes_order: Optional[str] = None):
        """
        Plots a frame with coordinates on top.

        Args:
            frame: frame to be plotted (forwarded to PlotFrame)
            pos_tar, phot_tar, pos_out, phot_out, pos_ini, phot_ini: positions and optional photon
                counts (forwarded to PlotCoordinates)
            extent: frame extent ((x0, x1), (y0, y1)), forwarded to PlotFrame
            coord_limit: axis limits, forwarded to PlotCoordinates as extent_limit
            norm: currently unused; accepted for backwards compatibility
            clim: clim values of the frame
            match_lines: connect corresponding target and output points
            labels: legend labels for (target, output, init)
            plot_colorbar_frame: plot a colorbar for the frame
            axes_order: either None (default) or 'future'; forwarded to both parents
        """
        PlotCoordinates.__init__(self,
                                 pos_tar=pos_tar,
                                 phot_tar=phot_tar,
                                 pos_out=pos_out,
                                 phot_out=phot_out,
                                 pos_ini=pos_ini,
                                 phot_ini=phot_ini,
                                 extent_limit=coord_limit,
                                 match_lines=match_lines,
                                 labels=labels,
                                 axes_order=axes_order)

        PlotFrame.__init__(self, frame, extent, clim,
                           plot_colorbar=plot_colorbar_frame, axes_order=axes_order)

    def plot(self):
        """
        Plot the frame first and then the coordinates on top of it.

        Returns:
            the axis returned by PlotCoordinates.plot (the current matplotlib axis)
        """
        PlotFrame.plot(self)
        return PlotCoordinates.plot(self)