How to do an epoch-dependent calculation in the loss function with CUDAGraph

Hi,

I have a loss function whose value depends on the current epoch, e.g.

optimizer.zero_grad()
forward_func()
loss = loss_func( epoch )
loss.backward()

When I try to use CUDAGraph, epoch never changes during replay after stream capture.
(Without CUDAGraph it works fine.)
Any idea how to pass epoch during CUDAGraph replay?

Thanks,
Joe

Try passing it as a tensor.
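In case it helps, a minimal sketch of that idea (the names and the epoch weighting here are illustrative, not from the original code): keep the epoch in a CUDA tensor whose address never changes, and update its contents in place before each replay so the captured kernels read the current value.

import torch

# Illustrative only: a 0-dim CUDA tensor holding the epoch. Its memory
# address stays fixed, so kernels captured in the graph read whatever
# value it contains at replay time.
static_epoch = torch.zeros((), device="cuda")

def loss_func(pred, epoch_t):
    # Hypothetical epoch-dependent weighting; the real loss_func will differ.
    return pred.pow(2).mean() * (1.0 + 0.1 * epoch_t)

# Later, before each g.replay():
#   static_epoch.fill_(float(epoch))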

Thank you, it works.
I realized that the arguments of functions inside the captured stream must not be rebound between replays. They have to live at a static address, e.g. a torch tensor whose contents are updated in place before each replay.
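For completeness, here is a hedged end-to-end sketch of that pattern, modeled on the whole-network capture example in the PyTorch CUDA graphs docs. The model, shapes, epoch weighting, and names are placeholders standing in for forward_func() and loss_func(epoch), not the original code.

import torch

# Placeholders standing in for the real model and loss.
model = torch.nn.Linear(64, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def loss_func(pred, target, epoch_t):
    # The epoch enters as a tensor, so each replay sees its current contents.
    return ((pred - target) ** 2).mean() * (1.0 + 0.1 * epoch_t)

# Static tensors: fixed addresses for the lifetime of the graph; only their
# contents are refreshed in place between replays.
static_input = torch.randn(32, 64, device="cuda")
static_target = torch.randn(32, 1, device="cuda")
static_epoch = torch.zeros((), device="cuda")

# Warmup on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        out = model(static_input)
        loss_func(out, static_target, static_epoch).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Capture forward + backward + step once. Grads are set to None first so the
# captured backward allocates them from the graph's private pool.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_out = model(static_input)
    static_loss = loss_func(static_out, static_target, static_epoch)
    static_loss.backward()
    optimizer.step()

# Training loop: refresh the static tensors in place, then replay.
for epoch in range(10):
    static_input.copy_(torch.randn(32, 64, device="cuda"))
    static_target.copy_(torch.randn(32, 1, device="cuda"))
    static_epoch.fill_(float(epoch))   # the "pass it as a tensor" part
    g.replay()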