Autocast with indexing: incorrect dtypes

I’m looking to see whether there is a better or easier way to do the following. I’m creating a zero-initialized tensor and only want to fill in the indices where the mask condition is true with the output of a network.

...
with torch.cuda.amp.autocast():
    next_state_values = torch.zeros(len(non_final_mask), device=self.device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()

However, this throws the following error:

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
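The mismatch comes from the two tensors being created differently. Factory functions such as `torch.zeros` are not covered by autocast, so the buffer keeps the default float32 dtype, while `linear` runs in lower precision. The minimal sketch below prints both dtypes. It assumes a toy `nn.Linear` as a hypothetical stand-in for `target_net` and uses CPU autocast with bfloat16 so it runs without a GPU; under `torch.cuda.amp.autocast()` the network output would be float16 instead.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for target_net: a single Linear layer.
target_net = nn.Linear(4, 2)
states = torch.randn(3, 4)

# CPU autocast (bfloat16) so this runs without CUDA; CUDA autocast would give float16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Factory functions are not autocast: this stays float32 (the default dtype).
    buffer = torch.zeros(5)
    # linear is on autocast's lower-precision list, so its output is bfloat16.
    net_out = target_net(states).max(1)[0].detach()

print(buffer.dtype, net_out.dtype)  # torch.float32 torch.bfloat16
```

Assigning `net_out` into `buffer` through a boolean mask goes through `index_put_`, which requires the source and destination dtypes to match, and that is the check the error message reports.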

I can use the following workaround, but it feels a bit hacky. Is there a better way to create the initial zero tensor when using AMP?

...
with torch.cuda.amp.autocast():
    next_state_values = torch.zeros(len(non_final_mask), device=self.device)
    net_out = target_net(non_final_next_states).max(1)[0].detach()
    next_state_values = next_state_values.to(net_out.dtype)
    next_state_values[non_final_mask] = net_out
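Two alternatives to the explicit `.to(net_out.dtype)` call are sketched below under the same assumptions: a toy `nn.Linear` as a hypothetical `target_net`, made-up `non_final_mask` and `non_final_next_states`, and CPU bfloat16 autocast in place of CUDA float16. One option allocates the buffer with `Tensor.new_zeros`, which inherits dtype and device from `net_out`. The other keeps the float32 buffer and upcasts the source with `.float()`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the names used in the question.
target_net = nn.Linear(4, 2)
non_final_mask = torch.tensor([True, False, True, True, False])
non_final_next_states = torch.randn(int(non_final_mask.sum()), 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    net_out = target_net(non_final_next_states).max(1)[0].detach()

    # Option 1: build the buffer from net_out so it inherits net_out's dtype
    # (bfloat16 here, float16 under CUDA autocast) and device.
    values_low = net_out.new_zeros(len(non_final_mask))
    values_low[non_final_mask] = net_out

    # Option 2: keep a float32 buffer and upcast the source to match it.
    values_fp32 = torch.zeros(len(non_final_mask), device=net_out.device)
    values_fp32[non_final_mask] = net_out.float()
```

Which one to prefer is a judgment call: the first keeps everything in the autocast dtype, while the second leaves the buffer in the same float32 dtype it would have outside autocast, so downstream code sees the dtype it expects.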