Hi,
I am trying to implement a CPU loss in a C++ extension that parallelizes across the batch dimension.
When running the backward pass I get a very strange error:
Traceback (most recent call last):
  File "test.py", line 8, in <module>
    torch.sum(loss(a, b)).backward()
  File "/home/.local/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/.local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: res[i].defined()INTERNAL ASSERT FAILED at "../torch/csrc/autograd/functions/tensor.cpp":97, please report a bug to PyTorch.
I’ve managed to reduce my complicated function to something minimal in order to provide a reproducible example.
Here are the instructions for setting it up (I am on Linux/Ubuntu 18.04 with PyTorch 1.5.0).
Create a directory named ‘loss’ and add the following files to it:
setup.py
from setuptools import setup
from torch.utils import cpp_extension

cpp_module = cpp_extension.CppExtension(
    'loss',
    sources=['loss.cpp'],
    extra_compile_args=['-fopenmp'],  # build with OpenMP so at::parallel_for can use multiple threads
    extra_link_args=['-lgomp']
)

setup(
    name='loss',
    ext_modules=[cpp_module],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
test.py
import torch

torch.set_num_threads(3)  # make sure at::parallel_for actually runs on multiple threads

from loss import loss

a = torch.nn.Parameter(torch.ones((3, 1000)))
b = torch.zeros((3, 1000))
torch.sum(loss(a, b)).backward()
loss.cpp
#include <torch/extension.h>
#include <ATen/Parallel.h>

torch::Tensor loss(torch::Tensor z, torch::Tensor x)
{
    torch::Tensor z_out = at::empty({z.size(0), z.size(1)}, z.options());
    int64_t batch_size = z.size(0);

    // Split the batch across threads; each thread writes its own rows of z_out.
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
        for (int64_t b = start; b < end; b++)
        {
            // In-place slice assignment; autograd records a node for each copy.
            z_out[b] = z[b] - x[b];
        }
    });

    return z_out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("loss", &loss, "loss");
}
test.sh
rm -rf build dist loss.egg-info __pycache__
pip uninstall loss -y
python setup.py install --user
python test.py
To run it, move into the ‘loss’ directory and run: sh test.sh
Note: the error doesn’t always appear; you may need to run the script several times to trigger it, which makes me suspect some kind of race between the threads. Let me know if this could be a version problem.
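For comparison: in this reduced example the whole computation is just an elementwise subtraction, so it can also be written as a single ATen op with no explicit at::parallel_for (my real loss is more complicated, which is why I parallelize by hand). This loss_vectorized variant is only a sketch for comparison, but I would expect it to avoid the assert, since autograd then sees one ordinary op instead of many concurrent in-place slice writes:

torch::Tensor loss_vectorized(torch::Tensor z, torch::Tensor x)
{
    // ATen parallelizes the elementwise kernel internally, so this still
    // uses multiple threads, but autograd records a single subtraction.
    return z - x;
}

My guess, and it is only a guess from the path in the assert message, is that the per-thread in-place writes z_out[b] = z[b] - x[b] race while autograd records them (torch/csrc/autograd/functions/tensor.cpp looks like the machinery behind in-place slice assignment), which would also explain why the failure is intermittent.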