Can I explicitly set dtype for torch.nn.Conv2d?

How do I set the dtype for torch.nn.Conv2d?
It seems that the default dtype for torch.nn.Conv2d is float (torch.float32).

I want to explicitly set the dtype for torch.nn.Conv2d,
the same way I set the dtype for a tensor, as below:

a = torch.tensor([1,2,3], dtype=torch.int8)


You can use your_mod.type(dtype) to change the dtype of its parameters, or your_mod.double() to switch to double precision.
Be careful: the conv module is not implemented for every type :wink:
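
To make that concrete, here is a minimal sketch of the options (not the only way to do it). The layer sizes and input shape are made up for illustration. The last variant assumes a PyTorch release recent enough that the Conv2d constructor accepts a dtype keyword.

```python
import torch
import torch.nn as nn

# A freshly built Conv2d uses the global default dtype (normally torch.float32).
conv = nn.Conv2d(3, 8, kernel_size=3)
default_dtype = conv.weight.dtype

# Three ways to cast an existing module's parameters to double precision.
conv_double = nn.Conv2d(3, 8, kernel_size=3).double()
conv_typed = nn.Conv2d(3, 8, kernel_size=3).type(torch.float64)
conv_to = nn.Conv2d(3, 8, kernel_size=3).to(torch.float64)

# Assumption: recent PyTorch versions accept dtype directly in the constructor.
conv_ctor = nn.Conv2d(3, 8, kernel_size=3, dtype=torch.float64)

# The input has to use the same dtype as the weights.
x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
out = conv_double(x)
```

The cast only changes the parameters, so the input tensor must be created in (or converted to) the matching dtype before the forward pass. Integer dtypes such as torch.int8 are the usual case where a regular Conv2d has no kernel available, which is the caveat above.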


Super simple