This is my understanding as well. It looks like a bug to me, because if `input` is being permuted to `channels_last` but `weight` remains `channels_first`, then the code effectively becomes:
```python
input = torch.randint(1, 10, (2, 4, 4, 8), dtype=torch.float32, device="cuda", requires_grad=True)
model = torch.nn.Conv2d(8, 5, 3).cuda().float()
out = model(input)
```
And it obviously fails because of the dim mismatch, but my example above runs fine even when `weight` remains in `channels_first` format. Something is not right here.
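For context, a proper conversion to `channels_last` via `memory_format` only changes the physical layout (the strides), not the logical `(N, C, H, W)` shape, which is why a correctly converted input still feeds `Conv2d` without a dim mismatch. A minimal sketch (CPU, assuming a recent PyTorch; shapes chosen here for illustration):

```python
import torch

# Logical NCHW shape; channels_last only rearranges the strides so that
# channels become the innermost (fastest-varying) dimension.
x = torch.randn(2, 8, 4, 4)
x_cl = x.contiguous(memory_format=torch.channels_last)

print(x_cl.shape)    # torch.Size([2, 8, 4, 4]) -- shape unchanged
print(x_cl.stride()) # (128, 1, 32, 8) -- channel stride is 1

# Conv2d still interprets the tensor as NCHW, so this runs fine.
model = torch.nn.Conv2d(8, 5, 3)
out = model(x_cl)
print(out.shape)     # torch.Size([2, 5, 2, 2])
```

So a dim mismatch only appears if the *shape itself* is physically permuted (as in the snippet above), not when the tensor is merely restrided to `channels_last`.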