Tensor Cores and mixed precision *matrix multiplication* - output in float32

https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/ states that “Each Tensor Core performs 64 floating point FMA mixed-precision operations per clock (FP16 input multiply with full-precision product and FP32 accumulate, as Figure 2 shows)”.

Is it possible to multiply two fp16 tensors but get the output in fp32? torch.mm with an out argument of type fp32 produces an error: RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #0 'result'

According to nvidia, accumulation is done in fp32, so it seems wasteful (in terms of performance) to return just fp16 as the output?

I’m aware of apex. The question here isn’t about end-to-end model training, but simply about matrix multiplication.


Hi Yuri!

It’s well known in the numerical analysis community that a
significant increase in accuracy can be achieved for vector
dot-products (i.e., for matrix multiplication) when extended
precision (that is, for your case, fp32 is “extended precision”
relative to fp16) is used for the “accumuland” of the dot-product
multiply-accumulate chain. (This is especially true for large
matrices.)

(I don’t have a good reference for this off hand, but if you
search for the rationale behind the 80-bit internal precision
of the Intel 8087 math coprocessor, you should be able to
find lots of good stuff.)

So nVidia is making a potentially sensible trade-off here. You
(potentially) get the most bang for the buck by performing the
tensor multiplication (multiply-accumulate) in extended precision
(fp32), even if you only return a regular precision (fp16) result.
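
To make this concrete, here is a minimal, deliberately single-threaded CUDA sketch I put together purely for illustration (the file and kernel names are made up, and a real GEMM is of course parallel). One kernel keeps the accumuland of a long dot product in fp16, the other keeps it in fp32 and also reports that result rounded back down to fp16. With 10,000 products of 0.1 * 0.1 the true sum is 100; the fp16 accumuland stalls far below that once its rounding step exceeds each addend, while the fp32 accumuland lands close to 100 even after the final cast back to fp16.

```cuda
// dot_acc_demo.cu (hypothetical), e.g.: nvcc -arch=sm_70 dot_acc_demo.cu -o dot_acc_demo
#include <cuda_fp16.h>
#include <cstdio>

// One long dot product with an fp16 accumuland (a "pure fp16" multiply-accumulate chain).
__global__ void dot_half_acc(const __half* a, const __half* b, int n, float* out) {
    __half acc = __float2half(0.0f);
    for (int i = 0; i < n; ++i)
        acc = __hadd(acc, __hmul(a[i], b[i]));       // fp16 product, fp16 accumulate
    *out = __half2float(acc);
}

// Same inputs, but the accumuland is fp32; also report the result rounded back to fp16.
__global__ void dot_float_acc(const __half* a, const __half* b, int n,
                              float* out_fp32, float* out_back_to_fp16) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i)
        acc += __half2float(a[i]) * __half2float(b[i]);   // fp32 accumulate
    *out_fp32 = acc;
    *out_back_to_fp16 = __half2float(__float2half(acc));  // truncate final result to fp16
}

__global__ void fill(__half* a, __half* b, int n) {
    for (int i = 0; i < n; ++i) {
        a[i] = __float2half(0.1f);
        b[i] = __float2half(0.1f);
    }
}

int main() {
    const int n = 10000;                 // true sum is 0.1 * 0.1 * 10000 = 100
    __half *a, *b;
    float *out;
    cudaMalloc(reinterpret_cast<void**>(&a), n * sizeof(__half));
    cudaMalloc(reinterpret_cast<void**>(&b), n * sizeof(__half));
    cudaMalloc(reinterpret_cast<void**>(&out), 3 * sizeof(float));

    fill<<<1, 1>>>(a, b, n);
    dot_half_acc<<<1, 1>>>(a, b, n, out);
    dot_float_acc<<<1, 1>>>(a, b, n, out + 1, out + 2);

    float h[3];
    cudaMemcpy(h, out, 3 * sizeof(float), cudaMemcpyDeviceToHost);
    printf("fp16 accumulation:               %f\n", h[0]);  // stalls far below 100
    printf("fp32 accumulation:               %f\n", h[1]);  // close to 100
    printf("fp32 accumulation, cast to fp16: %f\n", h[2]);  // still close to 100
    cudaFree(a); cudaFree(b); cudaFree(out);
    return 0;
}
```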

Best regards.

K. Frank

Thanks K. Frank,

Sorry, I might not have been clear. I’m aware of the trade-off, but thought that nvidia returns the result in fp32 (as per the link above). So I’d guess (maybe wrongly?) that it’s just PyTorch which doesn’t support returning an fp32 result.

Hello Yuri!

A quick comment on nVidia’s terminology: I do not see anything
that makes clear what nVidia means by “full precision.” I’d say
that this is either sloppy language or some subtle marketing
hype. Bear in mind that (other than possible exponent
underflow or overflow) multiplying two fp16’s as fp32’s gives
exactly the same result as multiplying them as fp64’s. So (if
only for sanity’s sake) nVidia should have labelled their diagram
as “FP32 product”, rather than the potentially ambiguous “Full
precision product”.
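
A quick way to convince yourself of that (just an illustrative snippet I’m adding here, compiled as ordinary C++): an fp16 value carries at most 11 significand bits, so the product of two fp16 values needs at most 22 bits, which fits exactly into fp32’s 24-bit significand. Below, the widest fp16 significand pattern is squared in fp32 and in fp64, and the two products agree exactly.

```cpp
#include <cstdio>

int main() {
    // 2047/1024 = 1.9990234375 uses all 11 significand bits of fp16
    // and is exactly representable in fp16 (and therefore in fp32).
    float a = 2047.0f / 1024.0f;
    float b = 2047.0f / 1024.0f;

    float  p32 = a * b;                  // product computed in fp32
    double p64 = (double)a * (double)b;  // product computed in fp64

    printf("fp32 product: %.17g\n", p32);
    printf("fp64 product: %.17g\n", p64);
    printf("identical: %s\n", ((double)p32 == p64) ? "yes" : "no");  // prints "yes"
    return 0;
}
```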

(Also, it’s a little odd that their last step is “Convert to FP32
result”. Last time I looked, the output of a “FP32 accumulator”
is a FP32 result.)

All that being said, it is true that you lose some information by
truncating your fp32 matrix multiplication back down to fp16.
It may be preferable not to. However, the lesson of the numerical
analysts is that you get a lot of benefit (in certain realistically
common cases) from performing the multiply-accumulates in
fp32, and that you keep most of that benefit even after truncating back
down to fp16. That is, performing the matrix multiplication in
fp16 gives you a fp16 result that has much less accuracy than
its fp16 precision might suggest, whereas multiplying in fp32
(and truncating back to fp16) can give you a result with near
full (or full) fp16 accuracy.

(The fp32 multiplication result may, in fact, be more accurate
than fp16 accuracy (or it could be less, in some cases), so
whether, in such a case, the loss of accuracy of truncating
back down to fp16 is outweighed by the speed up, is up to
you. But broad experience teaches us that in many cases
it is.)

(I don’t know whether pytorch supports returning a fp32 result.
Your nVidia link suggests that the gpu gives you a choice.
If pytorch doesn’t give you a choice, it probably should, but
in many realistic cases it won’t be wasteful to return the result
as fp16.)

Best.

K. Frank

Thanks.

Yeah, my point/question is exactly that nvidia gives fp32, but it looks like pytorch doesn’t have an option to return with that precision (allowing only fp16 as the output for an fp16 product). I wonder if I’m missing something, or if that’s indeed the case?

On “full precision”, I agree it could only mean fp32 in this context.

Yuri, you are correct. Cublas allows returning the result of the multiplication of 2 fp16 matrices in fp32; however, pytorch does not support this option. Pytorch could allow this if there is a compelling use case.

Many thanks @ngimel!

Sorry, I have one more question on this… Am I right that the implementation is here, and that what should be changed is c, CUDA_R_16F to c, CUDA_R_32F?
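
In other words, something along these lines? (A standalone sketch I put together to check my understanding; all the variable names are made up, error checking is omitted, and the only differences from the usual fp16-output call should be the Ctype plus fp32 alpha/beta and compute type. On cuBLAS 11+ the compute-type argument would be CUBLAS_COMPUTE_32F instead of CUDA_R_32F.)

```cpp
// gemm_fp16in_fp32out.cu (hypothetical), e.g.: nvcc gemm_fp16in_fp32out.cu -lcublas
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cstdio>
#include <vector>

int main() {
    const int M = 4, N = 4, K = 4;       // small column-major matrices of ones
    std::vector<__half> hA(M * K), hB(K * N);
    for (auto& x : hA) x = __float2half(1.0f);
    for (auto& x : hB) x = __float2half(1.0f);

    __half *dA, *dB;
    float *dC;
    cudaMalloc(reinterpret_cast<void**>(&dA), hA.size() * sizeof(__half));
    cudaMalloc(reinterpret_cast<void**>(&dB), hB.size() * sizeof(__half));
    cudaMalloc(reinterpret_cast<void**>(&dC), M * N * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(__half), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(__half), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    float alpha = 1.0f, beta = 0.0f;     // alpha/beta match the fp32 compute type

    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 M, N, K, &alpha,
                 dA, CUDA_R_16F, M,      // A: fp16 input
                 dB, CUDA_R_16F, K,      // B: fp16 input
                 &beta,
                 dC, CUDA_R_32F, M,      // C: fp32 output (was CUDA_R_16F for an fp16 result)
                 CUDA_R_32F,             // accumulate in fp32
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // allow Tensor Cores (pre-CUDA-11 hint)

    std::vector<float> hC(M * N);
    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0] = %f (expect %d)\n", hC[0], K);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```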

What I unfortunately don’t know (despite reading the tutorial) is how to create a simple extension which accepts just 3 tensors (the easy bit) and would call cublasGemmEx at the end. I’m basically missing the link between these two parts (high-level and low-level). It seems that one step back (from THCBlas.cu) is here, but I don’t see at the moment how to ‘adapt’ it. Maybe you could kindly provide a link to some simple (or not so simple) example… Many thanks in advance!

I guess the alternative would ‘simply’ be calling data_ptr() on the 3 tensors (inputs and output) and passing the pointers directly to cublasGemmEx?

You can use something like this: https://github.com/ngimel/rnn_ext. It’s more complicated than you need, but it shows how to call cublas directly from an extension.
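
For later readers, here is a minimal sketch of what such an extension could look like (my own simplification, not code taken from rnn_ext; the function name half_mm_fp32out is made up). It takes two contiguous fp16 CUDA tensors, allocates an fp32 output, and calls cublasGemmEx on the raw data pointers, handling the row-major vs. column-major mismatch by computing C' = B'A':

```cpp
// halfmm.cpp (hypothetical) -- a PyTorch C++ extension that links against cuBLAS.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDABlasHandle()
#include <cublas_v2.h>

// C = A @ B with A, B in fp16 and C in fp32, accumulated in fp32 on the GPU.
torch::Tensor half_mm_fp32out(torch::Tensor A, torch::Tensor B) {
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == torch::kHalf && B.scalar_type() == torch::kHalf,
                "inputs must be fp16");
    TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && A.size(1) == B.size(0),
                "shapes must be (M, K) and (K, N)");
    A = A.contiguous();
    B = B.contiguous();

    const int M = static_cast<int>(A.size(0));
    const int K = static_cast<int>(A.size(1));
    const int N = static_cast<int>(B.size(1));
    auto C = torch::empty({M, N}, A.options().dtype(torch::kFloat));

    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    float alpha = 1.0f, beta = 0.0f;   // scale factors match the fp32 compute type

    // cuBLAS is column-major while the tensors are row-major, so compute
    // C' = B'A': pass B first, A second, and swap M and N.
    auto status = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                               N, M, K, &alpha,
                               B.data_ptr<at::Half>(), CUDA_R_16F, N,
                               A.data_ptr<at::Half>(), CUDA_R_16F, K,
                               &beta,
                               C.data_ptr<float>(), CUDA_R_32F, N,  // fp32 output
                               CUDA_R_32F,                          // fp32 accumulation
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublasGemmEx failed");
    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("half_mm_fp32out", &half_mm_fp32out, "fp16 x fp16 matmul with fp32 output");
}
```

It could be built, for instance, with torch.utils.cpp_extension.load(name='halfmm', sources=['halfmm.cpp'], with_cuda=True, extra_ldflags=['-lcublas']); treat that build line as a guess at a typical setup rather than a recipe.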


Thanks a lot! That’s pretty much what I’m after! And it was easy to change, once you read that cuBlas assumes F-order (column-major), so (AB)' = B'A' comes to the rescue :)

Last question on this. Am I right that cuBlas does not support multiplication of fp32 by fp16 (which could be useful if you don’t want to convert fp16 to fp32, allocating extra memory)?

I’ve tried providing CUDA_R_32F and CUDA_R_16F for inputs and got CUBLAS_STATUS_NOT_SUPPORTED

Both inputs have to be of the same type; please refer to the table listing supported type combinations: https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx

Thanks! Alas… it would have been useful in my case.