DataParallel without gradient updates

I’m trying to run a computation-intensive, non-gradient-based optimization algorithm (call it FNograd) on PyTorch tensors across multiple GPUs, alongside my model. The input images go through this algorithm to generate secondary inputs for my gradient-based neural network. Here’s the logical flow of what I want to achieve.

Input image batch ---> CNN ---> loss1
Input image batch ---> FNograd ---> Fgradientweights ---> loss2

Here, I want FNograd, Fgradientweights and the CNN all running data-parallel across the GPUs. First, I ran FNograd on a single GPU and the CNN under DataParallel; this worked, but it was very slow. Then I tried manually sending each input image’s FNograd computation to the GPUs one by one in a loop. This also works, but the GPUs are used sequentially, so I’m sure there’s a better way to do it with multiple threads. I’d rather not create a fake model just for the sake of using DataParallel, since FNograd doesn’t use gradients at all.

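cuda_devices = ...  # list of torch.device objects, one per GPU

# Run FNograd one item at a time, cycling through the GPUs (sequential)
for i in range(0, input.size(0), len(cuda_devices)):
    for j in range(len(cuda_devices)):
        if i + j < input.size(0):  # the last group may be smaller
            # Send data item [i + j] to the j'th GPU and run FNograd on it

# Compose the transformed input data for DataParallel(Fgradientweights)
# Compute DataParallel(CNN)(input)
# Do the backward computations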
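
A threaded version of the loop above might look roughly like the sketch below. It is only an illustration under assumptions: `fnograd` is a hypothetical stand-in for FNograd, `run_on_devices` is a made-up helper name, and the CPU fallback exists only so the snippet runs on a machine without GPUs.

```python
from concurrent.futures import ThreadPoolExecutor

import torch


def fnograd(x):
    # Hypothetical stand-in for FNograd: any gradient-free tensor computation.
    return x * 2.0 + 1.0


def run_on_devices(fn, batch, devices):
    """Split `batch` along dim 0, run `fn` on each chunk on its own device
    in a separate thread, and gather the results on the batch's device."""
    chunks = torch.chunk(batch, len(devices), dim=0)

    def worker(chunk, device):
        # Grad mode is thread-local, so no_grad must be entered in each worker.
        with torch.no_grad():
            return fn(chunk.to(device, non_blocking=True))

    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [pool.submit(worker, c, d) for c, d in zip(chunks, devices)]
        outputs = [f.result() for f in futures]

    # Move every partial result back and restore the original batch order.
    return torch.cat([o.to(batch.device) for o in outputs], dim=0)


if torch.cuda.is_available():
    devices = [torch.device(f"cuda:{k}") for k in range(torch.cuda.device_count())]
else:
    devices = [torch.device("cpu")] * 2  # assumption: CPU fallback for testing

images = torch.randn(8, 3, 32, 32)
secondary = run_on_devices(fnograd, images, devices)
```

The result `secondary` carries no autograd history, so it can be fed to the gradient-based part as a plain input. Whether the threads actually overlap on the GPUs depends on how much of FNograd runs inside PyTorch ops rather than Python code.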

Any help is appreciated. Thanks.