Most efficient way to run a model on the GPU while optimizing the input

Hello everyone,

I want to optimize the input of a model. Here is the setup in somewhat Pythonic pseudocode:

import torch
from torch.optim import Adam

# the input itself is the parameter being optimized, so it needs requires_grad=True
var = torch.rand(..., device="cuda", requires_grad=True)  # shape omitted here
optim = Adam([var])

x = model1(var)

for i in range(10):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        x = model2(x)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    x = model3(x)

loss = x.sum()
loss.backward()
optim.step()

The problem is that the GPU memory overflows at some step of the for loop. With torch.no_grad() the complete forward pass fits on the GPU. Does anybody know of a blog post, or have any idea, on how this kind of optimization setup can be implemented efficiently while staying easy on GPU memory? Is it, for example, possible to do the forward pass on the GPU and keep the gradients on the CPU?

Thanks for the help 🙂

You could apply CPU-offloading as described here or reduce the number of iterations, if possible.
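In case a concrete example helps: I don't know which post the link above refers to, but one built-in option is torch.autograd.graph.save_on_cpu, which keeps the tensors saved for backward in (pinned) CPU memory and copies them back to the GPU only while the backward pass needs them. A minimal sketch, with dummy nn.Linear models and a made-up input shape standing in for the real ones from the question:

import torch
import torch.nn as nn

# Placeholders for the real models and input shape (not the actual setup).
model1 = nn.Linear(128, 128).cuda()
model2 = nn.Linear(128, 128).cuda()
model3 = nn.Linear(128, 128).cuda()

var = torch.rand(16, 128, device="cuda", requires_grad=True)
optim = torch.optim.Adam([var])

# Activations saved for backward are offloaded to pinned CPU memory
# and moved back to the GPU only during the backward pass.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    x = model1(var)
    for i in range(10):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = model2(x)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        x = model3(x)
    loss = x.sum()

loss.backward()  # the gradient of var still lives on the GPU, where var is
optim.step()

This keeps the computation on the GPU and trades memory for extra host-device copies; pin_memory=True hides part of that transfer cost. Whether it is fast enough depends on how large the saved activations are relative to the transfer bandwidth.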