How do I set dtype for torch.nn.Conv2d?
It seems that the default dtype for torch.nn.Conv2d is float.
I want to set the dtype for torch.nn.Conv2d explicitly,
the same way I set the dtype for a tensor:
a = torch.tensor([1,2,3], dtype=torch.int8)
Hi,
You can use your_mod.type(dtype) to change its type,
or your_mod.double() to convert it to double precision.
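A minimal sketch of the casting approach. The layer sizes are arbitrary and chosen only for illustration. Module.type() and Module.double() convert all parameters and buffers (weight and bias here) in place and return the module. Module.to() also accepts a dtype. Inputs must then use the same dtype.

```python
import torch
import torch.nn as nn

# Parameters are created in the default dtype (float32 unless it was changed)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Any of these casts weight and bias to float64 in place
conv.double()
conv.type(torch.float64)
conv.to(torch.float64)

# The input tensor must have the same dtype as the module's parameters
x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
y = conv(x)
print(conv.weight.dtype, conv.bias.dtype, y.dtype)
```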
Be careful: the conv module is not implemented for every dtype (integer types such as torch.int8, for example, are generally not supported).
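Recent PyTorch releases also accept a dtype keyword directly in the module constructor, similar to torch.tensor(..., dtype=...). The sketch below assumes a version that has these module factory keywords (device/dtype). It sticks to a floating-point dtype and shows that an input with a mismatched dtype is rejected with a RuntimeError.

```python
import torch
import torch.nn as nn

# Assumes a PyTorch version whose modules accept the `dtype` factory keyword
conv = nn.Conv2d(3, 8, kernel_size=3, dtype=torch.float64)

x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
y = conv(x)  # a 3x3 kernel with no padding gives output shape (1, 8, 14, 14)

# Feeding a float32 input to a float64 module raises an error
try:
    conv(torch.randn(1, 3, 16, 16))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```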
Super simple
Thanks!