How could I create a module with learnable parameters?

Sometimes we need to create a module with learnable parameters. For example, when we build an A-Softmax module, the module needs to contain a weight W that is learned and updated during training.

From reading the docs, it seems I should do it like this:

```python
import torch
import torch.nn as nn

class mod(nn.Module):
    def __init__(self):
        self.W = torch.tensor(torch.random(3,4,5), requires_grad=True)
    def forward(self, x):
        w_norm = torch.norm(self.W, 2, 1, True).expand_as(self.W)
        x_norm = torch.norm(x, 2, 1, True).expand_as(x)
        theta = torch.mm(x, self.W) / (x_norm * w_norm)
        return theta
loss = mod()
params = loss.parameters()
optim = Optimizer(params, lr = 1e-3...)
```

I'm just trying to confirm: is this the recommended way to do it?
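Separately from how W is registered, the forward pass above seems to aim for the cosine of the angle between each input row and each weight column, as in A-Softmax-style layers. A matrix product like torch.mm only accepts 2-D tensors, so a 3×4×5 weight will not work there, and the two norms have to be taken per input row and per weight column so they broadcast against the (batch, num_classes) product. A minimal sketch, assuming a 2-D weight of shape (in_features, num_classes); the function name cosine_logits is hypothetical:

```python
import torch
import torch.nn.functional as F

def cosine_logits(x, W):
    """Cosine of the angle between each row of x and each column of W.

    x: (batch, in_features), W: (in_features, num_classes)
    returns: (batch, num_classes), values in [-1, 1]
    """
    x_unit = F.normalize(x, p=2, dim=1)  # unit-length input rows
    w_unit = F.normalize(W, p=2, dim=0)  # unit-length weight columns
    return torch.mm(x_unit, w_unit)

x = torch.randn(8, 4)
W = torch.randn(4, 10)
cos_theta = cosine_logits(x, W)
print(cos_theta.shape)  # torch.Size([8, 10])
```

Normalizing before the product is equivalent to dividing torch.mm(x, W) by the outer product of the row norms of x and the column norms of W, but avoids building the expanded norm tensors.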

1. You need to run __init__ of mod's base class (nn.Module), e.g. with super().__init__().
2. The parameters should be wrapped in a torch.nn.Parameter.
3. torch.random is a module, not a function; assuming you want a normal distribution, use torch.randn. Putting these together:

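```python
def __init__(self):
    super().__init__()  # point 1: initialize nn.Module before assigning parameters
    self.W = torch.nn.Parameter(torch.randn(3, 4, 5))  # points 2 and 3
    self.W.requires_grad = True  # optional: nn.Parameter already sets requires_grad=True
```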

Then loss.parameters() will yield W. You can verify this by doing:

```python
hex(id(next(loss.parameters()))), hex(id(loss.W))  # the two addresses should match
```
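A small self-contained check of the three points above. NoInit and PlainTensor are hypothetical names for the two failure modes, and WithParam is the fixed version. Assigning an nn.Parameter before nn.Module.__init__() has run raises AttributeError, while a plain tensor attribute is stored but not registered, so it never reaches parameters() or the optimizer:

```python
import types

import torch
import torch.nn as nn

# Point 3: torch.random is a submodule, not a sampling function.
print(isinstance(torch.random, types.ModuleType))  # True

class NoInit(nn.Module):
    def __init__(self):
        # Point 1 violated: nn.Module.__init__() is never called.
        self.W = nn.Parameter(torch.randn(3, 4, 5))

class PlainTensor(nn.Module):
    def __init__(self):
        super().__init__()
        # Point 2 violated: a plain tensor is stored but not registered.
        self.W = torch.randn(3, 4, 5, requires_grad=True)

class WithParam(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter defaults to requires_grad=True.
        self.W = nn.Parameter(torch.randn(3, 4, 5))

try:
    NoInit()
except AttributeError as e:
    print("NoInit failed:", e)

print(len(list(PlainTensor().parameters())))  # 0
m = WithParam()
print([name for name, _ in m.named_parameters()])  # ['W']
print(next(m.parameters()) is m.W)  # True, same check as the id() comparison
```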


Got it, thanks a great deal


I tried what you suggested, but when I add the parameters to an optimizer:

```python
params = [net.parameters(), loss.parameters()]
opt = optim.Adam(params, lr = 1e-3, weight_decay = 5e-4)
```

it raises this error:

```
File "/home/zhangzy/.local/lib/python3.5/site-packages/torch/optim/", line 192, in add_param_group
    "but one of the params is " + torch.typename(param))
TypeError: optimizer can only optimize Tensors, but one of the params is Module.parameters
```

Have a look here. Each .parameters() call returns a generator. Normally we would write torch.optim.Adam(loss.parameters(), lr=...) when dealing with just one set of parameters, but since you have two sets, you need to make a list out of one generator and extend it with the other:

```python
params = list(net.parameters())
params.extend(loss.parameters())  # append the second set of parameters
opt = torch.optim.Adam(params, lr=1e-3, weight_decay=5e-4)
```
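Two other ways to hand both parameter sets to one optimizer, sketched with hypothetical stand-ins (net is an nn.Linear, CosineHead is a toy loss module with one learnable weight). itertools.chain joins the two generators into a single iterable of tensors, and a list of dicts creates separate parameter groups, each of which may override options such as weight_decay. In both cases the optimizer receives tensors rather than generator objects, which is what the TypeError above complains about:

```python
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineHead(nn.Module):
    """Toy stand-in for the loss module: one learnable 2-D weight."""
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_features, num_classes))

    def forward(self, x):
        return torch.mm(F.normalize(x, dim=1), F.normalize(self.W, dim=0))

net = nn.Linear(16, 4)    # hypothetical feature extractor
loss = CosineHead(4, 10)  # hypothetical loss module with learnable W

# Option A: chain the two generators into one iterable of tensors.
opt = torch.optim.Adam(itertools.chain(net.parameters(), loss.parameters()),
                       lr=1e-3, weight_decay=5e-4)

# Option B: parameter groups; the second group turns off weight decay for W.
opt_groups = torch.optim.Adam([
    {"params": net.parameters()},
    {"params": loss.parameters(), "weight_decay": 0.0},
], lr=1e-3, weight_decay=5e-4)

# One training step with option A: gradients reach both modules.
x = torch.randn(8, 16)
target = torch.randint(0, 10, (8,))
w_before = loss.W.detach().clone()
opt.zero_grad()
logits = loss(net(x))
F.cross_entropy(logits, target).backward()
opt.step()
print("W changed:", not torch.equal(w_before, loss.W))
```

Parameter groups are the more flexible choice when the loss weights should be trained with a different learning rate or without weight decay.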

Your solution works for me, many thanks