If I understand your use case correctly, you are transforming the input data and model parameters to FP16 manually without using automatic mixed-precision training?
If so, could you post the error message which shows which operation is raising the issue?
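For reference, by "manual" I mean something like the following sketch (model and frame standing in for your model and input), as opposed to letting autocast insert the casts per operation:
# manual FP16: cast the parameters and the input yourself
model = model.half().cuda()
frame = frame.half().cuda()
prediction = model(frame)

# automatic mixed precision: keep everything in float32 and let
# autocast choose the dtype for each operation
with torch.cuda.amp.autocast():
    prediction = model(frame)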
That is correct, this is how I did it. Here is the error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-11-1d79e07828d2> in <module>
34 with torch.no_grad():
35 inference_start_time = time.time() * 1000
---> 36 prediction = model(frame)
37 inference_end_time = time.time() * 1000
38
~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/PyTorch-ENet/models/enet.py in forward(self, x)
590 # Stage 1 - Encoder
591 stage1_input_size = x.size()
--> 592 x, max_indices1_0 = self.downsample1_0(x)
593 x = self.regular1_1(x)
594 x = self.regular1_2(x)
~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/PyTorch-ENet/models/enet.py in forward(self, x)
345 out = main + ext
346
--> 347 return self.out_activation(out), max_indices
348
349
~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, input)
986
987 def forward(self, input: Tensor) -> Tensor:
--> 988 return F.prelu(input, self.weight)
989
990 def extra_repr(self) -> str:
~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/functional.py in prelu(input, weight)
1317 if type(input) is not Tensor and has_torch_function((input,)):
1318 return handle_torch_function(prelu, (input,), input, weight)
-> 1319 return torch.prelu(input, weight)
1320
1321
RuntimeError: expected scalar type Float but found Half
PyTorch version: 1.6
Python: 3.7.8
Thanks for the stack trace.
It seems nn.PReLU() is causing the error, which I cannot reproduce using 1.7.0.dev20200830:
prelu = nn.PReLU().cuda().half()
x = torch.randn(10, 10).cuda().half()
out = prelu(x)
Could you update to the latest nightly binary and rerun the code?
I was able to run your code snippet in both 1.6 and 1.7.0.dev20200830. Unfortunately, my problem remains the same with the nightly version.
Just to reconfirm: the following output shows that the conversion to float16 was successful, right? Is there something else that could be checked?
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].dtype)
initial_block.main_branch.weight torch.float16
initial_block.batch_norm.weight torch.float16
initial_block.batch_norm.bias torch.float16
initial_block.batch_norm.running_mean torch.float16
initial_block.batch_norm.running_var torch.float16
initial_block.batch_norm.num_batches_tracked torch.int64
initial_block.out_activation.weight torch.float16
downsample1_0.ext_conv1.0.weight torch.float16
downsample1_0.ext_conv1.1.weight torch.float16
downsample1_0.ext_conv1.1.bias torch.float16
downsample1_0.ext_conv1.1.running_mean torch.float16
downsample1_0.ext_conv1.1.running_var torch.float16
downsample1_0.ext_conv1.1.num_batches_tracked torch.int64
downsample1_0.ext_conv1.2.weight torch.float16
downsample1_0.ext_conv2.0.weight torch.float16
downsample1_0.ext_conv2.1.weight torch.float16
downsample1_0.ext_conv2.1.bias torch.float16
downsample1_0.ext_conv2.1.running_mean torch.float16
downsample1_0.ext_conv2.1.running_var torch.float16
downsample1_0.ext_conv2.1.num_batches_tracked torch.int64
downsample1_0.ext_conv2.2.weight torch.float16
downsample1_0.ext_conv3.0.weight torch.float16
downsample1_0.ext_conv3.1.weight torch.float16
downsample1_0.ext_conv3.1.bias torch.float16
downsample1_0.ext_conv3.1.running_mean torch.float16
downsample1_0.ext_conv3.1.running_var torch.float16
downsample1_0.ext_conv3.1.num_batches_tracked torch.int64
downsample1_0.ext_conv3.2.weight torch.float16
downsample1_0.out_activation.weight torch.float16
regular1_1.ext_conv1.0.weight torch.float16
regular1_1.ext_conv1.1.weight torch.float16
regular1_1.ext_conv1.1.bias torch.float16
regular1_1.ext_conv1.1.running_mean torch.float16
regular1_1.ext_conv1.1.running_var torch.float16
regular1_1.ext_conv1.1.num_batches_tracked torch.int64
regular1_1.ext_conv1.2.weight torch.float16
regular1_1.ext_conv2.0.weight torch.float16
regular1_1.ext_conv2.1.weight torch.float16
regular1_1.ext_conv2.1.bias torch.float16
regular1_1.ext_conv2.1.running_mean torch.float16
regular1_1.ext_conv2.1.running_var torch.float16
regular1_1.ext_conv2.1.num_batches_tracked torch.int64
regular1_1.ext_conv2.2.weight torch.float16
regular1_1.ext_conv3.0.weight torch.float16
regular1_1.ext_conv3.1.weight torch.float16
regular1_1.ext_conv3.1.bias torch.float16
regular1_1.ext_conv3.1.running_mean torch.float16
regular1_1.ext_conv3.1.running_var torch.float16
regular1_1.ext_conv3.1.num_batches_tracked torch.int64
regular1_1.ext_conv3.2.weight torch.float16
regular1_1.out_activation.weight torch.float16
regular1_2.ext_conv1.0.weight torch.float16
regular1_2.ext_conv1.1.weight torch.float16
regular1_2.ext_conv1.1.bias torch.float16
regular1_2.ext_conv1.1.running_mean torch.float16
regular1_2.ext_conv1.1.running_var torch.float16
regular1_2.ext_conv1.1.num_batches_tracked torch.int64
regular1_2.ext_conv1.2.weight torch.float16
regular1_2.ext_conv2.0.weight torch.float16
regular1_2.ext_conv2.1.weight torch.float16
regular1_2.ext_conv2.1.bias torch.float16
regular1_2.ext_conv2.1.running_mean torch.float16
regular1_2.ext_conv2.1.running_var torch.float16
regular1_2.ext_conv2.1.num_batches_tracked torch.int64
regular1_2.ext_conv2.2.weight torch.float16
regular1_2.ext_conv3.0.weight torch.float16
regular1_2.ext_conv3.1.weight torch.float16
regular1_2.ext_conv3.1.bias torch.float16
regular1_2.ext_conv3.1.running_mean torch.float16
regular1_2.ext_conv3.1.running_var torch.float16
regular1_2.ext_conv3.1.num_batches_tracked torch.int64
regular1_2.ext_conv3.2.weight torch.float16
regular1_2.out_activation.weight torch.float16
regular1_3.ext_conv1.0.weight torch.float16
regular1_3.ext_conv1.1.weight torch.float16
regular1_3.ext_conv1.1.bias torch.float16
regular1_3.ext_conv1.1.running_mean torch.float16
regular1_3.ext_conv1.1.running_var torch.float16
regular1_3.ext_conv1.1.num_batches_tracked torch.int64
regular1_3.ext_conv1.2.weight torch.float16
regular1_3.ext_conv2.0.weight torch.float16
regular1_3.ext_conv2.1.weight torch.float16
regular1_3.ext_conv2.1.bias torch.float16
regular1_3.ext_conv2.1.running_mean torch.float16
regular1_3.ext_conv2.1.running_var torch.float16
regular1_3.ext_conv2.1.num_batches_tracked torch.int64
regular1_3.ext_conv2.2.weight torch.float16
regular1_3.ext_conv3.0.weight torch.float16
regular1_3.ext_conv3.1.weight torch.float16
regular1_3.ext_conv3.1.bias torch.float16
regular1_3.ext_conv3.1.running_mean torch.float16
regular1_3.ext_conv3.1.running_var torch.float16
regular1_3.ext_conv3.1.num_batches_tracked torch.int64
regular1_3.ext_conv3.2.weight torch.float16
regular1_3.out_activation.weight torch.float16
regular1_4.ext_conv1.0.weight torch.float16
regular1_4.ext_conv1.1.weight torch.float16
regular1_4.ext_conv1.1.bias torch.float16
regular1_4.ext_conv1.1.running_mean torch.float16
regular1_4.ext_conv1.1.running_var torch.float16
regular1_4.ext_conv1.1.num_batches_tracked torch.int64
regular1_4.ext_conv1.2.weight torch.float16
regular1_4.ext_conv2.0.weight torch.float16
regular1_4.ext_conv2.1.weight torch.float16
regular1_4.ext_conv2.1.bias torch.float16
regular1_4.ext_conv2.1.running_mean torch.float16
regular1_4.ext_conv2.1.running_var torch.float16
regular1_4.ext_conv2.1.num_batches_tracked torch.int64
regular1_4.ext_conv2.2.weight torch.float16
regular1_4.ext_conv3.0.weight torch.float16
regular1_4.ext_conv3.1.weight torch.float16
regular1_4.ext_conv3.1.bias torch.float16
regular1_4.ext_conv3.1.running_mean torch.float16
regular1_4.ext_conv3.1.running_var torch.float16
regular1_4.ext_conv3.1.num_batches_tracked torch.int64
regular1_4.ext_conv3.2.weight torch.float16
regular1_4.out_activation.weight torch.float16
downsample2_0.ext_conv1.0.weight torch.float16
downsample2_0.ext_conv1.1.weight torch.float16
downsample2_0.ext_conv1.1.bias torch.float16
downsample2_0.ext_conv1.1.running_mean torch.float16
downsample2_0.ext_conv1.1.running_var torch.float16
downsample2_0.ext_conv1.1.num_batches_tracked torch.int64
downsample2_0.ext_conv1.2.weight torch.float16
downsample2_0.ext_conv2.0.weight torch.float16
downsample2_0.ext_conv2.1.weight torch.float16
downsample2_0.ext_conv2.1.bias torch.float16
downsample2_0.ext_conv2.1.running_mean torch.float16
downsample2_0.ext_conv2.1.running_var torch.float16
downsample2_0.ext_conv2.1.num_batches_tracked torch.int64
downsample2_0.ext_conv2.2.weight torch.float16
downsample2_0.ext_conv3.0.weight torch.float16
downsample2_0.ext_conv3.1.weight torch.float16
downsample2_0.ext_conv3.1.bias torch.float16
downsample2_0.ext_conv3.1.running_mean torch.float16
downsample2_0.ext_conv3.1.running_var torch.float16
downsample2_0.ext_conv3.1.num_batches_tracked torch.int64
downsample2_0.ext_conv3.2.weight torch.float16
downsample2_0.out_activation.weight torch.float16
regular2_1.ext_conv1.0.weight torch.float16
regular2_1.ext_conv1.1.weight torch.float16
regular2_1.ext_conv1.1.bias torch.float16
regular2_1.ext_conv1.1.running_mean torch.float16
regular2_1.ext_conv1.1.running_var torch.float16
regular2_1.ext_conv1.1.num_batches_tracked torch.int64
regular2_1.ext_conv1.2.weight torch.float16
regular2_1.ext_conv2.0.weight torch.float16
regular2_1.ext_conv2.1.weight torch.float16
regular2_1.ext_conv2.1.bias torch.float16
regular2_1.ext_conv2.1.running_mean torch.float16
regular2_1.ext_conv2.1.running_var torch.float16
regular2_1.ext_conv2.1.num_batches_tracked torch.int64
regular2_1.ext_conv2.2.weight torch.float16
regular2_1.ext_conv3.0.weight torch.float16
regular2_1.ext_conv3.1.weight torch.float16
regular2_1.ext_conv3.1.bias torch.float16
regular2_1.ext_conv3.1.running_mean torch.float16
regular2_1.ext_conv3.1.running_var torch.float16
regular2_1.ext_conv3.1.num_batches_tracked torch.int64
regular2_1.ext_conv3.2.weight torch.float16
regular2_1.out_activation.weight torch.float16
dilated2_2.ext_conv1.0.weight torch.float16
dilated2_2.ext_conv1.1.weight torch.float16
dilated2_2.ext_conv1.1.bias torch.float16
dilated2_2.ext_conv1.1.running_mean torch.float16
dilated2_2.ext_conv1.1.running_var torch.float16
dilated2_2.ext_conv1.1.num_batches_tracked torch.int64
dilated2_2.ext_conv1.2.weight torch.float16
dilated2_2.ext_conv2.0.weight torch.float16
dilated2_2.ext_conv2.1.weight torch.float16
dilated2_2.ext_conv2.1.bias torch.float16
dilated2_2.ext_conv2.1.running_mean torch.float16
dilated2_2.ext_conv2.1.running_var torch.float16
dilated2_2.ext_conv2.1.num_batches_tracked torch.int64
dilated2_2.ext_conv2.2.weight torch.float16
dilated2_2.ext_conv3.0.weight torch.float16
dilated2_2.ext_conv3.1.weight torch.float16
dilated2_2.ext_conv3.1.bias torch.float16
dilated2_2.ext_conv3.1.running_mean torch.float16
dilated2_2.ext_conv3.1.running_var torch.float16
dilated2_2.ext_conv3.1.num_batches_tracked torch.int64
dilated2_2.ext_conv3.2.weight torch.float16
dilated2_2.out_activation.weight torch.float16
asymmetric2_3.ext_conv1.0.weight torch.float16
asymmetric2_3.ext_conv1.1.weight torch.float16
asymmetric2_3.ext_conv1.1.bias torch.float16
asymmetric2_3.ext_conv1.1.running_mean torch.float16
asymmetric2_3.ext_conv1.1.running_var torch.float16
asymmetric2_3.ext_conv1.1.num_batches_tracked torch.int64
asymmetric2_3.ext_conv1.2.weight torch.float16
asymmetric2_3.ext_conv2.0.weight torch.float16
asymmetric2_3.ext_conv2.1.weight torch.float16
asymmetric2_3.ext_conv2.1.bias torch.float16
asymmetric2_3.ext_conv2.1.running_mean torch.float16
asymmetric2_3.ext_conv2.1.running_var torch.float16
asymmetric2_3.ext_conv2.1.num_batches_tracked torch.int64
asymmetric2_3.ext_conv2.2.weight torch.float16
asymmetric2_3.ext_conv2.3.weight torch.float16
asymmetric2_3.ext_conv2.4.weight torch.float16
asymmetric2_3.ext_conv2.4.bias torch.float16
asymmetric2_3.ext_conv2.4.running_mean torch.float16
asymmetric2_3.ext_conv2.4.running_var torch.float16
asymmetric2_3.ext_conv2.4.num_batches_tracked torch.int64
asymmetric2_3.ext_conv2.5.weight torch.float16
asymmetric2_3.ext_conv3.0.weight torch.float16
asymmetric2_3.ext_conv3.1.weight torch.float16
asymmetric2_3.ext_conv3.1.bias torch.float16
asymmetric2_3.ext_conv3.1.running_mean torch.float16
asymmetric2_3.ext_conv3.1.running_var torch.float16
asymmetric2_3.ext_conv3.1.num_batches_tracked torch.int64
asymmetric2_3.ext_conv3.2.weight torch.float16
asymmetric2_3.out_activation.weight torch.float16
dilated2_4.ext_conv1.0.weight torch.float16
dilated2_4.ext_conv1.1.weight torch.float16
dilated2_4.ext_conv1.1.bias torch.float16
dilated2_4.ext_conv1.1.running_mean torch.float16
dilated2_4.ext_conv1.1.running_var torch.float16
dilated2_4.ext_conv1.1.num_batches_tracked torch.int64
dilated2_4.ext_conv1.2.weight torch.float16
dilated2_4.ext_conv2.0.weight torch.float16
dilated2_4.ext_conv2.1.weight torch.float16
dilated2_4.ext_conv2.1.bias torch.float16
dilated2_4.ext_conv2.1.running_mean torch.float16
dilated2_4.ext_conv2.1.running_var torch.float16
dilated2_4.ext_conv2.1.num_batches_tracked torch.int64
dilated2_4.ext_conv2.2.weight torch.float16
dilated2_4.ext_conv3.0.weight torch.float16
dilated2_4.ext_conv3.1.weight torch.float16
dilated2_4.ext_conv3.1.bias torch.float16
dilated2_4.ext_conv3.1.running_mean torch.float16
dilated2_4.ext_conv3.1.running_var torch.float16
dilated2_4.ext_conv3.1.num_batches_tracked torch.int64
dilated2_4.ext_conv3.2.weight torch.float16
dilated2_4.out_activation.weight torch.float16
regular2_5.ext_conv1.0.weight torch.float16
regular2_5.ext_conv1.1.weight torch.float16
regular2_5.ext_conv1.1.bias torch.float16
regular2_5.ext_conv1.1.running_mean torch.float16
regular2_5.ext_conv1.1.running_var torch.float16
regular2_5.ext_conv1.1.num_batches_tracked torch.int64
regular2_5.ext_conv1.2.weight torch.float16
regular2_5.ext_conv2.0.weight torch.float16
regular2_5.ext_conv2.1.weight torch.float16
regular2_5.ext_conv2.1.bias torch.float16
regular2_5.ext_conv2.1.running_mean torch.float16
regular2_5.ext_conv2.1.running_var torch.float16
regular2_5.ext_conv2.1.num_batches_tracked torch.int64
regular2_5.ext_conv2.2.weight torch.float16
regular2_5.ext_conv3.0.weight torch.float16
regular2_5.ext_conv3.1.weight torch.float16
regular2_5.ext_conv3.1.bias torch.float16
regular2_5.ext_conv3.1.running_mean torch.float16
regular2_5.ext_conv3.1.running_var torch.float16
regular2_5.ext_conv3.1.num_batches_tracked torch.int64
regular2_5.ext_conv3.2.weight torch.float16
regular2_5.out_activation.weight torch.float16
dilated2_6.ext_conv1.0.weight torch.float16
dilated2_6.ext_conv1.1.weight torch.float16
dilated2_6.ext_conv1.1.bias torch.float16
dilated2_6.ext_conv1.1.running_mean torch.float16
dilated2_6.ext_conv1.1.running_var torch.float16
dilated2_6.ext_conv1.1.num_batches_tracked torch.int64
dilated2_6.ext_conv1.2.weight torch.float16
dilated2_6.ext_conv2.0.weight torch.float16
dilated2_6.ext_conv2.1.weight torch.float16
dilated2_6.ext_conv2.1.bias torch.float16
dilated2_6.ext_conv2.1.running_mean torch.float16
dilated2_6.ext_conv2.1.running_var torch.float16
dilated2_6.ext_conv2.1.num_batches_tracked torch.int64
dilated2_6.ext_conv2.2.weight torch.float16
dilated2_6.ext_conv3.0.weight torch.float16
dilated2_6.ext_conv3.1.weight torch.float16
dilated2_6.ext_conv3.1.bias torch.float16
dilated2_6.ext_conv3.1.running_mean torch.float16
dilated2_6.ext_conv3.1.running_var torch.float16
dilated2_6.ext_conv3.1.num_batches_tracked torch.int64
dilated2_6.ext_conv3.2.weight torch.float16
dilated2_6.out_activation.weight torch.float16
asymmetric2_7.ext_conv1.0.weight torch.float16
asymmetric2_7.ext_conv1.1.weight torch.float16
asymmetric2_7.ext_conv1.1.bias torch.float16
asymmetric2_7.ext_conv1.1.running_mean torch.float16
asymmetric2_7.ext_conv1.1.running_var torch.float16
asymmetric2_7.ext_conv1.1.num_batches_tracked torch.int64
asymmetric2_7.ext_conv1.2.weight torch.float16
asymmetric2_7.ext_conv2.0.weight torch.float16
asymmetric2_7.ext_conv2.1.weight torch.float16
asymmetric2_7.ext_conv2.1.bias torch.float16
asymmetric2_7.ext_conv2.1.running_mean torch.float16
asymmetric2_7.ext_conv2.1.running_var torch.float16
asymmetric2_7.ext_conv2.1.num_batches_tracked torch.int64
asymmetric2_7.ext_conv2.2.weight torch.float16
asymmetric2_7.ext_conv2.3.weight torch.float16
asymmetric2_7.ext_conv2.4.weight torch.float16
asymmetric2_7.ext_conv2.4.bias torch.float16
asymmetric2_7.ext_conv2.4.running_mean torch.float16
asymmetric2_7.ext_conv2.4.running_var torch.float16
asymmetric2_7.ext_conv2.4.num_batches_tracked torch.int64
asymmetric2_7.ext_conv2.5.weight torch.float16
asymmetric2_7.ext_conv3.0.weight torch.float16
asymmetric2_7.ext_conv3.1.weight torch.float16
asymmetric2_7.ext_conv3.1.bias torch.float16
asymmetric2_7.ext_conv3.1.running_mean torch.float16
asymmetric2_7.ext_conv3.1.running_var torch.float16
asymmetric2_7.ext_conv3.1.num_batches_tracked torch.int64
asymmetric2_7.ext_conv3.2.weight torch.float16
asymmetric2_7.out_activation.weight torch.float16
dilated2_8.ext_conv1.0.weight torch.float16
dilated2_8.ext_conv1.1.weight torch.float16
dilated2_8.ext_conv1.1.bias torch.float16
dilated2_8.ext_conv1.1.running_mean torch.float16
dilated2_8.ext_conv1.1.running_var torch.float16
dilated2_8.ext_conv1.1.num_batches_tracked torch.int64
dilated2_8.ext_conv1.2.weight torch.float16
dilated2_8.ext_conv2.0.weight torch.float16
dilated2_8.ext_conv2.1.weight torch.float16
dilated2_8.ext_conv2.1.bias torch.float16
dilated2_8.ext_conv2.1.running_mean torch.float16
dilated2_8.ext_conv2.1.running_var torch.float16
dilated2_8.ext_conv2.1.num_batches_tracked torch.int64
dilated2_8.ext_conv2.2.weight torch.float16
dilated2_8.ext_conv3.0.weight torch.float16
dilated2_8.ext_conv3.1.weight torch.float16
dilated2_8.ext_conv3.1.bias torch.float16
dilated2_8.ext_conv3.1.running_mean torch.float16
dilated2_8.ext_conv3.1.running_var torch.float16
dilated2_8.ext_conv3.1.num_batches_tracked torch.int64
dilated2_8.ext_conv3.2.weight torch.float16
dilated2_8.out_activation.weight torch.float16
regular3_0.ext_conv1.0.weight torch.float16
regular3_0.ext_conv1.1.weight torch.float16
regular3_0.ext_conv1.1.bias torch.float16
regular3_0.ext_conv1.1.running_mean torch.float16
regular3_0.ext_conv1.1.running_var torch.float16
regular3_0.ext_conv1.1.num_batches_tracked torch.int64
regular3_0.ext_conv1.2.weight torch.float16
regular3_0.ext_conv2.0.weight torch.float16
regular3_0.ext_conv2.1.weight torch.float16
regular3_0.ext_conv2.1.bias torch.float16
regular3_0.ext_conv2.1.running_mean torch.float16
regular3_0.ext_conv2.1.running_var torch.float16
regular3_0.ext_conv2.1.num_batches_tracked torch.int64
regular3_0.ext_conv2.2.weight torch.float16
regular3_0.ext_conv3.0.weight torch.float16
regular3_0.ext_conv3.1.weight torch.float16
regular3_0.ext_conv3.1.bias torch.float16
regular3_0.ext_conv3.1.running_mean torch.float16
regular3_0.ext_conv3.1.running_var torch.float16
regular3_0.ext_conv3.1.num_batches_tracked torch.int64
regular3_0.ext_conv3.2.weight torch.float16
regular3_0.out_activation.weight torch.float16
dilated3_1.ext_conv1.0.weight torch.float16
dilated3_1.ext_conv1.1.weight torch.float16
dilated3_1.ext_conv1.1.bias torch.float16
dilated3_1.ext_conv1.1.running_mean torch.float16
dilated3_1.ext_conv1.1.running_var torch.float16
dilated3_1.ext_conv1.1.num_batches_tracked torch.int64
dilated3_1.ext_conv1.2.weight torch.float16
dilated3_1.ext_conv2.0.weight torch.float16
dilated3_1.ext_conv2.1.weight torch.float16
dilated3_1.ext_conv2.1.bias torch.float16
dilated3_1.ext_conv2.1.running_mean torch.float16
dilated3_1.ext_conv2.1.running_var torch.float16
dilated3_1.ext_conv2.1.num_batches_tracked torch.int64
dilated3_1.ext_conv2.2.weight torch.float16
dilated3_1.ext_conv3.0.weight torch.float16
dilated3_1.ext_conv3.1.weight torch.float16
dilated3_1.ext_conv3.1.bias torch.float16
dilated3_1.ext_conv3.1.running_mean torch.float16
dilated3_1.ext_conv3.1.running_var torch.float16
dilated3_1.ext_conv3.1.num_batches_tracked torch.int64
dilated3_1.ext_conv3.2.weight torch.float16
dilated3_1.out_activation.weight torch.float16
asymmetric3_2.ext_conv1.0.weight torch.float16
asymmetric3_2.ext_conv1.1.weight torch.float16
asymmetric3_2.ext_conv1.1.bias torch.float16
asymmetric3_2.ext_conv1.1.running_mean torch.float16
asymmetric3_2.ext_conv1.1.running_var torch.float16
asymmetric3_2.ext_conv1.1.num_batches_tracked torch.int64
asymmetric3_2.ext_conv1.2.weight torch.float16
asymmetric3_2.ext_conv2.0.weight torch.float16
asymmetric3_2.ext_conv2.1.weight torch.float16
asymmetric3_2.ext_conv2.1.bias torch.float16
asymmetric3_2.ext_conv2.1.running_mean torch.float16
asymmetric3_2.ext_conv2.1.running_var torch.float16
asymmetric3_2.ext_conv2.1.num_batches_tracked torch.int64
asymmetric3_2.ext_conv2.2.weight torch.float16
asymmetric3_2.ext_conv2.3.weight torch.float16
asymmetric3_2.ext_conv2.4.weight torch.float16
asymmetric3_2.ext_conv2.4.bias torch.float16
asymmetric3_2.ext_conv2.4.running_mean torch.float16
asymmetric3_2.ext_conv2.4.running_var torch.float16
asymmetric3_2.ext_conv2.4.num_batches_tracked torch.int64
asymmetric3_2.ext_conv2.5.weight torch.float16
asymmetric3_2.ext_conv3.0.weight torch.float16
asymmetric3_2.ext_conv3.1.weight torch.float16
asymmetric3_2.ext_conv3.1.bias torch.float16
asymmetric3_2.ext_conv3.1.running_mean torch.float16
asymmetric3_2.ext_conv3.1.running_var torch.float16
asymmetric3_2.ext_conv3.1.num_batches_tracked torch.int64
asymmetric3_2.ext_conv3.2.weight torch.float16
asymmetric3_2.out_activation.weight torch.float16
dilated3_3.ext_conv1.0.weight torch.float16
dilated3_3.ext_conv1.1.weight torch.float16
dilated3_3.ext_conv1.1.bias torch.float16
dilated3_3.ext_conv1.1.running_mean torch.float16
dilated3_3.ext_conv1.1.running_var torch.float16
dilated3_3.ext_conv1.1.num_batches_tracked torch.int64
dilated3_3.ext_conv1.2.weight torch.float16
dilated3_3.ext_conv2.0.weight torch.float16
dilated3_3.ext_conv2.1.weight torch.float16
dilated3_3.ext_conv2.1.bias torch.float16
dilated3_3.ext_conv2.1.running_mean torch.float16
dilated3_3.ext_conv2.1.running_var torch.float16
dilated3_3.ext_conv2.1.num_batches_tracked torch.int64
dilated3_3.ext_conv2.2.weight torch.float16
dilated3_3.ext_conv3.0.weight torch.float16
dilated3_3.ext_conv3.1.weight torch.float16
dilated3_3.ext_conv3.1.bias torch.float16
dilated3_3.ext_conv3.1.running_mean torch.float16
dilated3_3.ext_conv3.1.running_var torch.float16
dilated3_3.ext_conv3.1.num_batches_tracked torch.int64
dilated3_3.ext_conv3.2.weight torch.float16
dilated3_3.out_activation.weight torch.float16
regular3_4.ext_conv1.0.weight torch.float16
regular3_4.ext_conv1.1.weight torch.float16
regular3_4.ext_conv1.1.bias torch.float16
regular3_4.ext_conv1.1.running_mean torch.float16
regular3_4.ext_conv1.1.running_var torch.float16
regular3_4.ext_conv1.1.num_batches_tracked torch.int64
regular3_4.ext_conv1.2.weight torch.float16
regular3_4.ext_conv2.0.weight torch.float16
regular3_4.ext_conv2.1.weight torch.float16
regular3_4.ext_conv2.1.bias torch.float16
regular3_4.ext_conv2.1.running_mean torch.float16
regular3_4.ext_conv2.1.running_var torch.float16
regular3_4.ext_conv2.1.num_batches_tracked torch.int64
regular3_4.ext_conv2.2.weight torch.float16
regular3_4.ext_conv3.0.weight torch.float16
regular3_4.ext_conv3.1.weight torch.float16
regular3_4.ext_conv3.1.bias torch.float16
regular3_4.ext_conv3.1.running_mean torch.float16
regular3_4.ext_conv3.1.running_var torch.float16
regular3_4.ext_conv3.1.num_batches_tracked torch.int64
regular3_4.ext_conv3.2.weight torch.float16
regular3_4.out_activation.weight torch.float16
dilated3_5.ext_conv1.0.weight torch.float16
dilated3_5.ext_conv1.1.weight torch.float16
dilated3_5.ext_conv1.1.bias torch.float16
dilated3_5.ext_conv1.1.running_mean torch.float16
dilated3_5.ext_conv1.1.running_var torch.float16
dilated3_5.ext_conv1.1.num_batches_tracked torch.int64
dilated3_5.ext_conv1.2.weight torch.float16
dilated3_5.ext_conv2.0.weight torch.float16
dilated3_5.ext_conv2.1.weight torch.float16
dilated3_5.ext_conv2.1.bias torch.float16
dilated3_5.ext_conv2.1.running_mean torch.float16
dilated3_5.ext_conv2.1.running_var torch.float16
dilated3_5.ext_conv2.1.num_batches_tracked torch.int64
dilated3_5.ext_conv2.2.weight torch.float16
dilated3_5.ext_conv3.0.weight torch.float16
dilated3_5.ext_conv3.1.weight torch.float16
dilated3_5.ext_conv3.1.bias torch.float16
dilated3_5.ext_conv3.1.running_mean torch.float16
dilated3_5.ext_conv3.1.running_var torch.float16
dilated3_5.ext_conv3.1.num_batches_tracked torch.int64
dilated3_5.ext_conv3.2.weight torch.float16
dilated3_5.out_activation.weight torch.float16
asymmetric3_6.ext_conv1.0.weight torch.float16
asymmetric3_6.ext_conv1.1.weight torch.float16
asymmetric3_6.ext_conv1.1.bias torch.float16
asymmetric3_6.ext_conv1.1.running_mean torch.float16
asymmetric3_6.ext_conv1.1.running_var torch.float16
asymmetric3_6.ext_conv1.1.num_batches_tracked torch.int64
asymmetric3_6.ext_conv1.2.weight torch.float16
asymmetric3_6.ext_conv2.0.weight torch.float16
asymmetric3_6.ext_conv2.1.weight torch.float16
asymmetric3_6.ext_conv2.1.bias torch.float16
asymmetric3_6.ext_conv2.1.running_mean torch.float16
asymmetric3_6.ext_conv2.1.running_var torch.float16
asymmetric3_6.ext_conv2.1.num_batches_tracked torch.int64
asymmetric3_6.ext_conv2.2.weight torch.float16
asymmetric3_6.ext_conv2.3.weight torch.float16
asymmetric3_6.ext_conv2.4.weight torch.float16
asymmetric3_6.ext_conv2.4.bias torch.float16
asymmetric3_6.ext_conv2.4.running_mean torch.float16
asymmetric3_6.ext_conv2.4.running_var torch.float16
asymmetric3_6.ext_conv2.4.num_batches_tracked torch.int64
asymmetric3_6.ext_conv2.5.weight torch.float16
asymmetric3_6.ext_conv3.0.weight torch.float16
asymmetric3_6.ext_conv3.1.weight torch.float16
asymmetric3_6.ext_conv3.1.bias torch.float16
asymmetric3_6.ext_conv3.1.running_mean torch.float16
asymmetric3_6.ext_conv3.1.running_var torch.float16
asymmetric3_6.ext_conv3.1.num_batches_tracked torch.int64
asymmetric3_6.ext_conv3.2.weight torch.float16
asymmetric3_6.out_activation.weight torch.float16
dilated3_7.ext_conv1.0.weight torch.float16
dilated3_7.ext_conv1.1.weight torch.float16
dilated3_7.ext_conv1.1.bias torch.float16
dilated3_7.ext_conv1.1.running_mean torch.float16
dilated3_7.ext_conv1.1.running_var torch.float16
dilated3_7.ext_conv1.1.num_batches_tracked torch.int64
dilated3_7.ext_conv1.2.weight torch.float16
dilated3_7.ext_conv2.0.weight torch.float16
dilated3_7.ext_conv2.1.weight torch.float16
dilated3_7.ext_conv2.1.bias torch.float16
dilated3_7.ext_conv2.1.running_mean torch.float16
dilated3_7.ext_conv2.1.running_var torch.float16
dilated3_7.ext_conv2.1.num_batches_tracked torch.int64
dilated3_7.ext_conv2.2.weight torch.float16
dilated3_7.ext_conv3.0.weight torch.float16
dilated3_7.ext_conv3.1.weight torch.float16
dilated3_7.ext_conv3.1.bias torch.float16
dilated3_7.ext_conv3.1.running_mean torch.float16
dilated3_7.ext_conv3.1.running_var torch.float16
dilated3_7.ext_conv3.1.num_batches_tracked torch.int64
dilated3_7.ext_conv3.2.weight torch.float16
dilated3_7.out_activation.weight torch.float16
upsample4_0.main_conv1.0.weight torch.float16
upsample4_0.main_conv1.1.weight torch.float16
upsample4_0.main_conv1.1.bias torch.float16
upsample4_0.main_conv1.1.running_mean torch.float16
upsample4_0.main_conv1.1.running_var torch.float16
upsample4_0.main_conv1.1.num_batches_tracked torch.int64
upsample4_0.ext_conv1.0.weight torch.float16
upsample4_0.ext_conv1.1.weight torch.float16
upsample4_0.ext_conv1.1.bias torch.float16
upsample4_0.ext_conv1.1.running_mean torch.float16
upsample4_0.ext_conv1.1.running_var torch.float16
upsample4_0.ext_conv1.1.num_batches_tracked torch.int64
upsample4_0.ext_tconv1.weight torch.float16
upsample4_0.ext_tconv1_bnorm.weight torch.float16
upsample4_0.ext_tconv1_bnorm.bias torch.float16
upsample4_0.ext_tconv1_bnorm.running_mean torch.float16
upsample4_0.ext_tconv1_bnorm.running_var torch.float16
upsample4_0.ext_tconv1_bnorm.num_batches_tracked torch.int64
upsample4_0.ext_conv2.0.weight torch.float16
upsample4_0.ext_conv2.1.weight torch.float16
upsample4_0.ext_conv2.1.bias torch.float16
upsample4_0.ext_conv2.1.running_mean torch.float16
upsample4_0.ext_conv2.1.running_var torch.float16
upsample4_0.ext_conv2.1.num_batches_tracked torch.int64
regular4_1.ext_conv1.0.weight torch.float16
regular4_1.ext_conv1.1.weight torch.float16
regular4_1.ext_conv1.1.bias torch.float16
regular4_1.ext_conv1.1.running_mean torch.float16
regular4_1.ext_conv1.1.running_var torch.float16
regular4_1.ext_conv1.1.num_batches_tracked torch.int64
regular4_1.ext_conv2.0.weight torch.float16
regular4_1.ext_conv2.1.weight torch.float16
regular4_1.ext_conv2.1.bias torch.float16
regular4_1.ext_conv2.1.running_mean torch.float16
regular4_1.ext_conv2.1.running_var torch.float16
regular4_1.ext_conv2.1.num_batches_tracked torch.int64
regular4_1.ext_conv3.0.weight torch.float16
regular4_1.ext_conv3.1.weight torch.float16
regular4_1.ext_conv3.1.bias torch.float16
regular4_1.ext_conv3.1.running_mean torch.float16
regular4_1.ext_conv3.1.running_var torch.float16
regular4_1.ext_conv3.1.num_batches_tracked torch.int64
regular4_2.ext_conv1.0.weight torch.float16
regular4_2.ext_conv1.1.weight torch.float16
regular4_2.ext_conv1.1.bias torch.float16
regular4_2.ext_conv1.1.running_mean torch.float16
regular4_2.ext_conv1.1.running_var torch.float16
regular4_2.ext_conv1.1.num_batches_tracked torch.int64
regular4_2.ext_conv2.0.weight torch.float16
regular4_2.ext_conv2.1.weight torch.float16
regular4_2.ext_conv2.1.bias torch.float16
regular4_2.ext_conv2.1.running_mean torch.float16
regular4_2.ext_conv2.1.running_var torch.float16
regular4_2.ext_conv2.1.num_batches_tracked torch.int64
regular4_2.ext_conv3.0.weight torch.float16
regular4_2.ext_conv3.1.weight torch.float16
regular4_2.ext_conv3.1.bias torch.float16
regular4_2.ext_conv3.1.running_mean torch.float16
regular4_2.ext_conv3.1.running_var torch.float16
regular4_2.ext_conv3.1.num_batches_tracked torch.int64
upsample5_0.main_conv1.0.weight torch.float16
upsample5_0.main_conv1.1.weight torch.float16
upsample5_0.main_conv1.1.bias torch.float16
upsample5_0.main_conv1.1.running_mean torch.float16
upsample5_0.main_conv1.1.running_var torch.float16
upsample5_0.main_conv1.1.num_batches_tracked torch.int64
upsample5_0.ext_conv1.0.weight torch.float16
upsample5_0.ext_conv1.1.weight torch.float16
upsample5_0.ext_conv1.1.bias torch.float16
upsample5_0.ext_conv1.1.running_mean torch.float16
upsample5_0.ext_conv1.1.running_var torch.float16
upsample5_0.ext_conv1.1.num_batches_tracked torch.int64
upsample5_0.ext_tconv1.weight torch.float16
upsample5_0.ext_tconv1_bnorm.weight torch.float16
upsample5_0.ext_tconv1_bnorm.bias torch.float16
upsample5_0.ext_tconv1_bnorm.running_mean torch.float16
upsample5_0.ext_tconv1_bnorm.running_var torch.float16
upsample5_0.ext_tconv1_bnorm.num_batches_tracked torch.int64
upsample5_0.ext_conv2.0.weight torch.float16
upsample5_0.ext_conv2.1.weight torch.float16
upsample5_0.ext_conv2.1.bias torch.float16
upsample5_0.ext_conv2.1.running_mean torch.float16
upsample5_0.ext_conv2.1.running_var torch.float16
upsample5_0.ext_conv2.1.num_batches_tracked torch.int64
regular5_1.ext_conv1.0.weight torch.float16
regular5_1.ext_conv1.1.weight torch.float16
regular5_1.ext_conv1.1.bias torch.float16
regular5_1.ext_conv1.1.running_mean torch.float16
regular5_1.ext_conv1.1.running_var torch.float16
regular5_1.ext_conv1.1.num_batches_tracked torch.int64
regular5_1.ext_conv2.0.weight torch.float16
regular5_1.ext_conv2.1.weight torch.float16
regular5_1.ext_conv2.1.bias torch.float16
regular5_1.ext_conv2.1.running_mean torch.float16
regular5_1.ext_conv2.1.running_var torch.float16
regular5_1.ext_conv2.1.num_batches_tracked torch.int64
regular5_1.ext_conv3.0.weight torch.float16
regular5_1.ext_conv3.1.weight torch.float16
regular5_1.ext_conv3.1.bias torch.float16
regular5_1.ext_conv3.1.running_mean torch.float16
regular5_1.ext_conv3.1.running_var torch.float16
regular5_1.ext_conv3.1.num_batches_tracked torch.int64
transposed_conv.weight torch.float16
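For completeness, the input is converted the same way before the forward pass (a sketch of my preprocessing; frame is the input tensor from the first snippet):
frame = frame.half().cuda()
print(frame.dtype)  # torch.float16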
Could you post an executable code snippet using random inputs so that we could reproduce this issue?
@ptrblck here is everything you need to reproduce the issue. The repo is very minimal (2.4 MB) and was created specifically to reproduce the issue. I assume you already have all the requirements installed (pillow, torch, and numpy).
I am going to remove the repo, as I have found the solution to my issue.
I found what caused my problem: the default dtype of my model must have been float32. Adding the following line above my model definition solved the problem:
torch.set_default_dtype(torch.float16)
It was a little tricky to notice, but it does make sense. Thank you @ptrblck for the discussion and your questions; in the end they led me to the solution.
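To illustrate why this was tricky to spot: tensors created without an explicit dtype follow the global default, so any tensor built inside the model stayed float32 until the default was changed (a minimal sketch of the effect, at least in my environment):
import torch

print(torch.ones(2).dtype)  # torch.float32 (the default)
torch.set_default_dtype(torch.float16)
print(torch.ones(2).dtype)  # torch.float16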
PyTorch tensor arguments with dtype=torch.bool caused the CrossEntropyLoss function to throw an error. Neither the .long() casting method nor the other suggestions here helped:
RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'
But adding 0 to the arguments is a workaround, which is strange.
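For reference, a minimal sketch of the dtypes nn.CrossEntropyLoss expects: floating-point logits and int64 class indices. Since log_softmax is applied to the first argument, the error above suggests the logits (not the targets) ended up as an integer type:
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)           # floating point, shape [batch, num_classes]
target = torch.randint(0, 3, (4,))   # int64 class indices, shape [batch]
loss = criterion(logits, target)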
Hi @ptrblck, I just wanted to know the ideal way to do this casting between model parameters. I keep getting this error in backprop, although the loss is in float64 format.
tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
tot_loss = tot_loss.double()
TOT LOSS: tensor(-3.3673e+15, grad_fn=)
Traceback (most recent call last):
  File "AE.py", line 305, in <module>
    tot_loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Double
That’s strange, as the error seems to be raised in the backward pass (while the forward pass works). Could you post a minimal, executable code snippet to reproduce the issue, please?
Sure, thanks.
import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
import time
import matplotlib.pyplot as plt
import math
from scipy import stats
import scipy
import os
import datetime
from scipy.stats import gaussian_kde
from math import sqrt, log
from torch import optim
from torch.autograd import Variable
from sklearn.neighbors import KernelDensity
# parse input data
parser = argparse.ArgumentParser()
parser.add_argument("--code_size", default=20, help="size of the code", type=int)
parser.add_argument("--w_reg", default=0.001, help="weight of the regularization in the loss function", type=float)
parser.add_argument("--a_reg", default=0.2, help="weight of the kernel alignment", type=float)
parser.add_argument("--num_epochs", default=5000, help="number of epochs in training", type=int)
parser.add_argument("--batch_size", default=25, help="number of samples in each batch", type=int)
parser.add_argument("--max_gradient_norm", default=1.0, help="max gradient norm for gradient clipping", type=float)
parser.add_argument("--learning_rate", default=0.001, help="Adam initial learning rate", type=float)
parser.add_argument("--hidden_size", default=30, help="size of the code", type=int)
args = parser.parse_args()
print(args)
# ================= DATASET =================
# (train_data, train_labels, train_len, _, K_tr,
# valid_data, _, valid_len, _, K_vs,
# test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel='TCK',
# inp='zero') # data shape is [T, N, V] = [time_steps, num_elements, num_var]
train_data = np.random.rand(9000,6)
train_labels = np.ones([9000,1])
train_len = 9000
valid_data = np.random.rand(9000,6)
valid_len = 9000
test_data = np.random.rand(1500,6)
test_labels = np.ones([1500,1])
K_tr = np.random.rand(9000,9000)
K_ts = np.random.rand(1500,1500)
K_vs = np.random.rand(9000,9000)
#test_data = test_data_orig
print('\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n'.format(
    train_data.shape, valid_data.shape, test_data.shape))
input_length = train_data.shape[1] # same for all inputs
# ================= GRAPH =================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
encoder_inputs = train_data
prior_k = K_tr
# ============= TENSORBOARD =============
writer = SummaryWriter()
# # ----- ENCODER -----
input_length = encoder_inputs.shape[1]
print("INPUT ")
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.We1 = torch.nn.Parameter(
            torch.Tensor(input_length, args.hidden_size).uniform_(
                -1.0 / math.sqrt(input_length), 1.0 / math.sqrt(input_length)))
        self.We2 = torch.nn.Parameter(
            torch.Tensor(args.hidden_size, args.code_size).uniform_(
                -1.0 / math.sqrt(args.hidden_size), 1.0 / math.sqrt(args.hidden_size)))
        self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))

    def encoder(self, encoder_inputs):
        hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
        # print("CODE ENCODER SHAPE:", code.size())
        return code

def decoder(code):
    Wd1 = torch.nn.Parameter(
        torch.Tensor(args.code_size, args.hidden_size).uniform_(
            -1.0 / math.sqrt(args.code_size), 1.0 / math.sqrt(args.code_size)))
    Wd2 = torch.nn.Parameter(
        torch.Tensor(args.hidden_size, input_length).uniform_(
            -1.0 / math.sqrt(args.hidden_size), 1.0 / math.sqrt(args.hidden_size)))
    bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
    bd2 = torch.nn.Parameter(torch.zeros([input_length]))
    hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)
    dec_out = torch.matmul(hidden_2, Wd2) + bd2
    return dec_out

def kernel_loss(code, prior_K):
    # kernel on codes
    code_K = torch.mm(code, torch.t(code))
    # ----- LOSS -----
    # kernel alignment loss with KL divergence loss
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    k_loss = kl_loss(code_K, prior_K)
    return k_loss
# Initialize model
model = Model()
# trainable parameters count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total parameters: {}'.format(total_params))
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
# ================= TRAINING =================
# initialize training variables
time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = "logs/m_0.ckpt"
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
###############################################################################
# Training code
###############################################################################
try:
    for ep in range(args.num_epochs):
        # shuffle training data
        idx = np.random.permutation(train_data.shape[0])
        train_data_s = train_data[idx, :]
        K_tr_s = K_tr[idx, :][:, idx]
        for batch in range(max_batches):
            fdtr = {}
            fdtr["encoder_inputs"] = train_data_s[batch * batch_size:(batch + 1) * batch_size, :]
            fdtr["prior_K"] = K_tr_s[batch * batch_size:(batch + 1) * batch_size,
                                     batch * batch_size:(batch + 1) * batch_size]
            encoder_inputs = fdtr["encoder_inputs"].astype(float)
            encoder_inputs = torch.from_numpy(encoder_inputs)
            prior_K = fdtr["prior_K"].astype(float)
            prior_K = torch.from_numpy(prior_K)
            code = model.encoder(encoder_inputs)
            dec_out = decoder(code)
            reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
            reconstruct_loss = reconstruct_loss.float()
            # print("RECONS LOSS TRAIN:", reconstruct_loss)
            k_loss = kernel_loss(code, prior_K)
            k_loss = k_loss.float()
            # Regularization L2 loss
            reg_loss = 0
            parameters = torch.nn.utils.parameters_to_vector(model.parameters())
            # print("PARAMS:", parameters)
            for tf_var in parameters:
                reg_loss += torch.mean(torch.linalg.norm(tf_var))
            tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
            tot_loss = tot_loss.double()
            # Backpropagation
            optimizer.zero_grad()
            tot_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()
            loss_track.append(reconstruct_loss)
            kloss_track.append(k_loss)
        print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
            reconstruct_loss, k_loss, torch.mean(torch.stack(loss_track[-10:])),
            torch.mean(torch.stack(kloss_track[-10:]))))
except KeyboardInterrupt:
    print('training interrupted')
time_tr_end = time.time()
print('Tot training time: {}'.format((time_tr_end - time_tr_start) // 60))
writer.close()
Please run the code as:
!python3 file.py --code_size 9 --w_reg 0.001 --a_reg 0.1 --num_epochs 100 --max_gradient_norm 0.5 --learning_rate 0.001 --hidden_size 30
Thanks for the code snippet!
The issue is raised in KLDivLoss, as it doesn't seem to accept mixed dtypes:
kl_loss = nn.KLDivLoss(reduction="batchmean")
input = torch.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
target = torch.softmax(torch.rand(3, 5), dim=1).double()
output = kl_loss(input, target)
output.backward()
# RuntimeError: Found dtype Float but expected Double
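In the meantime, a possible workaround would be to cast both tensors to the same dtype before calling the loss:
output = kl_loss(input, target.float())  # or: kl_loss(input.double(), target)
output.backward()  # works once the dtypes match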
Could you create an issue on GitHub so that we can track and fix it, please?
Sure. The issue has been raised:
Backpropagation not working on KL divergence loss function due to data type mismatch #80158
Thanks for the explanation! It helped me a lot.