How to cache intermediate interactions of model parameters that are not input-dependent?

Assume that I have a model with two parameters A and B that takes a matrix C as input and returns f(g(A, B), C). I could compute this during every forward pass, but it should be possible to cache g(A, B). I tried computing self.D = g(A, B) in __init__, but the devices don’t match when I call f(self.D, C). Alternatively, if I do f(, C), the forward pass works, but the gradient ends up on the wrong device during the optimizer step. I could probably keep self.D in GPU memory at all times, but is that necessary? What’s the best way to accomplish this?
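A minimal sketch of the two cases, assuming g and f are cheap tensor ops (here the hypothetical stand-ins g(A, B) = A @ B and f(D, C) = C @ D). If A and B are fixed (not trained), the cached result can be registered with register_buffer, so model.to(device) moves it together with the parameters; a plain attribute like self.D is not moved. If A and B are trained, g(A, B) changes after every optimizer step, so a value cached in __init__ would be stale and would not propagate gradients to A and B; in that case it has to be recomputed in forward. The class names CachedModel and RecomputeModel are hypothetical.

```python
import torch
import torch.nn as nn


def g(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical parameter interaction standing in for the expensive g(A, B).
    return A @ B


def f(D: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    # Hypothetical combination with the input standing in for f(D, C).
    return C @ D


class CachedModel(nn.Module):
    """Assumes A and B are frozen, so g(A, B) never changes after __init__."""

    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        super().__init__()
        with torch.no_grad():
            # A buffer is moved by model.to(device) / model.cuda() along with the
            # parameters, unlike a plain tensor attribute such as self.D = g(A, B).
            self.register_buffer("D", g(A, B))

    def forward(self, C: torch.Tensor) -> torch.Tensor:
        return f(self.D, C)


class RecomputeModel(nn.Module):
    """Assumes A and B are trained, so g(A, B) must reflect their latest values."""

    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        super().__init__()
        self.A = nn.Parameter(A)
        self.B = nn.Parameter(B)

    def forward(self, C: torch.Tensor) -> torch.Tensor:
        # Recomputed on every forward pass so gradients flow back to A and B.
        D = g(self.A, self.B)
        return f(D, C)


device = "cuda" if torch.cuda.is_available() else "cpu"
A, B = torch.randn(4, 3), torch.randn(3, 5)
C = torch.randn(2, 4, device=device)

model = CachedModel(A, B).to(device)  # the buffer D follows the model to `device`
out = model(C)                        # shape (2, 5)

trained = RecomputeModel(A.clone(), B.clone()).to(device)
trained(C).sum().backward()           # gradients reach A and B
```

Keeping D as a buffer also means it is stored in the model's state_dict by default.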

What kind of error are you getting?
The to() operation is differentiable, so parameters (and tensors computed from them) can live on different devices: autograd sends the gradient back to the device of the original parameter.
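A small sketch of what that means in practice, assuming a single parameter w that lives on the CPU and an input on the GPU (falling back to the CPU if no GPU is available): the forward computation happens on the input's device, and backward() accumulates the gradient on w's own device.

```python
import torch

# The parameter lives on the CPU.
w = torch.nn.Parameter(torch.randn(3))

# Use the GPU if one is available; otherwise everything trivially stays on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(3, device=device)
# .to() is recorded by autograd, so the loss can be computed on `device`...
loss = (w.to(device) * x).sum()
loss.backward()

# ...and the gradient is accumulated on w's own device.
print(w.grad.device)  # cpu
```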

My apologies. The error is actually coming from a third-party transformer library. Since it was raised in optimizer.step(), I didn’t look into it in more detail. The stack trace is pasted below, but I will ask in their repo. Thanks!

File "", line 71, in optimizer_step
File ".../lib/python3.7/site-packages/torch/optim/", line 67, in wrapper
return wrapped(*args, **kwargs)
File ".../lib/python3.7/site-packages/transformers/", line 155, in step
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
RuntimeError: expected device cuda:0 but got device cpu

Sure, let us know if you figure out the issue.
Based on the error message, my best guess is that some internal buffers used by the optimizer are not on the right device, so check whether this repository implements an optimizer.cuda() (or similar) method that you are expected to call.
If not, try pushing the model to the device before passing its parameters to the optimizer.
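A minimal sketch of that ordering, using torch.optim.AdamW as a stand-in for the library's optimizer (an assumption; the third-party AdamW may behave differently) and a toy nn.Linear model. The torch.optim documentation likewise recommends moving a model to the GPU before constructing its optimizer. The helper report_devices is hypothetical and only collects the devices of the parameters, their gradients, and the per-parameter optimizer state, which should all agree.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)
# 1) Move the model first...
model.to(device)
# 2) ...then create the optimizer from the parameters that now live on `device`.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(8, 4, device=device)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # creates exp_avg / exp_avg_sq next to each parameter


def report_devices(model: nn.Module, optimizer: torch.optim.Optimizer) -> set:
    """Hypothetical helper: devices of parameters, gradients, and optimizer state."""
    devices = set()
    for p in model.parameters():
        devices.add(p.device)
        if p.grad is not None:
            devices.add(p.grad.device)
    for state in optimizer.state.values():
        for value in state.values():
            # Skip 0-dim tensors such as the "step" counter, which recent
            # PyTorch versions keep on the CPU by design.
            if torch.is_tensor(value) and value.dim() > 0:
                devices.add(value.device)
    return devices


print(report_devices(model, optimizer))  # a single device is expected
```

If the set contains more than one device, the mismatched tensor is likely the one that triggers the "expected device cuda:0 but got device cpu" error in step().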