Hello, I’m trying mixed-precision training with a pretrained model.
But when I load the pretrained model [1], the nn.Parameter [2] in WavLM (the pretrained model) causes trouble.
### Error Message
File "/data/leecho/xi-stt/xi-stt/model/WavLM.py", line 286, in apply_mask
x[mask_indices] = self.mask_emb
RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.
Is there any way to fix this without breaking the loaded pretrained weights?
### trainer.py
model = ASRModel(...)
with autocast(enabled=config.fp16_run):
predictions = model(inputs, apply_mask=True)
### ASRmodel.py
class ASRModel(nn.Module):
def __init__(
self, cfg: ASRConfig, ...) -> None:
super().__init__()
self.cfg = cfg
### [1] load pretrained_model
if load_from_pretrained_model:
checkpoint = torch.load(cfg.pretrained_model_path)
wavlm_cfg = WavLMConfig(checkpoint['cfg'])
wavlm_cfg.feature_grad_mult = 0.
self.acoustic_model = WavLM(wavlm_cfg)
self.acoustic_model.load_state_dict(checkpoint['model'])
self.cfg.wavlm_cfg = wavlm_cfg
### WavLM.py
### [2] self.mask_emb
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
def apply_mask(self, x, padding_mask):
B, T, C = x.shape
....
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None