Implementing a custom backward for a banded system solver

I am trying to implement a solver for a banded linear system (using torch.linalg.solve is not possible because of memory issues).
I wrote a custom autograd Function like this:

"""Implement banded linear solver using torch autograd"""

import torch
from torch.autograd import Function
from einops import rearrange, repeat, einsum

if torch.cuda.is_available():
    from hpfilter.cholesky_banded import (
        cholesky_banded_solver_cpu,
        cholesky_banded_solver_cuda,
    )
else:
    from hpfilter.cholesky_banded import cholesky_banded_solver_cpu


class BandedLinearSolver(Function):  # pylint: disable=W0223
    """Class implementing the BandedLinearSolver with autograd"""

    @staticmethod
    def forward(
        ctx, alpha: torch.Tensor, DTD: torch.Tensor, W: torch.Tensor, X: torch.Tensor
    ):
        assert X.dtype == torch.float64
        X_ = X.detach().clone()

        # Solve Banded Problem
        # cholesky_banded_solver solves in place: X_ is modified by the call
        if X.is_cuda:
            cholesky_banded_solver_cuda(DTD, W, X_, alpha.item())
        else:
            cholesky_banded_solver_cpu(DTD, W, X_, alpha.item())

        # Save for backward
        ctx.mark_non_differentiable(W, DTD, alpha)
        ctx.save_for_backward(X_, W, DTD, alpha)
        return X_

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        # Load saved tensors
        X_, W, DTD, alpha = ctx.saved_tensors
        Z = torch.bmm(repeat(DTD, "u v->b u v", b=X_.shape[0]), X_)

        if X_.is_cuda:
            cholesky_banded_solver_cuda(DTD, W, Z, alpha.item())
        else:
            cholesky_banded_solver_cpu(DTD, W, Z, alpha.item())

        grad_alpha = rearrange(
            -einsum(grad, Z, "b tu c, b tu c -> "), "->1"
        )  # need 1 dim tensor
        return (grad_alpha, None, None, None)
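
For context, the backward implements the derivative of the solution X(alpha) = (W + alpha * DTD)^{-1} B with respect to alpha, namely dX/dalpha = -(W + alpha * DTD)^{-1} DTD X, so grad_alpha = -<grad, (W + alpha * DTD)^{-1} DTD X>. The tensor Z in backward is that inner solve applied to DTD @ X_, and the einsum takes the scalar product with the incoming gradient.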

The cholesky_banded_solver_cpu/_cuda functions are implemented here and have been tested against the standard solver (the tests can be seen from the link above): similar results are obtained.

The forward pass is OK for both CPU and GPU. However, when I check the backward pass with PyTorch's gradcheck function, it works well on CPU, but on GPU I get errors that I don't understand:

Traceback (most recent call last):
  File "/datalocal1/share/fauvelm/hp-filter/simu_solve.py", line 75, in <module>
    loss.backward()
  File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/datalocal1/share/fauvelm/hp-filter/hpfilter/banded_linear_solver.py", line 50, in backward
    -einsum(grad, Z, "b tu c, b tu c -> "), "->1"
RuntimeError: CUDA error: invalid argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fe5faccf4d7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fe5fac9936b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fe5fad73fa8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: void at::native::gpu_kernel_impl<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> const&) + 0xc98 (0x7fe5fd4e9488 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: void at::native::gpu_kernel<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::neg_kernel_cuda, 6u>, double (double)> const&) + 0x11b (0x7fe5fd4e98eb in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: at::native::neg_kernel_cuda(at::TensorIteratorBase&) + 0x200 (0x7fe5fd4c8800 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x2b8a3f8 (0x7fe5fd9293f8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x2b8a49d (0x7fe5fd92949d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #8: at::_ops::neg::redispatch(c10::DispatchKeySet, at::Tensor const&) + 0x6b (0x7fe62325337b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x3f481dc (0x7fe624da71dc in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x3f486a0 (0x7fe624da76a0 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::_ops::neg::call(at::Tensor const&) + 0x12b (0x7fe6232a15eb in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x48a342 (0x7fe63a1c2342 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #20: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x1f3 (0x7fe63a46da63 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x49e847b (0x7fe62584747b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xe8d (0x7fe62584088d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #23: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x6b0 (0x7fe625841c00 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #24: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8b (0x7fe62583893b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #25: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5c (0x7fe63a4678ec in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #26: <unknown function> + 0xd3e95 (0x7fe671ae6e95 in /home/fauvelm/.conda/envs/mtan_classif/bin/../lib/libstdc++.so.6)
frame #27: <unknown function> + 0x7ea5 (0x7fe6f2edcea5 in /lib64/libpthread.so.0)
frame #28: clone + 0x6d (0x7fe6f22f4b0d in /lib64/libc.so.6)

It seems that the tensor Z is problematic, but I do not understand why it works on CPU and not on GPU.
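
For completeness, the gradient check is run roughly like this (a minimal sketch; alpha, DTD, W and X are the double-precision tensors passed to forward, and only alpha has requires_grad=True):

from torch.autograd import gradcheck

# alpha is the only differentiable input; the other tensors are treated as constants
print(gradcheck(BandedLinearSolver.apply, (alpha, DTD, W, X)))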

Here is some additional info about my configuration:

(mtan_classif) [fauvelm@cesbiocalc2 hp-filter]$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (conda-forge gcc 9.5.0-19) 9.5.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.17

Python version: 3.10.9 | packaged by conda-forge | (main, Feb  2 2023, 20:20:04) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.88.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.55
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-PCIE-32GB
Nvidia driver version: 460.106.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                48
On-line CPU(s) list:   0-47
Thread(s) per core:    2
Core(s) per socket:    12
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 85
Model name:            Intel(R) Xeon(R) Gold 6136 CPU @ 3.00GHz
Stepping:              4
CPU MHz:               1199.890
CPU max MHz:           3700,0000
CPU min MHz:           1200,0000
BogoMIPS:              6000.00
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              1024K
L3 cache:              25344K
NUMA node0 CPU(s):     0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s):     1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba rsb_ctxsw ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] pylsp-mypy==0.6.8
[pip3] pytorch-lightning==2.0.6
[pip3] torch==2.0.1
[pip3] torchmetrics==1.3.0.post0
[pip3] torchmuntan==0.0.0
[pip3] torchvision==0.15.2
[conda] numpy                     1.26.3                   pypi_0    pypi
[conda] pytorch-lightning         2.0.6                    pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torchmetrics              1.3.0.post0              pypi_0    pypi
[conda] torchmuntan               0.0.0                    pypi_0    pypi
[conda] torchvision               0.15.2                   pypi_0    pypi

and

(mtan_classif) [fauvelm@cesbiocalc2 hp-filter]$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Fri_Dec_17_18:16:03_PST_2021
Cuda compilation tools, release 11.6, V11.6.55
Build cuda_11.6.r11.6/compiler.30794723_0

Could you isolate the failing call and post the tensor shapes to reproduce the issue, please?

Dear @ptrblck, thanks for your comments. Let me know if the code below is clearer:

import os
import torch
from hpfilter.banded_linear_solver import BandedLinearSolver
from einops import repeat

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

# Init parameter
dtype = torch.float64
batch_size = 16
n_times = 5
n_channels = 2
cpu_device = torch.device("cpu")
gpu_device = torch.device("cuda", 0)

# Init data on CPU
dt_d = torch.Tensor(
    [
        [5.9113e-02, -9.8796e-02, 3.9683e-02, 0.0000e00, 0.0000e00],
        [-9.8796e-02, 2.1075e-01, -2.7670e-01, 1.6475e-01, 0.0000e00],
        [3.9683e-02, -2.7670e-01, 1.0305e01, -6.6333e01, 5.6265e01],
        [0.0000e00, 1.6475e-01, -6.6333e01, 4.6251e02, -3.9634e02],
        [0.0000e00, 0.0000e00, 5.6265e01, -3.9634e02, 3.4008e02],
    ],
    device=cpu_device,
).to(dtype=dtype)
alpha = torch.tensor([100.0], requires_grad=True)

mask = (torch.rand(batch_size, n_times, dtype=dtype, device=cpu_device) > 0.125).to(
    dtype
)
data = torch.randn(batch_size, n_times, n_channels, dtype=dtype, device=cpu_device)

# Solve [diag(mask) + alpha*dt_d]^{-1} @ data
# with torch.linalg.solve
A = repeat(alpha * dt_d, "u v -> b u v", b=batch_size) + torch.diag_embed(mask)
C_torch = torch.linalg.solve(A, data)


# Solve same problem using BandedLinearSolver on cpu
solver = BandedLinearSolver.apply
C_cpu = solver(alpha, dt_d, mask, data)

print(torch.allclose(C_torch, C_cpu))
# Perform backward
accu_cpu = C_cpu.sum()
accu_cpu.backward()

# Solve same problem using BandedLinearSolver on gpu
C_gpu = solver(
    alpha.to(gpu_device), dt_d.to(gpu_device), mask.to(gpu_device), data.to(gpu_device)
)

print(torch.allclose(C_torch, C_gpu.to(cpu_device)))

# Perform backward
accu_gpu = C_gpu.sum()
accu_gpu.backward()

I got the following output (which is a little bit different from what I get with gradcheck):

True
True
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 59
     57 # Perform backward
     58 accu_gpu = C_gpu.sum()
---> 59 accu_gpu.backward()

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/function.py:274, in BackwardCFunction.apply(self, *args)
    270     raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
    271                        "Function is not allowed. You should only implement one "
    272                        "of them.")
    273 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 274 return user_fn(self, *args)

File /datalocal1/share/fauvelm/hp-filter/hpfilter/banded_linear_solver.py:52, in BandedLinearSolver.backward(ctx, grad)
     48 else:
     49     cholesky_banded_solver_cpu(DTD, W, Z, alpha.item())
     51 grad_alpha = rearrange(
---> 52     -einsum(grad, Z, "b tu c, b tu c -> "), "->1"
     53 )  # need 1 dim tensor
     54 return (grad_alpha, None, None, None)

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/einops/einops.py:901, in einsum(*tensors_and_pattern)
    899 tensors = tensors_and_pattern[:-1]
    900 pattern = _compactify_pattern_for_einsum(pattern)
--> 901 return get_backend(tensors[0]).einsum(pattern, *tensors)

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/einops/_backends.py:287, in TorchBackend.einsum(self, pattern, *x)
    286 def einsum(self, pattern, *x):
--> 287     return self.torch.einsum(pattern, *x)

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/functional.py:378, in einsum(*args)
    373     return einsum(equation, *_operands)
    375 if len(operands) <= 2 or not opt_einsum.enabled:
    376     # the path for contracting 0 or 1 time(s) is already optimized
    377     # or the user has disabled using opt_einsum
--> 378     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    380 path = None
    381 if opt_einsum.is_available():

RuntimeError: CUDA error: invalid argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fd11e8b84d7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fd11e88236b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fd1240c3fa8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: void at::native::gpu_kernel_impl<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::direct_copy_kernel_cuda, 9u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::direct_copy_kernel_cuda, 9u>, double (double)> const&) + 0x676 (0x7fd0a68fa286 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: void at::native::gpu_kernel<__nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::direct_copy_kernel_cuda, 9u>, double (double)> >(at::TensorIteratorBase&, __nv_hdl_wrapper_t<false, true, __nv_dl_tag<void (*)(at::TensorIteratorBase&), &at::native::direct_copy_kernel_cuda, 9u>, double (double)> const&) + 0x11b (0x7fd0a68fad0b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: at::native::direct_copy_kernel_cuda(at::TensorIteratorBase&) + 0x318 (0x7fd0a68e43a8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::native::copy_device_to_device(at::TensorIterator&, bool, bool) + 0xd25 (0x7fd0a68e5165 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x152d526 (0x7fd0a68e6526 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0x1c53840 (0x7fd0cd0cc840 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x62 (0x7fd0cd0cd9d2 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) + 0x15f (0x7fd0cdd3c97f in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::native::clone(at::Tensor const&, c10::optional<c10::MemoryFormat>) + 0x1c7 (0x7fd0cd3f0ac7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2c29920 (0x7fd0ce0a2920 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::_ops::clone::call(at::Tensor const&, c10::optional<c10::MemoryFormat>) + 0x136 (0x7fd0cda8b306 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x2df169e (0x7fd0a81aa69e in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #15: <unknown function> + 0x2df1c9f (0x7fd0a81aac9f in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #16: at::native::structured_bmm_out_cuda::impl(at::Tensor const&, at::Tensor const&, at::Tensor const&) + 0x62 (0x7fd0a81ac132 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #17: <unknown function> + 0x2af6a6d (0x7fd0a7eafa6d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #18: <unknown function> + 0x2af6af0 (0x7fd0a7eafaf0 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #19: at::_ops::bmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) + 0x6e (0x7fd0cdcd84be in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #20: <unknown function> + 0x418c78c (0x7fd0cf60578c in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x418d193 (0x7fd0cf606193 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #22: at::_ops::bmm::call(at::Tensor const&, at::Tensor const&) + 0x161 (0x7fd0cdd2bdb1 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #23: <unknown function> + 0x1d146b8 (0x7fd0cd18d6b8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #24: at::native::einsum(c10::basic_string_view<char>, c10::ArrayRef<at::Tensor>, c10::OptionalArrayRef<long>) + 0x2450 (0x7fd0cd191820 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #25: <unknown function> + 0x2df46f2 (0x7fd0ce26d6f2 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #26: at::_ops::einsum::call(c10::basic_string_view<char>, c10::ArrayRef<at::Tensor>, c10::OptionalArrayRef<long>) + 0x211 (0x7fd0cdbea511 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #27: <unknown function> + 0x51a061 (0x7fd124609061 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #28: <unknown function> + 0x13fb27 (0x556796c4db27 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #29: _PyObject_MakeTpCall + 0x26b (0x556796c4742b in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #30: _PyEval_EvalFrameDefault + 0x5596 (0x556796c43386 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #31: _PyFunction_Vectorcall + 0x6f (0x556796c4df8f in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #32: _PyEval_EvalFrameDefault + 0x2ec2 (0x556796c40cb2 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #33: <unknown function> + 0x14b7a1 (0x556796c597a1 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #34: _PyEval_EvalFrameDefault + 0x2ec2 (0x556796c40cb2 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #35: _PyFunction_Vectorcall + 0x6f (0x556796c4df8f in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #36: _PyEval_EvalFrameDefault + 0x332 (0x556796c3e122 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #37: _PyFunction_Vectorcall + 0x6f (0x556796c4df8f in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #38: _PyEval_EvalFrameDefault + 0x2ec2 (0x556796c40cb2 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #39: <unknown function> + 0x14b7a1 (0x556796c597a1 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #40: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x1f3 (0x7fd124824a63 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #41: <unknown function> + 0x49e847b (0x7fd0cfe6147b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #42: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xe8d (0x7fd0cfe5a88d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #43: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x6b0 (0x7fd0cfe5bc00 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #44: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8b (0x7fd0cfe5293b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #45: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5c (0x7fd12481e8ec in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #46: <unknown function> + 0xd3e95 (0x7fd125278e95 in /home/fauvelm/.conda/envs/mtan_classif/bin/../lib/libstdc++.so.6)
frame #47: <unknown function> + 0x7ea5 (0x7fd130452ea5 in /lib64/libpthread.so.0)
frame #48: clone + 0x6d (0x7fd12f86ab0d in /lib64/libc.so.6)

The forward pass seems to be fine for both CPU and GPU, but the backward does not work on GPU. I tried changing the einsum call to grad_alpha = rearrange(torch.sum(grad * Z), "->1")  # need 1 dim tensor, but I get the same issue:

Cell In[1], line 59
     57 # Perform backward
     58 accu_gpu = C_gpu.sum()
---> 59 accu_gpu.backward()

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

File ~/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/autograd/function.py:274, in BackwardCFunction.apply(self, *args)
    270     raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
    271                        "Function is not allowed. You should only implement one "
    272                        "of them.")
    273 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 274 return user_fn(self, *args)

File /datalocal1/share/fauvelm/hp-filter/hpfilter/banded_linear_solver.py:54, in BandedLinearSolver.backward(ctx, grad)
     49     cholesky_banded_solver_cpu(DTD, W, Z, alpha.item())
     51 # grad_alpha = rearrange(
     52 #     -einsum(grad, Z, "b tu c, b tu c -> "), "->1"
     53 # )  # need 1 dim tensor
---> 54 grad_alpha = rearrange(torch.sum(grad * Z), "->1")  # need 1 dim tensor
     55 return (grad_alpha, None, None, None)

RuntimeError: CUDA error: invalid argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f81a03304d7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f81a02fa36b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81a03d4fa8 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: void at::native::gpu_kernel_impl<at::native::BinaryFunctor<double, double, double, at::native::binary_internal::MulFunctor<double> > >(at::TensorIteratorBase&, at::native::BinaryFunctor<double, double, double, at::native::binary_internal::MulFunctor<double> > const&) + 0xb27 (0x7f81227a89b7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: void at::native::gpu_kernel<at::native::BinaryFunctor<double, double, double, at::native::binary_internal::MulFunctor<double> > >(at::TensorIteratorBase&, at::native::BinaryFunctor<double, double, double, at::native::binary_internal::MulFunctor<double> > const&) + 0x33b (0x7f81227a936b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: void at::native::opmath_symmetric_gpu_kernel_with_scalars<double, double, at::native::binary_internal::MulFunctor<double> >(at::TensorIteratorBase&, at::native::binary_internal::MulFunctor<double> const&) + 0xdd (0x7f81227bf21d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::native::mul_kernel_cuda(at::TensorIteratorBase&) + 0x281 (0x7f812279ce21 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x2b8d773 (0x7f8123f46773 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0x2b8d820 (0x7f8123f46820 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #9: at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) + 0x6e (0x7f814987027e in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x3f4658d (0x7f814b3bf58d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x3f47013 (0x7f814b3c0013 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&) + 0x161 (0x7f81498cd8b1 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x4ad5a6 (0x7f81a08ad5a6 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #14: <unknown function> + 0x4ad6f7 (0x7f81a08ad6f7 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #15: <unknown function> + 0x141146 (0x55fd0606b146 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #16: <unknown function> + 0x1a9d93 (0x55fd060d3d93 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #17: <unknown function> + 0x2342f0 (0x55fd0615e2f0 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #18: PyNumber_Multiply + 0x47 (0x55fd06086477 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #19: _PyEval_EvalFrameDefault + 0xcb7 (0x55fd0605aaa7 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #20: _PyFunction_Vectorcall + 0x6f (0x55fd06069f8f in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #21: _PyEval_EvalFrameDefault + 0x2ec2 (0x55fd0605ccb2 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #22: <unknown function> + 0x14b7a1 (0x55fd060757a1 in /home/fauvelm/.conda/envs/mtan_classif/bin/python3.10)
frame #23: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x1f3 (0x7f81a0b35a63 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #24: <unknown function> + 0x49e847b (0x7f814be6147b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #25: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xe8d (0x7f814be5a88d in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #26: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x6b0 (0x7f814be5bc00 in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #27: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8b (0x7f814be5293b in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #28: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5c (0x7f81a0b2f8ec in /home/fauvelm/.conda/envs/mtan_classif/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #29: <unknown function> + 0xd3e95 (0x7f81a1589e95 in /home/fauvelm/.conda/envs/mtan_classif/bin/../lib/libstdc++.so.6)
frame #30: <unknown function> + 0x7ea5 (0x7f81ac763ea5 in /lib64/libpthread.so.0)
frame #31: clone + 0x6d (0x7f81abb7bb0d in /lib64/libc.so.6)
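
One additional isolation step would be to force a synchronization right after the custom kernel call in backward, to see whether the invalid argument error actually originates in cholesky_banded_solver_cuda and only surfaces at the next elementwise op (a sketch, not verified yet):

# Sketch: surface any pending asynchronous CUDA error immediately after the solver call
if X_.is_cuda:
    cholesky_banded_solver_cuda(DTD, W, Z, alpha.item())
    torch.cuda.synchronize()
else:
    cholesky_banded_solver_cpu(DTD, W, Z, alpha.item())
print(Z.shape, Z.dtype, Z.device, Z.is_contiguous())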


I’m trying to run your code on a 64GB GPU and am running out of memory.

Tried to allocate 9.54 GiB. GPU 0 has a total capacity of 63.29 GiB of which 3.40 GiB is free.

pointing to a conv layer.
I thus doubt your code would run on your V100 32GB.

Dear @ptrblck, thank you for taking the time to investigate my problem. I am surprised you get a memory error, since only tensors of limited size are created: the biggest one is of size A.shape = [16, 5, 5].

Also, there is no conv layer in the code, so I don't know where that comes from.

I did run the code on the configuration mentioned above, but maybe you can try reducing batch_size and n_channels.

I just tried: I get the same error reported above, and nvidia-smi reports a memory usage of 1.2 GB.

I think this is a CUDA issue, since if I replace the calls in the forward and backward methods:

cholesky_banded_solver_cuda(DTD, W, X_, alpha.item())

by moving everything to CPU, using cholesky_banded_solver_cpu, and moving back to GPU, I get no errors and my tests pass.
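
For reference, the CPU round trip that works looks roughly like this in forward (a sketch; backward is changed the same way):

# Workaround sketch: run the banded solve on CPU, then move the result back to the original device
X_cpu = X_.to("cpu")
cholesky_banded_solver_cpu(DTD.to("cpu"), W.to("cpu"), X_cpu, alpha.item())
X_ = X_cpu.to(X.device)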

Yet I don't understand why the forward pass is OK while the backward is not. I am surely missing something.

You might be right, as I might have executed a script from another topic here.
Let me rerun it later.