How to cast a tensor to another type?

If I understand your use case correctly, you are casting the input data and model parameters to FP16 manually, without using automatic mixed-precision training?
If so, could you post the error message showing which operation raises the issue?
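
For context, the two common approaches look roughly like this. This is just a sketch, with nn.Linear standing in for your model, and it assumes a CUDA device:

import torch
import torch.nn as nn

x = torch.randn(4, 10, device="cuda")

# manual FP16: cast the parameters/buffers and the inputs yourself
model_fp16 = nn.Linear(10, 10).cuda().half()
out_manual = model_fp16(x.half())

# automatic mixed precision: keep the parameters in FP32 and let autocast
# pick the per-op dtypes during the forward pass
model_fp32 = nn.Linear(10, 10).cuda()
with torch.cuda.amp.autocast():
    out_amp = model_fp32(x)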

That is correct, this is how I did it. Here is the error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-1d79e07828d2> in <module>
     34             with torch.no_grad():
     35                 inference_start_time = time.time() * 1000
---> 36                 prediction = model(frame)
     37                 inference_end_time = time.time() * 1000
     38 

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/PyTorch-ENet/models/enet.py in forward(self, x)
    590         # Stage 1 - Encoder
    591         stage1_input_size = x.size()
--> 592         x, max_indices1_0 = self.downsample1_0(x)
    593         x = self.regular1_1(x)
    594         x = self.regular1_2(x)

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/PyTorch-ENet/models/enet.py in forward(self, x)
    345         out = main + ext
    346 
--> 347         return self.out_activation(out), max_indices
    348 
    349 

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, input)
    986 
    987     def forward(self, input: Tensor) -> Tensor:
--> 988         return F.prelu(input, self.weight)
    989 
    990     def extra_repr(self) -> str:

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/functional.py in prelu(input, weight)
   1317         if type(input) is not Tensor and has_torch_function((input,)):
   1318             return handle_torch_function(prelu, (input,), input, weight)
-> 1319     return torch.prelu(input, weight)
   1320 
   1321 

RuntimeError: expected scalar type Float but found Half

PyTorch version: 1.6
Python: 3.7.8

Thanks for the stacktrace.
It seems nn.PReLU() is causing the error, which I cannot reproduce using 1.7.0.dev20200830:

import torch
import torch.nn as nn

prelu = nn.PReLU().cuda().half()
x = torch.randn(10, 10).cuda().half()
out = prelu(x)

Could you update to the latest nightly binary and rerun the code?

I was able to run your code snippet in both 1.6 and 1.7.0.dev20200830. Unfortunately, my problem remains the same with the nightly version.

Just to reconfirm: the following output shows that the conversion to float16 was successful, right? Is there something else that could be checked?

for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].dtype)
initial_block.main_branch.weight 	 torch.float16
initial_block.batch_norm.weight 	 torch.float16
initial_block.batch_norm.bias 	 torch.float16
initial_block.batch_norm.running_mean 	 torch.float16
initial_block.batch_norm.running_var 	 torch.float16
initial_block.batch_norm.num_batches_tracked 	 torch.int64
initial_block.out_activation.weight 	 torch.float16
downsample1_0.ext_conv1.0.weight 	 torch.float16
downsample1_0.ext_conv1.1.weight 	 torch.float16
downsample1_0.ext_conv1.1.bias 	 torch.float16
downsample1_0.ext_conv1.1.running_mean 	 torch.float16
downsample1_0.ext_conv1.1.running_var 	 torch.float16
downsample1_0.ext_conv1.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv1.2.weight 	 torch.float16
downsample1_0.ext_conv2.0.weight 	 torch.float16
downsample1_0.ext_conv2.1.weight 	 torch.float16
downsample1_0.ext_conv2.1.bias 	 torch.float16
downsample1_0.ext_conv2.1.running_mean 	 torch.float16
downsample1_0.ext_conv2.1.running_var 	 torch.float16
downsample1_0.ext_conv2.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv2.2.weight 	 torch.float16
downsample1_0.ext_conv3.0.weight 	 torch.float16
downsample1_0.ext_conv3.1.weight 	 torch.float16
downsample1_0.ext_conv3.1.bias 	 torch.float16
downsample1_0.ext_conv3.1.running_mean 	 torch.float16
downsample1_0.ext_conv3.1.running_var 	 torch.float16
downsample1_0.ext_conv3.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv3.2.weight 	 torch.float16
downsample1_0.out_activation.weight 	 torch.float16
regular1_1.ext_conv1.0.weight 	 torch.float16
regular1_1.ext_conv1.1.weight 	 torch.float16
regular1_1.ext_conv1.1.bias 	 torch.float16
regular1_1.ext_conv1.1.running_mean 	 torch.float16
regular1_1.ext_conv1.1.running_var 	 torch.float16
regular1_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv1.2.weight 	 torch.float16
regular1_1.ext_conv2.0.weight 	 torch.float16
regular1_1.ext_conv2.1.weight 	 torch.float16
regular1_1.ext_conv2.1.bias 	 torch.float16
regular1_1.ext_conv2.1.running_mean 	 torch.float16
regular1_1.ext_conv2.1.running_var 	 torch.float16
regular1_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv2.2.weight 	 torch.float16
regular1_1.ext_conv3.0.weight 	 torch.float16
regular1_1.ext_conv3.1.weight 	 torch.float16
regular1_1.ext_conv3.1.bias 	 torch.float16
regular1_1.ext_conv3.1.running_mean 	 torch.float16
regular1_1.ext_conv3.1.running_var 	 torch.float16
regular1_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv3.2.weight 	 torch.float16
regular1_1.out_activation.weight 	 torch.float16
regular1_2.ext_conv1.0.weight 	 torch.float16
regular1_2.ext_conv1.1.weight 	 torch.float16
regular1_2.ext_conv1.1.bias 	 torch.float16
regular1_2.ext_conv1.1.running_mean 	 torch.float16
regular1_2.ext_conv1.1.running_var 	 torch.float16
regular1_2.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv1.2.weight 	 torch.float16
regular1_2.ext_conv2.0.weight 	 torch.float16
regular1_2.ext_conv2.1.weight 	 torch.float16
regular1_2.ext_conv2.1.bias 	 torch.float16
regular1_2.ext_conv2.1.running_mean 	 torch.float16
regular1_2.ext_conv2.1.running_var 	 torch.float16
regular1_2.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv2.2.weight 	 torch.float16
regular1_2.ext_conv3.0.weight 	 torch.float16
regular1_2.ext_conv3.1.weight 	 torch.float16
regular1_2.ext_conv3.1.bias 	 torch.float16
regular1_2.ext_conv3.1.running_mean 	 torch.float16
regular1_2.ext_conv3.1.running_var 	 torch.float16
regular1_2.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv3.2.weight 	 torch.float16
regular1_2.out_activation.weight 	 torch.float16
regular1_3.ext_conv1.0.weight 	 torch.float16
regular1_3.ext_conv1.1.weight 	 torch.float16
regular1_3.ext_conv1.1.bias 	 torch.float16
regular1_3.ext_conv1.1.running_mean 	 torch.float16
regular1_3.ext_conv1.1.running_var 	 torch.float16
regular1_3.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv1.2.weight 	 torch.float16
regular1_3.ext_conv2.0.weight 	 torch.float16
regular1_3.ext_conv2.1.weight 	 torch.float16
regular1_3.ext_conv2.1.bias 	 torch.float16
regular1_3.ext_conv2.1.running_mean 	 torch.float16
regular1_3.ext_conv2.1.running_var 	 torch.float16
regular1_3.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv2.2.weight 	 torch.float16
regular1_3.ext_conv3.0.weight 	 torch.float16
regular1_3.ext_conv3.1.weight 	 torch.float16
regular1_3.ext_conv3.1.bias 	 torch.float16
regular1_3.ext_conv3.1.running_mean 	 torch.float16
regular1_3.ext_conv3.1.running_var 	 torch.float16
regular1_3.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv3.2.weight 	 torch.float16
regular1_3.out_activation.weight 	 torch.float16
regular1_4.ext_conv1.0.weight 	 torch.float16
regular1_4.ext_conv1.1.weight 	 torch.float16
regular1_4.ext_conv1.1.bias 	 torch.float16
regular1_4.ext_conv1.1.running_mean 	 torch.float16
regular1_4.ext_conv1.1.running_var 	 torch.float16
regular1_4.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv1.2.weight 	 torch.float16
regular1_4.ext_conv2.0.weight 	 torch.float16
regular1_4.ext_conv2.1.weight 	 torch.float16
regular1_4.ext_conv2.1.bias 	 torch.float16
regular1_4.ext_conv2.1.running_mean 	 torch.float16
regular1_4.ext_conv2.1.running_var 	 torch.float16
regular1_4.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv2.2.weight 	 torch.float16
regular1_4.ext_conv3.0.weight 	 torch.float16
regular1_4.ext_conv3.1.weight 	 torch.float16
regular1_4.ext_conv3.1.bias 	 torch.float16
regular1_4.ext_conv3.1.running_mean 	 torch.float16
regular1_4.ext_conv3.1.running_var 	 torch.float16
regular1_4.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv3.2.weight 	 torch.float16
regular1_4.out_activation.weight 	 torch.float16
downsample2_0.ext_conv1.0.weight 	 torch.float16
downsample2_0.ext_conv1.1.weight 	 torch.float16
downsample2_0.ext_conv1.1.bias 	 torch.float16
downsample2_0.ext_conv1.1.running_mean 	 torch.float16
downsample2_0.ext_conv1.1.running_var 	 torch.float16
downsample2_0.ext_conv1.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv1.2.weight 	 torch.float16
downsample2_0.ext_conv2.0.weight 	 torch.float16
downsample2_0.ext_conv2.1.weight 	 torch.float16
downsample2_0.ext_conv2.1.bias 	 torch.float16
downsample2_0.ext_conv2.1.running_mean 	 torch.float16
downsample2_0.ext_conv2.1.running_var 	 torch.float16
downsample2_0.ext_conv2.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv2.2.weight 	 torch.float16
downsample2_0.ext_conv3.0.weight 	 torch.float16
downsample2_0.ext_conv3.1.weight 	 torch.float16
downsample2_0.ext_conv3.1.bias 	 torch.float16
downsample2_0.ext_conv3.1.running_mean 	 torch.float16
downsample2_0.ext_conv3.1.running_var 	 torch.float16
downsample2_0.ext_conv3.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv3.2.weight 	 torch.float16
downsample2_0.out_activation.weight 	 torch.float16
regular2_1.ext_conv1.0.weight 	 torch.float16
regular2_1.ext_conv1.1.weight 	 torch.float16
regular2_1.ext_conv1.1.bias 	 torch.float16
regular2_1.ext_conv1.1.running_mean 	 torch.float16
regular2_1.ext_conv1.1.running_var 	 torch.float16
regular2_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv1.2.weight 	 torch.float16
regular2_1.ext_conv2.0.weight 	 torch.float16
regular2_1.ext_conv2.1.weight 	 torch.float16
regular2_1.ext_conv2.1.bias 	 torch.float16
regular2_1.ext_conv2.1.running_mean 	 torch.float16
regular2_1.ext_conv2.1.running_var 	 torch.float16
regular2_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv2.2.weight 	 torch.float16
regular2_1.ext_conv3.0.weight 	 torch.float16
regular2_1.ext_conv3.1.weight 	 torch.float16
regular2_1.ext_conv3.1.bias 	 torch.float16
regular2_1.ext_conv3.1.running_mean 	 torch.float16
regular2_1.ext_conv3.1.running_var 	 torch.float16
regular2_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv3.2.weight 	 torch.float16
regular2_1.out_activation.weight 	 torch.float16
dilated2_2.ext_conv1.0.weight 	 torch.float16
dilated2_2.ext_conv1.1.weight 	 torch.float16
dilated2_2.ext_conv1.1.bias 	 torch.float16
dilated2_2.ext_conv1.1.running_mean 	 torch.float16
dilated2_2.ext_conv1.1.running_var 	 torch.float16
dilated2_2.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv1.2.weight 	 torch.float16
dilated2_2.ext_conv2.0.weight 	 torch.float16
dilated2_2.ext_conv2.1.weight 	 torch.float16
dilated2_2.ext_conv2.1.bias 	 torch.float16
dilated2_2.ext_conv2.1.running_mean 	 torch.float16
dilated2_2.ext_conv2.1.running_var 	 torch.float16
dilated2_2.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv2.2.weight 	 torch.float16
dilated2_2.ext_conv3.0.weight 	 torch.float16
dilated2_2.ext_conv3.1.weight 	 torch.float16
dilated2_2.ext_conv3.1.bias 	 torch.float16
dilated2_2.ext_conv3.1.running_mean 	 torch.float16
dilated2_2.ext_conv3.1.running_var 	 torch.float16
dilated2_2.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv3.2.weight 	 torch.float16
dilated2_2.out_activation.weight 	 torch.float16
asymmetric2_3.ext_conv1.0.weight 	 torch.float16
asymmetric2_3.ext_conv1.1.weight 	 torch.float16
asymmetric2_3.ext_conv1.1.bias 	 torch.float16
asymmetric2_3.ext_conv1.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv1.1.running_var 	 torch.float16
asymmetric2_3.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv1.2.weight 	 torch.float16
asymmetric2_3.ext_conv2.0.weight 	 torch.float16
asymmetric2_3.ext_conv2.1.weight 	 torch.float16
asymmetric2_3.ext_conv2.1.bias 	 torch.float16
asymmetric2_3.ext_conv2.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv2.1.running_var 	 torch.float16
asymmetric2_3.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv2.2.weight 	 torch.float16
asymmetric2_3.ext_conv2.3.weight 	 torch.float16
asymmetric2_3.ext_conv2.4.weight 	 torch.float16
asymmetric2_3.ext_conv2.4.bias 	 torch.float16
asymmetric2_3.ext_conv2.4.running_mean 	 torch.float16
asymmetric2_3.ext_conv2.4.running_var 	 torch.float16
asymmetric2_3.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv2.5.weight 	 torch.float16
asymmetric2_3.ext_conv3.0.weight 	 torch.float16
asymmetric2_3.ext_conv3.1.weight 	 torch.float16
asymmetric2_3.ext_conv3.1.bias 	 torch.float16
asymmetric2_3.ext_conv3.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv3.1.running_var 	 torch.float16
asymmetric2_3.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv3.2.weight 	 torch.float16
asymmetric2_3.out_activation.weight 	 torch.float16
dilated2_4.ext_conv1.0.weight 	 torch.float16
dilated2_4.ext_conv1.1.weight 	 torch.float16
dilated2_4.ext_conv1.1.bias 	 torch.float16
dilated2_4.ext_conv1.1.running_mean 	 torch.float16
dilated2_4.ext_conv1.1.running_var 	 torch.float16
dilated2_4.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv1.2.weight 	 torch.float16
dilated2_4.ext_conv2.0.weight 	 torch.float16
dilated2_4.ext_conv2.1.weight 	 torch.float16
dilated2_4.ext_conv2.1.bias 	 torch.float16
dilated2_4.ext_conv2.1.running_mean 	 torch.float16
dilated2_4.ext_conv2.1.running_var 	 torch.float16
dilated2_4.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv2.2.weight 	 torch.float16
dilated2_4.ext_conv3.0.weight 	 torch.float16
dilated2_4.ext_conv3.1.weight 	 torch.float16
dilated2_4.ext_conv3.1.bias 	 torch.float16
dilated2_4.ext_conv3.1.running_mean 	 torch.float16
dilated2_4.ext_conv3.1.running_var 	 torch.float16
dilated2_4.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv3.2.weight 	 torch.float16
dilated2_4.out_activation.weight 	 torch.float16
regular2_5.ext_conv1.0.weight 	 torch.float16
regular2_5.ext_conv1.1.weight 	 torch.float16
regular2_5.ext_conv1.1.bias 	 torch.float16
regular2_5.ext_conv1.1.running_mean 	 torch.float16
regular2_5.ext_conv1.1.running_var 	 torch.float16
regular2_5.ext_conv1.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv1.2.weight 	 torch.float16
regular2_5.ext_conv2.0.weight 	 torch.float16
regular2_5.ext_conv2.1.weight 	 torch.float16
regular2_5.ext_conv2.1.bias 	 torch.float16
regular2_5.ext_conv2.1.running_mean 	 torch.float16
regular2_5.ext_conv2.1.running_var 	 torch.float16
regular2_5.ext_conv2.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv2.2.weight 	 torch.float16
regular2_5.ext_conv3.0.weight 	 torch.float16
regular2_5.ext_conv3.1.weight 	 torch.float16
regular2_5.ext_conv3.1.bias 	 torch.float16
regular2_5.ext_conv3.1.running_mean 	 torch.float16
regular2_5.ext_conv3.1.running_var 	 torch.float16
regular2_5.ext_conv3.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv3.2.weight 	 torch.float16
regular2_5.out_activation.weight 	 torch.float16
dilated2_6.ext_conv1.0.weight 	 torch.float16
dilated2_6.ext_conv1.1.weight 	 torch.float16
dilated2_6.ext_conv1.1.bias 	 torch.float16
dilated2_6.ext_conv1.1.running_mean 	 torch.float16
dilated2_6.ext_conv1.1.running_var 	 torch.float16
dilated2_6.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv1.2.weight 	 torch.float16
dilated2_6.ext_conv2.0.weight 	 torch.float16
dilated2_6.ext_conv2.1.weight 	 torch.float16
dilated2_6.ext_conv2.1.bias 	 torch.float16
dilated2_6.ext_conv2.1.running_mean 	 torch.float16
dilated2_6.ext_conv2.1.running_var 	 torch.float16
dilated2_6.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv2.2.weight 	 torch.float16
dilated2_6.ext_conv3.0.weight 	 torch.float16
dilated2_6.ext_conv3.1.weight 	 torch.float16
dilated2_6.ext_conv3.1.bias 	 torch.float16
dilated2_6.ext_conv3.1.running_mean 	 torch.float16
dilated2_6.ext_conv3.1.running_var 	 torch.float16
dilated2_6.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv3.2.weight 	 torch.float16
dilated2_6.out_activation.weight 	 torch.float16
asymmetric2_7.ext_conv1.0.weight 	 torch.float16
asymmetric2_7.ext_conv1.1.weight 	 torch.float16
asymmetric2_7.ext_conv1.1.bias 	 torch.float16
asymmetric2_7.ext_conv1.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv1.1.running_var 	 torch.float16
asymmetric2_7.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv1.2.weight 	 torch.float16
asymmetric2_7.ext_conv2.0.weight 	 torch.float16
asymmetric2_7.ext_conv2.1.weight 	 torch.float16
asymmetric2_7.ext_conv2.1.bias 	 torch.float16
asymmetric2_7.ext_conv2.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv2.1.running_var 	 torch.float16
asymmetric2_7.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv2.2.weight 	 torch.float16
asymmetric2_7.ext_conv2.3.weight 	 torch.float16
asymmetric2_7.ext_conv2.4.weight 	 torch.float16
asymmetric2_7.ext_conv2.4.bias 	 torch.float16
asymmetric2_7.ext_conv2.4.running_mean 	 torch.float16
asymmetric2_7.ext_conv2.4.running_var 	 torch.float16
asymmetric2_7.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv2.5.weight 	 torch.float16
asymmetric2_7.ext_conv3.0.weight 	 torch.float16
asymmetric2_7.ext_conv3.1.weight 	 torch.float16
asymmetric2_7.ext_conv3.1.bias 	 torch.float16
asymmetric2_7.ext_conv3.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv3.1.running_var 	 torch.float16
asymmetric2_7.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv3.2.weight 	 torch.float16
asymmetric2_7.out_activation.weight 	 torch.float16
dilated2_8.ext_conv1.0.weight 	 torch.float16
dilated2_8.ext_conv1.1.weight 	 torch.float16
dilated2_8.ext_conv1.1.bias 	 torch.float16
dilated2_8.ext_conv1.1.running_mean 	 torch.float16
dilated2_8.ext_conv1.1.running_var 	 torch.float16
dilated2_8.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv1.2.weight 	 torch.float16
dilated2_8.ext_conv2.0.weight 	 torch.float16
dilated2_8.ext_conv2.1.weight 	 torch.float16
dilated2_8.ext_conv2.1.bias 	 torch.float16
dilated2_8.ext_conv2.1.running_mean 	 torch.float16
dilated2_8.ext_conv2.1.running_var 	 torch.float16
dilated2_8.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv2.2.weight 	 torch.float16
dilated2_8.ext_conv3.0.weight 	 torch.float16
dilated2_8.ext_conv3.1.weight 	 torch.float16
dilated2_8.ext_conv3.1.bias 	 torch.float16
dilated2_8.ext_conv3.1.running_mean 	 torch.float16
dilated2_8.ext_conv3.1.running_var 	 torch.float16
dilated2_8.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv3.2.weight 	 torch.float16
dilated2_8.out_activation.weight 	 torch.float16
regular3_0.ext_conv1.0.weight 	 torch.float16
regular3_0.ext_conv1.1.weight 	 torch.float16
regular3_0.ext_conv1.1.bias 	 torch.float16
regular3_0.ext_conv1.1.running_mean 	 torch.float16
regular3_0.ext_conv1.1.running_var 	 torch.float16
regular3_0.ext_conv1.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv1.2.weight 	 torch.float16
regular3_0.ext_conv2.0.weight 	 torch.float16
regular3_0.ext_conv2.1.weight 	 torch.float16
regular3_0.ext_conv2.1.bias 	 torch.float16
regular3_0.ext_conv2.1.running_mean 	 torch.float16
regular3_0.ext_conv2.1.running_var 	 torch.float16
regular3_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv2.2.weight 	 torch.float16
regular3_0.ext_conv3.0.weight 	 torch.float16
regular3_0.ext_conv3.1.weight 	 torch.float16
regular3_0.ext_conv3.1.bias 	 torch.float16
regular3_0.ext_conv3.1.running_mean 	 torch.float16
regular3_0.ext_conv3.1.running_var 	 torch.float16
regular3_0.ext_conv3.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv3.2.weight 	 torch.float16
regular3_0.out_activation.weight 	 torch.float16
dilated3_1.ext_conv1.0.weight 	 torch.float16
dilated3_1.ext_conv1.1.weight 	 torch.float16
dilated3_1.ext_conv1.1.bias 	 torch.float16
dilated3_1.ext_conv1.1.running_mean 	 torch.float16
dilated3_1.ext_conv1.1.running_var 	 torch.float16
dilated3_1.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv1.2.weight 	 torch.float16
dilated3_1.ext_conv2.0.weight 	 torch.float16
dilated3_1.ext_conv2.1.weight 	 torch.float16
dilated3_1.ext_conv2.1.bias 	 torch.float16
dilated3_1.ext_conv2.1.running_mean 	 torch.float16
dilated3_1.ext_conv2.1.running_var 	 torch.float16
dilated3_1.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv2.2.weight 	 torch.float16
dilated3_1.ext_conv3.0.weight 	 torch.float16
dilated3_1.ext_conv3.1.weight 	 torch.float16
dilated3_1.ext_conv3.1.bias 	 torch.float16
dilated3_1.ext_conv3.1.running_mean 	 torch.float16
dilated3_1.ext_conv3.1.running_var 	 torch.float16
dilated3_1.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv3.2.weight 	 torch.float16
dilated3_1.out_activation.weight 	 torch.float16
asymmetric3_2.ext_conv1.0.weight 	 torch.float16
asymmetric3_2.ext_conv1.1.weight 	 torch.float16
asymmetric3_2.ext_conv1.1.bias 	 torch.float16
asymmetric3_2.ext_conv1.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv1.1.running_var 	 torch.float16
asymmetric3_2.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv1.2.weight 	 torch.float16
asymmetric3_2.ext_conv2.0.weight 	 torch.float16
asymmetric3_2.ext_conv2.1.weight 	 torch.float16
asymmetric3_2.ext_conv2.1.bias 	 torch.float16
asymmetric3_2.ext_conv2.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv2.1.running_var 	 torch.float16
asymmetric3_2.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv2.2.weight 	 torch.float16
asymmetric3_2.ext_conv2.3.weight 	 torch.float16
asymmetric3_2.ext_conv2.4.weight 	 torch.float16
asymmetric3_2.ext_conv2.4.bias 	 torch.float16
asymmetric3_2.ext_conv2.4.running_mean 	 torch.float16
asymmetric3_2.ext_conv2.4.running_var 	 torch.float16
asymmetric3_2.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv2.5.weight 	 torch.float16
asymmetric3_2.ext_conv3.0.weight 	 torch.float16
asymmetric3_2.ext_conv3.1.weight 	 torch.float16
asymmetric3_2.ext_conv3.1.bias 	 torch.float16
asymmetric3_2.ext_conv3.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv3.1.running_var 	 torch.float16
asymmetric3_2.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv3.2.weight 	 torch.float16
asymmetric3_2.out_activation.weight 	 torch.float16
dilated3_3.ext_conv1.0.weight 	 torch.float16
dilated3_3.ext_conv1.1.weight 	 torch.float16
dilated3_3.ext_conv1.1.bias 	 torch.float16
dilated3_3.ext_conv1.1.running_mean 	 torch.float16
dilated3_3.ext_conv1.1.running_var 	 torch.float16
dilated3_3.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv1.2.weight 	 torch.float16
dilated3_3.ext_conv2.0.weight 	 torch.float16
dilated3_3.ext_conv2.1.weight 	 torch.float16
dilated3_3.ext_conv2.1.bias 	 torch.float16
dilated3_3.ext_conv2.1.running_mean 	 torch.float16
dilated3_3.ext_conv2.1.running_var 	 torch.float16
dilated3_3.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv2.2.weight 	 torch.float16
dilated3_3.ext_conv3.0.weight 	 torch.float16
dilated3_3.ext_conv3.1.weight 	 torch.float16
dilated3_3.ext_conv3.1.bias 	 torch.float16
dilated3_3.ext_conv3.1.running_mean 	 torch.float16
dilated3_3.ext_conv3.1.running_var 	 torch.float16
dilated3_3.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv3.2.weight 	 torch.float16
dilated3_3.out_activation.weight 	 torch.float16
regular3_4.ext_conv1.0.weight 	 torch.float16
regular3_4.ext_conv1.1.weight 	 torch.float16
regular3_4.ext_conv1.1.bias 	 torch.float16
regular3_4.ext_conv1.1.running_mean 	 torch.float16
regular3_4.ext_conv1.1.running_var 	 torch.float16
regular3_4.ext_conv1.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv1.2.weight 	 torch.float16
regular3_4.ext_conv2.0.weight 	 torch.float16
regular3_4.ext_conv2.1.weight 	 torch.float16
regular3_4.ext_conv2.1.bias 	 torch.float16
regular3_4.ext_conv2.1.running_mean 	 torch.float16
regular3_4.ext_conv2.1.running_var 	 torch.float16
regular3_4.ext_conv2.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv2.2.weight 	 torch.float16
regular3_4.ext_conv3.0.weight 	 torch.float16
regular3_4.ext_conv3.1.weight 	 torch.float16
regular3_4.ext_conv3.1.bias 	 torch.float16
regular3_4.ext_conv3.1.running_mean 	 torch.float16
regular3_4.ext_conv3.1.running_var 	 torch.float16
regular3_4.ext_conv3.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv3.2.weight 	 torch.float16
regular3_4.out_activation.weight 	 torch.float16
dilated3_5.ext_conv1.0.weight 	 torch.float16
dilated3_5.ext_conv1.1.weight 	 torch.float16
dilated3_5.ext_conv1.1.bias 	 torch.float16
dilated3_5.ext_conv1.1.running_mean 	 torch.float16
dilated3_5.ext_conv1.1.running_var 	 torch.float16
dilated3_5.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv1.2.weight 	 torch.float16
dilated3_5.ext_conv2.0.weight 	 torch.float16
dilated3_5.ext_conv2.1.weight 	 torch.float16
dilated3_5.ext_conv2.1.bias 	 torch.float16
dilated3_5.ext_conv2.1.running_mean 	 torch.float16
dilated3_5.ext_conv2.1.running_var 	 torch.float16
dilated3_5.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv2.2.weight 	 torch.float16
dilated3_5.ext_conv3.0.weight 	 torch.float16
dilated3_5.ext_conv3.1.weight 	 torch.float16
dilated3_5.ext_conv3.1.bias 	 torch.float16
dilated3_5.ext_conv3.1.running_mean 	 torch.float16
dilated3_5.ext_conv3.1.running_var 	 torch.float16
dilated3_5.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv3.2.weight 	 torch.float16
dilated3_5.out_activation.weight 	 torch.float16
asymmetric3_6.ext_conv1.0.weight 	 torch.float16
asymmetric3_6.ext_conv1.1.weight 	 torch.float16
asymmetric3_6.ext_conv1.1.bias 	 torch.float16
asymmetric3_6.ext_conv1.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv1.1.running_var 	 torch.float16
asymmetric3_6.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv1.2.weight 	 torch.float16
asymmetric3_6.ext_conv2.0.weight 	 torch.float16
asymmetric3_6.ext_conv2.1.weight 	 torch.float16
asymmetric3_6.ext_conv2.1.bias 	 torch.float16
asymmetric3_6.ext_conv2.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv2.1.running_var 	 torch.float16
asymmetric3_6.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv2.2.weight 	 torch.float16
asymmetric3_6.ext_conv2.3.weight 	 torch.float16
asymmetric3_6.ext_conv2.4.weight 	 torch.float16
asymmetric3_6.ext_conv2.4.bias 	 torch.float16
asymmetric3_6.ext_conv2.4.running_mean 	 torch.float16
asymmetric3_6.ext_conv2.4.running_var 	 torch.float16
asymmetric3_6.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv2.5.weight 	 torch.float16
asymmetric3_6.ext_conv3.0.weight 	 torch.float16
asymmetric3_6.ext_conv3.1.weight 	 torch.float16
asymmetric3_6.ext_conv3.1.bias 	 torch.float16
asymmetric3_6.ext_conv3.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv3.1.running_var 	 torch.float16
asymmetric3_6.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv3.2.weight 	 torch.float16
asymmetric3_6.out_activation.weight 	 torch.float16
dilated3_7.ext_conv1.0.weight 	 torch.float16
dilated3_7.ext_conv1.1.weight 	 torch.float16
dilated3_7.ext_conv1.1.bias 	 torch.float16
dilated3_7.ext_conv1.1.running_mean 	 torch.float16
dilated3_7.ext_conv1.1.running_var 	 torch.float16
dilated3_7.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv1.2.weight 	 torch.float16
dilated3_7.ext_conv2.0.weight 	 torch.float16
dilated3_7.ext_conv2.1.weight 	 torch.float16
dilated3_7.ext_conv2.1.bias 	 torch.float16
dilated3_7.ext_conv2.1.running_mean 	 torch.float16
dilated3_7.ext_conv2.1.running_var 	 torch.float16
dilated3_7.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv2.2.weight 	 torch.float16
dilated3_7.ext_conv3.0.weight 	 torch.float16
dilated3_7.ext_conv3.1.weight 	 torch.float16
dilated3_7.ext_conv3.1.bias 	 torch.float16
dilated3_7.ext_conv3.1.running_mean 	 torch.float16
dilated3_7.ext_conv3.1.running_var 	 torch.float16
dilated3_7.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv3.2.weight 	 torch.float16
dilated3_7.out_activation.weight 	 torch.float16
upsample4_0.main_conv1.0.weight 	 torch.float16
upsample4_0.main_conv1.1.weight 	 torch.float16
upsample4_0.main_conv1.1.bias 	 torch.float16
upsample4_0.main_conv1.1.running_mean 	 torch.float16
upsample4_0.main_conv1.1.running_var 	 torch.float16
upsample4_0.main_conv1.1.num_batches_tracked 	 torch.int64
upsample4_0.ext_conv1.0.weight 	 torch.float16
upsample4_0.ext_conv1.1.weight 	 torch.float16
upsample4_0.ext_conv1.1.bias 	 torch.float16
upsample4_0.ext_conv1.1.running_mean 	 torch.float16
upsample4_0.ext_conv1.1.running_var 	 torch.float16
upsample4_0.ext_conv1.1.num_batches_tracked 	 torch.int64
upsample4_0.ext_tconv1.weight 	 torch.float16
upsample4_0.ext_tconv1_bnorm.weight 	 torch.float16
upsample4_0.ext_tconv1_bnorm.bias 	 torch.float16
upsample4_0.ext_tconv1_bnorm.running_mean 	 torch.float16
upsample4_0.ext_tconv1_bnorm.running_var 	 torch.float16
upsample4_0.ext_tconv1_bnorm.num_batches_tracked 	 torch.int64
upsample4_0.ext_conv2.0.weight 	 torch.float16
upsample4_0.ext_conv2.1.weight 	 torch.float16
upsample4_0.ext_conv2.1.bias 	 torch.float16
upsample4_0.ext_conv2.1.running_mean 	 torch.float16
upsample4_0.ext_conv2.1.running_var 	 torch.float16
upsample4_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv1.0.weight 	 torch.float16
regular4_1.ext_conv1.1.weight 	 torch.float16
regular4_1.ext_conv1.1.bias 	 torch.float16
regular4_1.ext_conv1.1.running_mean 	 torch.float16
regular4_1.ext_conv1.1.running_var 	 torch.float16
regular4_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv2.0.weight 	 torch.float16
regular4_1.ext_conv2.1.weight 	 torch.float16
regular4_1.ext_conv2.1.bias 	 torch.float16
regular4_1.ext_conv2.1.running_mean 	 torch.float16
regular4_1.ext_conv2.1.running_var 	 torch.float16
regular4_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv3.0.weight 	 torch.float16
regular4_1.ext_conv3.1.weight 	 torch.float16
regular4_1.ext_conv3.1.bias 	 torch.float16
regular4_1.ext_conv3.1.running_mean 	 torch.float16
regular4_1.ext_conv3.1.running_var 	 torch.float16
regular4_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv1.0.weight 	 torch.float16
regular4_2.ext_conv1.1.weight 	 torch.float16
regular4_2.ext_conv1.1.bias 	 torch.float16
regular4_2.ext_conv1.1.running_mean 	 torch.float16
regular4_2.ext_conv1.1.running_var 	 torch.float16
regular4_2.ext_conv1.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv2.0.weight 	 torch.float16
regular4_2.ext_conv2.1.weight 	 torch.float16
regular4_2.ext_conv2.1.bias 	 torch.float16
regular4_2.ext_conv2.1.running_mean 	 torch.float16
regular4_2.ext_conv2.1.running_var 	 torch.float16
regular4_2.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv3.0.weight 	 torch.float16
regular4_2.ext_conv3.1.weight 	 torch.float16
regular4_2.ext_conv3.1.bias 	 torch.float16
regular4_2.ext_conv3.1.running_mean 	 torch.float16
regular4_2.ext_conv3.1.running_var 	 torch.float16
regular4_2.ext_conv3.1.num_batches_tracked 	 torch.int64
upsample5_0.main_conv1.0.weight 	 torch.float16
upsample5_0.main_conv1.1.weight 	 torch.float16
upsample5_0.main_conv1.1.bias 	 torch.float16
upsample5_0.main_conv1.1.running_mean 	 torch.float16
upsample5_0.main_conv1.1.running_var 	 torch.float16
upsample5_0.main_conv1.1.num_batches_tracked 	 torch.int64
upsample5_0.ext_conv1.0.weight 	 torch.float16
upsample5_0.ext_conv1.1.weight 	 torch.float16
upsample5_0.ext_conv1.1.bias 	 torch.float16
upsample5_0.ext_conv1.1.running_mean 	 torch.float16
upsample5_0.ext_conv1.1.running_var 	 torch.float16
upsample5_0.ext_conv1.1.num_batches_tracked 	 torch.int64
upsample5_0.ext_tconv1.weight 	 torch.float16
upsample5_0.ext_tconv1_bnorm.weight 	 torch.float16
upsample5_0.ext_tconv1_bnorm.bias 	 torch.float16
upsample5_0.ext_tconv1_bnorm.running_mean 	 torch.float16
upsample5_0.ext_tconv1_bnorm.running_var 	 torch.float16
upsample5_0.ext_tconv1_bnorm.num_batches_tracked 	 torch.int64
upsample5_0.ext_conv2.0.weight 	 torch.float16
upsample5_0.ext_conv2.1.weight 	 torch.float16
upsample5_0.ext_conv2.1.bias 	 torch.float16
upsample5_0.ext_conv2.1.running_mean 	 torch.float16
upsample5_0.ext_conv2.1.running_var 	 torch.float16
upsample5_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv1.0.weight 	 torch.float16
regular5_1.ext_conv1.1.weight 	 torch.float16
regular5_1.ext_conv1.1.bias 	 torch.float16
regular5_1.ext_conv1.1.running_mean 	 torch.float16
regular5_1.ext_conv1.1.running_var 	 torch.float16
regular5_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv2.0.weight 	 torch.float16
regular5_1.ext_conv2.1.weight 	 torch.float16
regular5_1.ext_conv2.1.bias 	 torch.float16
regular5_1.ext_conv2.1.running_mean 	 torch.float16
regular5_1.ext_conv2.1.running_var 	 torch.float16
regular5_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv3.0.weight 	 torch.float16
regular5_1.ext_conv3.1.weight 	 torch.float16
regular5_1.ext_conv3.1.bias 	 torch.float16
regular5_1.ext_conv3.1.running_mean 	 torch.float16
regular5_1.ext_conv3.1.running_var 	 torch.float16
regular5_1.ext_conv3.1.num_batches_tracked 	 torch.int64
transposed_conv.weight 	 torch.float16
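
As an aside, instead of printing every entry, the same check can be summarized in one line. This is a sketch, with model being the ENet instance from above:

from collections import Counter

# tally how many state_dict entries there are per dtype
print(Counter(t.dtype for t in model.state_dict().values()))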

Could you post an executable code snippet using random inputs so that we could reproduce this issue?

@ptrblck here is everything you need to reproduce the issue. The repo is very minimalistic (2.4 MB) and was created specifically to reproduce the issue. I assume you already have all the requirements fulfilled (pillow, torch, and numpy).

I am going to remove the repo as I found the solution to my issue.

I found what caused my problem: the default dtype of my model must have been float32. Adding the following line above my model definition solved the problem:

torch.set_default_dtype(torch.float16)

It was a little tricky to notice, but it does make sense. Thank you @ptrblck for the discussion and your questions; in the end they led me to the solution.
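
For anyone finding this later, here is a minimal sketch of the effect. Note that float16 as a default dtype is not officially supported by every PyTorch version and op, so please verify this against your setup:

import torch

torch.set_default_dtype(torch.float16)  # must run before the model is constructed

x = torch.randn(2, 3)                      # new floating-point tensors now default to float16
print(x.dtype)                             # torch.float16

w = torch.nn.Parameter(torch.zeros(3, 3))  # freshly created parameters as well
print(w.dtype)                             # torch.float16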

PyTorch tensor arguments with dtype=torch.bool caused the CrossEntropyLoss function to throw an error. Neither the .long() casting method nor the other suggestions here helped:

RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'

But adding 0 to the arguments is a workaround, which is strange.
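
For reference, the documented dtype contract of nn.CrossEntropyLoss is floating-point logits and long (class-index) targets; if the logits themselves end up as Long, log_softmax raises exactly this error. A minimal sketch:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)            # the input must be floating point
targets = torch.tensor([0, 2, 1, 0])  # the targets must be torch.long class indices
loss = criterion(logits, targets)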

Hi @ptrblck, I just wanted to know the ideal way to do this casting between model params. I keep getting this error in backprop, although the loss is in float64 format.

tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
tot_loss = tot_loss.double()

TOT LOSS: tensor(-3.3673e+15, grad_fn=)
Traceback (most recent call last):
  File "AE.py", line 305, in <module>
    tot_loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Double

That’s strange as the error seems to be raised in the backward pass (while the forward works). Could you post a minimal, executable code snippet to reproduce the issue, please?

Sure, thanks.


import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
import time
import matplotlib.pyplot as plt
import math
from scipy import stats
import scipy
import os
import datetime
from scipy.stats import gaussian_kde
from math import sqrt
from math import log
from torch import optim
from torch.autograd import Variable
from sklearn.neighbors import KernelDensity


# parse input data
parser = argparse.ArgumentParser()
parser.add_argument("--code_size", default=20, help="size of the code", type=int)
parser.add_argument("--w_reg", default=0.001, help="weight of the regularization in the loss function", type=float)
parser.add_argument("--a_reg", default=0.2, help="weight of the kernel alignment", type=float)
parser.add_argument("--num_epochs", default=5000, help="number of epochs in training", type=int)
parser.add_argument("--batch_size", default=25, help="number of samples in each batch", type=int)
parser.add_argument("--max_gradient_norm", default=1.0, help="max gradient norm for gradient clipping", type=float)
parser.add_argument("--learning_rate", default=0.001, help="Adam initial learning rate", type=float)
parser.add_argument("--hidden_size", default=30, help="size of the code", type=int)
args = parser.parse_args()
print(args)

# ================= DATASET =================
# (train_data, train_labels, train_len, _, K_tr,
#  valid_data, _, valid_len, _, K_vs,
#  test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel='TCK',
#                                                             inp='zero')  # data shape is [T, N, V] = [time_steps, num_elements, num_var]

train_data = np.random.rand(9000,6)
train_labels = np.ones([9000,1])
train_len = 9000

valid_data = np.random.rand(9000,6)
valid_len = 9000

test_data = np.random.rand(1500,6)
test_labels = np.ones([1500,1])

K_tr = np.random.rand(9000,9000)
K_ts = np.random.rand(1500,1500)
K_vs =  np.random.rand(9000,9000)

#test_data = test_data_orig

print(
    '\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n'.format(train_data.shape, valid_data.shape, test_data.shape))

input_length = train_data.shape[1]  # same for all inputs

# ================= GRAPH =================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

encoder_inputs = train_data
prior_k = K_tr

# ============= TENSORBOARD =============
writer = SummaryWriter()

# # ----- ENCODER -----

input_length = encoder_inputs.shape[1]
print("INPUT ")


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.We1 = torch.nn.Parameter(
            torch.Tensor(input_length, args.hidden_size).uniform_(-1.0 / math.sqrt(input_length),
                                                                  1.0 / math.sqrt(input_length)))
        self.We2 = torch.nn.Parameter(
            torch.Tensor(args.hidden_size, args.code_size).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                                    1.0 / math.sqrt(args.hidden_size)))

        self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))

    def encoder(self, encoder_inputs):
        hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
        # print ("CODE ENCODER SHAPE:", code.size())
        return code


def decoder(code):
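    # NOTE: these weights are re-created on every call and are never registered
    # with the optimizer, so the decoder is effectively not trained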
    Wd1 = torch.nn.Parameter(
        torch.Tensor(args.code_size, args.hidden_size).uniform_(-1.0 / math.sqrt(args.code_size),
                                                                1.0 / math.sqrt(args.code_size)))
    Wd2 = torch.nn.Parameter(
        torch.Tensor(args.hidden_size, input_length).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                              1.0 / math.sqrt(args.hidden_size)))

    bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
    bd2 = torch.nn.Parameter(torch.zeros([input_length]))

    hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)

    dec_out = torch.matmul(hidden_2, Wd2) + bd2

    return dec_out


def kernel_loss(code, prior_K):
    # kernel on codes
    code_K = torch.mm(code, torch.t(code))

    # ----- LOSS -----
   # kernel alignment loss with KL divergence loss

    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    k_loss = kl_loss(code_K, prior_K)
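    # note: nn.KLDivLoss expects its input already in log-space; code_K here is a
    # raw Gram matrix, which is a separate issue from the dtype mismatch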
    return k_loss


# Initialize model
model = Model()

# trainable parameters count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total parameters: {}'.format(total_params))

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

# ================= TRAINING =================

# initialize training variables
time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = "logs/m_0.ckpt"

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

###############################################################################
# Training code
###############################################################################

try:
    for ep in range(args.num_epochs):

        # shuffle training data
        idx = np.random.permutation(train_data.shape[0])
        train_data_s = train_data[idx, :]
        K_tr_s = K_tr[idx, :][:, idx]

        for batch in range(max_batches):
            fdtr = {}
            fdtr["encoder_inputs"] = train_data_s[(batch) * batch_size:(batch + 1) * batch_size, :]
            fdtr["prior_K"] = K_tr_s[(batch) * batch_size:(batch + 1) * batch_size,
                              (batch) * batch_size:(batch + 1) * batch_size]

            encoder_inputs = (fdtr["encoder_inputs"].astype(float))
            encoder_inputs = torch.from_numpy(encoder_inputs)

            prior_K = (fdtr["prior_K"].astype(float))
            prior_K = torch.from_numpy(prior_K)

            code = model.encoder(encoder_inputs)
            dec_out = decoder(code)

            reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
            reconstruct_loss = reconstruct_loss.float()
            # print("RECONS LOSS TRAIN:", reconstruct_loss)

            k_loss = kernel_loss(code, prior_K)
            k_loss = k_loss.float()

            # Regularization L2 loss
            reg_loss = 0

            parameters = torch.nn.utils.parameters_to_vector(model.parameters())
            # print ("PARAMS:", (parameters))
            for tf_var in parameters:
                reg_loss += torch.mean(torch.linalg.norm(tf_var))

            tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
            tot_loss = tot_loss.double()

            # Backpropagation
            optimizer.zero_grad()
            tot_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()

            loss_track.append(reconstruct_loss)
            kloss_track.append(k_loss)

            print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
                reconstruct_loss, k_loss, torch.mean(torch.stack(loss_track[-10:])),
                torch.mean(torch.stack(kloss_track[-10:]))))

except KeyboardInterrupt:
    print('training interrupted')

time_tr_end = time.time()
print('Tot training time: {}'.format((time_tr_end - time_tr_start) // 60))

writer.close()

Please run the code as:

!python3 file.py --code_size 9 --w_reg 0.001 --a_reg 0.1 --num_epochs 100 --max_gradient_norm 0.5 --learning_rate 0.001 --hidden_size 30 

Thanks for the code snippet!
The issue is raised in KLDivLoss as it doesn’t seem to accept mixed dtypes:

kl_loss = nn.KLDivLoss(reduction="batchmean")
input = torch.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
target = torch.softmax(torch.rand(3, 5), dim=1).double()
output = kl_loss(input, target)
output.backward()
# RuntimeError: Found dtype Float but expected Double
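
Until that is fixed, a workaround is to keep both tensors in the same dtype. A sketch based on the snippet above:

# casting the target to the input's dtype avoids the mixed-dtype backward
output = kl_loss(input, target.to(input.dtype))
output.backward()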

Could you create an issue on GitHub so that we can track and fix it, please?

Sure, the issue has been raised:

Backpropagation not working on KL divergence loss function due to data type mismatch #80158

Thanks for the explanation! It helped me a lot.