Can this code be nicer?

if module.bias is not None:
    if len(out.shape) == 4:
        out = out + module.bias.view(1, -1, 1, 1)
    if len(out.shape) == 3:
        out = out + module.bias.view(1, -1, 1)
    if len(out.shape) == 2:
        out = out + module.bias.view(1, -1)

I don't know the original shape of module.bias, but if it is 1-D (one value per channel) and you only need to add singleton dimensions so it broadcasts against out, this code should work:

import torch

bias = torch.randn(3)

out = torch.randn(2, 3)
# [1, -1] followed by one trailing 1 per dimension after the channel dim
view_shape = [1, -1] + [1] * (out.dim() - 2)
print(bias.view(view_shape).shape)
# torch.Size([1, 3])

out = torch.randn(2, 3, 4)
view_shape = [1, -1] + [1] * (out.dim() - 2)
print(bias.view(view_shape).shape)
# torch.Size([1, 3, 1])

out = torch.randn(2, 3, 4, 5)
view_shape = [1, -1] + [1] * (out.dim() - 2)
print(bias.view(view_shape).shape)
# torch.Size([1, 3, 1, 1])
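Wrapping the same idea in a function gives a drop-in replacement for the if-chain in the question. This is a minimal sketch: the name add_channel_bias is hypothetical, and it assumes bias is 1-D with shape (C,) and that the channel dimension of out is dim 1. The loop checks it against the original if-chain for 2-D, 3-D and 4-D inputs:

```python
import torch


def add_channel_bias(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add a 1-D per-channel bias to `out` (hypothetical helper).

    Assumes `bias` has shape (C,) and `out` has shape (N, C, ...) with
    zero or more trailing dimensions after the channel dim.
    """
    # Shape [1, C, 1, ..., 1]: one trailing 1 per dimension after dim 1.
    view_shape = [1, -1] + [1] * (out.dim() - 2)
    return out + bias.view(view_shape)


bias = torch.randn(3)
for shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    out = torch.randn(shape)
    result = add_channel_bias(out, bias)
    # Reference result computed with the explicit if-chain from the question.
    if out.dim() == 4:
        expected = out + bias.view(1, -1, 1, 1)
    elif out.dim() == 3:
        expected = out + bias.view(1, -1, 1)
    else:
        expected = out + bias.view(1, -1)
    print(shape, torch.equal(result, expected))
```

In the original code this would become `if module.bias is not None: out = add_channel_bias(out, module.bias)`. Because broadcasting aligns shapes starting from the trailing dimension, the leading 1 is optional: `bias.view(-1, *[1] * (out.dim() - 2))` broadcasts the same way.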

Cool! It is nicer, thanks.