class Fun(torch.autograd.Function):
    """Placeholder for a custom autograd function (not yet implemented).

    The original stub declared ``forward`` and ``backward`` without bodies,
    which is a syntax error that prevents the whole module from loading.
    Both methods now raise ``NotImplementedError`` explicitly.

    NOTE(review): no visible code uses ``Fun``. Running a model's forward
    computation several times and updating the parameters once does not
    require a custom Function. Autograd records the operations on its own,
    and parameters change only when ``optimizer.step()`` is called.
    """

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Compute the custom forward pass (stub).

        Raises:
            NotImplementedError: always, until a real implementation is added.
        """
        raise NotImplementedError("Fun.forward is a stub and has no implementation yet")

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute gradients for the custom forward pass (stub).

        A real implementation must return one gradient (or ``None``) per input
        passed to ``forward``.

        Raises:
            NotImplementedError: always, until a real implementation is added.
        """
        raise NotImplementedError("Fun.backward is a stub and has no implementation yet")
class NN(nn.Module):
    """Two fully connected layers: 784 input features -> 800 -> 10 outputs."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 800)
        self.fc2 = nn.Linear(800, 10)

    def forward(self, inpu, L=5):
        """Run the forward computation ``L`` times and return the last result.

        Args:
            inpu: input tensor whose first dimension is the batch; the
                remaining dimensions are flattened and must hold 784 elements
                per sample.
            L: number of forward iterations; must be at least 1.

        Returns:
            Tensor of shape ``(batch, 10)`` from the final iteration.

        Raises:
            ValueError: if ``L`` is less than 1 (no iteration would produce
                a result).

        Running the forward loop several times never updates the parameters.
        They change only when the optimizer's ``step()`` is called, so one
        ``loss.backward()`` plus one ``optimizer.step()`` per batch updates
        them exactly once, whatever ``L`` is.
        """
        if L < 1:
            raise ValueError(f"L must be >= 1, got {L}")
        # Flatten once, using the real batch size. The original read an
        # undefined global ``batch`` and the builtin ``input`` instead of the
        # ``inpu`` argument.
        flat = inpu.reshape(inpu.size(0), -1)
        for _ in range(L):
            # NOTE(review): every iteration restarts from ``flat``, so all
            # passes compute the same value and only the last one is returned.
            # If the loop is meant to refine a state, feed the previous output
            # back in here.
            # NOTE(review): there is no nonlinearity between fc1 and fc2, so the
            # two layers compose into a single linear map; confirm it is intended.
            x = self.fc2(self.fc1(flat))
        return x
# Training loop: one backward pass and one parameter update per batch.
# NOTE(review): ``learning_rate``, ``e`` and ``train_loader`` are defined
# outside the visible code.
# Named ``model``, not ``nn``: rebinding ``nn`` shadows the ``torch.nn`` module
# and breaks ``nn.MSELoss()`` on the next line.
model = NN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(e):
    for x, y in train_loader:
        # Clear gradients from the previous batch; backward() accumulates
        # into existing .grad tensors otherwise.
        optimizer.zero_grad()
        # The model may loop internally (``L`` times); that only builds the
        # autograd graph and never changes the parameters by itself.
        outputs = model(x)
        # One-hot encode the labels to match the (batch, 10) outputs for
        # MSELoss. Deriving the size from ``y`` handles a short last batch, and
        # ``.to(outputs)`` matches the outputs' dtype and device.
        # Assumes ``y`` holds int64 class indices in [0, 10); confirm against
        # the loader.
        y_ = torch.nn.functional.one_hot(y, num_classes=10).to(outputs)
        loss = criterion(outputs, y_)
        # Exactly one backward pass and one optimizer step per batch, so the
        # parameters are updated once no matter how many forward iterations ran.
        loss.backward()
        optimizer.step()
As shown above, I need to loop multiple times in the forward pass, but I want to update the parameters only once per backward pass, not after every iteration of the forward loop.
What should I do? Thank you!