Hi everybody,
I am following the torchvision documentation for `torchvision.utils.draw_bounding_boxes` (uint8 dtype for the image input, an (N, 4)-sized boxes tensor), and I think the following minimal example should work:
```python
import torch
from torchvision.utils import draw_bounding_boxes

image = torch.ones(1, 1000, 1000).type(torch.uint8)
boxes = torch.Tensor([[250, 250, 500, 500]])

print("image.dtype: {}".format(image.dtype))
print("boxes.size(): {}".format(boxes.size()))
print("type(boxes): {}".format(type(boxes)))

bboxes = draw_bounding_boxes(image, boxes)
```
but instead I get the following output and stack trace:
```
image.dtype: torch.uint8
boxes.size(): torch.Size([1, 4])
type(boxes): <class 'torch.Tensor'>
Traceback (most recent call last):
  File "pytorch/1.8.1-py3.7/lib/python3.7/site-packages/PIL/Image.py", line 2772, in fromarray
    mode, rawmode = _fromarray_typemap[typekey]
KeyError: ((1, 1, 1), '|u1')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "minimal_example.py", line 9, in <module>
    bboxes = draw_bounding_boxes(image, boxes)
  File "pytorch/1.8.1-py3.7/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "pytorch/1.8.1-py3.7/lib/python3.7/site-packages/torchvision/utils.py", line 179, in draw_bounding_boxes
    img_to_draw = Image.fromarray(ndarr)
  File "pytorch/1.8.1-py3.7/lib/python3.7/site-packages/PIL/Image.py", line 2774, in fromarray
    raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
TypeError: Cannot handle this data type: (1, 1, 1), |u1
```
This issue discusses my problem, but as far as I can see my boxes tensor already is a `torch.Tensor`: `TypeError: Cannot handle this data type: (1, 1, 1), |u1` when using `torchvision.utils.draw_bounding_boxes` - #2 by ptrblck
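For what it's worth, the same `TypeError` seems reproducible with PIL alone, which makes me think the problem is the single channel dimension (the array handed to `Image.fromarray` has shape (H, W, 1)) rather than the boxes tensor. A minimal sketch, assuming that is the cause:

```python
import numpy as np
from PIL import Image

# A (H, W, 1) uint8 array: PIL has no mode for a trailing channel of size 1,
# so fromarray raises the same "Cannot handle this data type" TypeError.
arr = np.ones((10, 10, 1), dtype=np.uint8)
try:
    Image.fromarray(arr)
except TypeError as e:
    print(e)  # Cannot handle this data type: (1, 1, 1), |u1

# Squeezing out the channel dimension gives a plain (H, W) grayscale array,
# which fromarray accepts as mode "L".
img = Image.fromarray(arr.squeeze(-1))
print(img.mode)  # L
```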
I am using torchvision 0.9.1, in case that matters, but the documentation for this function is identical to the most up-to-date version.
Any help would be greatly appreciated.
EDIT:
As far as I know, `torchvision.utils.save_image` expects images valued between 0 and 1 (it will raise an error for images outside that range), but `torchvision.utils.draw_bounding_boxes` expects 8-bit images, so there seems to be a bit of an inconsistency between these utilities?