How does automatic mixed precision handle input downscaling?

Hello,

I was wondering whether, for automatic mixed precision, PyTorch requires or benefits from the input being in the more precise format (e.g. float32) rather than the smaller one (e.g. float16). I am asking because storing the input to my model is very expensive in terms of storage, and if I could downscale the input to float16 and operate on that, I would. Of course, I could just run the whole model in fp16, but that might reduce accuracy, which is why I am interested in automatic mixed precision.

Therefore, does PyTorch downscale the input to fp16 before applying e.g. a linear layer to it under torch.autocast()? If it already does that, could I store the fp16 representation myself and use automatic mixed precision to still get fp32 accumulation and train as if the input were fp32? Otherwise, does PyTorch require the input to be fp32, or would it benefit from it (in terms of performance, perhaps in backpropagation)?

Thank you in advance!

It depends on your actual use case and the layers used. In your case, the first layer is the interesting one. If it accepts float16 inputs (as a linear layer does), autocast will cast the inputs down for you and also keep the outputs in float16. The next layer can then reuse them directly, or they are cast back to float32 if that layer needs it.
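To make the casting behaviour concrete, here is a minimal sketch (not part of the original reply) that reports the dtypes around a single linear layer under autocast. It assumes float16 on a CUDA device and falls back to bfloat16 on CPU so that it can run without a GPU; the helper name `linear_dtypes_under_autocast` is made up for illustration.

```python
import torch
import torch.nn as nn


def linear_dtypes_under_autocast():
    # Assumption: float16 on CUDA; bfloat16 on CPU so the sketch runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    low = torch.float16 if device == "cuda" else torch.bfloat16

    lin = nn.Linear(16, 16).to(device)      # parameters are stored in float32
    x = torch.randn(1, 16, device=device)   # float32 input

    with torch.autocast(device_type=device, dtype=low):
        out = lin(x)  # autocast casts the input and parameters for this op only

    return {
        "autocast": low,
        "input": x.dtype,            # the original tensor is not modified
        "weight": lin.weight.dtype,  # the stored parameters are not modified
        "output": out.dtype,         # the output stays in the lower precision
    }


dtypes = linear_dtypes_under_autocast()
print(dtypes)
```

The input tensor and the stored parameters keep their float32 dtype; only the copies used inside the op, and the output it returns, are in the lower precision.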

A manual cast could work, but you should double check it with your actual model.

Here is a small example which shows the logged operations:

import torch
import torch.nn as nn
from torch.testing._internal.logging_tensor import LoggingTensorMode, capture_logs

# setup
lin = nn.Linear(16, 16).cuda()
x = torch.randn(1, 16).cuda()

# default: float32 input, autocast casts the input, weight, and bias to float16
with capture_logs(is_mode=True) as logs, LoggingTensorMode():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = lin(x)
for l in logs:
    print(l)
# $1 = torch._ops.aten._to_copy.default($0, dtype=torch.float16)
# $3 = torch._ops.aten._to_copy.default($2, dtype=torch.float16)
# $5 = torch._ops.aten._to_copy.default($4, dtype=torch.float16)
# $6 = torch._ops.aten.t.default($3)
# $7 = torch._ops.aten.addmm.default($1, $5, $6)

# manual cast: the input is already float16, so only the weight and bias are cast
x = x.half()
with capture_logs(is_mode=True) as logs, LoggingTensorMode():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = lin(x)
for l in logs:
    print(l)
# $0 = torch._ops.aten._to_copy.default($0, dtype=torch.float16)
# $1 = torch._ops.aten._to_copy.default($2, dtype=torch.float16)
# $2 = torch._ops.aten.t.default($1)
# $4 = torch._ops.aten.addmm.default($0, $3, $2)

As you can see, the second approach works and avoids one cast: the input is already float16, so only the weight and bias still have to be cast.
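One possible way to double check the manual cast on a single layer is sketched below (an illustration, not something from the original reply). It feeds the same data once as float32 and once pre-cast to the lower precision, compares the outputs, and reports how many bytes each version of the input occupies. The function name `compare_precast_input`, the batch size, and the CPU fallback to bfloat16 are assumptions made so the sketch runs without a GPU.

```python
import torch
import torch.nn as nn


def compare_precast_input(seed=0):
    torch.manual_seed(seed)
    # Assumption: float16 on CUDA; bfloat16 on CPU so the sketch runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    low = torch.float16 if device == "cuda" else torch.bfloat16

    lin = nn.Linear(16, 16).to(device)
    x32 = torch.randn(4, 16, device=device)  # input kept in float32
    x_low = x32.to(low)                      # input as it would be stored in lower precision

    with torch.no_grad(), torch.autocast(device_type=device, dtype=low):
        out_from_fp32 = lin(x32)   # autocast casts the input itself
        out_from_low = lin(x_low)  # input was cast ahead of time

    # Largest absolute difference between the two outputs, measured in float32.
    max_diff = (out_from_fp32.float() - out_from_low.float()).abs().max().item()

    # Memory taken by each version of the input tensor.
    bytes_fp32 = x32.element_size() * x32.nelement()
    bytes_low = x_low.element_size() * x_low.nelement()
    return max_diff, bytes_fp32, bytes_low


max_diff, bytes_fp32, bytes_low = compare_precast_input()
print(f"max abs difference: {max_diff}")
print(f"input size: {bytes_fp32} bytes (float32) vs {bytes_low} bytes (lower precision)")
```

This only covers the forward pass of one linear layer. For a real model, the same comparison would need to be repeated end to end, including the training loss, before relying on inputs stored in float16.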