Hi! I am trying to create a model that predicts a change in state and then updates the current state with that prediction. However, when I try to keep track of the state and update it across iterations, the computation graph seems to break (backward() fails after the first iteration). How do I address this, or is there something I'm doing incorrectly in the code below?
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
class FeedForward(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.mlp(x)
model = FeedForward(8, 32)
z = math.pi / 2.3
curr_state = torch.tensor([[0.0, 1.0, z]])  # state variables (e.g. pos, vel, accel)
curr_state = curr_state.repeat(5, 1)        # batch size of 5, all starting from the same state
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
next_pos = curr_state[:, 0:1]  # current position, shape (5, 1) to match the model output
for ep in range(10):
    with torch.autograd.set_detect_anomaly(True):
        print(ep)
        opt.zero_grad()
        x = torch.rand(5, 8)
        y = torch.rand(5, 1)
        delta_pos = model(x)  # predicted change in position, shape (5, 1)
        # next_pos = next_pos.detach()  # note: adding this fixes the issue (see the variant after the loop). Why is it needed? Is it the right fix?
        next_pos = next_pos + delta_pos  # carry the updated state into the next iteration
        loss = torch.sum((next_pos - y) ** 2)
        print(loss)
        loss.backward(retain_graph=True)
        opt.step()
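For reference, this is what I mean by "adding this fixes the issue": a minimal sketch of the same loop with next_pos detached each iteration (I also dropped retain_graph=True and the anomaly-detection context, since the graph is rebuilt from scratch every iteration):

for ep in range(10):
    opt.zero_grad()
    x = torch.rand(5, 8)
    y = torch.rand(5, 1)
    delta_pos = model(x)
    next_pos = next_pos.detach()        # cut the link to the previous iteration's graph
    next_pos = next_pos + delta_pos
    loss = torch.sum((next_pos - y) ** 2)
    loss.backward()                     # no retain_graph needed once the state is detached
    opt.step()

This version runs without errors, but I don't understand why the detach is necessary, or whether it is the correct way to keep updating the state.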
Thanks!