Bfloat16 + transformers

Context

In Hugging Face Transformers, the Pegasus and T5 models overflow during beam search in half precision (fp16).

Models that were originally trained in fairseq work well in half precision, which leads me to believe that models trained in bfloat16 (on TPUs with TensorFlow) will often fail to generate in fp16, which has far less dynamic range.
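
For illustration, a minimal comparison of the two formats' dynamic ranges (this assumes a PyTorch build with basic bfloat16 tensor support):

import torch
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same exponent range as float32
print(torch.finfo(torch.float32).max)   # ~3.40e38
# A value that bfloat16 represents comfortably overflows to inf in float16:
x = torch.tensor([1e5], dtype=torch.bfloat16)
print(x.to(torch.float16))              # tensor([inf], dtype=torch.float16)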

I was considering starting a project to further train these models with a penalty on large activations (to discourage overflow in fp16), but I was wondering whether this duplicates the PyTorch team’s current efforts; a rough sketch of the penalty is below.
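
To make the idea concrete, here is the kind of auxiliary loss I have in mind. The name activation_penalty and the threshold/weight values are placeholders I made up, not an existing API:

import torch

def activation_penalty(hidden_states, threshold=1e4, weight=1e-4):
    # Penalize activation magnitudes that approach the float16 maximum (~65504),
    # so that further training nudges the model back into an fp16-safe range.
    excess = (hidden_states.abs().float() - threshold).clamp(min=0)
    return weight * excess.pow(2).mean()

# During fine-tuning this would be added to the usual LM loss, e.g.
# total_loss = lm_loss + sum(activation_penalty(h) for h in hidden_states)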

Specific Questions

  • (a) Is the snippet below likely to work on a current nightly build?
  • (b) Are the various bfloat16 kernel implementations “in the works” (meaning my proposed project won’t be useful in a few months)?
  • (c) Is bfloat16 + CUDA a possibility?

Failing Snippet

The following snippet tries to run a transformer forward pass in bfloat16:

from transformers import BartForConditionalGeneration
import torch

model = BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-12-3")
model = model.to(torch.bfloat16)  # cast all weights to bfloat16
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]], dtype=torch.long)
model(input_ids)  # forward pass fails on CPU
# RuntimeError: "LayerNormKernelImpl" not implemented for 'BFloat16'
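
For what it’s worth, the only workaround I can think of is to keep the LayerNorm computation in float32 while the rest of the model stays in bfloat16. A rough, untested sketch (FP32LayerNorm and swap_layernorms are my own names, not a transformers/PyTorch API), and I’m not sure the remaining bfloat16 CPU kernels exist either, which is part of question (a):

import torch
from torch import nn
from transformers import BartForConditionalGeneration

class FP32LayerNorm(nn.LayerNorm):
    # Run the normalization in float32, then cast the result back to the input dtype.
    def forward(self, x):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)

def swap_layernorms(module):
    # Recursively replace every nn.LayerNorm with the float32 version above.
    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            new = FP32LayerNorm(child.normalized_shape, eps=child.eps,
                                elementwise_affine=child.elementwise_affine)
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            swap_layernorms(child)

model = BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-12-3")
model = model.to(torch.bfloat16)
swap_layernorms(model)  # LayerNorm parameters stay in float32
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]], dtype=torch.long)
model(input_ids)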

Currently, quantization supports int8/uint8/int32; bfloat16 is not covered by PyTorch quantization.

@izdeby can you answer the above questions?