How to perform mixed precision on a single F.linear?

Dear Community,

I was wondering whether I can further reduce memory consumption on a specific operation that I identified as using a lot of memory. I noticed that keeping 1) the A and B matrices in float32 and 2) the resulting coefficients c in float32 leads to rather large memory usage.

While I can manually cast A and B to float16 so that the resulting coefficients c are also float16, which reduces the memory footprint considerably, I was wondering whether it is possible to use mixed precision instead of manually casting the tensors to float16.

I am not sure whether this is necessary, since we don't need gradients on A and B, only on the coefficients c.
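
As a quick sanity check (with assumed shapes, just for illustration), gradients still flow back to the input of F.linear even when the weight does not require gradients, so freezing A and B should not prevent training the coefficients:

import torch
import torch.nn.functional as F

A = torch.randn(64, 64).requires_grad_(False)   # frozen weight, like self.A[t]
c = torch.randn(8, 64, requires_grad=True)      # coefficients we want gradients for

out = F.linear(c, A)
out.sum().backward()
print(c.grad is not None)   # True  -> gradient w.r.t. the coefficients exists
print(A.grad)               # None  -> no gradient stored for the frozen weight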

This is my code:

self.A = torch.from_numpy(GBTA).to(device, dtype=torch.float16).requires_grad_(False)
self.B = torch.from_numpy(GBTB).to(device, dtype=torch.float16).requires_grad_(False)
c_t = F.linear(c_t.half(), self.A[t].half()) + self.B[t].squeeze(-1) * input

I am not sure how to wrap GradScaler into this, as we already use autocast mixed precision and pass the scaler from GradScaler into the training function:

scaler = GradScaler()
train_loss, train_class_acc, train_noobj_acc, train_obj_acc = (
    trainholov4_enas_vid_bptt(
        device,
        train_loader,
        model,
        optimizer,
        scheduler,
        loss_f,
        scaled_anchors,
        scaler,
        conf_thresh=0.8,
        mode="ciou",
        target_batch_size=args.target_batch_size,
        ngpus=args.ngpus,
    )
)
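
For context, I assume the scaler is applied inside the training function in the usual autocast + GradScaler pattern, roughly like this sketch (model, optimizer, and loss_f stand in for the real objects):

import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_step(model, optimizer, loss_f, inputs, targets):
    optimizer.zero_grad()
    with autocast():                  # forward pass runs in mixed precision
        outputs = model(inputs)
        loss = loss_f(outputs, targets)
    scaler.scale(loss).backward()     # scale the loss to avoid float16 underflow
    scaler.step(optimizer)            # unscales gradients, then calls optimizer.step()
    scaler.update()                   # adjust the scale factor for the next iteration
    return loss.item()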

So it may not be necessary to add GradScaler here, and maybe something like this would just work?

self.A = torch.from_numpy(GBTA).to(device, dtype=torch.float16).requires_grad_(False)
self.B = torch.from_numpy(GBTB).to(device, dtype=torch.float16).requires_grad_(False)
with autocast():
    c_t = F.linear(c_t.half(), self.A[t].half()) + self.B[t].squeeze(-1) * input

Please let me know; I would be happy to hear any suggestions or answers.

The casts should not be needed when autocast is used. Do you see different behavior or memory usage between autocast and the explicit casts?
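
If it helps, a quick check along these lines (with dummy shapes) should confirm that autocast already produces float16 outputs from F.linear even when A is kept in float32, so the explicit .half() calls are redundant:

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

device = "cuda"
A = torch.randn(64, 64, device=device)                    # stored in float32
c = torch.randn(8, 64, device=device, requires_grad=True)

with autocast():
    out = F.linear(c, A)   # autocast runs this matmul in float16

print(out.dtype)           # torch.float16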