Issue when registering parameters with torch.nn.Parameter

I ran into a problem while studying some PyTorch code, which I have reduced to the model below. The goal is to fit y = 1000x with a single learnable scalar, and I tried two ways of registering that scalar in the model. a1 trains normally, but torch.exp() is applied to it only once, at initialization, so exp is not part of forward propagation. a2 raises the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

When I add retain_graph=True to backward(), as the error message suggests, training runs without errors, but only a1 is updated; a2 stays at 2.7183. Here is part of the output:

Parameter containing:
tensor([2.7183], requires_grad=True) tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([2.7183], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([22.6639], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([22.6639], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([42.2106], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([42.2106], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([61.3664], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([61.3664], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([80.1391], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([80.1391], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([98.5363], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([98.5363], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
a1 = Parameter containing:
tensor([116.5656], requires_grad=True), a2 = tensor([2.7183], grad_fn=<ExpBackward0>)
x * a1 = tensor([116.5656], grad_fn=<MulBackward0>), x * a2 = tensor([2.7183], grad_fn=<MulBackward0>)
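A minimal sketch of why the error appears, reproducing only the a2 pattern outside the module (p, a2 and x are stand-in names, not from the original code): torch.exp is recorded once, when the attribute is created, and every iteration's loss then reuses that single exp node. The first backward() frees the node's saved tensors, so the second one fails. retain_graph=True keeps those tensors alive, which silences the error, but the gradient only lands on the hidden leaf Parameter inside the exp, and that leaf is not in model.parameters(), so the optimizer never updates it.

```python
import torch

# Stand-in for what my_model.__init__ does with a2: a leaf Parameter is
# created and torch.exp is applied to it ONCE, outside any training loop.
p = torch.nn.Parameter(torch.tensor([1.0]))
a2 = torch.exp(p)  # the ExpBackward0 node is recorded here, a single time

x = torch.tensor([1.0])

# "Iteration 1": a new Mul node is built, but it hangs off the same exp node.
(x * a2).sum().backward()  # succeeds, then frees the exp node's saved output

# "Iteration 2": backward must pass through the already-freed exp node again.
try:
    (x * a2).sum().backward()
    second_backward_failed = False
except RuntimeError as err:
    second_backward_failed = True
    print(err)

print("second backward failed:", second_backward_failed)
print("hidden leaf p.grad:", p.grad)  # set by the first backward; no optimizer holds p
```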

What is the fundamental difference between these two approaches? Here is the model code:

import torch
class my_model(torch.nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.a1 = torch.nn.Parameter(torch.exp(torch.Tensor([1.0])), requires_grad=True)
        self.a2 = torch.exp(torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=True))
    def forward(self, x):
        return x * self.a1, x * self.a2

# y = 1000 * x
x, y = torch.Tensor([1]), torch.Tensor([1000])
model = my_model()
model.train()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for i in range(1000):
    optimizer.zero_grad()
    y1, y2 = model(x)
    loss = criterion(y1, y) + criterion(y2, y)
    loss.backward(retain_graph=True)
    optimizer.step()
    print(f"x * a1 = {y1}, x * a2 = {y2}")
    print(f"a1 = {model.a1}, a2 = {model.a2}")
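One way to see the difference is to ask the module what it has registered. This sketch reuses the class from the question unchanged and only inspects it: nn.Module registers an attribute as a parameter only when the assigned value is an nn.Parameter, and torch.exp of a Parameter returns a plain, non-leaf tensor.

```python
import torch

class my_model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a1 = torch.nn.Parameter(torch.exp(torch.Tensor([1.0])), requires_grad=True)
        self.a2 = torch.exp(torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=True))

    def forward(self, x):
        return x * self.a1, x * self.a2

model = my_model()

# Only a1 is registered; a2 is an ordinary attribute holding a computed tensor.
registered = [name for name, _ in model.named_parameters()]
print(registered)  # ['a1']

# a1 is a leaf Parameter; a2 is a plain Tensor produced by exp (it has a grad_fn).
print(type(model.a1).__name__, model.a1.is_leaf)
print(type(model.a2).__name__, model.a2.is_leaf, model.a2.grad_fn)
```

Because the optimizer was built from model.parameters(), it only ever sees a1.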

For reference, the code I was actually reading is:

if self.learned_temperature is True:
    self.temperature = (self.model_dim * torch.exp(torch.clamp(nn.Parameter(
        torch.Tensor([1]), requires_grad=True), min=0.005, max=5))).cuda()

(from attempt/third_party/models/t5/modeling_t5.py, line 916, in the AkariAsai/ATTEMPT repository on GitHub, main branch)
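As written, that snippet has the same structure as a2: the nn.Parameter is created inside the expression, clamp and exp run once when the attribute is assigned, and self.temperature ends up as a computed tensor rather than a registered parameter. Below is a minimal sketch of the usual alternative, assuming the intent is a learnable temperature: register the raw value as a parameter and compute model_dim * exp(clamp(...)) inside forward. The class LearnedTemperature, the attribute raw_temperature, the model_dim value, the learning rate and the loss target are all hypothetical.

```python
import torch
import torch.nn as nn

class LearnedTemperature(nn.Module):
    """Hypothetical module: keeps the raw value as a registered Parameter and
    applies clamp/exp inside forward, so a fresh graph is built on every call."""

    def __init__(self, model_dim: int):
        super().__init__()
        self.model_dim = model_dim
        # Registered leaf: shows up in .parameters(), so the optimizer updates it.
        self.raw_temperature = nn.Parameter(torch.tensor([1.0]))

    def forward(self) -> torch.Tensor:
        # Recomputed every call, so backward never reuses a freed graph.
        return self.model_dim * torch.exp(torch.clamp(self.raw_temperature, min=0.005, max=5))

module = LearnedTemperature(model_dim=8)
optimizer = torch.optim.SGD(module.parameters(), lr=1e-4)

before = module.raw_temperature.item()
for _ in range(3):
    optimizer.zero_grad()
    loss = ((module() - 100.0) ** 2).sum()  # arbitrary target, for illustration only
    loss.backward()  # no retain_graph needed
    optimizer.step()
after = module.raw_temperature.item()

names = [name for name, _ in module.named_parameters()]
print(names, before, after)
```

A side benefit of registering the parameter: module.cuda() or module.to(device) moves it together with the rest of the model, whereas the original expression calls .cuda() once on an already computed tensor.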

Have you tried using

torch.nn.Parameter(torch.exp(torch.Tensor([1.0])), requires_grad=True)

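Once you call torch.exp on an nn.Parameter, the result is no longer a parameter: it is an ordinary tensor with grad_fn=<ExpBackward0>, so the module does not register it and the optimizer never updates it.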
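The suggestion above registers a Parameter whose initial value happens to be exp(1), so the optimizer learns a directly, exactly like a1. If the goal is to keep exp in the forward pass, as a2 attempted, a common alternative (an assumption here, not something proposed earlier in the thread) is to learn log(a) and exponentiate in forward. In this sketch the class ExpModel and the attribute log_a are hypothetical, and Adam is swapped in for the question's SGD because the gradient through exp is large at the start.

```python
import torch

class ExpModel(torch.nn.Module):
    """Hypothetical variant: learns log(a) and applies exp in forward,
    so the effective multiplier exp(log_a) is always positive."""

    def __init__(self):
        super().__init__()
        # log_a = 1.0 means a = e, the same starting value as in the question.
        self.log_a = torch.nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        return x * torch.exp(self.log_a)  # exp is part of every forward pass

x, y = torch.tensor([1.0]), torch.tensor([1000.0])
model = ExpModel()
criterion = torch.nn.MSELoss()
# With exp inside the loop, the initial gradient is about 2 * (e - 1000) * e, roughly -5400,
# so SGD(lr=0.01) would jump far past log(1000); Adam takes bounded steps instead.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

initial_loss = criterion(model(x), y).item()
for _ in range(500):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()  # a fresh graph every iteration, so no retain_graph
    optimizer.step()
final_loss = criterion(model(x), y).item()

print(f"learned a = {torch.exp(model.log_a).item():.2f}, loss {initial_loss:.1f} -> {final_loss:.4f}")
```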