Can't save an int tensor with save_image

Hi all,
I have an output tensor of type torch.int64 with values in the range [0, 255]. I am trying to save it as an image using save_image(), but it gives me an error.
My tensor

img_hr,a,b = learn.predict(fn)
img_hr.dtype
torch.int64

Error

save_image(img_hr,'img.jpg')
RuntimeError                              Traceback (most recent call last)
Input In [186], in <cell line: 1>()
----> 1 save_image(img_hr,'img.jpg')

File ~/mambaforge/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/mambaforge/lib/python3.9/site-packages/torchvision/utils.py:154, in save_image(tensor, fp, format, **kwargs)
    152 grid = make_grid(tensor, **kwargs)
    153 # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
--> 154 ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    155 im = Image.fromarray(ndarr)
    156 im.save(fp, format=format)

File ~/mambaforge/lib/python3.9/site-packages/fastai/torch_core.py:365, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
    363 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
    364 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 365 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
    366 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
    367 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File ~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:1142, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1139     return NotImplemented
   1141 with _C.DisableTorchFunction():
-> 1142     ret = func(*args, **kwargs)
   1143     if func in get_default_nowrap_functions():
   1144         return ret

RuntimeError: result type Float can't be cast to the desired output type Long

Any ideas on how to save it?

Thanks,
Ankit

save_image only expects float tensor inputs with values in [0, 1]. Internally it computes grid.mul(255).add_(0.5): for an int64 tensor, mul(255) keeps the Long dtype, and the in-place add_(0.5) produces a float result that can't be written back into a Long tensor, hence the RuntimeError. Unfortunately, this function is quite antique. Please chime in on "Move torchvision.utils.save_image -> torchvision.io.save_image (or introduce this function from scratch) · Issue #5461 · pytorch/vision · GitHub" to vote for its modernization. In general, I/O in torch is a big design/implementation mess IMHO. For example, torchvision and torchaudio each have their own separate ffmpeg wrapper, each with its own quirks when building from source or linking against your own ffmpeg libraries. My proposal would be to create a separate torchio package grouping all this code.
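The failing step from the traceback (line 154 of torchvision/utils.py) can be reproduced without calling save_image at all. This is a minimal sketch assuming only torch is installed; img is a hypothetical random stand-in for img_hr.

```python
import torch

# Hypothetical stand-in for img_hr: a small int64 CHW "image" with values in [0, 255]
img = torch.randint(0, 256, (3, 4, 4), dtype=torch.int64)

# The first two ops of save_image's conversion line: mul(255) keeps the int64
# dtype, then the in-place add_(0.5) needs a float result it cannot store
error_message = None
try:
    img.mul(255).add_(0.5)
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Dividing by 255.0 first gives a float32 tensor in [0, 1], the input save_image expects
scaled = img / 255.0
print(scaled.dtype, float(scaled.min()) >= 0.0, float(scaled.max()) <= 1.0)
```

The printed error should match the last line of the traceback, since the same in-place type-promotion rule is being hit.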

Thanks @vadimkantorov, I will definitely chime in. But any ideas on how to save the image now?

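In the meantime, you can call torchvision.utils.save_image(your_int64_tensor / 255.0, "your_file.jpg"), which rescales the values into the [0, 1] float range that save_image expects. Or maybe use the newer encode_jpeg, which takes a uint8 tensor: open("your_file.jpg", "wb").write(torchvision.io.encode_jpeg(your_int64_tensor.to(torch.uint8)).numpy().tobytes()), or more simply torchvision.io.write_jpeg(your_int64_tensor.to(torch.uint8), "your_file.jpg").

Note that encode_jpeg expects a uint8 tensor, not int8, so don't shift the values by 128 or reinterpret them with .view(torch.int8) (which would map 200 to -56, not to 200 - 128 = 72). Double-check with your torchvision version.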


Thanks @vadimkantorov

torchvision.utils.save_image(your_int64_tensor / 255.0, 'img.jpg')

works pretty well.
