I’m in a situation where I want to use a function that accepts and returns a NumPy array inside a PyTorch Net, and I’m not sure of the best way to do this without disrupting the computational graph. The function doesn’t have any learnable parameters, so it doesn’t need to be part of the graph. In my example, I’m trying to use the `norm.ppf` function from `scipy.stats`. My current solution is to make a copy of my tensor, do the operation that returns a NumPy array on the copy, and then overwrite the data in the original tensor, but this feels a bit hacky. Here’s the code I have now:
```python
# x is a tensor of mini-batch inputs
x_copy = x.clone().detach().cpu()
# ranks along the batch dim, shifted into (0, 1); dividing by the
# batch size x_copy.shape[0] (not the whole shape tuple) since
# argsort runs over dim=0
x_copy = (torch.argsort(x_copy, dim=0) + 0.99) / x_copy.shape[0]
x_copy = norm.ppf(x_copy)
x.data = torch.Tensor(x_copy).to(get_device())
```
I’m doing it this way because if I don’t make a copy, the computational graph gets wiped out when the tensor is converted to a NumPy array inside the `norm.ppf` call.
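To make the problem concrete, here’s a minimal sketch of what I mean by the graph getting wiped out (a toy `x`, not my actual mini-batch): PyTorch refuses to convert a grad-requiring tensor to NumPy outright, and even after `detach()` the result of the round-trip carries no graph history back to `x`.

```python
import torch
from scipy.stats import norm

x = torch.rand(4, requires_grad=True)

# Converting a grad-requiring tensor to NumPy (which scipy does
# internally) raises a RuntimeError:
try:
    x.numpy()
except RuntimeError as e:
    print(e)

# After detaching, the conversion works, but the resulting tensor
# has no grad_fn, so gradients can't flow back to x:
y = torch.as_tensor(norm.ppf(x.detach().numpy()))
print(y.grad_fn)  # None
```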