RuntimeError when using load_in_8bit=True for AutoModelForSeq2SeqLM

I’m trying to train Google’s FLAN-T5-XL model, but due to resource constraints I have to lower the precision by setting load_in_8bit=True:

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, device_map="auto", load_in_8bit=True, return_dict=True)
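
For reference, the tokenizer used in the generate function below is loaded from the same checkpoint (a minimal sketch, assuming MODEL_NAME is the same string passed to from_pretrained above):

from transformers import AutoTokenizer

# Assumption: the tokenizer comes from the same checkpoint as the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)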

I’m able to fine-tune the model, but at inference time I run into this error:

dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
if dtype is None:
    ret = input.log_softmax(dim)
else:
    ret = input.log_softmax(dim, dtype=dtype)

RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'

I checked, and it is because log_softmax is not implemented for half-precision tensors on the CPU. I don’t know how to cast the tensors myself, because the softmax layer inside AutoModelForSeq2SeqLM is abstracted away. Is there a workaround that would let the model, trained with 8-bit lowered precision, generate sequences?

Below is my generate function:

def generate(qamodel, passage, question):
    # Encode the question/passage pair, truncating only the passage if too long
    source_encoding = tokenizer(
        question,
        passage,
        max_length=1024,
        padding="max_length",
        truncation="only_second",
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Beam-search generation from the fine-tuned model
    generated_token_ids = qamodel.model.generate(
        input_ids=source_encoding["input_ids"],
        attention_mask=source_encoding["attention_mask"],
        max_length=64,
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    # Decode each generated sequence back to text
    preds = [
        tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for token_ids in generated_token_ids
    ]

    return "".join(preds)

The error is raised in the nn.LogSoftmax layer because log_softmax is not implemented for float16 on the CPU. Move the tensors to the GPU, or cast them to float32 on the CPU, and it should work:

import torch
import torch.nn as nn

sm = nn.LogSoftmax(dim=1)

# float16 on the CPU: the kernel is missing, so this raises
x = torch.randn(1, 10, dtype=torch.half)
out = sm(x)
# RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'

# Option 1: move the tensor to the GPU, where the float16 kernel exists
out = sm(x.to("cuda"))
print(out.shape)
# torch.Size([1, 10])

# Option 2: cast to float32 and stay on the CPU
out = sm(x.float())
print(out.shape)
# torch.Size([1, 10])
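
Applied to your generate function, that means moving the encoded inputs onto the same device as the model before calling generate, so the half-precision kernels run on the GPU instead of the CPU. A minimal sketch, assuming device_map="auto" placed the model on a CUDA device and that qamodel.model is a Hugging Face model (which exposes a .device attribute):

def generate(qamodel, passage, question):
    source_encoding = tokenizer(
        question,
        passage,
        max_length=1024,
        padding="max_length",
        truncation="only_second",
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Move the input tensors to the model's device so log_softmax
    # runs on the GPU, where the float16 kernel is implemented
    device = qamodel.model.device
    generated_token_ids = qamodel.model.generate(
        input_ids=source_encoding["input_ids"].to(device),
        attention_mask=source_encoding["attention_mask"].to(device),
        max_length=64,
        num_beams=5,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for token_ids in generated_token_ids
    ]

    return "".join(preds)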