How to decrease GPU memory usage in F.grid_sample?

I have a ground truth tensor of shape [n_imgs, classes, H, W], and I expand it to [batch_size, n_imgs, classes, H, W] (expand does not copy it in memory). Then for each image I take a [batch_size, classes, H, W] slice and feed it to F.grid_sample. The forward pass works fine, but during backward the GPU memory usage is around 76 GB. I tried computing it one batch at a time, but that costs too much time. How should I construct the input tensor so the backward pass is both fast and memory-efficient?
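A minimal sketch of the setup described above, with small hypothetical sizes (`n_imgs`, `classes`, `H`, `W`, `batch_size` are placeholders, not the asker's real dimensions). One common way to cap peak memory is to call backward once per image so each autograd graph is freed immediately, instead of accumulating one large graph over all images; whether this is fast enough for the asker's workload is not guaranteed:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
n_imgs, classes, H, W = 2, 3, 8, 8
batch_size = 4

# Ground-truth tensor, expanded over the batch dimension without copying.
gt = torch.rand(n_imgs, classes, H, W)
gt_batched = gt.unsqueeze(0).expand(batch_size, n_imgs, classes, H, W)

# Sampling grid in [-1, 1]; a leaf tensor so gradients accumulate into .grad.
grid = torch.rand(batch_size, H, W, 2) * 2 - 1
grid.requires_grad_(True)

for i in range(n_imgs):
    inp = gt_batched[:, i]                          # [batch_size, classes, H, W]
    out = F.grid_sample(inp, grid, align_corners=False)
    # Backward per image: each graph is freed right away,
    # keeping peak memory proportional to one image, not all n_imgs.
    out.mean().backward()

print(grid.grad.shape)  # torch.Size([4, 8, 8, 2])
```

Gradients from every image accumulate into `grid.grad`, so the end result matches a single backward over the summed loss, but the intermediate activations of only one `grid_sample` call are alive at a time.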