Hi, I am new to the PyTorch C++ extension API. I wanted to translate my Python code into C++, but I got a "RuntimeError: CUDA out of memory" while training. Here are my Python and C++ codes. I wonder if there are any obvious mistakes in my C++ code leading to the OOM.
The code block below (pure Python) runs without errors.
class TopKLoss(object):
    def __init__(self, c=5., T=10000, clamp_min=-10):
        self.c = c
        self.T = T
        self.clamp_min = clamp_min

    def __call__(self, outputs, labels):
        loss = []
        for w, A in zip(outputs, labels):
            A = torch.from_numpy(np.array(A)).to(w.device)
            w = w.view(1, -1)
            A = A.view(1, -1)
            loss.append(
                nll_topk_one_sample(
                    w, A, self.c, self.T, self.clamp_min))
        return torch.stack(loss).mean()
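For context, the loss is used in the training loop roughly like this (simplified; the traceback below shows the actual call site in cmain.py):

criterion = TopKLoss(c=5., T=10000, clamp_min=-10)
# outputs: per-sample score tensors on the GPU
# labels: a list of int lists, one list per sample
loss = criterion(outputs, labels)
loss.backward()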
Then I translated it into a corresponding C++ implementation and used pybind11 to create the extension. The Python side becomes a thin wrapper around the bound function:
class TopKLoss(object):
    def __init__(self, c=5., T=10000, clamp_min=-10):
        self.c = c
        self.T = T
        self.clamp_min = clamp_min

    def __call__(self, outputs, labels):
        # use the C++ extension here
        loss = TopKLoss_call_cimpl(outputs, labels, self.c, self.T, self.clamp_min)
        return loss
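I build and load the extension with torch.utils.cpp_extension.load, something like this (the file and module names here are placeholders, not my exact ones):

from torch.utils.cpp_extension import load

# JIT-compile the C++ source and import the bound function
# ("closs_cpp" / "closs.cpp" are placeholder names)
closs_cpp = load(name="closs_cpp", sources=["closs.cpp"], verbose=True)
TopKLoss_call_cimpl = closs_cpp.TopKLoss_call_cimpl

The C++ implementation itself: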
torch::Tensor TopKLoss_call_cimpl(torch::Tensor outputs,
                                  std::vector<std::vector<int>> labels,
                                  double c, double T, double clamp_min) {
    std::vector<torch::Tensor> loss;
    int64_t size = (int64_t)labels.size();
    loss.reserve(size);
    at::parallel_for(0, size, 0, [&](int64_t begin, int64_t end) {
        for (auto i = begin; i < end; i++) {
            auto w = outputs.index({i}).view({1, -1});
            auto A = torch::tensor(labels[i], w.device()).view({1, -1});
            loss.push_back(nll_topk_one_sample_cimpl(w, A, c, T, clamp_min));
        }
    });
    return torch::stack(loss).mean();
}
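The binding is the usual pybind11 / torch extension boilerplate, roughly:

#include <torch/extension.h>

// Expose the C++ implementation to Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("TopKLoss_call_cimpl", &TopKLoss_call_cimpl,
          "top-k NLL loss over a batch (C++ implementation)");
}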
Then I got:
Traceback (most recent call last):
  File "cmain.py", line 283, in <module>
    loss = criterion(outputs, labels)  # loss
  File "/home/jimmyzxj/Research/ListwiseLTR/closs.py", line 79, in __call__
    loss = TopKLoss_call_cimpl(outputs, labels, self.c, self.T, self.clamp_min)
RuntimeError: CUDA out of memory. Tried to allocate 504.00 MiB (GPU 0; 15.78 GiB total capacity; 13.10 GiB already allocated; 443.50 MiB free; 14.23 GiB reserved in total by PyTorch)
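One thing I am unsure about: is it even safe to call loss.push_back from multiple at::parallel_for worker threads? For comparison, a plain sequential version would look roughly like this (an untested sketch for illustration; TopKLoss_call_seq is just a name I made up):

torch::Tensor TopKLoss_call_seq(torch::Tensor outputs,
                                std::vector<std::vector<int>> labels,
                                double c, double T, double clamp_min) {
    std::vector<torch::Tensor> loss;
    loss.reserve(labels.size());
    // same per-sample computation, but single-threaded and in order
    for (int64_t i = 0; i < (int64_t)labels.size(); i++) {
        auto w = outputs.index({i}).view({1, -1});
        auto A = torch::tensor(labels[i], w.device()).view({1, -1});
        loss.push_back(nll_topk_one_sample_cimpl(w, A, c, T, clamp_min));
    }
    return torch::stack(loss).mean();
}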
Please let me know if you can spot anything that might cause the OOM. Many thanks!