Hello, I’ve defined a simple network given by:
```python
import torch.nn as nn


class RBM_Network(nn.Module):
    def __init__(self, n_visible, n_hidden):
        super(RBM_Network, self).__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        # define network layers
        self.hidden = nn.Linear(n_visible, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        self.sigmoid = nn.Sigmoid()

    # forward function is our feed forward algorithm
    def forward(self, s):
        s = self.hidden(s)
        s = self.sigmoid(s)
        s = self.output(s)
        s = self.sigmoid(s)
        return s
```
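Here is a quick check of how this architecture's parameters are grouped. It is a self-contained sketch, and the class is restated in compact form so the snippet runs on its own. `parameters()` yields four tensors, the weight and bias of each `nn.Linear` layer, which together hold 2*2 + 2 + 1*2 + 1 = 9 scalar values.

```python
import torch
import torch.nn as nn


class RBM_Network(nn.Module):
    # Same architecture as above: n_visible -> n_hidden -> 1, sigmoid after each layer
    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.hidden = nn.Linear(n_visible, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, s):
        return self.sigmoid(self.output(self.sigmoid(self.hidden(s))))


net = RBM_Network(2, 2)
# parameters() yields tensors in registration order:
# hidden.weight, hidden.bias, output.weight, output.bias
shapes = [tuple(p.shape) for p in net.parameters()]
n_scalars = sum(p.numel() for p in net.parameters())
print(shapes)     # [(2, 2), (2,), (1, 2), (1,)]
print(n_scalars)  # 9
```

So "9 parameters" means 9 scalar values spread over 4 tensors. Each entry collected from `network.parameters()` is therefore one of these 4 tensors, not a single number.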
The network’s input is a spin state: [0,0], [0,1], [1,0], or [1,1]. It outputs a single value representing a wavefunction coefficient. As implemented, the network has 9 scalar parameters, grouped into 4 tensors (a weight matrix and a bias vector for each of the two linear layers). When I pass one of these states through the network and call `.backward()` on the output, I get a unique set of parameter gradients for that state (which makes sense). I’m trying to store these gradients in a list, one entry per spin state. This seems like it should be easy, but the parameter gradients appear to be overwritten when I append them to a simple list, `param_grads`. My implementation code is given below.

I expect `param_grads` to be a list of 4 sub-lists (one per spin state), each holding that state’s gradients of the network output with respect to the parameters. However, `param_grads` ends up being a list of 4 identical sub-lists, all equal to the parameter gradients from the most recent state passed through the network. I’d appreciate any help here; I’m relatively new to PyTorch and might be missing something simple.
```python
import torch
import numpy as np
from RBM_Network import RBM_Network

inputs = 2
hidden_nodes = 2

spin1 = torch.tensor([0,0], dtype=torch.float)
spin2 = torch.tensor([0,1], dtype=torch.float)
spin3 = torch.tensor([1,0], dtype=torch.float)
spin4 = torch.tensor([1,1], dtype=torch.float)
spins = [spin1, spin2, spin3, spin4]

# initialize network
network = RBM_Network(inputs, hidden_nodes)

param_grads = []
psi_omega = []

# pass spin states through network to obtain coefficients for psi_omega
# track parameter gradients for each spin state in param_grads
for spin in spins:
    network.zero_grad()
    psi = network(spin)
    psi.backward()
    psi_omega.append(psi)
    current_grads = []
    for param in network.parameters():
        current_grads.append(param.grad)
    param_grads.append(current_grads)
    print(param_grads)
    print("\n")
```
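The following is a hedged sketch of a likely cause and fix, not a definitive answer. The code appends `param.grad` by reference. In some setups, `zero_grad()` zeroes the existing gradient tensors in place. Before PyTorch 2.0, `Module.zero_grad` defaulted to `set_to_none=False`. `backward()` then accumulates into those same tensors. As a result, every sub-list in `param_grads` points at the same four tensor objects and shows only the last state's gradients. Storing an independent copy with `.detach().clone()` avoids this regardless of version. The network is restated inline so the snippet runs on its own. The fixed seed and the use of `psi.detach()` for `psi_omega` are assumptions added for illustration.

```python
import torch
import torch.nn as nn


class RBM_Network(nn.Module):
    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.hidden = nn.Linear(n_visible, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, s):
        s = self.sigmoid(self.hidden(s))
        return self.sigmoid(self.output(s))


torch.manual_seed(0)  # assumption: fixed seed only for reproducibility
network = RBM_Network(2, 2)
spins = [torch.tensor(s, dtype=torch.float) for s in ([0, 0], [0, 1], [1, 0], [1, 1])]

param_grads = []  # one sub-list of gradient tensors per spin state
psi_omega = []    # network output (coefficient) per spin state

for spin in spins:
    network.zero_grad()
    psi = network(spin)
    psi.backward()
    # assumption: only the value of psi is needed later, not its graph
    psi_omega.append(psi.detach())
    # param.grad may be reused and overwritten by later zero_grad()/backward()
    # calls, so store an independent copy instead of a reference to it
    param_grads.append([p.grad.detach().clone() for p in network.parameters()])

for i, grads in enumerate(param_grads):
    print(f"spin {spins[i].tolist()}:", [g.tolist() for g in grads])
```

After the loop, `param_grads` should hold four different sets of gradients. For the `[0,0]` state, the gradient of `hidden.weight` is all zeros. That gradient is the outer product of the back-propagated error and the input, and the input here is zero.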