Visualizing weight updates when the loss itself contains gradients

Hi, I’ve written a neural network where the loss function is built from gradients of the output with respect to the inputs. The code works fine, but I want to inspect how the gradients flow when I update the network weights with respect to this loss. I can’t figure out how the computation graph is laid out here, and I hope someone can help me understand how exactly the gradients are computed.

Here’s my code:

import torch

dtype = torch.float
device = torch.device("cuda:0")  # switch to torch.device("cpu") if no GPU is available
x = torch.tensor([[1., 1.]], device=device, dtype=dtype, requires_grad=True)  # single input point (x, t)

w1 = torch.tensor([[ 0.6561,  0.6202, -1.5620],
        [ 1.1547, -1.3227,  0.4719]], device=device, dtype=dtype, requires_grad=True)
w2 = torch.tensor([[-0.4917],
        [-0.3886],
        [-1.4218]], device=device, dtype=dtype, requires_grad=True)

learning_rate = 0.01

z1 = x.mm(w1)
a1 = torch.sin(z1)
z2 = a1.mm(w2)
p = torch.sin(z2)  # network output p(x, t)

# Gradients of the output w.r.t. the inputs; create_graph=True keeps these
# gradients differentiable so they can appear inside the loss.
grads, = torch.autograd.grad(p, x,
                grad_outputs=torch.ones_like(p),
                create_graph=True, only_inputs=True)
dpdx, dpdt = grads[:, 0], grads[:, 1]  # dp/dx and dp/dt
pde = dpdt + dpdx                      # PDE residual
loss = pde.pow(2).mean()

loss.backward()
with torch.no_grad():
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad
    # Manually zero the gradients after updating weights
    w1.grad.zero_()
    w2.grad.zero_()       
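
To poke at this, I tried dumping the backward graph of loss with a small recursive helper (print_graph is just a name I made up, not a PyTorch function), hoping to see which Backward nodes are chained together:

def print_graph(fn, depth=0):
    # Recursively print the autograd graph that loss.backward() will traverse.
    # Leaf tensors (x, w1, w2) show up as AccumulateGrad nodes.
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        print_graph(next_fn, depth + 1)

print_graph(loss.grad_fn)

This prints a tree of Backward/AccumulateGrad nodes, but I can’t map them back onto my formulas.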

I just can’t understand how w1.grad and w2.grad are produced (i.e. what the connections in the computation graph are) when I run loss.backward().
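
For reference, here is my pencil-and-paper understanding of what loss.backward() should be computing (this is just my own derivation from the code above, with x and t denoting the two components of the input row X = [x, t], so it may well be wrong). Since the mean is over a single element:

\text{loss} = \Big(\frac{\partial p}{\partial x} + \frac{\partial p}{\partial t}\Big)^{2},
\qquad
\frac{\partial\,\text{loss}}{\partial w_1}
  = 2\Big(\frac{\partial p}{\partial x} + \frac{\partial p}{\partial t}\Big)\,
    \frac{\partial}{\partial w_1}\Big(\frac{\partial p}{\partial x} + \frac{\partial p}{\partial t}\Big)

with p = \sin\!\big(\sin(X w_1)\, w_2\big), and similarly for w_2. So the weight gradients should contain mixed second derivatives of p, which I assume is why create_graph=True is needed when computing grads, but I still don’t see how autograd arranges this as a graph.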