How to cast a tensor to another type?

Casting does not seem to work in PyTorch 1.0. Please see the output below. How can I fix this?

d.type(torch.DoubleTensor)
tensor([[[ 1.5446,  0.3419,  0.1070, -0.6632,  0.5054,  0.7074],
         [-0.5460, -0.0041, -0.6613, -1.5072,  0.4836,  3.1626],
         [-0.9564,  1.8512, -0.6912, -1.0977,  0.4808, -0.5918],
         [-1.3628,  2.2673, -0.9875,  1.0004,  0.1614, -0.4596],
         [-2.0670,  1.4336, -1.1763,  0.1440, -0.5740,  0.2190]],

        [[ 1.5446,  0.3419,  0.1070, -0.6632,  0.5054,  0.7074],
         [-0.5460, -0.0041, -0.6613, -1.5072,  0.4836,  3.1626],
         [-0.9564,  1.8512, -0.6912, -1.0977,  0.4808, -0.5918],
         [-1.3628,  2.2673, -0.9875,  1.0004,  0.1614, -0.4596],
         [-2.0670,  1.4336, -1.1763,  0.1440, -0.5740,  0.2190]]],
       dtype=torch.float64)

In modern PyTorch, you just call float_tensor.double() to cast a float tensor to a double tensor. There is a method for each type you might want to cast to. If, instead, you have a dtype and want to cast to that, use float_tensor.to(dtype=your_dtype) (e.g., your_dtype = torch.float64).
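A minimal sketch of both approaches, assuming a float32 tensor like the one in the question (the shape here is only illustrative). It also shows the most likely reason the original cast "seemed" not to work: d.type(torch.DoubleTensor) returns a new, cast tensor (hence the float64 output above) and leaves d itself unchanged, so the result has to be assigned.

```python
import torch

d = torch.randn(2, 5, 6)  # float32 under the default dtype setting

# .type() returns a new, cast tensor; d itself keeps its original dtype.
d.type(torch.DoubleTensor)
print(d.dtype)  # torch.float32

# Assign the result to keep the cast tensor.
d_double = d.double()
print(d_double.dtype)  # torch.float64

# Same cast when the target dtype is stored in a variable.
your_dtype = torch.float64
d_to = d.to(dtype=your_dtype)
print(d_to.dtype)  # torch.float64
```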


@alan_ayu @ezyang
Isn’t there a method to change the dtype of a model?

The .to() method also works on models, e.g. model.to(torch.double) will convert all floating-point parameters and buffers to float64.
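A small sketch of casting a whole model, using a hypothetical toy model (not one from this thread). Module.to(dtype) casts only floating-point parameters and buffers, so integer buffers such as BatchNorm's num_batches_tracked stay torch.int64, and the input has to be cast to the same dtype as the parameters.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3), nn.PReLU())

# Module.to(dtype) casts floating-point parameters and buffers
# and returns the module itself.
model = model.to(torch.double)

for name, tensor in model.state_dict().items():
    print(name, tensor.dtype)  # num_batches_tracked remains torch.int64

# The input must have the same dtype as the parameters.
x = torch.randn(8, 4).double()
out = model(x)
print(out.dtype)  # torch.float64
```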


A quick on-the-go solution:

tensor_one.float(): returns tensor_one cast to torch.float32
tensor_one.double(): returns tensor_one cast to torch.float64
tensor_one.int(): returns tensor_one cast to torch.int32
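A short sketch checking the resulting dtypes of these shortcut methods (plus .long() and .half()) on an illustrative float32 tensor. Note that float-to-integer casts truncate toward zero rather than round.

```python
import torch

tensor_one = torch.tensor([1.7, -1.7, 2.5])  # float32 by default

as_float = tensor_one.float()    # torch.float32
as_double = tensor_one.double()  # torch.float64
as_int = tensor_one.int()        # torch.int32
as_long = tensor_one.long()      # torch.int64
as_half = tensor_one.half()      # torch.float16

for t in (as_float, as_double, as_int, as_long, as_half):
    print(t.dtype)

# Float -> integer casts truncate toward zero rather than rounding.
print(as_int.tolist())  # [1, -1, 2]
```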


Cast your tensors using .long() (torch.int64).

This worked for me.


How can the above conversion be done in libtorch?

Hi Ptrblck,

I am running this command:

 PP3 = (P1*P2).view(-1).sum().item()

The error is (I ran this code on the CPU and it does not work. I changed it to the GPU and )

It gives me this error. I have tried every option to convert (P1*P2) to double, but it gives me float again. Would you please help me with that?
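Since the actual error and the definitions of P1 and P2 are not shown, this is only a hedged sketch with hypothetical float32 inputs. It illustrates three facts that often cause this confusion: a cast returns a new tensor that must be used, mixing float32 and float64 operands promotes the result to float64, and .item() always returns a plain Python float regardless of the tensor's dtype.

```python
import torch

# Hypothetical inputs; the original P1 and P2 are not shown in the post.
P1 = torch.randn(3, 4)  # float32
P2 = torch.randn(3, 4)  # float32

# Cast the product and keep the returned tensor.
prod = (P1 * P2).double()
print(prod.dtype)  # torch.float64

# Casting one operand first also works: float32 * float64 promotes to float64.
prod2 = P1.double() * P2
print(prod2.dtype)  # torch.float64

# .item() converts a one-element tensor to a Python number (here a float).
PP3 = prod.view(-1).sum().item()
print(type(PP3))  # <class 'float'>
```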

Could you post the error message you are seeing as well as the workaround you are trying to use, please?

@ptrblck thanks for pointing out the dtype conversion for the whole model.
After applying it to my model, I first received an error because I had not changed the dtype of the input to what the model now expects:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

This made sense to me, so I switched the dtype of the input accordingly with input.to(dtype=torch.float16), but then I received the following error, which I cannot resolve:

RuntimeError: expected scalar type Float but found Half

Any help would be much appreciated. P.S. I searched for similar issues, but they did not help in my case.

Also, iterating over the model's state dict and printing the dtypes confirms that the conversion from float32 to float16 was successful.
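As a compact alternative to printing every state-dict entry, the dtypes can be counted. This sketch uses a hypothetical three-layer stand-in rather than the actual model from the thread; .half() casts only floating-point tensors, so integer buffers keep torch.int64.

```python
import collections

import torch
import torch.nn as nn

# Hypothetical stand-in for the real model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.PReLU()).half()

# Count state-dict entries per dtype instead of printing each one.
dtype_counts = collections.Counter(t.dtype for t in model.state_dict().values())
print(dtype_counts)

# Parameters and buffers checked separately.
param_dtypes = {p.dtype for p in model.parameters()}
buffer_dtypes = {b.dtype for b in model.buffers()}
print(param_dtypes, buffer_dtypes)
```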

If I understand your use case correctly, you are transforming the input data and model parameters to FP16 manually without using automatic mixed-precision training?
If so, could you post the error message which shows which operation is raising the issue?
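Besides the stack trace, forward pre-hooks can help locate the first module whose input dtype differs from its own parameters. The helper add_dtype_check_hooks below is hypothetical, and the float64 toy model with a float32 input only simulates a mismatch; it relies on the standard nn.Module.register_forward_pre_hook API, whose hook receives the module and a tuple of positional inputs.

```python
import torch
import torch.nn as nn


def add_dtype_check_hooks(model, mismatches):
    """Hypothetical debugging helper: before each module that owns parameters
    runs, record (name, input dtype, parameter dtype) if a floating-point
    input has a different dtype than that module's parameters."""
    handles = []
    for name, module in model.named_modules():
        own_params = list(module.parameters(recurse=False))
        if not own_params:
            continue
        param_dtype = own_params[0].dtype

        # Default arguments bind the current name and dtype to this hook.
        def hook(mod, inputs, name=name, param_dtype=param_dtype):
            for inp in inputs:
                if torch.is_tensor(inp) and inp.is_floating_point() and inp.dtype != param_dtype:
                    mismatches.append((name, inp.dtype, param_dtype))

        handles.append(module.register_forward_pre_hook(hook))
    return handles


# Toy example: parameters in float64, input accidentally left in float32.
model = nn.Sequential(nn.Linear(4, 3), nn.PReLU()).double()
mismatches = []
handles = add_dtype_check_hooks(model, mismatches)
try:
    model(torch.randn(2, 4))
except RuntimeError:
    pass  # the mismatch still raises; the hook has already recorded it
for h in handles:
    h.remove()
print(mismatches)  # [('0', torch.float32, torch.float64)]
```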

That is correct; this is how I did it. Here is the error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-1d79e07828d2> in <module>
     34             with torch.no_grad():
     35                 inference_start_time = time.time() * 1000
---> 36                 prediction = model(frame)
     37                 inference_end_time = time.time() * 1000
     38 

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/PyTorch-ENet/models/enet.py in forward(self, x)
    590         # Stage 1 - Encoder
    591         stage1_input_size = x.size()
--> 592         x, max_indices1_0 = self.downsample1_0(x)
    593         x = self.regular1_1(x)
    594         x = self.regular1_2(x)

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/PyTorch-ENet/models/enet.py in forward(self, x)
    345         out = main + ext
    346 
--> 347         return self.out_activation(out), max_indices
    348 
    349 

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, input)
    986 
    987     def forward(self, input: Tensor) -> Tensor:
--> 988         return F.prelu(input, self.weight)
    989 
    990     def extra_repr(self) -> str:

~/anaconda3/envs/enet-pytorch/lib/python3.7/site-packages/torch/nn/functional.py in prelu(input, weight)
   1317         if type(input) is not Tensor and has_torch_function((input,)):
   1318             return handle_torch_function(prelu, (input,), input, weight)
-> 1319     return torch.prelu(input, weight)
   1320 
   1321 

RuntimeError: expected scalar type Float but found Half

PyTorch version: 1.6
Python: 3.7.8

Thanks for the stacktrace.
It seems nn.PReLU() is causing the error, which I cannot reproduce using 1.7.0.dev20200830:

import torch
import torch.nn as nn

# Requires a CUDA device.
prelu = nn.PReLU().cuda().half()
x = torch.randn(10, 10).cuda().half()
out = prelu(x)

Could you update to the latest nightly binary and rerun the code?

I was able to run your code snippet in both 1.6 and 1.7.0.dev20200830. Unfortunately, my problem remains the same with the nightly version.

Just to reconfirm: the following output shows that the conversion to float16 was successful, right? Is there anything else I could check?

for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].dtype)
initial_block.main_branch.weight 	 torch.float16
initial_block.batch_norm.weight 	 torch.float16
initial_block.batch_norm.bias 	 torch.float16
initial_block.batch_norm.running_mean 	 torch.float16
initial_block.batch_norm.running_var 	 torch.float16
initial_block.batch_norm.num_batches_tracked 	 torch.int64
initial_block.out_activation.weight 	 torch.float16
downsample1_0.ext_conv1.0.weight 	 torch.float16
downsample1_0.ext_conv1.1.weight 	 torch.float16
downsample1_0.ext_conv1.1.bias 	 torch.float16
downsample1_0.ext_conv1.1.running_mean 	 torch.float16
downsample1_0.ext_conv1.1.running_var 	 torch.float16
downsample1_0.ext_conv1.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv1.2.weight 	 torch.float16
downsample1_0.ext_conv2.0.weight 	 torch.float16
downsample1_0.ext_conv2.1.weight 	 torch.float16
downsample1_0.ext_conv2.1.bias 	 torch.float16
downsample1_0.ext_conv2.1.running_mean 	 torch.float16
downsample1_0.ext_conv2.1.running_var 	 torch.float16
downsample1_0.ext_conv2.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv2.2.weight 	 torch.float16
downsample1_0.ext_conv3.0.weight 	 torch.float16
downsample1_0.ext_conv3.1.weight 	 torch.float16
downsample1_0.ext_conv3.1.bias 	 torch.float16
downsample1_0.ext_conv3.1.running_mean 	 torch.float16
downsample1_0.ext_conv3.1.running_var 	 torch.float16
downsample1_0.ext_conv3.1.num_batches_tracked 	 torch.int64
downsample1_0.ext_conv3.2.weight 	 torch.float16
downsample1_0.out_activation.weight 	 torch.float16
regular1_1.ext_conv1.0.weight 	 torch.float16
regular1_1.ext_conv1.1.weight 	 torch.float16
regular1_1.ext_conv1.1.bias 	 torch.float16
regular1_1.ext_conv1.1.running_mean 	 torch.float16
regular1_1.ext_conv1.1.running_var 	 torch.float16
regular1_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv1.2.weight 	 torch.float16
regular1_1.ext_conv2.0.weight 	 torch.float16
regular1_1.ext_conv2.1.weight 	 torch.float16
regular1_1.ext_conv2.1.bias 	 torch.float16
regular1_1.ext_conv2.1.running_mean 	 torch.float16
regular1_1.ext_conv2.1.running_var 	 torch.float16
regular1_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv2.2.weight 	 torch.float16
regular1_1.ext_conv3.0.weight 	 torch.float16
regular1_1.ext_conv3.1.weight 	 torch.float16
regular1_1.ext_conv3.1.bias 	 torch.float16
regular1_1.ext_conv3.1.running_mean 	 torch.float16
regular1_1.ext_conv3.1.running_var 	 torch.float16
regular1_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_1.ext_conv3.2.weight 	 torch.float16
regular1_1.out_activation.weight 	 torch.float16
regular1_2.ext_conv1.0.weight 	 torch.float16
regular1_2.ext_conv1.1.weight 	 torch.float16
regular1_2.ext_conv1.1.bias 	 torch.float16
regular1_2.ext_conv1.1.running_mean 	 torch.float16
regular1_2.ext_conv1.1.running_var 	 torch.float16
regular1_2.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv1.2.weight 	 torch.float16
regular1_2.ext_conv2.0.weight 	 torch.float16
regular1_2.ext_conv2.1.weight 	 torch.float16
regular1_2.ext_conv2.1.bias 	 torch.float16
regular1_2.ext_conv2.1.running_mean 	 torch.float16
regular1_2.ext_conv2.1.running_var 	 torch.float16
regular1_2.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv2.2.weight 	 torch.float16
regular1_2.ext_conv3.0.weight 	 torch.float16
regular1_2.ext_conv3.1.weight 	 torch.float16
regular1_2.ext_conv3.1.bias 	 torch.float16
regular1_2.ext_conv3.1.running_mean 	 torch.float16
regular1_2.ext_conv3.1.running_var 	 torch.float16
regular1_2.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_2.ext_conv3.2.weight 	 torch.float16
regular1_2.out_activation.weight 	 torch.float16
regular1_3.ext_conv1.0.weight 	 torch.float16
regular1_3.ext_conv1.1.weight 	 torch.float16
regular1_3.ext_conv1.1.bias 	 torch.float16
regular1_3.ext_conv1.1.running_mean 	 torch.float16
regular1_3.ext_conv1.1.running_var 	 torch.float16
regular1_3.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv1.2.weight 	 torch.float16
regular1_3.ext_conv2.0.weight 	 torch.float16
regular1_3.ext_conv2.1.weight 	 torch.float16
regular1_3.ext_conv2.1.bias 	 torch.float16
regular1_3.ext_conv2.1.running_mean 	 torch.float16
regular1_3.ext_conv2.1.running_var 	 torch.float16
regular1_3.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv2.2.weight 	 torch.float16
regular1_3.ext_conv3.0.weight 	 torch.float16
regular1_3.ext_conv3.1.weight 	 torch.float16
regular1_3.ext_conv3.1.bias 	 torch.float16
regular1_3.ext_conv3.1.running_mean 	 torch.float16
regular1_3.ext_conv3.1.running_var 	 torch.float16
regular1_3.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_3.ext_conv3.2.weight 	 torch.float16
regular1_3.out_activation.weight 	 torch.float16
regular1_4.ext_conv1.0.weight 	 torch.float16
regular1_4.ext_conv1.1.weight 	 torch.float16
regular1_4.ext_conv1.1.bias 	 torch.float16
regular1_4.ext_conv1.1.running_mean 	 torch.float16
regular1_4.ext_conv1.1.running_var 	 torch.float16
regular1_4.ext_conv1.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv1.2.weight 	 torch.float16
regular1_4.ext_conv2.0.weight 	 torch.float16
regular1_4.ext_conv2.1.weight 	 torch.float16
regular1_4.ext_conv2.1.bias 	 torch.float16
regular1_4.ext_conv2.1.running_mean 	 torch.float16
regular1_4.ext_conv2.1.running_var 	 torch.float16
regular1_4.ext_conv2.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv2.2.weight 	 torch.float16
regular1_4.ext_conv3.0.weight 	 torch.float16
regular1_4.ext_conv3.1.weight 	 torch.float16
regular1_4.ext_conv3.1.bias 	 torch.float16
regular1_4.ext_conv3.1.running_mean 	 torch.float16
regular1_4.ext_conv3.1.running_var 	 torch.float16
regular1_4.ext_conv3.1.num_batches_tracked 	 torch.int64
regular1_4.ext_conv3.2.weight 	 torch.float16
regular1_4.out_activation.weight 	 torch.float16
downsample2_0.ext_conv1.0.weight 	 torch.float16
downsample2_0.ext_conv1.1.weight 	 torch.float16
downsample2_0.ext_conv1.1.bias 	 torch.float16
downsample2_0.ext_conv1.1.running_mean 	 torch.float16
downsample2_0.ext_conv1.1.running_var 	 torch.float16
downsample2_0.ext_conv1.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv1.2.weight 	 torch.float16
downsample2_0.ext_conv2.0.weight 	 torch.float16
downsample2_0.ext_conv2.1.weight 	 torch.float16
downsample2_0.ext_conv2.1.bias 	 torch.float16
downsample2_0.ext_conv2.1.running_mean 	 torch.float16
downsample2_0.ext_conv2.1.running_var 	 torch.float16
downsample2_0.ext_conv2.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv2.2.weight 	 torch.float16
downsample2_0.ext_conv3.0.weight 	 torch.float16
downsample2_0.ext_conv3.1.weight 	 torch.float16
downsample2_0.ext_conv3.1.bias 	 torch.float16
downsample2_0.ext_conv3.1.running_mean 	 torch.float16
downsample2_0.ext_conv3.1.running_var 	 torch.float16
downsample2_0.ext_conv3.1.num_batches_tracked 	 torch.int64
downsample2_0.ext_conv3.2.weight 	 torch.float16
downsample2_0.out_activation.weight 	 torch.float16
regular2_1.ext_conv1.0.weight 	 torch.float16
regular2_1.ext_conv1.1.weight 	 torch.float16
regular2_1.ext_conv1.1.bias 	 torch.float16
regular2_1.ext_conv1.1.running_mean 	 torch.float16
regular2_1.ext_conv1.1.running_var 	 torch.float16
regular2_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv1.2.weight 	 torch.float16
regular2_1.ext_conv2.0.weight 	 torch.float16
regular2_1.ext_conv2.1.weight 	 torch.float16
regular2_1.ext_conv2.1.bias 	 torch.float16
regular2_1.ext_conv2.1.running_mean 	 torch.float16
regular2_1.ext_conv2.1.running_var 	 torch.float16
regular2_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv2.2.weight 	 torch.float16
regular2_1.ext_conv3.0.weight 	 torch.float16
regular2_1.ext_conv3.1.weight 	 torch.float16
regular2_1.ext_conv3.1.bias 	 torch.float16
regular2_1.ext_conv3.1.running_mean 	 torch.float16
regular2_1.ext_conv3.1.running_var 	 torch.float16
regular2_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular2_1.ext_conv3.2.weight 	 torch.float16
regular2_1.out_activation.weight 	 torch.float16
dilated2_2.ext_conv1.0.weight 	 torch.float16
dilated2_2.ext_conv1.1.weight 	 torch.float16
dilated2_2.ext_conv1.1.bias 	 torch.float16
dilated2_2.ext_conv1.1.running_mean 	 torch.float16
dilated2_2.ext_conv1.1.running_var 	 torch.float16
dilated2_2.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv1.2.weight 	 torch.float16
dilated2_2.ext_conv2.0.weight 	 torch.float16
dilated2_2.ext_conv2.1.weight 	 torch.float16
dilated2_2.ext_conv2.1.bias 	 torch.float16
dilated2_2.ext_conv2.1.running_mean 	 torch.float16
dilated2_2.ext_conv2.1.running_var 	 torch.float16
dilated2_2.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv2.2.weight 	 torch.float16
dilated2_2.ext_conv3.0.weight 	 torch.float16
dilated2_2.ext_conv3.1.weight 	 torch.float16
dilated2_2.ext_conv3.1.bias 	 torch.float16
dilated2_2.ext_conv3.1.running_mean 	 torch.float16
dilated2_2.ext_conv3.1.running_var 	 torch.float16
dilated2_2.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_2.ext_conv3.2.weight 	 torch.float16
dilated2_2.out_activation.weight 	 torch.float16
asymmetric2_3.ext_conv1.0.weight 	 torch.float16
asymmetric2_3.ext_conv1.1.weight 	 torch.float16
asymmetric2_3.ext_conv1.1.bias 	 torch.float16
asymmetric2_3.ext_conv1.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv1.1.running_var 	 torch.float16
asymmetric2_3.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv1.2.weight 	 torch.float16
asymmetric2_3.ext_conv2.0.weight 	 torch.float16
asymmetric2_3.ext_conv2.1.weight 	 torch.float16
asymmetric2_3.ext_conv2.1.bias 	 torch.float16
asymmetric2_3.ext_conv2.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv2.1.running_var 	 torch.float16
asymmetric2_3.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv2.2.weight 	 torch.float16
asymmetric2_3.ext_conv2.3.weight 	 torch.float16
asymmetric2_3.ext_conv2.4.weight 	 torch.float16
asymmetric2_3.ext_conv2.4.bias 	 torch.float16
asymmetric2_3.ext_conv2.4.running_mean 	 torch.float16
asymmetric2_3.ext_conv2.4.running_var 	 torch.float16
asymmetric2_3.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv2.5.weight 	 torch.float16
asymmetric2_3.ext_conv3.0.weight 	 torch.float16
asymmetric2_3.ext_conv3.1.weight 	 torch.float16
asymmetric2_3.ext_conv3.1.bias 	 torch.float16
asymmetric2_3.ext_conv3.1.running_mean 	 torch.float16
asymmetric2_3.ext_conv3.1.running_var 	 torch.float16
asymmetric2_3.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric2_3.ext_conv3.2.weight 	 torch.float16
asymmetric2_3.out_activation.weight 	 torch.float16
dilated2_4.ext_conv1.0.weight 	 torch.float16
dilated2_4.ext_conv1.1.weight 	 torch.float16
dilated2_4.ext_conv1.1.bias 	 torch.float16
dilated2_4.ext_conv1.1.running_mean 	 torch.float16
dilated2_4.ext_conv1.1.running_var 	 torch.float16
dilated2_4.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv1.2.weight 	 torch.float16
dilated2_4.ext_conv2.0.weight 	 torch.float16
dilated2_4.ext_conv2.1.weight 	 torch.float16
dilated2_4.ext_conv2.1.bias 	 torch.float16
dilated2_4.ext_conv2.1.running_mean 	 torch.float16
dilated2_4.ext_conv2.1.running_var 	 torch.float16
dilated2_4.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv2.2.weight 	 torch.float16
dilated2_4.ext_conv3.0.weight 	 torch.float16
dilated2_4.ext_conv3.1.weight 	 torch.float16
dilated2_4.ext_conv3.1.bias 	 torch.float16
dilated2_4.ext_conv3.1.running_mean 	 torch.float16
dilated2_4.ext_conv3.1.running_var 	 torch.float16
dilated2_4.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_4.ext_conv3.2.weight 	 torch.float16
dilated2_4.out_activation.weight 	 torch.float16
regular2_5.ext_conv1.0.weight 	 torch.float16
regular2_5.ext_conv1.1.weight 	 torch.float16
regular2_5.ext_conv1.1.bias 	 torch.float16
regular2_5.ext_conv1.1.running_mean 	 torch.float16
regular2_5.ext_conv1.1.running_var 	 torch.float16
regular2_5.ext_conv1.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv1.2.weight 	 torch.float16
regular2_5.ext_conv2.0.weight 	 torch.float16
regular2_5.ext_conv2.1.weight 	 torch.float16
regular2_5.ext_conv2.1.bias 	 torch.float16
regular2_5.ext_conv2.1.running_mean 	 torch.float16
regular2_5.ext_conv2.1.running_var 	 torch.float16
regular2_5.ext_conv2.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv2.2.weight 	 torch.float16
regular2_5.ext_conv3.0.weight 	 torch.float16
regular2_5.ext_conv3.1.weight 	 torch.float16
regular2_5.ext_conv3.1.bias 	 torch.float16
regular2_5.ext_conv3.1.running_mean 	 torch.float16
regular2_5.ext_conv3.1.running_var 	 torch.float16
regular2_5.ext_conv3.1.num_batches_tracked 	 torch.int64
regular2_5.ext_conv3.2.weight 	 torch.float16
regular2_5.out_activation.weight 	 torch.float16
dilated2_6.ext_conv1.0.weight 	 torch.float16
dilated2_6.ext_conv1.1.weight 	 torch.float16
dilated2_6.ext_conv1.1.bias 	 torch.float16
dilated2_6.ext_conv1.1.running_mean 	 torch.float16
dilated2_6.ext_conv1.1.running_var 	 torch.float16
dilated2_6.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv1.2.weight 	 torch.float16
dilated2_6.ext_conv2.0.weight 	 torch.float16
dilated2_6.ext_conv2.1.weight 	 torch.float16
dilated2_6.ext_conv2.1.bias 	 torch.float16
dilated2_6.ext_conv2.1.running_mean 	 torch.float16
dilated2_6.ext_conv2.1.running_var 	 torch.float16
dilated2_6.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv2.2.weight 	 torch.float16
dilated2_6.ext_conv3.0.weight 	 torch.float16
dilated2_6.ext_conv3.1.weight 	 torch.float16
dilated2_6.ext_conv3.1.bias 	 torch.float16
dilated2_6.ext_conv3.1.running_mean 	 torch.float16
dilated2_6.ext_conv3.1.running_var 	 torch.float16
dilated2_6.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_6.ext_conv3.2.weight 	 torch.float16
dilated2_6.out_activation.weight 	 torch.float16
asymmetric2_7.ext_conv1.0.weight 	 torch.float16
asymmetric2_7.ext_conv1.1.weight 	 torch.float16
asymmetric2_7.ext_conv1.1.bias 	 torch.float16
asymmetric2_7.ext_conv1.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv1.1.running_var 	 torch.float16
asymmetric2_7.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv1.2.weight 	 torch.float16
asymmetric2_7.ext_conv2.0.weight 	 torch.float16
asymmetric2_7.ext_conv2.1.weight 	 torch.float16
asymmetric2_7.ext_conv2.1.bias 	 torch.float16
asymmetric2_7.ext_conv2.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv2.1.running_var 	 torch.float16
asymmetric2_7.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv2.2.weight 	 torch.float16
asymmetric2_7.ext_conv2.3.weight 	 torch.float16
asymmetric2_7.ext_conv2.4.weight 	 torch.float16
asymmetric2_7.ext_conv2.4.bias 	 torch.float16
asymmetric2_7.ext_conv2.4.running_mean 	 torch.float16
asymmetric2_7.ext_conv2.4.running_var 	 torch.float16
asymmetric2_7.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv2.5.weight 	 torch.float16
asymmetric2_7.ext_conv3.0.weight 	 torch.float16
asymmetric2_7.ext_conv3.1.weight 	 torch.float16
asymmetric2_7.ext_conv3.1.bias 	 torch.float16
asymmetric2_7.ext_conv3.1.running_mean 	 torch.float16
asymmetric2_7.ext_conv3.1.running_var 	 torch.float16
asymmetric2_7.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric2_7.ext_conv3.2.weight 	 torch.float16
asymmetric2_7.out_activation.weight 	 torch.float16
dilated2_8.ext_conv1.0.weight 	 torch.float16
dilated2_8.ext_conv1.1.weight 	 torch.float16
dilated2_8.ext_conv1.1.bias 	 torch.float16
dilated2_8.ext_conv1.1.running_mean 	 torch.float16
dilated2_8.ext_conv1.1.running_var 	 torch.float16
dilated2_8.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv1.2.weight 	 torch.float16
dilated2_8.ext_conv2.0.weight 	 torch.float16
dilated2_8.ext_conv2.1.weight 	 torch.float16
dilated2_8.ext_conv2.1.bias 	 torch.float16
dilated2_8.ext_conv2.1.running_mean 	 torch.float16
dilated2_8.ext_conv2.1.running_var 	 torch.float16
dilated2_8.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv2.2.weight 	 torch.float16
dilated2_8.ext_conv3.0.weight 	 torch.float16
dilated2_8.ext_conv3.1.weight 	 torch.float16
dilated2_8.ext_conv3.1.bias 	 torch.float16
dilated2_8.ext_conv3.1.running_mean 	 torch.float16
dilated2_8.ext_conv3.1.running_var 	 torch.float16
dilated2_8.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated2_8.ext_conv3.2.weight 	 torch.float16
dilated2_8.out_activation.weight 	 torch.float16
regular3_0.ext_conv1.0.weight 	 torch.float16
regular3_0.ext_conv1.1.weight 	 torch.float16
regular3_0.ext_conv1.1.bias 	 torch.float16
regular3_0.ext_conv1.1.running_mean 	 torch.float16
regular3_0.ext_conv1.1.running_var 	 torch.float16
regular3_0.ext_conv1.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv1.2.weight 	 torch.float16
regular3_0.ext_conv2.0.weight 	 torch.float16
regular3_0.ext_conv2.1.weight 	 torch.float16
regular3_0.ext_conv2.1.bias 	 torch.float16
regular3_0.ext_conv2.1.running_mean 	 torch.float16
regular3_0.ext_conv2.1.running_var 	 torch.float16
regular3_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv2.2.weight 	 torch.float16
regular3_0.ext_conv3.0.weight 	 torch.float16
regular3_0.ext_conv3.1.weight 	 torch.float16
regular3_0.ext_conv3.1.bias 	 torch.float16
regular3_0.ext_conv3.1.running_mean 	 torch.float16
regular3_0.ext_conv3.1.running_var 	 torch.float16
regular3_0.ext_conv3.1.num_batches_tracked 	 torch.int64
regular3_0.ext_conv3.2.weight 	 torch.float16
regular3_0.out_activation.weight 	 torch.float16
dilated3_1.ext_conv1.0.weight 	 torch.float16
dilated3_1.ext_conv1.1.weight 	 torch.float16
dilated3_1.ext_conv1.1.bias 	 torch.float16
dilated3_1.ext_conv1.1.running_mean 	 torch.float16
dilated3_1.ext_conv1.1.running_var 	 torch.float16
dilated3_1.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv1.2.weight 	 torch.float16
dilated3_1.ext_conv2.0.weight 	 torch.float16
dilated3_1.ext_conv2.1.weight 	 torch.float16
dilated3_1.ext_conv2.1.bias 	 torch.float16
dilated3_1.ext_conv2.1.running_mean 	 torch.float16
dilated3_1.ext_conv2.1.running_var 	 torch.float16
dilated3_1.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv2.2.weight 	 torch.float16
dilated3_1.ext_conv3.0.weight 	 torch.float16
dilated3_1.ext_conv3.1.weight 	 torch.float16
dilated3_1.ext_conv3.1.bias 	 torch.float16
dilated3_1.ext_conv3.1.running_mean 	 torch.float16
dilated3_1.ext_conv3.1.running_var 	 torch.float16
dilated3_1.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_1.ext_conv3.2.weight 	 torch.float16
dilated3_1.out_activation.weight 	 torch.float16
asymmetric3_2.ext_conv1.0.weight 	 torch.float16
asymmetric3_2.ext_conv1.1.weight 	 torch.float16
asymmetric3_2.ext_conv1.1.bias 	 torch.float16
asymmetric3_2.ext_conv1.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv1.1.running_var 	 torch.float16
asymmetric3_2.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv1.2.weight 	 torch.float16
asymmetric3_2.ext_conv2.0.weight 	 torch.float16
asymmetric3_2.ext_conv2.1.weight 	 torch.float16
asymmetric3_2.ext_conv2.1.bias 	 torch.float16
asymmetric3_2.ext_conv2.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv2.1.running_var 	 torch.float16
asymmetric3_2.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv2.2.weight 	 torch.float16
asymmetric3_2.ext_conv2.3.weight 	 torch.float16
asymmetric3_2.ext_conv2.4.weight 	 torch.float16
asymmetric3_2.ext_conv2.4.bias 	 torch.float16
asymmetric3_2.ext_conv2.4.running_mean 	 torch.float16
asymmetric3_2.ext_conv2.4.running_var 	 torch.float16
asymmetric3_2.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv2.5.weight 	 torch.float16
asymmetric3_2.ext_conv3.0.weight 	 torch.float16
asymmetric3_2.ext_conv3.1.weight 	 torch.float16
asymmetric3_2.ext_conv3.1.bias 	 torch.float16
asymmetric3_2.ext_conv3.1.running_mean 	 torch.float16
asymmetric3_2.ext_conv3.1.running_var 	 torch.float16
asymmetric3_2.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric3_2.ext_conv3.2.weight 	 torch.float16
asymmetric3_2.out_activation.weight 	 torch.float16
dilated3_3.ext_conv1.0.weight 	 torch.float16
dilated3_3.ext_conv1.1.weight 	 torch.float16
dilated3_3.ext_conv1.1.bias 	 torch.float16
dilated3_3.ext_conv1.1.running_mean 	 torch.float16
dilated3_3.ext_conv1.1.running_var 	 torch.float16
dilated3_3.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv1.2.weight 	 torch.float16
dilated3_3.ext_conv2.0.weight 	 torch.float16
dilated3_3.ext_conv2.1.weight 	 torch.float16
dilated3_3.ext_conv2.1.bias 	 torch.float16
dilated3_3.ext_conv2.1.running_mean 	 torch.float16
dilated3_3.ext_conv2.1.running_var 	 torch.float16
dilated3_3.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv2.2.weight 	 torch.float16
dilated3_3.ext_conv3.0.weight 	 torch.float16
dilated3_3.ext_conv3.1.weight 	 torch.float16
dilated3_3.ext_conv3.1.bias 	 torch.float16
dilated3_3.ext_conv3.1.running_mean 	 torch.float16
dilated3_3.ext_conv3.1.running_var 	 torch.float16
dilated3_3.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_3.ext_conv3.2.weight 	 torch.float16
dilated3_3.out_activation.weight 	 torch.float16
regular3_4.ext_conv1.0.weight 	 torch.float16
regular3_4.ext_conv1.1.weight 	 torch.float16
regular3_4.ext_conv1.1.bias 	 torch.float16
regular3_4.ext_conv1.1.running_mean 	 torch.float16
regular3_4.ext_conv1.1.running_var 	 torch.float16
regular3_4.ext_conv1.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv1.2.weight 	 torch.float16
regular3_4.ext_conv2.0.weight 	 torch.float16
regular3_4.ext_conv2.1.weight 	 torch.float16
regular3_4.ext_conv2.1.bias 	 torch.float16
regular3_4.ext_conv2.1.running_mean 	 torch.float16
regular3_4.ext_conv2.1.running_var 	 torch.float16
regular3_4.ext_conv2.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv2.2.weight 	 torch.float16
regular3_4.ext_conv3.0.weight 	 torch.float16
regular3_4.ext_conv3.1.weight 	 torch.float16
regular3_4.ext_conv3.1.bias 	 torch.float16
regular3_4.ext_conv3.1.running_mean 	 torch.float16
regular3_4.ext_conv3.1.running_var 	 torch.float16
regular3_4.ext_conv3.1.num_batches_tracked 	 torch.int64
regular3_4.ext_conv3.2.weight 	 torch.float16
regular3_4.out_activation.weight 	 torch.float16
dilated3_5.ext_conv1.0.weight 	 torch.float16
dilated3_5.ext_conv1.1.weight 	 torch.float16
dilated3_5.ext_conv1.1.bias 	 torch.float16
dilated3_5.ext_conv1.1.running_mean 	 torch.float16
dilated3_5.ext_conv1.1.running_var 	 torch.float16
dilated3_5.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv1.2.weight 	 torch.float16
dilated3_5.ext_conv2.0.weight 	 torch.float16
dilated3_5.ext_conv2.1.weight 	 torch.float16
dilated3_5.ext_conv2.1.bias 	 torch.float16
dilated3_5.ext_conv2.1.running_mean 	 torch.float16
dilated3_5.ext_conv2.1.running_var 	 torch.float16
dilated3_5.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv2.2.weight 	 torch.float16
dilated3_5.ext_conv3.0.weight 	 torch.float16
dilated3_5.ext_conv3.1.weight 	 torch.float16
dilated3_5.ext_conv3.1.bias 	 torch.float16
dilated3_5.ext_conv3.1.running_mean 	 torch.float16
dilated3_5.ext_conv3.1.running_var 	 torch.float16
dilated3_5.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_5.ext_conv3.2.weight 	 torch.float16
dilated3_5.out_activation.weight 	 torch.float16
asymmetric3_6.ext_conv1.0.weight 	 torch.float16
asymmetric3_6.ext_conv1.1.weight 	 torch.float16
asymmetric3_6.ext_conv1.1.bias 	 torch.float16
asymmetric3_6.ext_conv1.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv1.1.running_var 	 torch.float16
asymmetric3_6.ext_conv1.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv1.2.weight 	 torch.float16
asymmetric3_6.ext_conv2.0.weight 	 torch.float16
asymmetric3_6.ext_conv2.1.weight 	 torch.float16
asymmetric3_6.ext_conv2.1.bias 	 torch.float16
asymmetric3_6.ext_conv2.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv2.1.running_var 	 torch.float16
asymmetric3_6.ext_conv2.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv2.2.weight 	 torch.float16
asymmetric3_6.ext_conv2.3.weight 	 torch.float16
asymmetric3_6.ext_conv2.4.weight 	 torch.float16
asymmetric3_6.ext_conv2.4.bias 	 torch.float16
asymmetric3_6.ext_conv2.4.running_mean 	 torch.float16
asymmetric3_6.ext_conv2.4.running_var 	 torch.float16
asymmetric3_6.ext_conv2.4.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv2.5.weight 	 torch.float16
asymmetric3_6.ext_conv3.0.weight 	 torch.float16
asymmetric3_6.ext_conv3.1.weight 	 torch.float16
asymmetric3_6.ext_conv3.1.bias 	 torch.float16
asymmetric3_6.ext_conv3.1.running_mean 	 torch.float16
asymmetric3_6.ext_conv3.1.running_var 	 torch.float16
asymmetric3_6.ext_conv3.1.num_batches_tracked 	 torch.int64
asymmetric3_6.ext_conv3.2.weight 	 torch.float16
asymmetric3_6.out_activation.weight 	 torch.float16
dilated3_7.ext_conv1.0.weight 	 torch.float16
dilated3_7.ext_conv1.1.weight 	 torch.float16
dilated3_7.ext_conv1.1.bias 	 torch.float16
dilated3_7.ext_conv1.1.running_mean 	 torch.float16
dilated3_7.ext_conv1.1.running_var 	 torch.float16
dilated3_7.ext_conv1.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv1.2.weight 	 torch.float16
dilated3_7.ext_conv2.0.weight 	 torch.float16
dilated3_7.ext_conv2.1.weight 	 torch.float16
dilated3_7.ext_conv2.1.bias 	 torch.float16
dilated3_7.ext_conv2.1.running_mean 	 torch.float16
dilated3_7.ext_conv2.1.running_var 	 torch.float16
dilated3_7.ext_conv2.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv2.2.weight 	 torch.float16
dilated3_7.ext_conv3.0.weight 	 torch.float16
dilated3_7.ext_conv3.1.weight 	 torch.float16
dilated3_7.ext_conv3.1.bias 	 torch.float16
dilated3_7.ext_conv3.1.running_mean 	 torch.float16
dilated3_7.ext_conv3.1.running_var 	 torch.float16
dilated3_7.ext_conv3.1.num_batches_tracked 	 torch.int64
dilated3_7.ext_conv3.2.weight 	 torch.float16
dilated3_7.out_activation.weight 	 torch.float16
upsample4_0.main_conv1.0.weight 	 torch.float16
upsample4_0.main_conv1.1.weight 	 torch.float16
upsample4_0.main_conv1.1.bias 	 torch.float16
upsample4_0.main_conv1.1.running_mean 	 torch.float16
upsample4_0.main_conv1.1.running_var 	 torch.float16
upsample4_0.main_conv1.1.num_batches_tracked 	 torch.int64
upsample4_0.ext_conv1.0.weight 	 torch.float16
upsample4_0.ext_conv1.1.weight 	 torch.float16
upsample4_0.ext_conv1.1.bias 	 torch.float16
upsample4_0.ext_conv1.1.running_mean 	 torch.float16
upsample4_0.ext_conv1.1.running_var 	 torch.float16
upsample4_0.ext_conv1.1.num_batches_tracked 	 torch.int64
upsample4_0.ext_tconv1.weight 	 torch.float16
upsample4_0.ext_tconv1_bnorm.weight 	 torch.float16
upsample4_0.ext_tconv1_bnorm.bias 	 torch.float16
upsample4_0.ext_tconv1_bnorm.running_mean 	 torch.float16
upsample4_0.ext_tconv1_bnorm.running_var 	 torch.float16
upsample4_0.ext_tconv1_bnorm.num_batches_tracked 	 torch.int64
upsample4_0.ext_conv2.0.weight 	 torch.float16
upsample4_0.ext_conv2.1.weight 	 torch.float16
upsample4_0.ext_conv2.1.bias 	 torch.float16
upsample4_0.ext_conv2.1.running_mean 	 torch.float16
upsample4_0.ext_conv2.1.running_var 	 torch.float16
upsample4_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv1.0.weight 	 torch.float16
regular4_1.ext_conv1.1.weight 	 torch.float16
regular4_1.ext_conv1.1.bias 	 torch.float16
regular4_1.ext_conv1.1.running_mean 	 torch.float16
regular4_1.ext_conv1.1.running_var 	 torch.float16
regular4_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv2.0.weight 	 torch.float16
regular4_1.ext_conv2.1.weight 	 torch.float16
regular4_1.ext_conv2.1.bias 	 torch.float16
regular4_1.ext_conv2.1.running_mean 	 torch.float16
regular4_1.ext_conv2.1.running_var 	 torch.float16
regular4_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_1.ext_conv3.0.weight 	 torch.float16
regular4_1.ext_conv3.1.weight 	 torch.float16
regular4_1.ext_conv3.1.bias 	 torch.float16
regular4_1.ext_conv3.1.running_mean 	 torch.float16
regular4_1.ext_conv3.1.running_var 	 torch.float16
regular4_1.ext_conv3.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv1.0.weight 	 torch.float16
regular4_2.ext_conv1.1.weight 	 torch.float16
regular4_2.ext_conv1.1.bias 	 torch.float16
regular4_2.ext_conv1.1.running_mean 	 torch.float16
regular4_2.ext_conv1.1.running_var 	 torch.float16
regular4_2.ext_conv1.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv2.0.weight 	 torch.float16
regular4_2.ext_conv2.1.weight 	 torch.float16
regular4_2.ext_conv2.1.bias 	 torch.float16
regular4_2.ext_conv2.1.running_mean 	 torch.float16
regular4_2.ext_conv2.1.running_var 	 torch.float16
regular4_2.ext_conv2.1.num_batches_tracked 	 torch.int64
regular4_2.ext_conv3.0.weight 	 torch.float16
regular4_2.ext_conv3.1.weight 	 torch.float16
regular4_2.ext_conv3.1.bias 	 torch.float16
regular4_2.ext_conv3.1.running_mean 	 torch.float16
regular4_2.ext_conv3.1.running_var 	 torch.float16
regular4_2.ext_conv3.1.num_batches_tracked 	 torch.int64
upsample5_0.main_conv1.0.weight 	 torch.float16
upsample5_0.main_conv1.1.weight 	 torch.float16
upsample5_0.main_conv1.1.bias 	 torch.float16
upsample5_0.main_conv1.1.running_mean 	 torch.float16
upsample5_0.main_conv1.1.running_var 	 torch.float16
upsample5_0.main_conv1.1.num_batches_tracked 	 torch.int64
upsample5_0.ext_conv1.0.weight 	 torch.float16
upsample5_0.ext_conv1.1.weight 	 torch.float16
upsample5_0.ext_conv1.1.bias 	 torch.float16
upsample5_0.ext_conv1.1.running_mean 	 torch.float16
upsample5_0.ext_conv1.1.running_var 	 torch.float16
upsample5_0.ext_conv1.1.num_batches_tracked 	 torch.int64
upsample5_0.ext_tconv1.weight 	 torch.float16
upsample5_0.ext_tconv1_bnorm.weight 	 torch.float16
upsample5_0.ext_tconv1_bnorm.bias 	 torch.float16
upsample5_0.ext_tconv1_bnorm.running_mean 	 torch.float16
upsample5_0.ext_tconv1_bnorm.running_var 	 torch.float16
upsample5_0.ext_tconv1_bnorm.num_batches_tracked 	 torch.int64
upsample5_0.ext_conv2.0.weight 	 torch.float16
upsample5_0.ext_conv2.1.weight 	 torch.float16
upsample5_0.ext_conv2.1.bias 	 torch.float16
upsample5_0.ext_conv2.1.running_mean 	 torch.float16
upsample5_0.ext_conv2.1.running_var 	 torch.float16
upsample5_0.ext_conv2.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv1.0.weight 	 torch.float16
regular5_1.ext_conv1.1.weight 	 torch.float16
regular5_1.ext_conv1.1.bias 	 torch.float16
regular5_1.ext_conv1.1.running_mean 	 torch.float16
regular5_1.ext_conv1.1.running_var 	 torch.float16
regular5_1.ext_conv1.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv2.0.weight 	 torch.float16
regular5_1.ext_conv2.1.weight 	 torch.float16
regular5_1.ext_conv2.1.bias 	 torch.float16
regular5_1.ext_conv2.1.running_mean 	 torch.float16
regular5_1.ext_conv2.1.running_var 	 torch.float16
regular5_1.ext_conv2.1.num_batches_tracked 	 torch.int64
regular5_1.ext_conv3.0.weight 	 torch.float16
regular5_1.ext_conv3.1.weight 	 torch.float16
regular5_1.ext_conv3.1.bias 	 torch.float16
regular5_1.ext_conv3.1.running_mean 	 torch.float16
regular5_1.ext_conv3.1.running_var 	 torch.float16
regular5_1.ext_conv3.1.num_batches_tracked 	 torch.int64
transposed_conv.weight 	 torch.float16
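A listing like the one above comes from iterating over a model's state_dict(). The int64 entries are expected: converting a module with .half() (or .to(torch.float16)) only touches floating-point parameters and buffers, so integer buffers such as BatchNorm's num_batches_tracked keep their dtype. Here is a minimal sketch that reproduces the pattern. The two-layer model is hypothetical and only mirrors the layer types in the dump:

```python
import torch
import torch.nn as nn

# Hypothetical small block with the same layer types as the dump:
# a convolution followed by batch norm.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# .half() converts floating-point parameters and buffers to float16 ...
model.half()

# ... but integer buffers such as num_batches_tracked keep their int64 dtype.
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.dtype)
```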

Could you post an executable code snippet using random inputs so that we could reproduce this issue?

@ptrblck here is everything you need to reproduce the issue. The repo is very minimal (2.4 MB) and was created specifically to reproduce the issue. I assume you already have all the requirements installed (pillow, torch and numpy).

I am going to remove the repo, as I have found the solution to my issue.

I found what caused my problem. The default dtype of my model must have been float. Setting the following line above my model definition solved the problem:

torch.set_default_dtype(torch.float16)

It was a little tricky to notice it, but it does make sense. Thank you @ptrblck for the discussion and your questions. In the end they were responsible for me finding the solution.
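Note that torch.set_default_dtype is a global setting: it affects floating-point tensors created after the call, including the parameters of any module constructed afterwards, but it does not convert tensors that already exist. A minimal sketch of that behavior follows. It uses float64 instead of the float16 from the post only because some float16 operations are not implemented on CPU in older PyTorch releases, and it restores the previous default so the change does not leak into other code:

```python
import torch
import torch.nn as nn

previous = torch.get_default_dtype()          # normally torch.float32
try:
    # Affects tensors and parameters created *after* this call.
    torch.set_default_dtype(torch.float64)
    layer = nn.Linear(4, 2)                   # weight and bias are created as float64
    x = torch.randn(3, 4)                     # factory functions also use the new default
    y = layer(x)
finally:
    torch.set_default_dtype(previous)         # restore the global setting

# A non-global alternative: convert an existing model explicitly.
model = nn.Linear(4, 2).to(torch.float64)
```

Converting with model.to(dtype) keeps the change local to one model, which is usually easier to reason about than changing the process-wide default.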

PyTorch tensor arguments with dtype=torch.bool caused the CrossEntropyLoss function to throw an error. Neither the .long() casting method nor the other suggestions here helped:
RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'
However, adding 0 to the arguments works around it, which is strange.
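nn.CrossEntropyLoss expects floating-point logits as its input and int64 (long) class indices as its target. The error above names log_softmax, which runs on the input. One plausible explanation is that .long() was applied to the logits rather than only to the target, but this is an assumption because the post does not show the call. Adding 0 to a bool tensor yields int64 under PyTorch's type promotion rules, which is likely why that workaround produced a valid target. Here is a minimal sketch with hypothetical shapes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 3 classes.
logits = torch.randn(4, 3)                               # float32 logits (model output)
target_bool = torch.tensor([True, False, True, False])   # bool labels, as in the post

# bool + int promotes to int64, so adding 0 happens to give a valid target dtype.
target_promoted = target_bool + 0

# The explicit way: cast only the *target* to long and keep the logits float.
target = target_bool.long()

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
```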

Hi @ptrblck, I just wanted to know the ideal way to handle this casting between model parameters. I keep getting this error during backprop, although the loss is in float64 format.

tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
tot_loss = tot_loss.double()

TOT LOSS: tensor(-3.3673e+15, grad_fn=)
Traceback (most recent call last):
  File "AE.py", line 305, in
    tot_loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Double
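Casting the finished loss with .double() does not change the dtypes of the tensors earlier in the graph, so a mismatch there can still fail during backward. A common cause of this error is mixing float32 and float64 tensors inside one of the loss terms, for example a float64 target loaded from NumPy compared against float32 model output. This is an assumption, since the post does not show how reconstruct_loss, reg_loss or k_loss are computed. The sketch below uses a hypothetical model and data and aligns the dtypes *before* computing the loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the autoencoder in the post.
model = nn.Linear(8, 8)                               # parameters are float32 by default
x = torch.randn(16, 8)                                # float32 input
target = torch.randn(16, 8, dtype=torch.float64)      # e.g. data converted from NumPy

pred = model(x)

# Cast the target to the prediction's dtype before the loss, so every tensor
# in the graph shares one dtype; no .double() on the final loss is needed.
loss = nn.functional.mse_loss(pred, target.to(pred.dtype))
loss.backward()

# Alternative: run everything in float64 with model.double() and x.double().
```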