Tracking the parameter gradients of different inputs

Hello, I’ve defined a simple network given by:

import torch.nn as nn

class RBM_Network(nn.Module):

    def __init__(self, n_visible, n_hidden):
        super(RBM_Network, self).__init__()

        self.n_visible = n_visible
        self.n_hidden = n_hidden

        # define network layers 
        self.hidden = nn.Linear(n_visible, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        self.sigmoid = nn.Sigmoid()


    # forward function is our feed forward algorithm
    def forward(self, s):

        s = self.hidden(s)
        s = self.sigmoid(s)
        s = self.output(s)
        s = self.sigmoid(s)

        return s

The network’s input is a spin state: [0,0], [0,1], [1,0], or [1,1]. It outputs a single value representing a wavefunction coefficient. As implemented, the network has 9 parameters. When I pass one of these states through the network and run .backward() on the output, I get a unique set of parameter gradients for that state (which makes sense).

I’m trying to store these gradients in a list, param_grads, with one entry per spin state. This seems like it should be easy, but the gradients appear to be overwritten when I append them to the list. My implementation is below. I expect param_grads to be a list of 4 sub-lists (one per spin state), each holding the gradients of the network output with respect to the parameters for that state (4 tensors: hidden.weight, hidden.bias, output.weight and output.bias, which together cover the 9 parameters). Instead, param_grads ends up as 4 identical sub-lists, all equal to the gradients of the most recent state passed through the network.

I appreciate any help here; I’m relatively new to PyTorch and might be missing something simple.
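A quick way to see how those 9 parameters are laid out: network.parameters() yields one tensor per weight matrix and bias vector, not one per scalar. The class below is a condensed copy of RBM_Network with the same layers, repeated only so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Condensed copy of RBM_Network (same layers, same parameter layout).
class RBM_Network(nn.Module):
    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.hidden = nn.Linear(n_visible, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, s):
        return self.sigmoid(self.output(self.sigmoid(self.hidden(s))))

network = RBM_Network(2, 2)

# parameters() yields one tensor per weight/bias, not one per scalar.
for name, p in network.named_parameters():
    print(name, tuple(p.shape))
# hidden.weight (2, 2)
# hidden.bias (2,)
# output.weight (1, 2)
# output.bias (1,)

n_tensors = len(list(network.parameters()))
n_scalars = sum(p.numel() for p in network.parameters())
print(n_tensors, n_scalars)  # 4 9
```

So each per-state sub-list naturally holds 4 gradient tensors whose shapes mirror the parameters, 9 numbers in total.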

import torch
import numpy as np 
from RBM_Network import RBM_Network

inputs = 2
hidden_nodes = 2
spin1 = torch.tensor([0,0], dtype=torch.float)
spin2 = torch.tensor([0,1], dtype=torch.float)
spin3 = torch.tensor([1,0], dtype=torch.float)
spin4 = torch.tensor([1,1], dtype=torch.float)
spins = [spin1, spin2, spin3, spin4]

# initialize network
network = RBM_Network(inputs, hidden_nodes)
param_grads = []
psi_omega = []

# pass spin states through network to obtain coefficients for psi_omega
# track parameter gradients for each spin state in param_grads
for spin in spins:
    
    network.zero_grad()
    psi = network(spin)
    psi.backward()
    psi_omega.append(psi)
    
    current_grads = []
    for param in network.parameters():
        current_grads.append(param.grad)
        
    param_grads.append(current_grads)
    print(param_grads)
    print("\n")
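The overwriting can be reproduced in isolation. This sketch uses an nn.Sequential stand-in with the same architecture as RBM_Network, and calls zero_grad(set_to_none=False) so the .grad tensors are zeroed in place regardless of PyTorch version (an assumption about how the original loop behaved, since older releases zeroed in place by default).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in with the same layers as RBM_Network: 2 -> 2 -> 1, sigmoids.
network = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())
spins = [torch.tensor(s, dtype=torch.float) for s in ([0, 0], [0, 1], [1, 0], [1, 1])]

param_grads = []
for spin in spins:
    # Zero the existing .grad tensors in place instead of replacing them.
    network.zero_grad(set_to_none=False)
    psi = network(spin)
    psi.backward()  # accumulates into the same .grad tensors
    # Appending p.grad stores a reference, not a snapshot.
    param_grads.append([p.grad for p in network.parameters()])

# The first and last sub-lists point at the very same tensor objects...
same_objects = all(g is h for g, h in zip(param_grads[0], param_grads[-1]))
print(same_objects)  # True
# ...so every sub-list shows the gradients of the last spin, [1, 1].
```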

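These tensors are reused: network.zero_grad() zeroes each existing param.grad in place, and psi.backward() then accumulates the new gradient into that same tensor. Appending param.grad therefore stores a reference rather than a snapshot, so all four sub-lists point at the same tensors and show the gradients of the last spin. Store a copy instead:

current_grads.append(param.grad.clone())

(Since PyTorch 2.0, zero_grad() defaults to set_to_none=True, which replaces the .grad tensors instead of zeroing them in place, but cloning is still the safe way to keep a snapshot.)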
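A sketch of the corrected loop, again with an nn.Sequential stand-in for RBM_Network and set_to_none=False to force the in-place reuse that caused the problem. Because the input [0, 0] is all zeros, the hidden.weight gradient for that spin is exactly zero, which makes it easy to confirm the snapshots no longer get overwritten.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in with the same layers as RBM_Network: 2 -> 2 -> 1, sigmoids.
network = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 1), nn.Sigmoid())
spins = [torch.tensor(s, dtype=torch.float) for s in ([0, 0], [0, 1], [1, 0], [1, 1])]

param_grads = []
psi_omega = []
for spin in spins:
    # Worst case: .grad tensors are zeroed and refilled in place.
    network.zero_grad(set_to_none=False)
    psi = network(spin)
    psi.backward()
    # detach() keeps the stored coefficient without its autograd graph.
    psi_omega.append(psi.detach())
    # clone() copies each gradient into fresh memory, so later
    # zero_grad()/backward() calls cannot overwrite this snapshot.
    param_grads.append([p.grad.clone() for p in network.parameters()])

# Spin [0, 0]: the hidden-weight gradient is outer(delta, x) with x = 0.
print(param_grads[0][0])   # all zeros
print(param_grads[3][0])   # non-zero for spin [1, 1]
```

If you don't need the .grad attributes at all, torch.autograd.grad(psi, list(network.parameters())) returns a tuple of freshly computed gradient tensors without accumulating into .grad, which avoids the aliasing entirely.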
