Sparse bmm causes CUDA misaligned address error

Hi everyone,
I’m new to PyTorch, CUDA, and sparse memory formats. I’m doing a computation on a batched sparse tensor, in this code:

import torch
from torch import Tensor

SEED = 42
# torch.random.manual_seed(SEED)

def generate_random_dataset(
    min_num_categorical: int,
    max_num_categorical: int,
    min_groups: int,
    max_groups: int,
    min_rows: int,
    max_rows: int,
    shuffle_rows: bool,
    dtype=torch.float64,
) -> torch.Tensor:
    def randn_scalar(low=0.0, high=1.0):
        return torch.normal(low, high, size=())

    def randint_scalar(low, high):
        return torch.randint(low, high, size=()).item()

    # --- Covariance Matrix Setup (Numerical Columns X and Y) ---
    cov_scalar = randn_scalar()
    number_of_groups = randint_scalar(min_groups, max_groups)

    means = torch.tensor(
        [
            randint_scalar(-5, 6),
            randint_scalar(-5, 6),
        ],
        dtype=dtype,
    )
    var_X = randn_scalar() * randint_scalar(1, 6)
    var_Y = randn_scalar() * randint_scalar(1, 6)

    # Create and "square" the matrix to ensure it's positive semi-definite
    A = torch.tensor([[var_X, cov_scalar], [cov_scalar, var_Y]], dtype=dtype)
    cov_matrix = A.T @ A

    groups = []

    for shift in range(number_of_groups):
        group_size = randint_scalar(min_rows, max_rows)
        group_xy = (
            torch.distributions.MultivariateNormal(means, cov_matrix).sample(
                (group_size,)
            )
            + shift * 0.5
        )

        # Create the Kth column (key/group ID)
        group_k = torch.full((group_size, 1), fill_value=shift, dtype=dtype)

        # Concatenate K, X, Y: [K | X | Y]
        group = torch.hstack([group_k, group_xy])
        groups.append(group)

    data = torch.cat(groups, dim=0)

    if max_num_categorical >= min_num_categorical > 0:
        N = data.shape[0]

        # Randomly define how many categorical columns we will append;
        # this number accounts for the basic one created above
        num_categorical = (
            randint_scalar(min_num_categorical, max_num_categorical + 1) - 1
        )

        # Generate a random number of categories for each column,
        # ensuring they're sorted in ascending order
        num_categories_list = sorted(
            [randint_scalar(2, number_of_groups) for _ in range(num_categorical)]
        )

        # Ensure the last categorical column has at most as many distinct values as the K column
        num_categories_list[-1] = int(
            min(
                torch.tensor(num_categories_list[-1]),
                torch.tensor(number_of_groups),
            ).item()
        )

        categorical_cols = []

        # Get the categorical data from a normal distribution
        # combined with a multinomial one
        for num_categories in num_categories_list:
            y = (
                torch.distributions.Normal(
                    loc=torch.tensor([10.0]), scale=torch.tensor([5.0])
                )
                .sample((num_categories,))
                .reshape((1, -1))
            )
            y = y * torch.sign(y)
            y, _ = torch.sort(y)
            y = y / torch.norm(y)

            d = torch.multinomial(y, num_samples=N, replacement=True).reshape((-1, 1))
            categorical_cols.append(d)

        # Prepend categorical columns to data
        categorical_data = torch.hstack(categorical_cols)
        categorical_data = categorical_data.to(dtype=dtype)
        data = torch.hstack([categorical_data, data])

    if shuffle_rows:
        indices = torch.randperm(data.shape[0])
        data = data[indices]
    return data

def t_create_batch_index_matrix_sparse(D: Tensor, dtype=torch.float64) -> Tensor:
    # B: number of categorical columns
    # N: number of records
    # K: number of groups (max. number of unique elements among all categorical columns)
    N, B = D.shape
    K = D.unique(sorted=False).shape[0]

    batch_idx = torch.arange(B, device=D.device).repeat_interleave(N)
    row_idx = torch.arange(N, device=D.device).repeat(B)
    column_idx = D.T.flatten()

    indices = torch.stack([batch_idx, row_idx, column_idx])
    values = torch.ones(B * N, device=D.device)
    size = torch.Size([B, N, K])

    G = torch.sparse_coo_tensor(
        indices=indices, values=values, size=size, dtype=dtype, device=D.device
    ).coalesce()

    return G

def proc_batch_matrix_sparse(G: Tensor, X: Tensor, Y: Tensor) -> Tensor:
    B, N, K = G.shape

    Xb = X.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)
    Yb = Y.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)

    Gt = G.transpose(1, 2)
    GtX = torch.bmm(Gt, Xb)
    return GtX.to("cpu")

if __name__ == "__main__":
    DTYPE = torch.float64
    GPU = True
    NUMBER_OF_TESTS = 10

    MIN_NUM_CATEGORICAL, MAX_NUM_CATEGORICAL = 50, 50
    MIN_GROUPS, MAX_GROUPS = 10, 100
    MIN_GROUP_ROWS, MAX_GROUP_ROWS = 50, 100

    device = "cuda" if GPU and torch.cuda.is_available() else "cpu"

    for i in range(NUMBER_OF_TESTS):
        print(f" Run {i} ".center(100, "="))
        data = generate_random_dataset(
            MIN_NUM_CATEGORICAL,
            MAX_NUM_CATEGORICAL,
            MIN_GROUPS,
            MAX_GROUPS,
            MIN_GROUP_ROWS,
            MAX_GROUP_ROWS,
            shuffle_rows=True,
            dtype=DTYPE,
        ).to(device)

        D = data[:, :-2]  # batch of "categorical" columns [NxB]
        X = data[:, -2].reshape((1, -1))
        Y = data[:, -1].reshape((1, -1))

        G = t_create_batch_index_matrix_sparse(D, dtype=DTYPE)
        proc_batch_matrix_sparse(G, X, Y)

I create a random dataset (generate_random_dataset), take the last two columns as X and Y, and transform the remaining columns into a batched sparse COO tensor of one-hot-encoded matrices (t_create_batch_index_matrix_sparse), then pass these to the actual computation (proc_batch_matrix_sparse).
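To make the layout concrete, here is a tiny illustrative example (mine, not part of the failing run) of what t_create_batch_index_matrix_sparse builds for a small D with B = 2 categorical columns, N = 4 rows and K = 3 distinct values:

D_small = torch.tensor(
    [[0.0, 1.0],
     [1.0, 0.0],
     [2.0, 1.0],
     [0.0, 2.0]],
    dtype=torch.float64,
)  # N = 4 rows, B = 2 categorical columns, K = 3 distinct values

G_small = t_create_batch_index_matrix_sparse(D_small)
print(G_small.shape)       # torch.Size([2, 4, 3])
print(G_small.to_dense())  # one one-hot [N x K] matrix per categorical column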

I encounter this error:

torch.AcceleratorError: CUDA error: misaligned address
Search for cudaErrorMisalignedAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

when computing the batch matrix-matrix product in proc_batch_matrix_sparse.
I’ve checked the torch.sparse docs, and both Gt and Xb should satisfy the required shapes and layouts. The error doesn’t occur on every run, and I haven’t identified specific conditions that trigger it, except that it happens more often with a higher number of dataset rows. Converting G to dense seems to avoid it (a minimal sketch of that fallback is below), but this is neither desirable nor feasible for large inputs.
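For reference, this is roughly what I mean by the dense fallback (a sketch only; it sidesteps the sparse bmm path but materializes the full [B x N x K] tensor):

def proc_batch_matrix_dense_fallback(G: Tensor, X: Tensor) -> Tensor:
    # Densify G before the transpose/bmm: this avoids the cusparse kernel entirely,
    # but the dense [B x N x K] tensor becomes prohibitively large for big datasets.
    B, N, K = G.shape
    Xb = X.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)
    Gt_dense = G.to_dense().transpose(1, 2)
    return torch.bmm(Gt_dense, Xb).to("cpu")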

Spec

I’ve run tests on these two systems:

  • GeForce RTX 4090, CUDA 12.2, Driver 535.104.05, torch 2.9;
  • Tesla T4, CUDA 13.0, Driver 580.95.05, torch 2.9.

Update

Running this on the single matrices in the batch (with torch.sparse.mm) and then stacking the results works fine, but it requires a loop over the batch index; a minimal sketch of that workaround is shown below.
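Something along these lines (a sketch of what I tested, not a drop-in replacement):

def proc_batch_matrix_sparse_loop(G: Tensor, X: Tensor) -> Tensor:
    B, N, K = G.shape
    Xb = X.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)  # [B x N x 1], dense
    out = []
    for b in range(B):
        # G[b] is a 2-D sparse COO matrix [N x K]; torch.sparse.mm works fine here
        Gb_t = G[b].transpose(0, 1).coalesce()
        out.append(torch.sparse.mm(Gb_t, Xb[b]))
    return torch.stack(out).to("cpu")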
The output of compute-sanitizer is a long list of:

========= Invalid __global__ read of size 16 bytes
=========     at void cusparse::coomv_kernel<(bool)0, int, double, double, double, double>(cusparse::KernelCoeffs<T6>, T2, const T2 *, const T2 *, const T3 *, const T4 *, T5 *, T2 *, T6 *)+0x2b0
=========     by thread (32,0,0) in block (0,0,0)
=========     Access to 0x7f1fa52e2f48 is misaligned
=========     and is inside the nearest allocation at 0x7f1fa4000000 of size 20,971,520 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========         Host Frame:  [0xa0e735] in libcusparse.so.12
=========         Host Frame:  [0xa74c77] in libcusparse.so.12
=========         Host Frame:  [0x1b4d59] in libcusparse.so.12
=========         Host Frame:  [0x1c5044] in libcusparse.so.12
=========         Host Frame: cusparseSpMM [0xfb023] in libcusparse.so.12
=========         Host Frame: at::native::bmm_out_sparse_cuda(at::Tensor const&, at::Tensor const&, at::Tensor&)::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const [0x2f49e33] in libtorch_cuda.so
=========         Host Frame: at::native::bmm_out_sparse_cuda(at::Tensor const&, at::Tensor const&, at::Tensor&) [0x2f4b373] in libtorch_cuda.so
=========         Host Frame: at::native::bmm_sparse_cuda(at::Tensor const&, at::Tensor const&) [0x2f4d36f] in libtorch_cuda.so
=========         Host Frame: at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__bmm(at::Tensor const&, at::Tensor const&) [0x3536c1b] in libtorch_cuda.so
=========         Host Frame: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__bmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x3536c9e] in libtorch_cuda.so
=========         Host Frame: at::_ops::bmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x27e8e88] in libtorch_cpu.so
=========         Host Frame: torch::autograd::VariableType::(anonymous namespace)::bmm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x4d5de6a] in libtorch_cpu.so
=========         Host Frame: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::bmm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x4d5e421] in libtorch_cpu.so
=========         Host Frame: at::_ops::bmm::call(at::Tensor const&, at::Tensor const&) [0x2829c6b] in libtorch_cpu.so
=========         Host Frame: torch::autograd::THPVariable_bmm(_object*, _object*, _object*) [0x59918e] in libtorch_python.so
=========         Host Frame: cfunction_call in methodobject.c:537 [0x143943] in python
=========         Host Frame: _PyObject_MakeTpCall in call.c:240 [0x11778b] in python
=========         Host Frame: _PyEval_EvalFrameDefault in bytecodes.c:2715 [0x121951] in python
=========         Host Frame: PyEval_EvalCode in ceval.c:580 [0x1de5cd] in python
=========         Host Frame: run_eval_code_obj in pythonrun.c:1757 [0x21b7b6] in python
=========         Host Frame: run_mod in pythonrun.c:1778 [0x216306] in python
=========         Host Frame: pyrun_file in pythonrun.c:1674 [0x2131c1] in python
=========         Host Frame: _PyRun_SimpleFileObject in pythonrun.c:459 [0x212d7f] in python
=========         Host Frame: _PyRun_AnyFileObject in pythonrun.c:78 [0x212882] in python
=========         Host Frame: Py_RunMain in main.c:714 [0x20f6c6] in python
=========         Host Frame: Py_BytesMain in main.c:768 [0x1c6bb8] in python
=========         Host Frame:  [0x27249] in libc.so.6
=========         Host Frame: __libc_start_main [0x27304] in libc.so.6
=========         Host Frame:  [0x1c69e8] in python
=========         Host Frame: proc_batch_matrix_sparse in myfile.py:148
=========         Host Frame: <module> in myfile.py:191

Thank you for reporting this issue! Do you see the same error using the latest stable or nightly binary with CUDA 13.0?

The error seems to occur with both the latest stable and nightly builds with CUDA 13.0.
Also, I’ve just found this related GitHub issue: https://github.com/pytorch/pytorch/issues/119076

In case it is useful, I’ve made another, shorter example with a similar workflow that reproduces the error (below). From what I have tested, padding X with zeros along dim 2, as suggested in the referenced GitHub issue, helps bypass the error; a sketch of that workaround follows the reproducer.

import torch
from torch import Tensor

torch.manual_seed(42)

def proc_batch_matrix_sparse(G: Tensor, X: Tensor) -> Tensor:
    Gt = G.transpose(1, 2)
    GtX = torch.bmm(Gt, X)
    return GtX.to("cpu")

if __name__ == "__main__":
    DTYPE = torch.float64
    GPU = True
    device = "cuda" if GPU and torch.cuda.is_available() else "cpu"

    torch.set_default_dtype(DTYPE)
    torch.set_default_device(device)

    G_shape = (3, 2779, 4)
    X_shape = (3, 2779, 1)
    density = 0.25

    X = torch.distributions.Normal(100, 50).sample(X_shape)
    G = torch.rand(G_shape)
    mask = G < density
    G[mask] = 1
    G[~mask] = 0
    G_density = G.sum() / (G_shape[0] * G_shape[1] * G_shape[2])

    G: Tensor = G.to_sparse_coo()

    print(f"{G.shape=}, {G.dtype=}, {G.device=}, {G.is_sparse=}, {G.is_coalesced()=}")
    proc_batch_matrix_sparse(G, X)
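And this is roughly how the padding workaround from the GitHub issue applies to the reproducer above (a sketch; the extra zero column is discarded afterwards):

def proc_batch_matrix_sparse_padded(G: Tensor, X: Tensor) -> Tensor:
    # Pad X with one extra column of zeros along dim 2 (pytorch/pytorch#119076),
    # run the sparse bmm, then drop the padded column from the result.
    X_padded = torch.cat([X, torch.zeros_like(X)], dim=2)  # [B x N x 2]
    Gt = G.transpose(1, 2)
    GtX = torch.bmm(Gt, X_padded)
    return GtX[:, :, :1].to("cpu")

# e.g. proc_batch_matrix_sparse_padded(G, X) instead of proc_batch_matrix_sparse(G, X)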