Everywhere I read, people say to just pass `retain_graph=True`
to solve my issue, but I'd like to understand what is happening under the hood. For example, the
`loss.backward()` call below does not need `retain_graph` to be passed at all:
class Mnist_Logistic(nn.Module):
    """Single linear layer mapping 784 input features to 10 outputs."""

    def __init__(self):
        """Create the one ``nn.Linear(784, 10)`` layer used by ``forward``."""
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        """Return the linear layer applied to the batch ``xb``.

        ``xb`` must have 784 features in its last dimension, as required by
        ``nn.Linear(784, 10)``.
        """
        return self.lin(xb)
# The forward pass is re-run inside the loop, so every epoch builds a brand-new
# autograd graph. backward() therefore only ever walks a graph created in the
# same iteration, which is why no retain_graph=True is needed here.
# NOTE: this snippet shows neither an optimizer step nor zero_grad().
for epoch in range(epochs):
    pred = model(input)
    loss = loss_func(pred, output)
    loss.backward()
For my model below, however, PyTorch complains that the computational graph has already been freed, so when I call `backward()`
in the next epoch it can no longer access the saved state variables. The code above does not seem to have this problem.
class Ode(nn.Module):
def __init__(self, len_data, alpha=0.57, beta=0.11, ...):
super().__init__()
self.I, self.E, self.H = torch.zeros(1), torch.zeros(1), torch.zeros(1)
self.S = nn.Parameter(torch.tensor([0.5]).to(device))
self.sigma = nn.Parameter(torch.tensor([sigma])).to(device)
# many nn.Parameter later
def f(self, t, y):
# only + and * parameters together
def forward(self, I0, E0, H0):
return torchdiffeq.odeint(self.f, t=time_range, y0=self.y0, method='rk4')
# One optimisation step per batch: clear stale gradients, run a fresh forward
# pass (which builds a new autograd graph), back-propagate, then update.
for data, y_exact in tqdm(train_dataloader):
    optimizer.zero_grad()
    y_approx = model(data[0,0,0], data[0,0,1], data[0,0,2])
    # Skip the first returned time point and compare only state columns 1, 6 and 7.
    loss = loss_fun(y_approx[1:, [1, 6, 7]], y_exact.squeeze())
    loss.backward()
    optimizer.step()
    # NOTE(review): the post lost its indentation, so it is unclear whether the
    # scheduler is stepped per batch (as written here) or once after the loop.
    scheduler.step()