CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul from torch.nn.functional.linear()

When moving from cuda-11.3 to cuda-11.6+, a call to torch.nn.functional.linear() began to fail with a CUBLAS_STATUS_NOT_SUPPORTED error. I was able to reproduce the error with the following script, which places one of the input tensors of the linear() operation on a 2-byte (“torch.half”) boundary.

import torch
import torch.nn.functional as F
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

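# A single fp16 element (2 bytes) placed in front of A in one flat buffer,
# so A's data starts 2 bytes past the buffer's own (larger) alignment.
pad = torch.rand((1), requires_grad=True, dtype=torch.half, device="cuda")
A = torch.rand((5120, 2560), requires_grad=True, dtype=torch.half, device="cuda")
all_tensors = [pad, A]
# Flatten both tensors into one contiguous buffer, then get views back into it.
new_tensors = _unflatten_dense_tensors(_flatten_dense_tensors([p.clone().detach() for p in all_tensors]), all_tensors)
pad, A = new_tensors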

X = torch.rand((26, 1, 2560), requires_grad=True, dtype=torch.half, device="cuda")
B = torch.rand((5120), requires_grad=True, dtype=torch.half, device="cuda")
out = F.linear(X, A, B)
print(out)

The following trace was produced with the PyTorch nightly build against cuda-11.6, but from what I can tell the issue affects pytorch-1.12+ with cuda-11.6+. I ran the test script above as CUBLASLT_LOG_LEVEL=5 python test.py.

[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulDescCreate] matmulDesc=0X7FFC66EBCDD8 computeType=COMPUTE_32F scaleType=0
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55C8AC9AC430 attr=MATMUL_DESC_TRANSA buf=0X7FFC66EBCDB8 sizeInBytes=4
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55C8AC9AC430 attr=MATMUL_DESC_TRANSB buf=0X7FFC66EBCDBC sizeInBytes=4
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55C8AC9AC430 attr=MATMUL_DESC_EPILOGUE buf=0X7FFC66EBCDC0 sizeInBytes=4
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55C8AC9AC430 attr=MATMUL_DESC_BIAS_POINTER buf=0X7FFC66EBCFB8 sizeInBytes=8
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFC66EBCDD8 type=R_16F rows=2560 cols=5120 ld=2560
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFC66EBCDD8 type=R_16F rows=2560 cols=26 ld=2560
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFC66EBCDD8 type=R_16F rows=5120 cols=26 ld=5120
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulPreferenceCreate] matmulPref=0X7FFC66EBCDD8
[2023-01-13 18:46:05][cublasLt][265705][Api][cublasLtMatmulPreferenceSetAttribute] pref=0X55C8AC9ADF00 attr=MATMUL_PREF_MAX_WORKSPACE_BYTES buf=0X7FFC66EBCDC8 sizeInBytes=8
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatmulAlgoGetHeuristic] Adesc=[type=R_16F rows=2560 cols=5120 ld=2560] Bdesc=[type=R_16F rows=2560 cols=26 ld=2560] Cdesc=[type=R_16F rows=5120 cols=26 ld=5120] preference=[maxWavesCount=0.0 maxWorkspaceSizeinBytes=1048576] computeDesc=[computeType=COMPUTE_32F scaleType=R_32F transa=OP_T epilogue=EPILOGUE_BIAS biasPointer=0x7fe48bc20a00]
[2023-01-13 18:46:06][cublasLt][265705][Info][cublasLtMatmulAlgoGetHeuristic]  heuristicResults=[6]
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatmul] A=0X7FE452000002 Adesc=0X55C8AC9AD160 B=0X7FE48BC00200 Bdesc=0X55C8AC9AD5F0 C=0X7FE48BC23200 Cdesc=0X55C8AC9AD630 D=0X7FE48BC23200 Ddesc=0X55C8AC9AD630 computeDesc=0X55C8AC9AC430 algo=0X7FFC66EBCE00 workSpace=0X7FE48BC64200 workSpaceSizeInBytes=1048576 stream=0X0
[2023-01-13 18:46:06][cublasLt][265705][Trace][cublasLtMatmul] A=0X7FE452000002 Adesc=[type=R_16F rows=2560 cols=5120 ld=2560] B=0X7FE48BC00200 Bdesc=[type=R_16F rows=2560 cols=26 ld=2560] C=0X7FE48BC23200 Cdesc=[type=R_16F rows=5120 cols=26 ld=5120] D=0X7FE48BC23200 Ddesc=[type=R_16F rows=5120 cols=26 ld=5120] computeDesc=[computeType=COMPUTE_32F scaleType=R_32F transa=OP_T epilogue=EPILOGUE_BIAS biasPointer=0x7fe48bc20a00] algo=[algoId=6 tile=MATMUL_TILE_64x64 stages=MATMUL_STAGES_64x6] workSpace=0X7FE48BC64200 workSpaceSizeInBytes=1048576 beta=0 outOfPlace=0 stream=0X0
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatmulPreferenceDestroy] matmulPref=0X55C8AC9ADF00
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55C8AC9AD630
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55C8AC9AD5F0
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55C8AC9AD160
[2023-01-13 18:46:06][cublasLt][265705][Api][cublasLtMatmulDescDestroy] matmulDesc=0X55C8AC9AC430
Traceback (most recent call last):
  File "/home/ubuntu/src/augment/models/gpt-neox/mytest.py", line 13, in <module>
    out = F.linear(X, A, B)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 5120 n 26 k 2560 mat1_ld 2560 mat2_ld 2560 result_ld 5120 abcType 2 computeType 68 scaleType 0
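
The numeric codes at the end of this RuntimeError are raw CUDA enum values. A small lookup like the one below decodes them; the values are as I understand them from CUDA's library_types.h (cudaDataType_t) and cublas_api.h (cublasComputeType_t), and only a subset is included, so it is worth double-checking against your local headers. The decoded result agrees with the trace above: fp16 inputs, fp32 compute, fp32 scale.

```python
# Subset of cudaDataType_t values (library_types.h) relevant to this thread.
CUDA_DATA_TYPES = {0: "CUDA_R_32F", 1: "CUDA_R_64F", 2: "CUDA_R_16F", 14: "CUDA_R_16BF"}

# Subset of cublasComputeType_t values (cublas_api.h).
CUBLAS_COMPUTE_TYPES = {
    64: "CUBLAS_COMPUTE_16F",
    68: "CUBLAS_COMPUTE_32F",
    70: "CUBLAS_COMPUTE_64F",
    77: "CUBLAS_COMPUTE_32F_FAST_TF32",
}


def decode_types(abc_type, compute_type, scale_type):
    """Translate the abcType/computeType/scaleType codes from the error message."""
    return (
        CUDA_DATA_TYPES.get(abc_type, f"unknown({abc_type})"),
        CUBLAS_COMPUTE_TYPES.get(compute_type, f"unknown({compute_type})"),
        CUDA_DATA_TYPES.get(scale_type, f"unknown({scale_type})"),
    )


print(decode_types(2, 68, 0))
# ('CUDA_R_16F', 'CUBLAS_COMPUTE_32F', 'CUDA_R_32F')
```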

To summarize, it appears that cublasLtMatmul cannot handle a 2-byte-aligned matrix when a bias is included and the other two matrices are aligned on a larger boundary.
It seems the heuristic chosen for this matmul may be incorrect, and I can’t tell whether the error is in PyTorch's request for the heuristic or in cuda-11.6 choosing the wrong heuristic.
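
To check how a given tensor is aligned, you can look at the low bits of its data pointer. This is a minimal sketch with a hypothetical helper name (pointer_alignment); here it is fed the addresses copied from the CUBLASLT_LOG_LEVEL=5 trace above rather than live tensors, so it runs without a GPU.

```python
def pointer_alignment(ptr: int, cap: int = 256) -> int:
    """Return the largest power of two (up to `cap`) that divides the address `ptr`."""
    if ptr == 0:
        return cap
    # ptr & -ptr isolates the lowest set bit, i.e. the largest power of two dividing ptr.
    return min(ptr & -ptr, cap)


# Addresses copied from the cublasLtMatmul trace above.
pointers = {
    "A (weight)": 0x7FE452000002,
    "B (input X)": 0x7FE48BC00200,
    "C/D (output)": 0x7FE48BC23200,
    "bias": 0x7FE48BC20A00,
}
for name, ptr in pointers.items():
    print(f"{name:>13}: {pointer_alignment(ptr)}-byte aligned")
```

Only the weight comes out 2-byte aligned; the other three hit the 256-byte cap. With a real tensor you would pass `A.data_ptr()`. My understanding (an assumption, not confirmed in this thread) is that fp16 Tensor Core kernels generally want 16-byte-aligned pointers, and that `A.clone()` gets a fresh allocation from PyTorch's caching allocator that is aligned well beyond that, which may be a local workaround when `pointer_alignment(A.data_ptr()) < 16`.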

I’m happy to provide additional information or reproductions if needed. Thanks!

Thank you for providing this great minimal and executable code snippet!
We’ll look into the issue.

Yes, it appears that the heuristics are incorrect. The failure was not observed previously because older versions of PyTorch did not have a cuBlasLt path for addmm and instead relied on an unfused implementation backed by cuBlas.

I’ve opened “[cublas][cublasLt] Fall back to unfused addmm for 2-byte-aligned inputs” by eqy (Pull Request #92201, pytorch/pytorch on github.com) to fix this. In the meantime, you can set the environment variable DISABLE_ADDMM_CUDA_LT=1 to work around the issue.
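
If you prefer to apply the workaround from Python rather than the shell, one option is to launch the repro in a child process with the variable already set, so it is in place before torch is ever imported. This is a minimal sketch; the helper names are hypothetical, and my assumption (worth verifying for your version) is that PyTorch reads DISABLE_ADDMM_CUDA_LT once and caches it, so changing it mid-run may have no effect.

```python
import os
import subprocess
import sys


def lt_disabled_env(base=None):
    """Return a copy of the environment with the cuBLASLt addmm path disabled."""
    env = dict(os.environ if base is None else base)  # copy; never mutate the caller's mapping
    env["DISABLE_ADDMM_CUDA_LT"] = "1"
    return env


def run_script_without_lt(script_path):
    """Run a Python script in a child process with the workaround applied.

    Equivalent to `DISABLE_ADDMM_CUDA_LT=1 python <script_path>` in a shell.
    """
    return subprocess.run([sys.executable, script_path], env=lt_disabled_env(), check=False)
```

For example, `run_script_without_lt("test.py")` reruns the repro above with the flag set. Running `export DISABLE_ADDMM_CUDA_LT=1` in a shell and then launching python from that same shell has the same effect, since exported variables are inherited by child processes.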

We’ll also follow up with cuBlasLt to see if the underlying heuristics can be fixed here.

I think I got the same error, but I’m not totally sure: I launched my program with CUBLASLT_LOG_LEVEL=5 and got

[2023-02-19 16:40:17][cublasLt][8536][Api][cublasLtMatmulPreferenceDestroy] matmulPref=0X55BFB2C3A410
[2023-02-19 16:40:17][cublasLt][8536][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55BFB2A31BD0
[2023-02-19 16:40:17][cublasLt][8536][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55BFB2A31C90
[2023-02-19 16:40:17][cublasLt][8536][Api][cublasLtMatrixLayoutDestroy] matLayout=0X55BFB2A331B0
[2023-02-19 16:40:17][cublasLt][8536][Api][cublasLtMatmulDescDestroy] matmulDesc=0X55BFB2B39D70

but the end error is

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

torch.__version__ = 1.13.1+cu117

I tried DISABLE_ADDMM_CUDA_LT=1, but it did not help.
Is it the same error? What should I do?

Thank you!

Could you post a minimal and executable code snippet which would reproduce the issue, please?

I’m trying, but when I extract the baddbmm() call that triggers the crash and run it with the same parameters, it no longer crashes.
So it’s likely related to the exact context in which this call is made. Could it be because the GPU memory is full or nearly full?

I’ll keep investigating when I have time, but so far I haven’t been able to reproduce it outside the full program.

Thank you!

The cublas error might indeed be the victim of another failure. Could you update to the latest nightly release and check if you are still seeing the same error?

Using torch==1.13.1 (sorry, I don’t have permission to update or use GitHub’s master branch right now, as I am running this on a cluster), I get the error with PyTorch’s DataParallel tutorial. Hope it helps.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# SEE HERE  https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

# Parameters
input_size = 5
output_size = 2

batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Generate random tensors as data
class RandomDataset(Dataset):
    def __init__(self, nb_features, nb_data):
        self.len = nb_data
        self.data = torch.randn(nb_data, nb_features)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

# simple linear model
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input: torch.Tensor):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())

        return output

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)

model = Model(input_size, output_size)
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)

model.to(device)

# Run the model
for data in rand_loader:
    input = data.to(device)
    output = model(input)
    print("Outside: input size", input.size(),
          "output_size", output.size())

Result:

Let's use 2 GPUs!
	In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
	In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
	In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])

Traceback (most recent call last):
[...]
RuntimeError: Caught RuntimeError in replica 0 on device 0.
[...]
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( ltHandle, computeDesc.descriptor(), Adesc.descriptor(), Bdesc.descriptor(), Cdesc.descriptor(), Cdesc.descriptor(), preference.descriptor(), 1, &heuristicResult, &returnedResult)`

I tried setting the environment variable (export DISABLE_ADDMM_CUDA_LT=1 in my terminal; is this the right way?), but it didn’t fix the issue.

If I do not use DataParallel, the script runs fine.

Thank you!

Could you rerun your code with CUDA_LAUNCH_BLOCKING=1 and check the stack trace, please? Also, which GPUs are you using?
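
When answering questions like this, it helps to report the versions together. PyTorch ships `python -m torch.utils.collect_env`, which prints a full report; the sketch below (hypothetical function name) gathers just the fields discussed in this thread, including the CUDA version PyTorch was built with, and still runs if torch is not installed.

```python
import importlib.util
import platform


def collect_env_info():
    """Gather version details commonly requested when reporting CUDA/cuBLAS errors."""
    info = {"python": platform.python_version()}
    if importlib.util.find_spec("torch") is None:
        info["torch"] = None  # torch not installed in this interpreter
        return info
    import torch

    info["torch"] = torch.__version__
    info["torch_cuda_build"] = torch.version.cuda  # CUDA version torch was built with (None on CPU builds)
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        # One (name, (major, minor) compute capability) entry per visible GPU.
        info["gpus"] = [
            (torch.cuda.get_device_name(i), torch.cuda.get_device_capability(i))
            for i in range(torch.cuda.device_count())
        ]
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```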

CUDA_LAUNCH_BLOCKING=1 does not seem to fix the issue either.
(I tried both exporting it and launching the script with CUDA_LAUNCH_BLOCKING=1 python my_script.py.)

The GPU is V100-SXM2.

However, I ran the DistributedDataParallel tutorial successfully.

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

I cannot reproduce the issue using your code and torch==1.13.1+cu117 on 8x V100 16GB.
This is the output I get after removing the trailing "output_size fragment, which looks like a copy/paste error in your code:

Let's use 8 GPUs!
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([30, 5])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([30, 5])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([30, 5])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
	In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
Outside: input size torch.Size([10, 5])

I had the same issue. It happens when your CUDA runtime version differs from (is older than?) the CUDA version used to build PyTorch.

I’m running into the same problem. How can I solve this? Thanks.

Hi, did you solve this problem?