Channels last question

I’m looking at https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html

Specifically, this example:

    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
    model = torch.nn.Conv2d(8, 5, 3).cuda().float()

    input = input.contiguous(memory_format=torch.channels_last)
    model = model.to(memory_format=torch.channels_last) # Module parameters need to be Channels Last

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last)) # Outputs: True

If I comment out the line that converts the model to channels_last format, I expect it to fail, but it does not.
How come an input in channels_last format can be convolved with weight filters in channels_first format without a dim mismatch error:

>>> model.weight.shape
torch.Size([5, 8, 3, 3])
>>> model.weight.is_contiguous(memory_format=torch.channels_last)
False
>>> input.shape
torch.Size([2, 8, 4, 4])
>>> input.is_contiguous(memory_format=torch.channels_last)
True

My understanding is that the input’s actual shape in this case is [2, 4, 4, 8], and therefore the weight’s input-channels dim (8) should not match the input’s channels dim (4).

@VitalyFedyunin ?

I think the input would be transformed to channels_last internally (I asked the same question recently and discussed it with one of the original authors), but let’s see if Vitaly can confirm it.

This is my understanding as well. It looks like a bug to me, because if the input is being permuted to channels_last but the weight remains channels_first, then the code effectively becomes:

input = torch.randint(1, 10, (2, 4, 4, 8), dtype=torch.float32, device="cuda", requires_grad=True)
model = torch.nn.Conv2d(8, 5, 3).cuda().float()
out = model(input)

And that obviously fails because of the dim mismatch, but my example above runs fine even when the weight remains in channels_first format. Something is not right here.

No, I don’t think it’s a bug, as long as PyTorch internally makes sure that the input and the parameters use the same memory layout.
Also, the result is the same:

input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
model = torch.nn.Conv2d(8, 5, 3).cuda().float()

input = input.contiguous(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last) # Module parameters need to be Channels Last

out_ref = model(input)
print(out_ref.is_contiguous(memory_format=torch.channels_last))
> True

model.to(memory_format=torch.contiguous_format)
out = model(input)

print(out.is_contiguous(memory_format=torch.channels_last))
> True

print((out_ref - out).abs().max())
> tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)

What does it mean? To me it seems that when the layouts are different, the op should fail with a “memory layout mismatch” error message. Otherwise, how does PyTorch know which layout is the correct one? Does it convert the weight to channels_last to match the input, or does it convert the input to channels_first to match the weight? If it’s always the former, what’s the point of the model = model.to(memory_format=torch.channels_last) line?

This might have been one possible approach, but it would potentially break on ambiguous memory layouts, such as 1x1 kernels.
The current implementation checks whether the suggested memory format of the input or weight is channels_last (code) and uses it, if applicable.
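To illustrate why there is no dim mismatch in the first place: converting a tensor to channels_last only changes its strides (the physical layout), not the logical NCHW shape that Conv2d sees. A minimal sketch:

    import torch

    x = torch.randn(2, 8, 4, 4)
    x_cl = x.contiguous(memory_format=torch.channels_last)

    print(x_cl.shape)     # torch.Size([2, 8, 4, 4]) -- the logical shape is still NCHW
    print(x.stride())     # (128, 16, 4, 1) -- contiguous (NCHW) strides
    print(x_cl.stride())  # (128, 1, 32, 8) -- channels_last (NHWC) strides
    print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True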

I see. So does this mean that, in the tutorial example, it’s not necessary to set the input to channels_last? It seems redundant, because setting the model to channels_last forces it anyway. Is there any situation where we would want to set the input in addition to setting the model?

I would always set the memory layout explicitly, not rely on the internal workflow of fixing ambiguous memory layouts, and would stick to the tutorial.

Makes sense, thank you Piotr. I have to say, this behavior will cause confusion, because people will forget to set one or the other, see that it still works, and will wonder whether it actually worked as intended. At the very least I’d mention something about this in the tutorial.

Is it possible to run Conv2d on an NHWC tensor (not NCHW)? It seems that if I put the channels explicitly at the end in @michaelklachko’s example and keep the “channels_last” behavior, it does not work.

Not sure what you mean by “put channels explicitly”. If you mean you manually transposed the tensor so that the channels dim is the last one, then yes, it will fail, because the “channels_last” feature is designed to work with tensors in their default NCHW shape. The transformation happens internally.
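A minimal sketch of the distinction (using a small 3-channel input just for illustration):

    import torch

    x = torch.randn(2, 3, 8, 8)
    conv = torch.nn.Conv2d(3, 5, 3)

    # channels_last keeps the logical NCHW shape; only the strides change -> this works
    out = conv(x.contiguous(memory_format=torch.channels_last))

    # actually permuting to NHWC changes the logical shape to (2, 8, 8, 3),
    # so Conv2d now sees 8 "channels" and raises a channel-mismatch error
    try:
        conv(x.permute(0, 2, 3, 1))
    except RuntimeError as e:
        print(e)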

Like I said, this is going to be confusing to people until the docs and error messages are improved.

Yes, I agree, and yes, this is what I meant. It would be easier to allow the conv2d operator to run on NHWC tensors.

Hello, I know this discussion is more than a year old, but I just came across it because I’m facing an issue that I think is relevant.

I couldn’t reproduce some output, and when I dug into it I found that the reason was that in one case model.to(memory_format=torch.channels_last) is set, but in the other case it’s not. Now, according to this discussion it shouldn’t matter and PyTorch should take care of it internally, but that’s not what I’m observing.

Here’s a code sample:

import torch

input = torch.randn(1, 1, 100, 100)
conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)
conv.load_state_dict({"weight": torch.randn(64, 1, 3, 3)})
with torch.no_grad():
    out_ref = conv(input)

conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(input)

print(torch.mean(torch.abs(out - out_ref)))

It should print some non-zero value.

Interestingly, if the input channel count is 3 instead of 1, the result is correct. Just wondering if this could be a bug? I know the proper way is to also set the input memory format to channels_last, but this is an online repo I’m looking at rather than my own code, so I’m wondering whether this behavior is intended by PyTorch / the author, or if it’s a bug the author isn’t aware of, in which case I should raise it with them.

@emilyfy how large are the differences?

@michaelklachko it’s pretty significant; it’s definitely not a precision issue.
error mean: tensor(3.1527)
error sum: tensor(2017749.7500)

Also, I noticed this only happens on the CPU; if I add .cuda() to both input and conv, the error is 0.
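For completeness, a sketch of the same repro moved to the GPU (assuming a CUDA device is available), where the difference comes out as zero for me:

    import torch

    input = torch.randn(1, 1, 100, 100).cuda()
    conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False).cuda()

    with torch.no_grad():
        out_ref = conv(input)

    conv.to(memory_format=torch.channels_last)
    with torch.no_grad():
        out = conv(input)

    # prints tensor(0., device='cuda:0') on my setup
    print(torch.mean(torch.abs(out - out_ref)))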

Looks like a bug. Can you please create an issue on github?

Done: Wrong output of single-channel channels_last convolution with channels_first input · Issue #82060 · pytorch/pytorch · GitHub. Thanks!

I tried to use the channels_last memory format, but it seems to give a different result. I am using cuDNN 8.6.0, CUDA 11.7, Python 3.8.16, and PyTorch 2.0.1+cu117. Here are the results:

import copy

import torch
import torch.nn as nn


x = torch.randn(64, 4, 64, 64).cuda()
layer = nn.Conv2d(4, 128, 3).cuda()

y = layer(x)

x_mem = (x + 0).to(memory_format=torch.channels_last)
layer_mem = copy.deepcopy(layer).to(memory_format=torch.channels_last)
y_mem = layer_mem(x_mem)

x_half = (x + 0).half()
layer_half = copy.deepcopy(layer).half()
y_half = layer_half(x_half)

x_half_mem = (x_half + 0).to(memory_format=torch.channels_last)
layer_half_mem = copy.deepcopy(layer_half).to(memory_format=torch.channels_last)
y_half_mem = layer_half_mem(x_half_mem)

print('x - x_half: ', (x - x_half).abs().max().item())
print('x - x_mem: ', (x - x_mem).abs().max().item())
print('x_half - x_half_mem: ', (x_half - x_half_mem).abs().max().item())
print('y - y_half: ', (y - y_half).abs().max().item())
print('y - y_mem: ', (y - y_mem).abs().max().item())
print('y_half - y_half_mem: ', (y_half - y_half_mem).abs().max().item())

Output:

x - x_half: 0.0018982887268066406
x - x_mem: 0.0
x_half - x_half_mem: 0.0
y - y_half: 0.0024781227111816406
y - y_mem: 0.0012029409408569336
y_half - y_half_mem: 0.001953125

The differences between y and y_mem or y_half and y_half_mem are not negligible.

The posted differences are expected.
In your code you are directly calling half() on the input and the layer and then comparing y_half against y_half_mem, which shows a mismatch in the expected range for float16.
Also, you are not disallowing TF32 while comparing y vs. y_mem.
Disable it via torch.backends.cudnn.allow_tf32 = False and the numerical mismatch will be reduced.
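For example, a minimal sketch (assuming a CUDA device is available) of the float32 comparison with TF32 disabled:

    import copy
    import torch
    import torch.nn as nn

    torch.backends.cudnn.allow_tf32 = False  # force full float32 precision in cuDNN convolutions

    x = torch.randn(64, 4, 64, 64).cuda()
    layer = nn.Conv2d(4, 128, 3).cuda()

    y = layer(x)
    x_mem = x.to(memory_format=torch.channels_last)
    layer_mem = copy.deepcopy(layer).to(memory_format=torch.channels_last)
    y_mem = layer_mem(x_mem)

    # with TF32 disabled, the remaining difference should be on the order of
    # float32 rounding error rather than ~1e-3
    print('y - y_mem: ', (y - y_mem).abs().max().item())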