Repeat an nn.Parameter for efficient computation

Hi everyone.

I am implementing the local reparameterization trick (https://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick.pdf) and realized that I need a matrix that has the same parameter vector in every row. Suppose a layer with 512 neurons.

If I code this:

bias=nn.Parameter(torch.zeros(512,)).repeat(batch,1)

If I now sample from this bias matrix, does PyTorch (when performing backward) know that each row is the same parameter?

What I want to do is avoid this:

bias=nn.Parameter(torch.zeros(512,))
for i in range(batch):
    bias.sample()  # pseudo-code: draw a new bias sample for each element of the batch

Hi,

Yes it will work.
Be careful though: if you do bias=nn.Parameter(torch.zeros(512,)).repeat(batch,1), the Python variable bias will not contain the original Parameter object, and it is not the tensor that will be optimized.
You would need to do:

bias_for_optimizer = nn.Parameter(torch.zeros(512,))
bias = bias_for_optimizer.repeat(batch, 1)
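
For example, a minimal sketch (assuming a batch size of 4) showing that the backward pass sums the gradient of every row back into the single 512-element parameter:

import torch
from torch import nn

batch = 4
bias_for_optimizer = nn.Parameter(torch.zeros(512))
bias = bias_for_optimizer.repeat(batch, 1)  # shape (batch, 512), every row built from the same parameter

bias.sum().backward()
# Each element of the parameter appears `batch` times in the repeated matrix,
# so its accumulated gradient is `batch`.
print(bias_for_optimizer.grad)  # tensor of shape (512,), filled with 4.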

Thanks.

What is the exact reason why I cannot code it in one line?

Because in the one-liner, the variable that you get is not a leaf tensor, so its gradient will not be saved, meaning you won’t be able to optimize it:

import torch
from torch import nn

a = nn.Parameter(torch.rand(10))
b = a.repeat(2)

b.sum().backward()

print("b is leaf: ", b.is_leaf)  # False
print("b.grad: ", b.grad)        # None
print("a is leaf: ", a.is_leaf)  # True
print("a.grad: ", a.grad)        # some gradients

# Same thing with the one-liner: the repeated tensor is not a leaf,
# so no gradient is stored for it.
a = nn.Parameter(torch.rand(10)).repeat(2)

a.sum().backward()
print("a is leaf: ", a.is_leaf)  # False
print("a.grad: ", a.grad)        # None


The problem with this approach is that the resulting tensor bias is not a parameter:

bias_for_optim = nn.Parameter(torch.zeros(topology[idx+1],).cuda())
bias = bias_for_optim.repeat(batch, 1)
print(type(bias))  # <class 'torch.Tensor'>, not nn.Parameter

and cannot be registered in the module. Should I wrap it in nn.Parameter again after the repeat?

That’s my point:
You should register bias_for_optim in the module and pass it to the optimizer.
The bias=bias_for_optim.repeat(batch,1) should only be done during the forward pass.
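
Something along these lines (a minimal sketch; the module and names are just illustrative):

import torch
from torch import nn

class RepeatedBias(nn.Module):
    # The Parameter is registered once in __init__; the per-batch matrix
    # is rebuilt with repeat() in every forward pass.
    def __init__(self, n_out=512):
        super().__init__()
        self.bias_for_optim = nn.Parameter(torch.zeros(n_out))

    def forward(self, batch):
        return self.bias_for_optim.repeat(batch, 1)  # shape (batch, n_out)

layer = RepeatedBias()
out = layer(8)
out.sum().backward()
print(layer.bias_for_optim.grad.shape)  # torch.Size([512]), gradients summed over the rows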


Ah okay, I wanted to avoid calling repeat in each forward step for efficiency. However, I suppose .repeat() is well optimized (at least it should be much quicker than looping at the Python level).

Thanks @albanD

Hi,

You can consider repeat essentially free here: it just allocates the (batch, 512) tensor and does one contiguous copy, which is negligible compared to the rest of the forward pass. You should not worry about it :wink:
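
As a side note, if you wanted to skip even that copy, expand gives a broadcasted view of the same memory (a quick sketch, just an illustrative alternative):

import torch
from torch import nn

batch = 64
bias = nn.Parameter(torch.zeros(512))

expanded = bias.unsqueeze(0).expand(batch, -1)  # view: stride 0 along dim 0, no new memory
repeated = bias.repeat(batch, 1)                # real (batch, 512) allocation and copy

print(expanded.data_ptr() == bias.data_ptr())   # True: same storage as the parameter
print(repeated.data_ptr() == bias.data_ptr())   # False: freshly allocated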

Yes, I was talking more about memory reservation. However, PyTorch’s pooling memory allocator should not have any problems at this scale; my intention was to have all the memory allocated up front, to avoid the typical problems that arise when you allocate and deallocate dynamically.