[Mixed precision] nn.Parameter autocast problem

Hello, I'm trying mixed precision training with a pretrained model.

But when I load the pretrained model [1], the nn.Parameter [2] in WavLM (the pretrained model) causes an error.

### Error Message
  File "/data/leecho/xi-stt/xi-stt/model/WavLM.py", line 286, in apply_mask
    x[mask_indices] = self.mask_emb
RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.

Is there any way to fix this without breaking loaded data?

### trainer.py
model = ASRModel(...)

with autocast(enabled=config.fp16_run):
    predictions = model(inputs, apply_mask=True)

### ASRmodel.py
class ASRModel(nn.Module):
    def __init__(
        self, cfg: ASRConfig, ...) -> None:
        super().__init__()
        
        self.cfg = cfg
        
        ### [1] load pretrained_model
        if load_from_pretrained_model:
            checkpoint = torch.load(cfg.pretrained_model_path)
            wavlm_cfg = WavLMConfig(checkpoint['cfg'])
            wavlm_cfg.feature_grad_mult = 0.
            self.acoustic_model = WavLM(wavlm_cfg)
            self.acoustic_model.load_state_dict(checkpoint['model'])
                    
            self.cfg.wavlm_cfg = wavlm_cfg


### WavLM.py
### [2] self.mask_emb
self.mask_emb = nn.Parameter(
    torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)

def apply_mask(self, x, padding_mask):
    B, T, C = x.shape
    ....
        mask_indices = torch.from_numpy(mask_indices).to(x.device)
        x[mask_indices] = self.mask_emb
    else:
        mask_indices = None

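It seems you are calling index_put_ with mixed dtypes, which raises the error: under autocast the destination tensor (x) is float16, while the value being written (the source, self.mask_emb) is still float32. Cast the value tensor to the destination dtype and it should work: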

import torch
import torch.nn as nn

lin = nn.Linear(1, 1).cuda()
x = torch.randn(1, 1).cuda()

with torch.cuda.amp.autocast():
    out = lin(x)
    
print(out.dtype)
# torch.float16

val = torch.randn(1, dtype=torch.float32, device='cuda')

out.index_put_((torch.tensor(0),), val)
# RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.

out.index_put_((torch.tensor(0),), val.to(out.dtype)) # works
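Applied to the apply_mask method from the question, the same cast might look like the sketch below. ToyMasker is a hypothetical stand-in for the relevant part of WavLM, not the real class, and the float16 input is created by hand so the example also runs without a GPU. Only the value being written is cast, so the parameter loaded from the checkpoint stays in float32.

```python
import torch
import torch.nn as nn


class ToyMasker(nn.Module):
    """Hypothetical stand-in for the masking part of WavLM."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        # Stored in float32, like a parameter loaded from a checkpoint.
        self.mask_emb = nn.Parameter(torch.rand(embed_dim))

    def apply_mask(self, x: torch.Tensor, mask_indices: torch.Tensor) -> torch.Tensor:
        # Cast only the value being written; the parameter itself is untouched.
        x[mask_indices] = self.mask_emb.to(x.dtype)
        return x


masker = ToyMasker(embed_dim=8)

# Stands in for a float16 activation produced under autocast.
x = torch.zeros(2, 4, 8, dtype=torch.float16)

# Boolean mask over (batch, time); True marks frames to overwrite.
mask = torch.zeros(2, 4, dtype=torch.bool)
mask[0, 1] = True
mask[1, 3] = True

out = masker.apply_mask(x, mask)
print(out.dtype, masker.mask_emb.dtype)
```

Because .to(x.dtype) is an ordinary differentiable op, mask_emb should still receive gradients during training, and the state dict keeps its original float32 values.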

Thank you for your reply.

The problem was solved by your suggestion.

Thank you.