Yes, it will work.
Be careful though: if you do bias = nn.Parameter(torch.zeros(512,)).repeat(batch, 1), the Python variable bias will not contain the original Parameter object, and it is not the tensor that will be optimized.
You would need to do:
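bias_for_optim = nn.Parameter(torch.zeros(512,))   # leaf Parameter: this is what the optimizer updates
bias = bias_for_optim.repeat(batch, 1)             # non-leaf tensor used in the computation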
Because in the one-liner, the variable that you get is not a leaf tensor and so will not have its gradient saved, meaning you won’t be able to optimize it:
import torch
from torch import nn

# Case 1: keep a reference to the leaf Parameter
a = nn.Parameter(torch.rand(10))
b = a.repeat(2)
b.sum().backward()
print("b is leaf: ", b.is_leaf)  # False
print("b.grad: ", b.grad)        # None
print("a is leaf: ", a.is_leaf)  # True
print("a.grad: ", a.grad)        # some gradients

# Case 2: the one-liner loses the reference to the leaf Parameter
a = nn.Parameter(torch.rand(10)).repeat(2)
a.sum().backward()
print("a is leaf: ", a.is_leaf)  # False
print("a.grad: ", a.grad)        # None
That’s my point: you should register bias_for_optim in the module and pass it to the optimizer.
The bias = bias_for_optim.repeat(batch, 1) call should only be done during the forward pass.
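As a minimal sketch of that pattern, assuming the bias is added to an input of shape (batch, 512); the module name, layer shape, and the final usage lines are illustrative, only bias_for_optim comes from the discussion above:

import torch
from torch import nn

class BiasedLayer(nn.Module):
    def __init__(self, features=512):
        super().__init__()
        # Registered as a leaf Parameter, so model.parameters() (and the optimizer) will see it
        self.bias_for_optim = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        # Repeat only at forward time; bias is a non-leaf tensor,
        # but gradients still flow back to self.bias_for_optim
        bias = self.bias_for_optim.repeat(x.shape[0], 1)
        return x + bias

layer = BiasedLayer()
opt = torch.optim.SGD(layer.parameters(), lr=0.1)  # bias_for_optim is included
out = layer(torch.rand(4, 512))
out.sum().backward()
print(layer.bias_for_optim.grad.shape)  # torch.Size([512])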
Ah okay, I wanted to avoid calling repeat in each forward step for efficiency. However, I suppose .repeat() is well optimized (at least it should be much quicker than looping at the Python level).
Yes, I was talking more about memory reservation. However, PyTorch’s pooled memory allocator should not have problems at this level; my intention was to have all memory allocated up front to avoid the typical problems that arise when you allocate and deallocate dynamically.