I want to optimize the input of a model; here is some Pythonic pseudocode:
```python
var = torch.rand()  # the input being optimized (needs requires_grad=True)
optim = Adam([var])
x = model1(var)
for i in range(10):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        x = model2(x)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        x = model3(x)
loss = x.sum()
loss.backward()
optim.step()
```
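One direction I have been looking at is `torch.utils.checkpoint`, which discards intermediate activations during the forward pass and recomputes them during backward, trading compute for memory. Here is a minimal sketch of the loop above using it (the `Linear` modules are just placeholders standing in for my real models, and I run it on CPU here for simplicity):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Placeholder modules standing in for model2/model3 (made up for the sketch).
model2 = torch.nn.Linear(16, 16)
model3 = torch.nn.Linear(16, 16)

var = torch.rand(16, requires_grad=True)  # the input being optimized
optim = torch.optim.Adam([var])

x = var
for i in range(10):
    # checkpoint() does not store this module's intermediate activations;
    # they are recomputed when backward() reaches this segment.
    x = checkpoint(model2, x, use_reentrant=False)
    x = checkpoint(model3, x, use_reentrant=False)

loss = x.sum()
loss.backward()
optim.step()
```

I am not sure how well this interacts with `torch.autocast`, so treat it as a sketch rather than a drop-in replacement.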
The problem is that at some step of the for loop the GPU memory overflows. When using torch.no_grad(), the complete forward pass fits on the GPU. For this specific type of optimization architecture, does anybody know of a blog post, or have any idea how this can be implemented efficiently while staying easy on GPU memory? Is it e.g. possible to do the forward pass on the GPU and keep the gradients on the CPU?
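Regarding the last question, I came across `torch.autograd.graph.save_on_cpu`, which seems related: inside the context manager, tensors saved for backward are moved to CPU during the forward pass and copied back when backward needs them. A tiny sketch of what I mean (on an actual GPU I would pass `pin_memory=True` to speed up the transfers; here it is off so the snippet also runs on CPU):

```python
import torch

model = torch.nn.Linear(16, 16)
x = torch.rand(16, requires_grad=True)

# Activations saved for backward are offloaded to CPU inside this context,
# so the GPU only holds the tensors needed for the current computation.
with torch.autograd.graph.save_on_cpu(pin_memory=False):
    y = model(x).sum()

y.backward()
```

Would something like this be the right tool for my setup, or is checkpointing the more common approach?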
Thanks for the help