Using utils.checkpoint with cuda.amp to save memory

Hi, I was using cuda.amp.autocast to save memory during training.

However, when I use checkpoint in the middle of the network's forward pass,

```python
# inside the model's forward(); `checkpoint` is the torch.utils.checkpoint module
x = checkpoint.checkpoint(self.layer2, x)
feat = checkpoint.checkpoint(self.layer3, x)
```

I get the following error:

```
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
```
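For context, this is the message a convolution raises when its input and weight dtypes differ. The small CPU-only sketch below uses bfloat16 instead of float16 so that it runs without a GPU, and the layer shapes are arbitrary. It shows that feeding a reduced-precision tensor to a float32 layer fails outside an autocast region and succeeds inside one, because autocast casts the input and the weight to the same lower-precision dtype for the convolution.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)                  # weights stay float32
x_low = torch.randn(1, 3, 16, 16).to(torch.bfloat16)   # reduced-precision input

def try_forward(enable_autocast: bool) -> str:
    """Run the float32 conv on the bfloat16 input with autocast on or off."""
    try:
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16,
                            enabled=enable_autocast):
            out = conv(x_low)
        return f"ok, output dtype {out.dtype}"
    except RuntimeError as err:
        return f"RuntimeError: {err}"

print(try_forward(enable_autocast=False))  # input/weight dtype mismatch
print(try_forward(enable_autocast=True))   # autocast casts both to bfloat16
```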

Is it not possible to use cuda.amp and checkpoint together?

Yes, it should work, so could you post a minimal, executable code snippet to reproduce the issue, please?
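As a reference point, here is a minimal sketch of the two features combined in one training step. It assumes a recent PyTorch release (for torch.autocast and the use_reentrant argument of torch.utils.checkpoint.checkpoint); TinyNet and its layer sizes are hypothetical. It uses float16 on CUDA and falls back to bfloat16 on CPU so it also runs without a GPU.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Autocast dtype per device: float16 on CUDA, bfloat16 on CPU (assumption for portability).
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

class TinyNet(nn.Module):
    """Hypothetical model in which layer2 and layer3 are checkpointed."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.head = nn.Linear(16, 10)

    def forward(self, x):
        x = self.layer1(x)
        # Activations inside layer2/layer3 are not stored; they are recomputed in backward.
        x = checkpoint(self.layer2, x, use_reentrant=False)
        feat = checkpoint(self.layer3, x, use_reentrant=False)
        return self.head(feat.mean(dim=(2, 3)))

model = TinyNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Loss scaling is only needed for float16; when disabled, the scaler is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

inputs = torch.randn(4, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (4,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = model(inputs)
    loss = nn.functional.cross_entropy(logits, targets)
# Backward runs outside the autocast block, as recommended for autocast.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
print(loss.item())
```

The parameters themselves stay float32; autocast only casts per-op copies, so the gradients that reach the optimizer are float32 as well.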

Okay. Below is my model code, with some layers wrapped in checkpoint.

Below is the training code, which runs inside amp.autocast() and calculates the loss with

The error is raised in layer3 of the model, where checkpoint is used.
The input and weight types do not seem to match when checkpoint is used together with amp.
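One possible explanation, not confirmed here because the full code is not shown: if the installed PyTorch does not restore the autocast state when a checkpointed segment is recomputed during backward, layer3 would be re-run on the half-precision activation saved from layer2 with its float32 weights, which produces exactly this message. Backward reaches layer3 before layer2, which would be consistent with the error appearing in layer3. A defensive workaround that does not rely on that behaviour is to enter autocast inside the checkpointed function itself, so the recomputation is always mixed precision too. The sketch below assumes a recent PyTorch with torch.autocast (older versions offer torch.cuda.amp.autocast for the same purpose); the helper autocast_wrap and the layer sizes are hypothetical, and it uses the reentrant checkpoint variant, the older default.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def autocast_wrap(module: nn.Module, device_type: str, dtype: torch.dtype):
    """Hypothetical helper: run `module` under autocast, so both the original
    forward and checkpoint's recomputation in backward use mixed precision."""
    def run(*args):
        with torch.autocast(device_type=device_type, dtype=dtype):
            return module(*args)
    return run

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

layer2 = nn.Conv2d(8, 8, 3, padding=1).to(device)
layer3 = nn.Conv2d(8, 8, 3, padding=1).to(device)

# Reentrant checkpoint needs at least one input that requires grad.
x = torch.randn(2, 8, 16, 16, device=device, requires_grad=True)
with torch.autocast(device_type=device, dtype=amp_dtype):
    h = checkpoint(autocast_wrap(layer2, device, amp_dtype), x, use_reentrant=True)
    feat = checkpoint(autocast_wrap(layer3, device, amp_dtype), h, use_reentrant=True)

# Recomputation of layer3, then layer2, happens inside this backward call.
feat.float().sum().backward()
print(feat.dtype, layer3.weight.grad.dtype)
```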

Could you post the code snippets wrapped in three backticks ```, please?