How to convert a PyTorch nn.Module to float64

I want a simple way to convert a PyTorch nn.Module to a float64 model.

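To cast all floating-point parameters and buffers of a module to float64, use model.double(). Integer buffers (e.g. the num_batches_tracked counter in batchnorm layers) keep their dtype, and the inputs you pass to the model must also be float64.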
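A minimal sketch of this, assuming a small toy model (the layer sizes are arbitrary). model.double() modifies the module in place and also returns it. Afterwards the floating-point parameters and buffers are float64, while the integer batchnorm counter stays int64:

```python
import torch
import torch.nn as nn

# Toy model chosen only for illustration; BatchNorm1d adds buffers to the mix.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

model = model.double()  # same effect as model.to(torch.float64)

# Floating-point parameters and buffers are now float64.
for name, p in model.named_parameters():
    print(name, p.dtype)
for name, b in model.named_buffers():
    print(name, b.dtype)  # num_batches_tracked stays torch.int64

# Inputs must match the model's dtype, otherwise the forward pass raises an error.
x = torch.randn(3, 4, dtype=torch.float64)
out = model(x)
print(out.dtype)  # torch.float64
```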

1 Like

What is a PyTorch buffer?

Buffers are tensors that are registered to the parent module but don’t require gradients, so they are not returned by parameters() and are not updated by the optimizer. E.g. the running stats in batchnorm layers are buffers. You can register a buffer via self.register_buffer inside an nn.Module.
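As a rough sketch, here is a hypothetical RunningMean module (not a PyTorch class) that keeps a running mean of its inputs in a buffer, similar in spirit to batchnorm's running stats. The buffer shows up in named_buffers() rather than named_parameters(), does not require gradients, and is cast by .double() along with the parameters:

```python
import torch
import torch.nn as nn


class RunningMean(nn.Module):
    """Hypothetical module: centers inputs using a running mean stored in a buffer."""

    def __init__(self, num_features: int, momentum: float = 0.1):
        super().__init__()
        self.momentum = momentum
        # Buffer: part of the module's state, but not a trainable parameter.
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Trainable parameter, for contrast.
        self.scale = nn.Parameter(torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Update the buffer in place without tracking gradients.
            with torch.no_grad():
                batch_mean = x.mean(dim=0)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
        return (x - self.running_mean) * self.scale


m = RunningMean(3)
y = m(torch.randn(5, 3))

print([n for n, _ in m.named_parameters()])  # ['scale']
print([n for n, _ in m.named_buffers()])     # ['running_mean']
print(m.running_mean.requires_grad)          # False

m = m.double()  # buffers are cast together with parameters
print(m.running_mean.dtype)                  # torch.float64
```

Because the buffer is registered on the module, it is also saved in m.state_dict() and moved by calls such as m.to(device).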