How do I set dtype for torch.nn.Conv2d?
It seems that the default dtype for torch.nn.Conv2d is float.
I want to explicitly set the dtype for torch.nn.Conv2d,
the same way I set the dtype for a tensor, as below:
a = torch.tensor([1,2,3], dtype=torch.int8)
Hi,
You can use your_mod.type(dtype)
to change its type, or your_mod.double()
to change to double precision.
Be careful: the conv module is not implemented for every type.
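To make the suggestion above concrete, here is a minimal sketch. The layer sizes and input shape are made up for illustration, and it assumes the global default dtype has not been changed from float32. Note that the input tensor has to use the same dtype as the module's weights.

```python
import torch
import torch.nn as nn

# Hypothetical layer sizes, chosen only for illustration.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
default_dtype = conv.weight.dtype  # float32 unless the global default was changed

# Cast all parameters of the module; .double() is shorthand for float64.
conv = conv.type(torch.float64)  # same effect here as conv.double()

# The input must match the dtype of the weights.
x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
y = conv(x)
print(default_dtype, conv.weight.dtype, y.dtype, tuple(y.shape))
```

Both .type(dtype) and .double() cast the module in place and return the module itself, so reassigning the result is optional.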
Super simple
Thanks!
Hello, thank you for this solution, but can I specify which layers to convert? In my application I need some layers in int8 and want to keep the others in float.
I found the solution, in case someone needs it later on:
with torch.no_grad():
    # named_parameters() yields (name, parameter) pairs, e.g. "conv1.weight"
    for layer_name, param in model.named_parameters():
        if "weight" in layer_name:  # you can apply any condition you want
            param.data = param.data.type(torch.uint8)
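For anyone who wants to try the per-layer approach on something small, here is a self-contained sketch. The toy nn.Sequential model and the name prefix "2." are assumptions made for the example, and float16 stands in for the target dtype so the result is easy to inspect. Casting float weights to an integer dtype such as uint8 truncates the values, so check that this is really what you want.

```python
import torch
import torch.nn as nn

# Toy model for illustration only; in nn.Sequential the parameter names
# are "<index>.weight" and "<index>.bias".
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3),
)

with torch.no_grad():
    for name, param in model.named_parameters():
        # Cast only the parameters of the second conv layer (index 2).
        if name.startswith("2."):
            param.data = param.data.to(torch.float16)

# Collect the dtype of every parameter to confirm which ones changed.
dtypes = {name: param.dtype for name, param in model.named_parameters()}
print(dtypes)
```

Keep in mind that a model with mixed dtypes will not run a forward pass as is: the activations have to be cast to each layer's dtype before they reach it.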