Background: when constructing a model during development, I prefer to use high numerical precision, i.e. double, e.g. when initializing constant tensor attributes from NumPy arrays. For actual use, the model can then be converted to a lower precision with `.float()`. Now I'm constructing a model that has a pre-computed complex-valued tensor attribute. In earlier projects I've used the split (real/imag) complex representation, but starting a new project with a new torch version I decided to give the complex dtypes a try.
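A minimal sketch of the kind of setup I mean (the buffer name and values are just placeholders):

```python
import numpy as np
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Pre-computed complex-valued constant, kept at full precision
        # (complex128) during development.
        kernel = np.exp(2j * np.pi * np.arange(8) / 8)  # placeholder values
        self.register_buffer("kernel", torch.from_numpy(kernel))
        # Ordinary real-valued parameter for comparison (float64).
        self.weight = torch.nn.Parameter(torch.zeros(8, dtype=torch.float64))
```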
This brought up something that needs clarification:
When calling `model.float()`, only the floating-point attributes are converted to the given precision; the complex-valued ones are not touched. When calling `model.to(torch.float)`, all attribute buffers are converted to float, even the complex-valued ones. What I was expecting was that the `complex128` buffers would be converted to `complex64`.
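Roughly how I observe this, using the sketched `MyModel` from above (the dtypes in the comments are what I see on my setup):

```python
m1 = MyModel().float()
print(m1.kernel.dtype, m1.weight.dtype)
# observed: kernel stays complex128, weight becomes float32

m2 = MyModel().to(torch.float)
print(m2.kernel.dtype, m2.weight.dtype)
# observed: kernel is cast to float32 (imaginary part dropped), weight becomes float32
# expected: kernel -> complex64
```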
Reading the documentation of `torch.nn.Module`, this is indeed what is claimed to happen. It is just a bit confusing if one assumes `float()` and `to(torch.float)` to be equivalent, as they are for `torch.Tensor`. Assuming that this difference in behaviour is intended, what would be the best way to handle the situation?
- Construct the model using the correct precision from the start? This is probably the best option, with the only drawback being a small loss of development flexibility.
- Overload `.to()` in the model to handle the precision of the complex-valued buffers correctly? (See the sketch after this list.)
- Forget the complex dtype and continue using the split representation? That in turn makes using functions like the FFT more cumbersome.
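For the second option, something along these lines is what I had in mind. This is only a sketch under the assumption that a plain `module.to(dtype)` call with a real floating-point dtype is the case to intercept; it remaps that dtype to the corresponding complex dtype for complex tensors and otherwise falls back to the standard behaviour. It relies on `Module._apply`, which is the (semi-private) mechanism `Module.to` uses internally.

```python
import torch


class ComplexAwareModule(torch.nn.Module):
    """Sketch of a Module whose .to(dtype) keeps complex buffers complex.

    Only the plain `module.to(dtype)` call with a real floating-point dtype
    is intercepted; every other call signature falls back to Module.to.
    """

    # Map real floating-point dtypes to their complex counterparts.
    _REAL_TO_COMPLEX = {
        torch.float32: torch.complex64,
        torch.float64: torch.complex128,
    }

    def to(self, *args, **kwargs):
        dtype = kwargs.get("dtype")
        if dtype is None and args and isinstance(args[0], torch.dtype):
            dtype = args[0]
        complex_dtype = self._REAL_TO_COMPLEX.get(dtype)
        if complex_dtype is None:
            # Not a plain real-dtype conversion: use the standard behaviour.
            return super().to(*args, **kwargs)

        def convert(t):
            if t.is_complex():
                return t.to(complex_dtype)
            if t.is_floating_point():
                return t.to(dtype)
            return t

        # _apply walks all parameters and buffers, including submodules.
        return self._apply(convert)
```

With this, `model.to(torch.float)` would cast `float64` attributes to `float32` and `complex128` attributes to `complex64`, which is the behaviour I was expecting in the first place.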