Your use case is a bit unusual, but it should generally work, as shown in this small example that overfits a static target:
import torch
import torch.nn as nn

class DummyGenerator(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # nn.Parameter sets requires_grad=True by default
        self.dummy_tensor = nn.Parameter(torch.rand(1, 2, 1024, 128))

    def forward(self):
        return self.dummy_tensor

model = DummyGenerator()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# static target to overfit
y = torch.randn(1, 2, 1024, 128)

for epoch in range(10000):
    optimizer.zero_grad()
    out = model()
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {}'.format(epoch, loss.item()))
# ...
# epoch 9998, loss 1.416539715387577e-11
# epoch 9999, loss 1.4069788736859046e-11
Check if your parameter has a valid .grad attribute after the backward() call and verify that it is indeed being updated.
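As a rough sketch of that check (reusing the DummyGenerator model, criterion, and target from above; the clone-and-compare approach is just one illustrative way to verify the update):

out = model()
loss = criterion(out, y)
loss.backward()

for name, param in model.named_parameters():
    # .grad should not be None after backward() if the parameter is part of the computation graph
    grad_sum = param.grad.abs().sum() if param.grad is not None else None
    print(name, 'grad is None:', param.grad is None, 'abs grad sum:', grad_sum)

# compare the parameter before and after the optimizer step
before = model.dummy_tensor.detach().clone()
optimizer.step()
after = model.dummy_tensor.detach().clone()
# a non-zero difference indicates the parameter was indeed updated
print('max abs update:', (after - before).abs().max())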