How should I detach a channels_last format tensor?

Hi,

I am using the channels_last memory format together with AMP (autocast) for training.
Since PyTorch 1.8, when I detach a channels_last tensor and feed it to another model, PyTorch fails with an error.
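For reference, `.detach()` on its own does not change the layout. The minimal CPU-only check below is an illustration added here, not part of the original repro. It shows that the detached tensor aliases the same storage and keeps the same channels_last strides as the original.

```python
import torch

# A channels_last (NHWC in memory) tensor with the same shape as in the repro.
x = torch.randn(1, 3, 4, 4).to(memory_format=torch.channels_last)
d = x.detach()

# detach() returns an alias: same storage, same sizes and strides.
print(x.stride(), d.stride())  # (48, 1, 12, 3) (48, 1, 12, 3)
print(d.data_ptr() == x.data_ptr())  # True
print(d.is_contiguous(memory_format=torch.channels_last))  # True
```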

Here’s an example snippet.

```python
import torch
import torch.nn as nn
import torch.cuda.amp as amp

device = torch.device('cuda')
dtype = torch.float32
memory_format = torch.channels_last
# memory_format = torch.contiguous_format  # with this layout the snippet runs without error

model1 = nn.Conv2d(3,3,1,1).to(device=device, dtype=dtype, non_blocking=True, memory_format=memory_format)
model2 = nn.Conv2d(3,3,1,1).to(device=device, dtype=dtype, non_blocking=True, memory_format=memory_format)

input = torch.randn(1,3,4,4).to(device, dtype=dtype, memory_format=memory_format)

with amp.autocast():
    out1 = model1(input)
    out2 = model2(out1.detach())  # raises the RuntimeError below on PyTorch 1.8.1
```

Here’s the error message:

```
RuntimeError: set_sizes_and_strides is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
    x.data.set_(y)
to:
    with torch.no_grad():
        x.set_(y)
```

If I change the memory format to torch.contiguous_format, it works fine.
Is detaching a channels_last tensor forbidden?
Since there are no resizing/reshaping operations in the second model, the error message, which talks about changing sizes/strides, seems irrelevant.
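One workaround that may be worth trying is to give the detached values their own tensor with `.clone()`, so that the tensor `model2` receives was not itself produced by `.detach()`. This is an assumption and has not been verified against 1.8.1. `clone()` uses `torch.preserve_format` by default, so the channels_last layout should be kept, but it costs one extra copy per step. The sketch falls back to CPU with autocast disabled when CUDA is unavailable, so it can run anywhere.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
memory_format = torch.channels_last

model1 = nn.Conv2d(3, 3, 1, 1).to(device=device, memory_format=memory_format)
model2 = nn.Conv2d(3, 3, 1, 1).to(device=device, memory_format=memory_format)
x = torch.randn(1, 3, 4, 4, device=device).to(memory_format=memory_format)

# Autocast only makes sense on CUDA here; on CPU it is simply disabled.
with torch.cuda.amp.autocast(enabled=use_cuda):
    out1 = model1(x)
    # detach() cuts the graph to model1; clone() (assumed workaround) gives
    # model2 a fresh tensor with its own storage instead of the detached alias.
    out2 = model2(out1.detach().clone())

# Gradients reach model2 only, because out1 was detached before model2.
out2.float().sum().backward()
```

If the clone avoids the error, the gradient behaviour is the same as with a plain `.detach()`: `model1.weight.grad` stays `None` and only `model2` receives gradients.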

A related issue is here: https://github.com/pytorch/pytorch/issues/55301

PyTorch version: 1.8.1