Can I explicitly set dtype for torch.nn.Conv2d?

How do I set the dtype for torch.nn.Conv2d?
It seems that the default dtype for torch.nn.Conv2d is float.

I want to explicitly set the dtype for torch.nn.Conv2d,
the same way I set the dtype for a tensor:

    a = torch.tensor([1, 2, 3], dtype=torch.int8)
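A quick way to confirm the default, and the constructor-level alternative, is sketched below. This assumes a reasonably recent PyTorch release: newer versions let module constructors such as `nn.Conv2d` take `device` and `dtype` keywords, while older ones do not. The channel counts and kernel size are arbitrary illustration values.

```python
import torch
import torch.nn as nn

# Parameters are created in torch.get_default_dtype(), which is float32 unless changed
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
print(conv.weight.dtype)  # torch.float32

# Assumption: a PyTorch version whose module constructors accept dtype=
conv64 = nn.Conv2d(3, 8, kernel_size=3, dtype=torch.float64)
print(conv64.weight.dtype, conv64.bias.dtype)  # torch.float64 torch.float64
```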

Hi,

You can use your_mod.type(dtype) to change its type, or your_mod.double() to convert it to double precision.
Be careful: the conv module is not implemented for every dtype :wink:
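A minimal sketch of the `.type()` / `.double()` approach on an already-built layer (sizes chosen only for illustration). Note that the input tensor has to be cast to the same dtype as the weights, otherwise the forward pass raises a dtype-mismatch error.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)

# Module.type() casts every parameter and buffer in place and returns the module
conv.type(torch.float64)       # same effect here as conv.double()
print(conv.weight.dtype)       # torch.float64

# The input must match the weights' dtype
x = torch.randn(1, 3, 32, 32, dtype=torch.float64)
y = conv(x)
print(y.dtype, y.shape)        # torch.float64 torch.Size([1, 8, 30, 30])
```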


Super simple
Thanks!

Hello, thank you for this solution. Can I apply it only to specific layers? In my application I need some layers in int8 while keeping the others in float.

I found a solution, in case someone needs it later:

    import torch

    with torch.no_grad():
        for layer_name, param in model.named_parameters():
            # names look like "conv1.weight"; apply any condition you want
            if "weight" in layer_name:
                # plain cast to int8, no scaling or quantization is applied
                param.data = param.data.to(torch.int8)
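Selecting layers by module instead of by parameter name is another option. The sketch below uses a hypothetical two-layer `nn.Sequential` and casts only the layer named `"0"`. It uses float64 rather than int8 because `Module.to()` only accepts floating-point or complex dtypes, and the standard convolution kernels do not handle integer inputs (real int8 convolutions usually go through PyTorch's quantization APIs instead). Activations then have to be cast between layers by hand.

```python
import torch
import torch.nn as nn

# Hypothetical model used only for illustration
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Conv2d(8, 4, kernel_size=3),
)

for name, module in model.named_modules():
    if name == "0":                 # any selection condition works here
        module.to(torch.float64)    # casts only this module's params/buffers

for name, param in model.named_parameters():
    print(name, param.dtype)
# 0.weight torch.float64
# 0.bias torch.float64
# 1.weight torch.float32
# 1.bias torch.float32

# Mixed dtypes: cast the activations to match each layer
x = torch.randn(1, 3, 16, 16)
h = model[0](x.double())
y = model[1](h.float())
print(y.dtype)  # torch.float32
```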