I try to implement a solver for a banded system (using torch.linalg.solve is not possible because of memory issue).
I write a custom autograd function like this:
"""Implement banded linear solver using torch autograd"""
import torch
from torch.autograd import Function
from einops import rearrange, repeat, einsum
if torch.cuda.is_available():
from hpfilter.cholesky_banded import (
cholesky_banded_solver_cpu,
cholesky_banded_solver_cuda,
)
else:
from hpfilter.cholesky_banded import cholesky_banded_solver_cpu
class BandedLinearSolver(Function): # pylint: disable=W0223
"""Class implementing the BandedLinearSolver with autograd"""
@staticmethod
def forward(
ctx, alpha: torch.Tensor, DTD: torch.Tensor, W: torch.tensor, X: torch.Tensor
):
assert X.dtype == torch.float64
X_ = X.detach().clone()
# Solve Banded Problem
# cholesky_bande_solver solves in place: X_ is modified after the function call
if X.is_cuda:
cholesky_banded_solver_cuda(DTD, W, X_, alpha.item())
else:
cholesky_banded_solver_cpu(DTD, W, X_, alpha.item())
# Save for backward
ctx.mark_non_differentiable(W, DTD, alpha)
ctx.save_for_backward(X_, W, DTD, alpha)
return X_
@staticmethod
def backward(ctx, grad: torch.Tensor):
# Load save tensor
X_, W, DTD, alpha = ctx.saved_tensors
Z = torch.bmm(repeat(DTD, "u v->b u v", b=X_.shape[0]), X_)
if X_.is_cuda:
cholesky_banded_solver_cuda(DTD, W, Z, alpha.item())
else:
cholesky_banded_solver_cpu(DTD, W, Z, alpha.item())
grad_alpha = rearrange(
-einsum(grad, Z, "b tu c, b tu c -> "), "->1"
) # need 1 dim tensor
return (grad_alpha, None, None, None)
the cholesky_banded_solver_cpu/_cuda
function are implemeted here and have been tested w.r.t standard solver (test can be seen from the link above): similar results are obtained.
The forward pass is ok for both cpu and gpu. However, when I check the backward pass using the gradcheck
pytorch function it works well for CPU, but for GPU I have errors that I don’t understand:
Traceback (most recent call last):
File "/datalocal1/share/fauvelm/hp-filter/simu_solve.py", line 75, in <module>
loss.backward()
File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
File "/datalocal1/share/fauvelm/hp-filter/hpfilter/banded_linear_solver.py", line 50, in backward
-einsum(grad, Z, "b tu c, b tu c -> "), "->1"
RuntimeError: CUDA error: invalid argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fe5faccf4d7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fe5fac9936b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fe5fad73fa8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: void at::native::gpu_kernel_impl<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> const&) + 0xc98 (0x7fe5fd4e9488 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: void at::native::gpu_kernel<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> const&) + 0x11b (0x7fe5fd4e98eb in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: at::native::neg_kernel_cuda(at::TensorIteratorBase&) + 0x200 (0x7fe5fd4c8800 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x2b8a3f8 (0x7fe5fd9293f8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x2b8a49d (0x7fe5fd92949d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #8: at::_ops::neg::redispatch(c10::DispatchKeySet, at::Tensor const&) + 0x6b (0x7fe62325337b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x3f481dc (0x7fe624da71dc in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x3f486a0 (0x7fe624da76a0 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::_ops::neg::call(at::Tensor const&) + 0x12b (0x7fe6232a15eb in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x48a342 (0x7fe63a1c2342 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #20: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x1f3 (0x7fe63a46da63 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x49e847b (0x7fe62584747b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xe8d (0x7fe62584088d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #23: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x6b0 (0x7fe625841c00 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #24: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8b (0x7fe62583893b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #25: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5c (0x7fe63a4678ec in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #26: <unknown function> + 0xd3e95 (0x7fe671ae6e95 in /home/fauvelm/.conda/envs/mtan_classif/bin/../lib/libstdc++.so.6)
frame #27: <unknown function> + 0x7ea5 (0x7fe6f2edcea5 in /lib64/libpthread.so.0)
frame #28: clone + 0x6d (0x7fe6f22f4b0d in /lib64/libc.so.6)
It seems that the tensors Z
is problematic but I do not understand why it goes well for CPU and not for GPU.
Here some additional info about my configuration:
(mtan_classif) [fauvelm@cesbiocalc2 hp-filter]$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (conda-forge gcc 9.5.0-19) 9.5.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.17
Python version: 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.88.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.55
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-PCIE-32GB
Nvidia driver version: 460.106.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture : x86_64
Mode(s) opératoire(s) des processeurs : 32-bit, 64-bit
Boutisme : Little Endian
Processeur(s) : 48
Liste de processeur(s) en ligne : 0-47
Thread(s) par cœur : 2
Cœur(s) par socket : 12
Socket(s) : 2
Nœud(s) NUMA : 2
Identifiant constructeur : GenuineIntel
Famille de processeur : 6
Modèle : 85
Nom de modèle : Intel(R) Xeon(R) Gold 6136 CPU @ 3.00GHz
Révision : 4
Vitesse du processeur en MHz : 1199.890
CPU max MHz: 3700,0000
CPU min MHz: 1200,0000
BogoMIPS : 6000.00
Virtualisation : VT-x
Cache L1d : 32K
Cache L1i : 32K
Cache L2 : 1024K
Cache L3 : 25344K
Nœud NUMA 0 de processeur(s) : 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
Nœud NUMA 1 de processeur(s) : 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba rsb_ctxsw ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear spec_ctrl intel_stibp flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] pylsp-mypy==0.6.8
[pip3] pytorch-lightning==2.0.6
[pip3] torch==2.0.1
[pip3] torchmetrics==1.3.0.post0
[pip3] torchmuntan==0.0.0
[pip3] torchvision==0.15.2
[conda] numpy 1.26.3 pypi_0 pypi
[conda] pytorch-lightning 2.0.6 pypi_0 pypi
[conda] torch 2.0.1 pypi_0 pypi
[conda] torchmetrics 1.3.0.post0 pypi_0 pypi
[conda] torchmuntan 0.0.0 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
and
(mtan_classif) [fauvelm@cesbiocalc2 hp-filter]$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Fri_Dec_17_18:16:03_PST_2021
Cuda compilation tools, release 11.6, V11.6.55
Build cuda_11.6.r11.6/compiler.30794723_0