Hello,
I am new to PyTorch (I'm using 1.6.0), and I know this topic, or similar ones, already has a number of entries, but after studying them I still can't see the problem with my code, and would appreciate help with it. I define the following model:
```python
import torch

""" Model definition """
class NNModel( torch.nn.Module ):

    def __init__( self, nFeatures, nNeurons ):
        """
        The model consists of two hidden layers with tanh activation and a single neuron output
        from the third layer. The input to the first layer is a tensor containing the input features
        (nFeatures); the output of the third layer is a single number.
        """
        super( NNModel, self ).__init__()
        self.linear1 = torch.nn.Linear( nFeatures, nNeurons )
        self.activn1 = torch.nn.Tanh()
        self.linear2 = torch.nn.Linear( nNeurons, 1 )
        self.activn2 = torch.nn.Tanh()

    def forward( self, x ):
        """
        x is a tensor containing all symmetry functions for the present configuration; therefore
        it has dimensions (nObservations, nFeatures). The model must loop over each observation,
        calculating the contribution of each one to the output (the sum of them).
        """
        nObservations, _ = x.shape
        z = torch.zeros( nObservations, requires_grad = True )
        for n in range( nObservations ):
            y = self.linear1( x[n,:] )
            y = self.activn1( y )
            y = self.linear2( y )
            z[n] = self.activn2( y )
        addition = z.sum()
        return addition
```
My loss function and optimizer are:
```python
lossFunction = torch.nn.MSELoss( reduction = 'sum' )
optimizer = torch.optim.SGD( model.parameters(), lr = 1.0e-4 )
```
and I run this in a loop like so:
```python
for t in range( 500 ):
    # forward pass
    for n in range( nCases ):
        y_pred[n] = model( sym[n] )
    # compute and print loss
    loss = lossFunction( y_pred, energy )
    print( t, loss.item() )
    # zero gradients, perform a backward pass and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
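For completeness, the names used in the loop (`model`, `nCases`, `sym`, `energy`, `y_pred`) are set up along these lines earlier in my script; the sizes below are just placeholders, not my real data:

```python
import torch

nFeatures, nNeurons = 4, 10   # placeholder sizes
nCases = 8                    # number of training configurations

# the model itself (NNModel is the class defined above):
# model = NNModel( nFeatures, nNeurons )

# sym[n] holds the symmetry functions of configuration n,
# one row per observation (here 5 rows, just as an example)
sym = [ torch.randn( 5, nFeatures ) for n in range( nCases ) ]

# target energies, one per configuration
energy = torch.randn( nCases )

# predictions, filled in one by one inside the training loop
y_pred = torch.zeros( nCases )
```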
The first pass through the loop prints a loss value, but on the next iteration the program crashes with the well-known `RuntimeError: leaf variable has been moved into the graph interior`.
I guess this has to do with the loop over nObservations in the forward function, but I do not understand why, nor what I can do to solve the problem. Any help would be appreciated. Thanks!
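Trying to narrow it down, I can reproduce a similar RuntimeError with nothing but an in-place write into a freshly created `requires_grad=True` tensor, which is what `z[n] = ...` does inside `forward`. Depending on the PyTorch version, the error seems to appear either at the assignment itself or only at `backward()`:

```python
import torch

z = torch.zeros( 3, requires_grad = True )       # a leaf tensor, like z in forward()

try:
    z[0] = torch.tanh( torch.tensor( 1.0 ) )     # in-place write into the leaf
    z.sum().backward()                           # some versions only fail here
    message = None
except RuntimeError as e:
    message = str( e )

print( message )   # a "leaf variable ... in-place" style error
```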