[nonissue] Autograd fails when using half-precision - overflow on matrix size

I got a Titan V and have been experimenting with half precision.

In half-precision mode I can't backpropagate through a matmul of two all-zeros matrices, because the number of elements in the resulting matrix is outside of the half-precision range.

I am getting the same error if I use Conv1d, Conv2d, or bmm.

This minimal computation graph replicates the problem:

import torch, torch.autograd, torch.nn, numpy

with torch.cuda.device(0):
    # PyTorch 0.3 style: tensors are wrapped in Variable / Parameter explicitly
    test_input = torch.autograd.Variable(torch.zeros(257, 509)).cuda().half()
    test_w = torch.nn.Parameter(torch.zeros(509, 263)).cuda().half()

    matmul_result = torch.matmul(test_input, test_w)
    print(matmul_result.size())
    print(numpy.prod(matmul_result.size()))  # 257 * 263 = 67591 elements

    test_output = matmul_result.abs().mean()
    test_output.backward()  # raises the overflow error below

And the result is:

torch.Size([257, 263])
67591
Traceback (most recent call last):
  File "<stdin>", line 11, in <module>
  File "/home/dzmitry/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 167, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/home/dzmitry/miniconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    variables, grad_variables, retain_graph)
RuntimeError: value cannot be converted to type Half without overflow: 67591

I am using PyTorch 0.3 and could replicate the issue with CUDA 8 and CUDA 9.


I found my mistake. It actually fails because I use .mean(), whose backward pass involves the total number of elements (67591), and that value is too large to represent in half precision.

So what is the solution to this problem? I'm still stuck! Would you recommend any changes to the mean() code?
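
For reference, the half-precision limit can be checked directly (torch.finfo is available in newer PyTorch releases, not yet in 0.3):

import torch

print(torch.finfo(torch.half).max)  # 65504.0 -- the element count 67591 exceeds this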

Reductions are sensitive to overflow if you are using FP16.
You should perform all reductions in FP32 to make sure you get a valid result.
Just call .float() on your tensor before passing it to torch.mean().
Autograd will track this cast and undo it in the backward pass, so your model will still be in half precision.
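
Applied to the snippet above, the fix would look roughly like this (same variable names as in the original post):

test_output = matmul_result.float().abs().mean()  # reduce in FP32 so the element count fits
test_output.backward()  # gradients are cast back to FP16 on the way to test_w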

Thanks. It works, but the loss goes to NaN. :sweat_smile: I think I need to read some low-precision training papers before trying my own experiments with handling these NaNs. I find that it's super unstable to train with low precision; the values get washed out easily.

Have a look at this post for some info on FP16 training.

Hello, I have a question about FP16. I found that if I apply FP16 in the forward pass, it becomes very slow, e.g.:
FP32 forward time: 3.27 s
FP16 forward time: 13 s

There are nn.Linear and torch.matmul ops in my network. Is FP16 optimized only for certain ops, not for all ops?

FP16 operations can be accelerated using Tensor cores, which were introduced with the Volta and Turing GPUs (e.g. the RTX 2080).

Linear layers use GEMMs to compute the output. The matrices involved in this computation must have sizes that are multiples of 8 to use Tensor cores:

# all dims (8, 16, 32) are multiples of 8, so cuBLAS can use Tensor cores in FP16
a = torch.randn(8, 16, device='cuda', dtype=torch.half)
b = torch.randn(16, 32, device='cuda', dtype=torch.half)
torch.matmul(a, b)

This requirement exists for all cublas and cudnn versions, as far as I know.

There were also other requirements for convolution layers for cudnn 7.2 and older.
The current PyTorch binaries ship with cudnn 7.6, so you don’t have to worry about it.
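
If you want to double check which versions your binaries ship with:

import torch

print(torch.__version__)               # PyTorch version
print(torch.version.cuda)              # CUDA version of the binaries
print(torch.backends.cudnn.version())  # cudnn version as an integer (7.6.x shows up as 76xx)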

That being said, the posted times look quite bad.
Could you post the code you’ve used to profile your model?
Also, which GPU are you using?
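
For the profiling itself, note that CUDA calls are asynchronous, so a sketch along these lines (using torch.cuda.Event, with placeholder shapes) usually gives more reliable numbers than plain wall-clock timing:

import torch

# placeholder tensors; substitute your real inputs
a = torch.randn(1024, 1024, device='cuda', dtype=torch.half)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.half)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
start.record()
torch.matmul(a, b)
end.record()
torch.cuda.synchronize()              # wait for the kernel to finish before reading the timer
print(start.elapsed_time(end), 'ms')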

Hi, the code is too long to post, but I tried it in my code and the problem seems to come from the torch.matmul() function.
Below are the per-batch timings from my training process.
FP32, each batch:
Time: 0.00418 s
Time: 0.00555 s
Time: 0.00513 s
Time: 0.00510 s
Time: 0.00475 s
Time: 0.00477 s
Time: 0.00465 s
Time: 0.00469 s

FP16, each batch:
Time: 0.46559 s
Time: 0.46569 s
Time: 0.46576 s
Time: 0.46529 s
Time: 0.46545 s
Time: 0.46568 s
Time: 0.46566 s
Time: 0.46571 s
Time: 0.46584 s
My GPU is a GTX 1080 Ti.

Your GTX 1080 Ti does not have Tensor / Volta cores, so you won't see any speedup.
However, a ~100x performance degradation shouldn't occur either.
Without the code, it's hard to debug the issue. :confused:

Test code:

torch.matmul(col, weight)
# col.shape:    [10, 22500, 1, 1728]
# weight.shape: [22500, 1728, 3]

This is slow in half precision.

Unfortunately, only 1728 is a multiple of 8, so this operation won't be sped up by a Tensor-core GEMM in FP16.
On my Titan V, the operation in FP16 is approx. 10x slower than in FP32.
When I adapt the shapes to multiples of 8, I get approx. a 2x speedup.
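
For reference, this is roughly how I compare the two; the shapes below are just an illustration in which every GEMM dimension is a multiple of 8, not your actual sizes:

import time
import torch

def timed_matmul(dtype):
    # every dimension (1024, 2048, 512) is a multiple of 8
    a = torch.randn(1024, 2048, device='cuda', dtype=dtype)
    b = torch.randn(2048, 512, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        torch.matmul(a, b)
    torch.cuda.synchronize()  # wait for all kernels before stopping the clock
    return time.perf_counter() - start

print('FP32:', timed_matmul(torch.float))
print('FP16:', timed_matmul(torch.half))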

I improved my code:

import time
import numpy as np
import torch

start = time.process_time()
a = torch.randn(8, 8000, 1, 1728)
b = torch.randn(8000, 1728, 3)
# a = a.half()
# b = b.half()
if True:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
a = a.to(device)
b = b.to(device)
torch.matmul(a, b)
end = time.process_time()
print("Elapsed time of {}: {} s".format('main', np.round(end - start, decimals=5)))

FP16: 4.7057 s
FP32: 3.43783 s

As I said, your GPU does not have any Volta / Tensor cores, so the performance won't be improved.
However, you'll see a smaller memory usage with FP16.

If you would like to train your model using mixed precision, I would recommend having a look at apex/amp.
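
A minimal sketch of how apex/amp is typically wired in; MyModel, loader, and loss_fn are placeholders for your own objects:

import torch
from apex import amp  # https://github.com/NVIDIA/apex

model = MyModel().cuda()                  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# opt_level 'O1' patches common ops to run in FP16 where it is considered safe
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

for data, target in loader:               # placeholder data loader
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)   # placeholder loss function
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()            # loss scaling helps avoid FP16 underflow
    optimizer.step()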

Thank you for your reply. I have another problem.
Test code:

import time
import torch

a = torch.randn(1, 49920, 512)
c = torch.randn(1, 512, 49920)
# b = torch.randn(153000, 288, 3, 3)
# a = a.half()
# b = b.half()
if True:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
a = a.to(device)
# b = b.to(device)
c = c.to(device)
torch.cuda.synchronize()
start = time.process_time()
# torch.matmul(a, b)
# F.conv1d(a, c)
torch.matmul(a, c)

As I calculated, the input feature map should take 49920 × 512 × 4 / 1024 / 1024 = 97.5 MB, but when I run it on the GPU (GTX 1080 Ti), peak memory usage is more than 10 GB.
Do you know why that is, and how to avoid it?

Since the result of torch.matmul(a, c) will have the shape [1, 49920, 49920], it will use approx. 9.3 GiB of memory.
Given that the CUDA context uses some memory as well, along with a and c, the total usage of 10 GB seems reasonable.
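
A quick back-of-the-envelope check of the output size alone (FP32, 4 bytes per element):

elems = 1 * 49920 * 49920   # shape of torch.matmul(a, c)
print(elems * 4 / 1024**3)  # ~9.28 GiB just for the result tensor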

Yes, it is. Thank you!