Hello.
I’m trying to understand the following torch code.

```python
ran_model = Random()
ran_optim = torch.optim.SGD(
    ran_model.parameters(), 0.01
)

model = Model()
optim = torch.optim.SGD(
    model.parameters(), 0.01
)
model_params = list(model.parameters())
```
```python
for i in range(5):
    loss_mod  = torch.mean(torch.log(model.forward(x)[:, i]))
    loss_rand = torch.mean(torch.log(model.forward(y)[:, i]))

    model_grad = torch.autograd.grad(loss_mod, model_params)
    rand_grad  = torch.autograd.grad(
        loss_rand, model_params, create_graph=True
    )

    loss = 0
    loss = loss + torch.stack(
        [a, b], dim=0
    ).sum(dim=0).sum(dim=0).mean(0)

    loss.backward()
    ran_optim.step()

    new_loss = -torch.mean(torch.log(model.forward(y)))
    new_loss.backward()
    optim.step()
    print('loss value  ', loss, new_loss)
```

Here,

1. Get the losses `loss_mod` and `loss_rand` after the forward pass.
```python
loss_mod  = torch.mean(torch.log(model.forward(x)[:, i]))
loss_rand = torch.mean(torch.log(model.forward(y)[:, i]))
```
1. Compute grads of `loss_mod` and `loss_rand` w.r.t `list(model.parameters())`.
```python
model_grad = torch.autograd.grad(loss_mod, model_params)
rand_grad  = torch.autograd.grad(
    loss_rand, model_params, create_graph=True
)
```
1. Update the above grads, i.e. modify them in some way.
```python
loss = some_method(model_grad, rand_grad)
```
1. Zero the grads of the `ran_model = Random()` model. Then compute the gradients of the above updated loss (which comes from the two gradients) w.r.t `ran_model`'s params, and finally update the state of `ran_model` with its optimizer.

```python
ran_model.zero_grad()
loss.backward()
ran_optim.step()
```

How are the gradients calculated here (`loss.backward()`) w.r.t `ran_model`? The loss is not derived from `ran_model`, so there should be no connection. How does Torch handle these cases? The loss, which comes from the two gradients (`model_grad` and `rand_grad`), is derived from `model = Model()` and not from `ran_model = Random()`.

If there is no connection, then gradients wouldn't be calculated w.r.t `ran_model`. Unlike `.grad()`, `.backward()` doesn't complain if there is no connection; it would just accumulate the gradients to the inputs that are connected in the graph.
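A minimal sketch of that difference, using toy scalar tensors in place of the thread's models (names are made up for illustration):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # stands in for a connected parameter
u = torch.tensor(3.0, requires_grad=True)  # stands in for a disconnected one

# .grad() complains when an input is not part of the graph:
loss = (w * 5.0) ** 2   # depends only on w, not on u
try:
    torch.autograd.grad(loss, u)
except RuntimeError as e:
    print("grad() raised:", e)

# .backward() does not complain; it just accumulates into connected leaves.
loss = (w * 5.0) ** 2   # rebuild the graph, the first one may have been freed
loss.backward()
print(w.grad)  # d/dw (5w)^2 = 50w -> tensor(100.)
print(u.grad)  # None: u was never in the graph
```

So `.grad()` surfaces a missing connection as an error, while `.backward()` silently leaves the disconnected leaf's `.grad` as `None`.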

> It would just accumulate the gradients to the inputs that are connected in the graph.

In the code, while computing the grad of `loss_rand` w.r.t. `model_params`, we set `create_graph=True`. I think that means `loss_rand` and `model_params` somehow get connected here, thus allowing us to compute the grad of grads in the later part.

```python
rand_grad = torch.autograd.grad(
    loss_rand, model_params, create_graph=True
)
```

Yes, that means `params.grad` is now potentially a function of `loss_rand`, hence connected.
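A toy illustration of that connection, with scalar stand-ins for the models (`theta` playing the role of `ran_model`'s parameter is my assumption, not the original code):

```python
import torch

theta = torch.tensor(1.5, requires_grad=True)  # "ran_model" parameter
w = torch.tensor(2.0, requires_grad=True)      # "model" parameter

y = theta * 4.0           # "output of ran_model"
loss_rand = (w * y) ** 2  # depends on both w and theta

# create_graph=True records the gradient computation itself, so the
# returned grad is a differentiable function of theta: 2*w*y^2 = 32*w*theta^2
(rand_grad,) = torch.autograd.grad(loss_rand, [w], create_graph=True)

# A loss built from that gradient can therefore be backpropagated to theta.
loss = rand_grad.sum()
loss.backward()
print(theta.grad)  # d/dtheta (32*w*theta^2) = 64*w*theta -> tensor(192.)
```

Without `create_graph=True`, `rand_grad` would be a plain tensor with no `grad_fn`, and `loss.backward()` would have nothing to propagate back to `theta`.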

> It would just accumulate the gradients to the inputs that are connected in the graph.

Ah, all I mean here is that the preferred way of figuring out what is connected to what is with `.grad()` rather than `.backward()`.
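For that kind of connectivity probing, `.grad()` also accepts `allow_unused=True`, which returns `None` for disconnected inputs instead of raising. A small sketch with made-up tensors:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
loss = a * 3.0  # depends only on a

# allow_unused=True makes .grad() report None for inputs that are
# not in the graph, rather than raising a RuntimeError.
grads = torch.autograd.grad(loss, [a, b], allow_unused=True)
print(grads)  # (tensor(3.), None)
```

The `None` entries tell you exactly which inputs the output is not connected to.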

Thank you.
Could you please suggest any blogs/docs that describe the behavior of such cases of grad calculation in PyTorch?

If you are trying to understand higher-order graphs, e.g. double backward, there's this one that may be helpful: Double Backward with Custom Functions — PyTorch Tutorials 2.0.1+cu117 documentation. Technically it describes how one would implement custom functions that can support autograd, but it has some visualizations of what the graph would look like.