As a workaround, I would try using OpenCV to draw the labels directly onto the images. The resulting image is barely presentable in a business setting, but it is still informative.
import math

import cv2
import numpy as np
import torch
from torchvision.transforms import transforms
# Alias for the builtin ``range``: ``make_grid_with_labels`` has a parameter
# that is itself named ``range``, which shadows the builtin inside its body.
irange = range
def _draw_label(image, text, font_scale=2, color=(255, 0, 0), thickness=1):
    """Return a float copy of ``image`` (C x H x W, values in [0, 1]) with ``text`` drawn on it."""
    # Imported lazily (the original function also imported cv2 locally), so
    # OpenCV is only required when a label is actually rendered.
    import cv2

    # detach()/cpu() make this work for CUDA tensors and tensors that require
    # grad; OpenCV wants channel-last (H x W x C) pixel data.
    pixels = image.detach().cpu().numpy().transpose(1, 2, 0)
    # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around. OpenCV also needs a C-contiguous buffer, which a transposed
    # view is not.
    canvas = np.ascontiguousarray(np.clip(pixels * 255, 0, 255).astype(np.uint8))
    # Text baseline sits at the left edge, 30% of the way down the image.
    origin = (0, int(image.shape[1] * 0.3))
    # The colour tuple is applied in the tensor's own channel order, so
    # (255, 0, 0) saturates the first channel.
    canvas = cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_PLAIN,
                         font_scale, color, thickness, cv2.LINE_AA)
    # Back to channel-first floats in [0, 1].
    return torch.from_numpy(canvas).permute(2, 0, 1).float().div(255)


def make_grid_with_labels(tensor, labels, nrow=8, limit=20, padding=2,
                          normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images with a text label drawn on each image.

    Image values are expected to lie in [0, 1] when the labels are drawn
    (pass ``normalize=True`` otherwise); out-of-range values are clipped.
    Labelled images pass through an 8-bit conversion, so their values are
    quantised to multiples of 1/255.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A single image
            (H x W or C x H x W) is also accepted.
        labels (list): either a flat list with one label per image, or
            ``[labels_1, labels_2, ..., labels_n]`` where every ``labels_i``
            is a B x 1 column vector; in the latter case the n labels of an
            image are rendered together as one string.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        limit (int, optional): Maximum number of images (and labels) placed in
            the grid; ``None`` disables the limit. Default: ``20``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        Tensor: the grid as a C x H' x W' tensor created with
        ``tensor.new_full``. When only one image remains after applying
        ``limit``, that image is returned as C x H x W *without* a label.

    Raises:
        ValueError: if ``labels`` is not a list, if there are fewer labels
            than images, or if there are no images to arrange.
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors,
            or if ``range`` is given but is not a tuple.
    """
    if not isinstance(labels, list):
        raise ValueError('labels must be a list, got {}'.format(type(labels)))
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # Bring every accepted input form to a 4D mini-batch *before* slicing by
    # ``limit``; slicing first breaks list inputs and crops single images.
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)
    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # A list of n (B x 1) vectors has shape (n, B, 1); ``.T[0]`` turns it into
    # (B, n), i.e. one row of labels per image. A flat list is used as is.
    label_rows = np.asarray(labels)
    if label_rows.ndim > 1:
        label_rows = label_rows.T[0]

    if limit is not None:
        tensor = tensor[:limit]
        label_rows = label_rows[:limit]

    nmaps = tensor.size(0)
    if nmaps == 0:
        raise ValueError('no images to arrange in a grid')
    if len(label_rows) < nmaps:
        raise ValueError('got {} labels for {} images'.format(len(label_rows), nmaps))

    if normalize:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if range is not None and not isinstance(range, tuple):
            raise TypeError(
                "range has to be a tuple (min, max) if specified. min and max are numbers")

        def _rescale(img, low, high):
            img.clamp_(min=low, max=high)
            img.add_(-low).div_(high - low + 1e-5)

        def _rescale_to_unit(img):
            if range is not None:
                _rescale(img, range[0], range[1])
            else:
                _rescale(img, float(img.min()), float(img.max()))

        if scale_each:
            for img in tensor:  # loop over mini-batch dimension
                _rescale_to_unit(img)
        else:
            _rescale_to_unit(tensor)

    # NOTE: a single image is returned as is, without a label drawn on it.
    if nmaps == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    grid = tensor.new_full(
        (tensor.size(1), height * ymaps + padding, width * xmaps + padding), pad_value)
    for k, image in enumerate(tensor):
        y, x = divmod(k, xmaps)  # row-major cell position
        labelled = _draw_label(image, str(label_rows[k]))
        grid.narrow(1, y * height + padding, height - padding) \
            .narrow(2, x * width + padding, width - padding) \
            .copy_(labelled)
    return grid
def show(img):
    """Display a channel-first (C x H x W) image tensor with ``plt.imshow``.

    Args:
        img (Tensor): image tensor in C x H x W layout.
    """
    # NOTE(review): ``plt`` is not imported in the visible part of this file;
    # it is presumably ``matplotlib.pyplot`` -- confirm it is in scope.
    # detach()/cpu() let this accept CUDA tensors and tensors that require
    # grad, for which a bare ``.numpy()`` raises.
    pixels = img.detach().cpu().numpy()
    # imshow expects channel-last data, so move the channel axis to the end.
    plt.imshow(pixels.transpose(1, 2, 0), interpolation='nearest')