Hello.

I’m trying to understand the following torch code.

```
import torch

# Random and Model are assumed to be nn.Module subclasses defined elsewhere.
ran_model = Random()
ran_optim = torch.optim.SGD(ran_model.parameters(), 0.01)

model = Model()
optim = torch.optim.SGD(model.parameters(), 0.01)
model_params = list(model.parameters())
```

```
for i in range(5):
    loss_mod = torch.mean(torch.log(model.forward(x)[:, i]))
    loss_rand = torch.mean(torch.log(model.forward(y)[:, i]))
    model_grad = torch.autograd.grad(loss_mod, model_params)
    rand_grad = torch.autograd.grad(
        loss_rand, model_params, create_graph=True
    )
    loss = 0
    for j in range(len(model_grad)):
        a = model_grad[j]
        b = rand_grad[j]
        loss = loss + torch.stack(
            [a, b], dim=0
        ).sum(dim=0).sum(dim=0).mean(0)
    ran_model.zero_grad()
    loss.backward()
    ran_optim.step()
    new_loss = -torch.mean(torch.log(model.forward(y)))
    optim.zero_grad()
    new_loss.backward()
    optim.step()
    print('loss value ', loss, new_loss)
```
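To make the inner `for j` loop easier to follow, here is a minimal sketch of the same reduction as a standalone function. The name `combine_grads` and the toy tensors are hypothetical and only illustrate the arithmetic; they are not part of the original code.

```python
import torch


def combine_grads(model_grad, rand_grad):
    """Reduce pairs of gradient tensors to one scalar, as the loop body does."""
    loss = 0
    for a, b in zip(model_grad, rand_grad):
        # Stack to shape (2, *a.shape), add the pair elementwise, sum over the
        # first remaining dimension, then average what is left.
        loss = loss + torch.stack([a, b], dim=0).sum(dim=0).sum(dim=0).mean(0)
    return loss


# Toy 2-D "gradients" (hypothetical values) to show the arithmetic.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.ones(2, 2)
toy_loss = combine_grads([a], [b])
print(toy_loss)
```

The resulting `loss` is built purely from the gradient tensors, so whatever graph it carries comes from them.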

Here is my understanding of the steps:

- Get the losses after the forward pass, `loss_mod` and `loss_rand`.

```
loss_mod = torch.mean(torch.log(model.forward(x)[:,i]))
loss_rand = torch.mean(torch.log(model.forward(y)[:,i]))
```

- Compute the gradients of `loss_mod` and `loss_rand` w.r.t. `list(model.parameters())`.

```
model_grad = torch.autograd.grad(loss_mod, model_params)
rand_grad = torch.autograd.grad(
    loss_rand, model_params, create_graph=True
)
```

- Combine the above gradients into a new `loss`, modifying them in some way.

```
loss = some_method(model_grad, rand_grad)
```

- Zero the gradients of the `ran_model = Random()` model, compute the gradients of the loss above (which is built from the two sets of gradients) w.r.t. `ran_model`’s parameters, and then update the state of `ran_model` with its optimizer.

```
ran_model.zero_grad()
loss.backward()
ran_optim.step()
```
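For the second step, the difference between the two `torch.autograd.grad` calls can be checked directly. The following is a minimal sketch, assuming a hypothetical stand-in for `Model` (a single linear layer followed by a softmax) and random input; it only inspects whether the returned gradients are themselves tracked by autograd.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for `Model`; the real architecture is not shown above.
model = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.Softmax(dim=1))
model_params = list(model.parameters())
x = torch.randn(4, 3)

loss_a = torch.mean(torch.log(model(x)[:, 0]))
loss_b = torch.mean(torch.log(model(x)[:, 0]))

# Default call: the returned gradients are plain tensors, not part of a graph.
plain_grad = torch.autograd.grad(loss_a, model_params)
# create_graph=True: the gradient computation is recorded, so these gradients
# can be differentiated again with respect to model_params.
graph_grad = torch.autograd.grad(loss_b, model_params, create_graph=True)

plain_tracked = [g.requires_grad for g in plain_grad]
graph_tracked = [g.requires_grad for g in graph_grad]
print(plain_tracked)
print(graph_tracked)
```

In this sketch only the gradients obtained with `create_graph=True` stay connected to `model`’s parameters, which suggests that in the original loop `rand_grad` is the part of `loss` that `loss.backward()` can flow through.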

I’m confused about step 4.

How are the gradients w.r.t. `ran_model` calculated here (in `loss.backward()`)? The loss is not computed from `ran_model`, so there should be no connection. How does Torch handle such cases? The loss comes from two sets of gradients (`model_grad` and `rand_grad`), which are obtained from `model = Model()` and not from `ran_model = Random()`.
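One way to probe this is a small check along the following lines. It is only a sketch under assumptions: `torch.nn.Linear` layers stand in for the unseen `Model` and `Random` classes, and the loss is a simplified sum over `rand_grad` instead of the stacking used above. It records whether `ran_model`’s parameters receive any gradient and whether `ran_optim.step()` changes them.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for `Model` and `Random`.
model = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.Softmax(dim=1))
ran_model = torch.nn.Linear(3, 5)
ran_optim = torch.optim.SGD(ran_model.parameters(), 0.01)
model_params = list(model.parameters())
y = torch.randn(4, 3)

loss_rand = torch.mean(torch.log(model(y)[:, 0]))
rand_grad = torch.autograd.grad(loss_rand, model_params, create_graph=True)
# A scalar built only from gradients of `model`; `ran_model` is never used.
loss = sum(g.sum() for g in rand_grad)

# Snapshot ran_model's parameters before the update step.
before = [p.detach().clone() for p in ran_model.parameters()]
ran_model.zero_grad()
loss.backward()
ran_optim.step()

ran_grads_missing = all(p.grad is None for p in ran_model.parameters())
ran_unchanged = all(
    torch.equal(b, p) for b, p in zip(before, ran_model.parameters())
)
model_got_grads = all(p.grad is not None for p in model.parameters())
print(ran_grads_missing, ran_unchanged, model_got_grads)
```

In this stand-in setup, `loss.backward()` populates `.grad` only on `model`’s parameters (through the graph kept by `create_graph=True`), `ran_model`’s `.grad` fields stay `None`, and `ran_optim.step()` leaves `ran_model` unchanged because the optimizer skips parameters without a gradient. Whether the same holds for the real `Model` and `Random` depends on how `x`, `y`, and the two classes are actually defined.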