I wrote an operator, but backward fails on multiple GPUs with the error "arguments are located on different GPUs"

I wrote a new function to flip my tensor. When training on one GPU the model works fine, but with multiple GPUs it crashes at loss.backward().
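For context, this is roughly how I run it on multiple GPUs (a minimal sketch, not my exact training code: Wrapper, the input shape, and the sum() loss are placeholders, and I assume nn.DataParallel for the multi-GPU part; Flip is defined below):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    # Placeholder module: my real network is larger, but it applies the
    # custom Flip op to an intermediate tensor in the same way.
    def forward(self, x):
        return Flip()(x)  # legacy-style autograd.Function call

model = nn.DataParallel(Wrapper().cuda())
x = torch.randn(4, 3, 8, 8).cuda()
loss = model(x).sum()
loss.backward()  # RuntimeError: arguments are located on different GPUs

The custom op itself: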

import torch

def flip(x, dim):
    # Normalize a negative dim, then index that dim with a reversed arange.
    dim = x.dim() + dim if dim < 0 else dim
    return x[tuple(slice(None, None) if i != dim
                   else torch.arange(x.size(i) - 1, -1, -1).long()
                   for i in range(x.dim()))]
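One thing I suspect: torch.arange(...).long() creates the index tensor on the CPU (the default device), not on the device x lives on, so under DataParallel the replicas on the other GPUs may end up indexing with a tensor from a different device. A device-aware variant I'm considering (just a sketch, I haven't confirmed it fixes the crash):

def flip_on_device(x, dim):
    # Same logic as flip(), but the index is built directly on x's device,
    # so the advanced indexing never mixes tensors from different devices.
    dim = x.dim() + dim if dim < 0 else dim
    idx = torch.arange(x.size(dim) - 1, -1, -1, device=x.device, dtype=torch.long)
    return x[tuple(slice(None) if i != dim else idx for i in range(x.dim()))]

The autograd Function that wraps flip: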

class Flip(torch.autograd.Function):
    """
    input: a CHW tensor; flips along H (dim 1) and W (dim 2).
    """
    def forward(self, input):
        out = input.clone()
        out = flip(out, 1)  # flip H
        out = flip(out, 2)  # flip W
        return out

    def backward(self, grad_outputs):
        # A flip is its own inverse, so the gradient is flipped the same way.
        grad = grad_outputs.clone()
        grad = flip(grad, 1)
        grad = flip(grad, 2)
        return grad
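For reference, newer PyTorch expects autograd Functions to use static forward/backward methods and .apply() rather than the legacy instance-method style above. A static-method version of the same op would look roughly like this (a sketch; the clone() calls are dropped because the advanced indexing in flip() already returns a copy):

class FlipFn(torch.autograd.Function):
    """Flip a CHW tensor along H (dim 1) and W (dim 2)."""

    @staticmethod
    def forward(ctx, input):
        return flip(flip(input, 1), 2)

    @staticmethod
    def backward(ctx, grad_output):
        # A flip is its own inverse, so backward mirrors forward.
        return flip(flip(grad_output, 1), 2)

This would be called as FlipFn.apply(x) instead of Flip()(x).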