Hi everyone,
I'm new to PyTorch, CUDA, and sparse memory formats. I'm doing computations on a batched sparse tensor in this code:
```python
import torch
from torch import Tensor

SEED = 42
# torch.random.manual_seed(SEED)


def generate_random_dataset(
    min_num_categorical: int,
    max_num_categorical: int,
    min_groups: int,
    max_groups: int,
    min_rows: int,
    max_rows: int,
    shuffle_rows: bool,
    dtype=torch.float64,
) -> torch.Tensor:
    def randn_scalar(low=0.0, high=1.0):
        return torch.normal(low, high, size=())

    def randint_scalar(low, high):
        return torch.randint(low, high, size=()).item()

    # --- Covariance Matrix Setup (Numerical Columns X and Y) ---
    cov_scalar = randn_scalar()
    number_of_groups = randint_scalar(min_groups, max_groups)
    means = torch.tensor(
        [
            randint_scalar(-5, 6),
            randint_scalar(-5, 6),
        ],
        dtype=dtype,
    )
    var_X = randn_scalar() * randint_scalar(1, 6)
    var_Y = randn_scalar() * randint_scalar(1, 6)
    # Create and "square" the matrix to ensure it's positive semi-definite
    A = torch.tensor([[var_X, cov_scalar], [cov_scalar, var_Y]], dtype=dtype)
    cov_matrix = A.T @ A

    groups = []
    for shift in range(number_of_groups):
        group_size = randint_scalar(min_rows, max_rows)
        group_xy = (
            torch.distributions.MultivariateNormal(means, cov_matrix).sample(
                (group_size,)
            )
            + shift * 0.5
        )
        # Create the Kth column (key/group ID)
        group_k = torch.full((group_size, 1), fill_value=shift, dtype=dtype)
        # Concatenate K, X, Y: [K | X | Y]
        group = torch.hstack([group_k, group_xy])
        groups.append(group)
    data = torch.cat(groups, dim=0)

    if max_num_categorical >= min_num_categorical > 0:
        N = data.shape[0]
        # Randomly define how many categorical columns we will append;
        # this number accounts for the basic one created above
        num_categorical = (
            randint_scalar(min_num_categorical, max_num_categorical + 1) - 1
        )
        # Generate a random number of categories for each column,
        # ensuring they're sorted in ascending order
        num_categories_list = sorted(
            [randint_scalar(2, number_of_groups) for _ in range(num_categorical)]
        )
        # Ensure the last categorical column has <= distinct values than the K column
        num_categories_list[-1] = int(
            min(
                torch.tensor(num_categories_list[-1]),
                torch.tensor(number_of_groups),
            ).item()
        )
        categorical_cols = []
        # Get the categorical data from a normal distribution
        # combined with a multinomial one
        for num_categories in num_categories_list:
            y = (
                torch.distributions.Normal(
                    loc=torch.tensor([10.0]), scale=torch.tensor([5.0])
                )
                .sample((num_categories,))
                .reshape((1, -1))
            )
            y = y * torch.sign(y)
            y, _ = torch.sort(y)
            y = y / torch.norm(y)
            d = torch.multinomial(y, num_samples=N, replacement=True).reshape((-1, 1))
            categorical_cols.append(d)
        # Prepend categorical columns to data
        categorical_data = torch.hstack(categorical_cols)
        categorical_data = categorical_data.to(dtype=dtype)
        data = torch.hstack([categorical_data, data])

    if shuffle_rows:
        indices = torch.randperm(data.shape[0])
        data = data[indices]
    return data


def t_create_batch_index_matrix_sparse(D: Tensor, dtype=torch.float64) -> Tensor:
    # B: number of categorical columns
    # N: number of records
    # K: number of groups (max. number of unique elements among all categorical columns)
    N, B = D.shape
    K = D.unique(sorted=False).shape[0]
    batch_idx = torch.arange(B, device=D.device).repeat_interleave(N)
    row_idx = torch.arange(N, device=D.device).repeat(B)
    column_idx = D.T.flatten()
    indices = torch.stack([batch_idx, row_idx, column_idx])
    values = torch.ones(B * N, device=D.device)
    size = torch.Size([B, N, K])
    G = torch.sparse_coo_tensor(
        indices=indices, values=values, size=size, dtype=dtype, device=D.device
    ).coalesce()
    return G


def proc_batch_matrix_sparse(G: Tensor, X: Tensor, Y: Tensor) -> Tensor:
    B, N, K = G.shape
    Xb = X.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)
    Yb = Y.unsqueeze(0).expand(B, -1, -1).transpose(1, 2)
    Gt = G.transpose(1, 2)
    GtX = torch.bmm(Gt, Xb)
    return GtX.to("cpu")


if __name__ == "__main__":
    DTYPE = torch.float64
    GPU = True
    NUMBER_OF_TESTS = 10
    MIN_NUM_CATEGORICAL, MAX_NUM_CATEGORICAL = 50, 50
    MIN_GROUPS, MAX_GROUPS = 10, 100
    MIN_GROUP_ROWS, MAX_GROUP_ROWS = 50, 100

    device = "cuda" if GPU and torch.cuda.is_available() else "cpu"
    for i in range(NUMBER_OF_TESTS):
        print(f" Run {i} ".center(100, "="))
        data = generate_random_dataset(
            MIN_NUM_CATEGORICAL,
            MAX_NUM_CATEGORICAL,
            MIN_GROUPS,
            MAX_GROUPS,
            MIN_GROUP_ROWS,
            MAX_GROUP_ROWS,
            shuffle_rows=True,
            dtype=DTYPE,
        ).to(device)
        D = data[:, :-2]  # batch of "categorical" columns [N x B]
        X = data[:, -2].reshape((1, -1))
        Y = data[:, -1].reshape((1, -1))
        G = t_create_batch_index_matrix_sparse(D, dtype=DTYPE)
        proc_batch_matrix_sparse(G, X, Y)
```
I create a random dataset (generate_random_dataset), take the last two columns as X and Y, and transform the remaining columns into a batched sparse COO tensor of one-hot encoded matrices (t_create_batch_index_matrix_sparse), then pass these to the actual computation (proc_batch_matrix_sparse).
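For reference, the shapes involved (as I read them from the code) are:

```python
# D  : [N, B]    dense matrix of categorical codes
# G  : [B, N, K] sparse COO batch of one-hot matrices
# Gt : [B, K, N] sparse (G.transpose(1, 2))
# Xb : [B, N, 1] dense
# torch.bmm(Gt, Xb) -> [B, K, 1]
```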
When computing the batched matrix-matrix product in proc_batch_matrix_sparse, I encounter this error:

```
torch.AcceleratorError: CUDA error: misaligned address
Search for `cudaErrorMisalignedAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
I've checked the torch.sparse docs, and both Gt and Xb should satisfy the required shapes and layouts. The error doesn't occur on every run, and I haven't pinned down the conditions that trigger it, except that it happens more often with a larger number of dataset rows. Converting G to a dense tensor seems to fix it, but that is neither desirable nor feasible for large inputs.
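Concretely, the densify workaround is something like this one-line change inside proc_batch_matrix_sparse (it avoids the error, but defeats the purpose for large inputs):

```python
# Replacing the sparse bmm with a dense one makes the error disappear:
GtX = torch.bmm(G.to_dense().transpose(1, 2), Xb)
```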
Spec
I've run tests on these two systems:
- GeForce RTX 4090, CUDA 12.2, Driver 535.104.05, torch 2.9;
- Tesla T4, CUDA 13.0, Driver 580.95.05, torch 2.9.
Update
Running this on the individual matrices in the batch (with torch.sparse.mm) and then stacking the results works fine, but it requires a loop over the batch index (see the sketch below).
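For clarity, this is roughly what the loop-based version looks like (a minimal sketch; proc_batch_matrix_sparse_loop is a hypothetical name, and X is the same [1, N] row vector as above):

```python
def proc_batch_matrix_sparse_loop(G: Tensor, X: Tensor) -> Tensor:
    # Same result as the batched torch.bmm(Gt, Xb), but computed with one
    # torch.sparse.mm call per batch element and stacked at the end.
    B, N, K = G.shape
    Xt = X.reshape(-1, 1)  # dense column vector [N, 1]
    outs = []
    for b in range(B):
        Gb = G[b].coalesce()  # 2-D sparse COO slice [N, K]
        outs.append(torch.sparse.mm(Gb.t(), Xt))  # dense [K, 1]
    return torch.stack(outs).to("cpu")  # [B, K, 1]
```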
The output of compute-sanitizer is a long list of entries like:

```
========= Invalid __global__ read of size 16 bytes
========= at void cusparse::coomv_kernel<(bool)0, int, double, double, double, double>(cusparse::KernelCoeffs<T6>, T2, const T2 *, const T2 *, const T3 *, const T4 *, T5 *, T2 *, T6 *)+0x2b0
========= by thread (32,0,0) in block (0,0,0)
========= Access to 0x7f1fa52e2f48 is misaligned
========= and is inside the nearest allocation at 0x7f1fa4000000 of size 20,971,520 bytes
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame: [0xa0e735] in libcusparse.so.12
========= Host Frame: [0xa74c77] in libcusparse.so.12
========= Host Frame: [0x1b4d59] in libcusparse.so.12
========= Host Frame: [0x1c5044] in libcusparse.so.12
========= Host Frame: cusparseSpMM [0xfb023] in libcusparse.so.12
========= Host Frame: at::native::bmm_out_sparse_cuda(at::Tensor const&, at::Tensor const&, at::Tensor&)::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const [0x2f49e33] in libtorch_cuda.so
========= Host Frame: at::native::bmm_out_sparse_cuda(at::Tensor const&, at::Tensor const&, at::Tensor&) [0x2f4b373] in libtorch_cuda.so
========= Host Frame: at::native::bmm_sparse_cuda(at::Tensor const&, at::Tensor const&) [0x2f4d36f] in libtorch_cuda.so
========= Host Frame: at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__bmm(at::Tensor const&, at::Tensor const&) [0x3536c1b] in libtorch_cuda.so
========= Host Frame: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__bmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x3536c9e] in libtorch_cuda.so
========= Host Frame: at::_ops::bmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x27e8e88] in libtorch_cpu.so
========= Host Frame: torch::autograd::VariableType::(anonymous namespace)::bmm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x4d5de6a] in libtorch_cpu.so
========= Host Frame: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::bmm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x4d5e421] in libtorch_cpu.so
========= Host Frame: at::_ops::bmm::call(at::Tensor const&, at::Tensor const&) [0x2829c6b] in libtorch_cpu.so
========= Host Frame: torch::autograd::THPVariable_bmm(_object*, _object*, _object*) [0x59918e] in libtorch_python.so
========= Host Frame: cfunction_call in methodobject.c:537 [0x143943] in python
========= Host Frame: _PyObject_MakeTpCall in call.c:240 [0x11778b] in python
========= Host Frame: _PyEval_EvalFrameDefault in bytecodes.c:2715 [0x121951] in python
========= Host Frame: PyEval_EvalCode in ceval.c:580 [0x1de5cd] in python
========= Host Frame: run_eval_code_obj in pythonrun.c:1757 [0x21b7b6] in python
========= Host Frame: run_mod in pythonrun.c:1778 [0x216306] in python
========= Host Frame: pyrun_file in pythonrun.c:1674 [0x2131c1] in python
========= Host Frame: _PyRun_SimpleFileObject in pythonrun.c:459 [0x212d7f] in python
========= Host Frame: _PyRun_AnyFileObject in pythonrun.c:78 [0x212882] in python
========= Host Frame: Py_RunMain in main.c:714 [0x20f6c6] in python
========= Host Frame: Py_BytesMain in main.c:768 [0x1c6bb8] in python
========= Host Frame: [0x27249] in libc.so.6
========= Host Frame: __libc_start_main [0x27304] in libc.so.6
========= Host Frame: [0x1c69e8] in python
========= Host Frame: proc_batch_matrix_sparse in myfile.py:148
========= Host Frame: <module> in myfile.py:191
```