nn.Linear with 3D (or higher-dimensional) input

Hi guys,

I was toying around with nn.Linear and 3D tensor inputs. When I pass the 3D tensor directly, the grad_fn of the output is <UnsafeViewBackward>. When I manually flatten the input to a 2D tensor and view the output back to a 3D tensor, the grad_fn is <ViewBackward>, as expected. Now I’m curious why it’s called ‘unsafe’, and whether I should rather manually view the 3D tensor as 2D or pass it in as a 3D tensor.

Code (not sure if necessary):

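import torch
import torch.nn as nn

# Linear layer without bias, all weights set to 1 so the outputs are easy to check
l1 = nn.Linear(4, 5, bias=False)
l1.weight = nn.Parameter(torch.ones(l1.weight.shape))

# 3D input of shape [batch_size, seq_len, features]
b, s, e = 2, 3, 4
a = torch.arange(0, b*s*e, dtype=torch.float).view(b, s, e)

out1 = l1(a)                              # pass the 3D tensor directly
out2 = l1(a.view(-1, e)).view(b, s, -1)   # flatten to 2D, then view back to 3D

print(out1, out2, sep="\n")
print(torch.equal(out1, out2))
print(out1.grad_fn, out2.grad_fn)         # inspect the backward nodes of both outputs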

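The UnsafeViewBackward shouldn’t be problematic, as explained here. For the matrix multiplication, the 3D input is folded into a 2D matrix, and the freshly created result is reshaped back via _unsafe_view, which autograd does not track as a view. That would only be “unsafe” if the tensor being viewed were also used elsewhere, which is not the case for this temporary result.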
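To check that the “unsafe” view doesn’t change the backward pass, here is a minimal sketch that compares the outputs and gradients of both approaches. It reuses the shapes from the question, but uses random weights and inputs (an arbitrary choice for illustration) instead of ones and arange. The printed grad_fn names depend on the installed PyTorch version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, s, e, out_features = 2, 3, 4, 5

lin = nn.Linear(e, out_features, bias=False)
x = torch.randn(b, s, e, requires_grad=True)

# Path 1: pass the 3D tensor directly
out1 = lin(x)
print(type(out1.grad_fn).__name__)  # exact name depends on the PyTorch version
out1.sum().backward()
w_grad1, x_grad1 = lin.weight.grad.clone(), x.grad.clone()

# Reset the accumulated gradients before the second pass
lin.weight.grad = None
x.grad = None

# Path 2: flatten to 2D manually, then view the output back to 3D
out2 = lin(x.view(-1, e)).view(b, s, -1)
print(type(out2.grad_fn).__name__)
out2.sum().backward()
w_grad2, x_grad2 = lin.weight.grad.clone(), x.grad.clone()

# Both paths should agree in the forward and in the backward pass
print(torch.allclose(out1, out2, atol=1e-6),
      torch.allclose(w_grad1, w_grad2, atol=1e-6),
      torch.allclose(x_grad1, x_grad2, atol=1e-6))
```

If all three checks print True, the choice between the two paths is a matter of readability rather than correctness.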

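However, note that the linear layer will interpret a 2D and a 3D input differently.
While a 2D input of shape [batch_size, in_features] is the “standard” use case, a 3D input of shape [batch_size, seq_len, in_features] is treated as a sequence of inputs (in dim1), i.e. the layer is applied to each of the seq_len feature vectors independently. The numerical result matches the manually flattened version, as the torch.equal check in your code shows.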
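As a minimal sketch of what “treated as a sequence of inputs” means, the following compares the 3D call with applying the layer to each position in dim1 separately, and with an explicit matmul over the last dimension. The shapes and the use of a bias are arbitrary choices for illustration. The last part shows that any number of leading dimensions is handled the same way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 5)  # in_features=4, out_features=5, with bias

x = torch.randn(2, 3, 4)  # [batch_size, seq_len, in_features]
out = lin(x)              # [2, 3, 5]

# Apply the layer to each sequence position separately and stack along dim1
out_loop = torch.stack([lin(x[:, t]) for t in range(x.size(1))], dim=1)

# The same computation as an explicit matmul over the last dimension
out_manual = x @ lin.weight.T + lin.bias

# With a 4D input, all leading dims act as batch dims
x4 = torch.randn(2, 3, 6, 4)
out4 = lin(x4)  # [2, 3, 6, 5]
out4_flat = lin(x4.reshape(-1, 4)).reshape(*x4.shape[:-1], -1)

print(out.shape, out4.shape)
print(torch.allclose(out, out_loop, atol=1e-6),
      torch.allclose(out, out_manual, atol=1e-6),
      torch.allclose(out4, out4_flat, atol=1e-6))
```

The loop version is only there to make the per-position semantics explicit; in practice, passing the N-D tensor directly avoids the Python loop.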
