Multiply feature map by a learnable scalar

I’d like to know if there is a way to multiply the output of a convolutional layer (a set of N feature maps) by N learnable multipliers, or, similarly, how to multiply all feature maps in a stack by a single learnable parameter. What layer/function should I use? My use case: I have the outputs of two parallel CNN branches, A and B, with the same size and number of feature maps, and I want to form a new output C = alpha*A + beta*B, where alpha and beta are learnable parameters. Thanks!

Make your scalar a Variable containing a 1D tensor, and use the expand_as function.

matrix = Variable(torch.rand(3,3))
scalar = Variable(torch.rand(1), requires_grad=True)
output = matrix * scalar.expand_as(matrix)
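
For reference, broadcasting also covers this without an explicit expand_as (a minimal sketch, assuming a recent PyTorch version where tensors carry requires_grad directly):

import torch

matrix = torch.rand(3, 3)
scalar = torch.rand(1, requires_grad=True)  # learnable scalar
output = matrix * scalar  # broadcasting expands the scalar over the matrix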

Thanks @fmassa, I have another question though.
How do I perform training with such a multiplier?
Right now, I have this:

in my Model definition:

    self.multip = torch.autograd.Variable(torch.rand(1).cuda(), requires_grad=True)
    self.multip = self.multip.cuda()

in my Model forward:

def forward(self, x):
  x1 = self.relu(self.conv1(x))
  x2 = self.relu(self.conv2(x))
  x1 = x1 * self.multip.expand_as(x1) # multiply x1 output by learnable parameter "multip"
  x = torch.add(x1, x2)
  return x

and in my training code:

optimizer.zero_grad()
loss = criterion(model(input), target)
loss.backward()
optimizer.step()

As you see, nothing out of the ordinary. But when I print the value of model.multip.data[0], I see that the initial value of multip remains unchanged (all the other params in my Model do change). I suspect I messed up somewhere and the gradients are not being applied to self.multip.
Am I right?

If you are using an nn.Module and the multiplier lives inside the network, you need to make it an nn.Parameter so that it is registered as a parameter when you call model.parameters().
So instead of wrapping your multiplier in a Variable, you should use an nn.Parameter.

Quick example:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
        self.multp = Variable(torch.rand(1), requires_grad=True) # plain Variable: not returned by m1.parameters()

class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()
        self.multp = nn.Parameter(torch.rand(1)) # requires_grad is True by default for Parameter

m1 = Model1()
m2 = Model2()

print('m1', list(m1.parameters()))
print('m2', list(m2.parameters()))
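
Applied to the original two-branch question, a minimal sketch (the layer shapes and names here are just for illustration, and it assumes a PyTorch version with broadcasting):

import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    def __init__(self):
        super(TwoBranch, self).__init__()
        self.convA = nn.Conv2d(3, 16, 3, padding=1)
        self.convB = nn.Conv2d(3, 16, 3, padding=1)
        self.alpha = nn.Parameter(torch.ones(1))  # learnable scalar for branch A
        self.beta = nn.Parameter(torch.ones(1))   # learnable scalar for branch B

    def forward(self, x):
        a = self.convA(x)
        b = self.convB(x)
        return self.alpha * a + self.beta * b  # C = alpha*A + beta*B

Both alpha and beta show up in model.parameters(), so the optimizer updates them along with the conv weights.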

Thank you. Alright, I can now see nonzero values in
m2.multp.grad.data
after I do the optimizer steps, but the value of multp still doesn’t change during training. Will investigate. I think it might have something to do with the learning rate.

Now, how can I specify a special learning rate for this new param? I used to do something like this for my convolutional layers to set a smaller learning rate for layer 'conv2':

optimizer = optim.Adam([{'params': model.conv2.parameters(), 'lr': opt.lr*0.1}], lr=opt.lr)

Now if I do this:

optimizer = optim.Adam([{'params': model.multp.parameters(), 'lr': opt.lr*0.1}], lr=opt.lr)

I get an error:

File "/home/anaconda2/lib/python2.7/site-packages/torch/autograd/variable.py", line 63, in __getattr__
    raise AttributeError(name)
AttributeError: parameters

UPD: I replaced my 'custom' Adam (see above), which had different learning rates, with a regular Adam that uses the same learning rate for all layers, and it worked: multp now gets updated during training (it also worked with SGD). However, my question about setting a special learning rate for the nn.Parameter remains.

You can simply do something like

optimizer = optim.Adam([{'params':[model.multp], 'lr':opt.lr*0.1}], lr=opt.lr)
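
As written, this optimizer only contains multp. If the rest of the model should also be trained at the base learning rate, one way (a sketch, assuming multp is the only parameter that needs the smaller rate) is to use two parameter groups:

import torch.optim as optim

base_params = [p for name, p in model.named_parameters() if name != 'multp']
optimizer = optim.Adam([
    {'params': base_params},                        # falls back to the default lr below
    {'params': [model.multp], 'lr': opt.lr * 0.1},  # smaller lr for the scalar
], lr=opt.lr)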

Great, it worked! Thank you.

Hey, I was wondering if it is possible to multiply each feature map by a different scalar?

Maybe using a 1x1 conv with groups=n_channels?
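
Alternatively, a per-channel multiplier can be written directly as an nn.Parameter with one entry per feature map, broadcast over the batch and spatial dimensions (a sketch, assuming NCHW feature maps):

import torch
import torch.nn as nn

class ChannelScale(nn.Module):
    def __init__(self, n_channels):
        super(ChannelScale, self).__init__()
        # one learnable scalar per feature map, shaped for broadcasting over N, H, W
        self.scale = nn.Parameter(torch.ones(1, n_channels, 1, 1))

    def forward(self, x):
        return x * self.scale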

Hi, did you work it out? I also need to do something similar.