# Add the module's per-channel bias to `out`, reshaping the bias so it lines
# up with dim 1 of `out` and broadcasts over every other dim.
# Only outputs of rank 2, 3 or 4 are handled; for any other rank the bias is
# silently NOT added (the generic view-shape version below avoids this).
if module.bias is not None:
    rank = len(out.shape)
    # `elif` is safe: adding an equal-rank broadcast bias keeps `out`'s rank,
    # so the original separate `if`s could never match more than once anyway.
    if rank == 4:
        out = out + module.bias.view(1, -1, 1, 1)
    elif rank == 3:
        out = out + module.bias.view(1, -1, 1)
    elif rank == 2:
        out = out + module.bias.view(1, -1)
I don’t know the original shape of `module.bias`, but if you need to unsqueeze it
so that it broadcasts against `out` of any rank, this code should work:
import torch

# Build a view shape of [1, C, 1, ..., 1] with the same rank as `out`:
# the -1 lets view() infer C from the bias, placing it on dim 1, and the
# trailing 1s let it broadcast over every dim after the channel dim.
bias = torch.randn(3)

for shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    out = torch.randn(*shape)
    view_shape = [1, -1] + [1] * (out.dim() - 2)
    print(bias.view(view_shape).shape)

# Output:
# torch.Size([1, 3])
# torch.Size([1, 3, 1])
# torch.Size([1, 3, 1, 1])
1 Like
Cool! It is nicer, thanks.