I’m wondering whether there is a better or easier way to do the following. I’m creating a zero-filled tensor and only want to fill in the indices where the mask is true with the output of a network.
...
with torch.cuda.amp.autocast():
    next_state_values = torch.zeros(len(non_final_mask), device=self.device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
However, this throws the following error:
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
I can do the following as a workaround, but this feels a bit hacky. Is there a better way to create the initial empty tensor while using amp?
...
with torch.cuda.amp.autocast():
    next_state_values = torch.zeros(len(non_final_mask), device=self.device)
    net_out = target_net(non_final_next_states).max(1)[0].detach()
    next_state_values = next_state_values.to(net_out.dtype)
    next_state_values[non_final_mask] = net_out
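
For reference, here is a minimal sketch of two variants I could imagine instead of the extra `.to()` call on the destination: either allocate the zeros directly with the dtype of the network output, or keep the float32 buffer and cast the source when assigning. The helper name `fill_masked` and its `match_source` flag are made up for illustration, and a hand-written half-precision tensor stands in for the output of `target_net` under autocast so the snippet runs without a GPU. I'm not claiming either variant is the recommended pattern.

```python
import torch


def fill_masked(net_out: torch.Tensor, mask: torch.Tensor, match_source: bool = True) -> torch.Tensor:
    """Write net_out into a zero tensor at the positions where mask is True.

    Hypothetical helper for illustration only; not part of PyTorch.
    """
    if match_source:
        # Variant A: allocate the destination with the source's dtype and device,
        # so the masked assignment sees matching dtypes.
        out = torch.zeros(mask.shape[0], dtype=net_out.dtype, device=net_out.device)
        out[mask] = net_out
    else:
        # Variant B: keep the default float32 destination and cast the source instead.
        out = torch.zeros(mask.shape[0], device=net_out.device)
        out[mask] = net_out.to(out.dtype)
    return out


# Stand-in for the float16 output of target_net(...).max(1)[0].detach() under autocast
mask = torch.tensor([True, False, True, True])
net_out = torch.tensor([1.5, 2.0, -0.5]).half()

half_values = fill_masked(net_out, mask, match_source=True)    # float16 result
float_values = fill_masked(net_out, mask, match_source=False)  # float32 result
```

Variant B keeps `next_state_values` in float32, which may be preferable if the values feed into a loss computed in full precision afterwards; variant A stays closest to the workaround above.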