Background: during development, I prefer to construct models in high numerical precision, i.e. double, e.g. when initializing constant tensor attributes from NumPy arrays. For actual use, the model can then be converted to a lower precision via `.float()`. Now I'm constructing a model that has a pre-computed complex-valued tensor attribute. Earlier I've used a split (real/imag) representation for complex data, but since I'm starting a new project with a new torch version, I decided to give the complex dtypes a try.
This triggered a need for clarification: with `model.float()`, only the floating point attributes are converted to the given precision, but the complex-valued ones are not touched. With `model.to(torch.float)`, all attribute buffers are converted to `float`, even the complex-valued ones. What I was expecting was that the `complex128` buffers would be converted to `complex64`. Reading the documentation of `torch.nn.Module`, the observed behaviour is actually what is claimed to happen: `.float()` casts only the floating point parameters and buffers, while `.to(dtype)` casts both the floating point and the complex ones to the given dtype. It is just a bit confusing if one assumes `.float()` and `.to(torch.float)` to be equivalent, similar to how they behave on a `torch.Tensor`. Assuming that this difference in behaviour is intended, what would be the best way to handle the situation?
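To make the difference concrete, here is a minimal reproduction (the module and buffer names are my own; behaviour as observed on a recent torch version):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-ins for pre-computed constant attributes
        self.register_buffer("real_buf", torch.zeros(3, dtype=torch.float64))
        self.register_buffer("cplx_buf", torch.zeros(3, dtype=torch.complex128))

m = Demo()
m.float()
print(m.real_buf.dtype, m.cplx_buf.dtype)
# float64 -> float32, but the complex128 buffer is left untouched

m = Demo()
m.to(torch.float)  # warns: casting complex to real discards the imaginary part
print(m.real_buf.dtype, m.cplx_buf.dtype)
# float64 -> float32, and complex128 -> float32 as well
```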
- Construct the model using the correct precision from the start? This is probably the best option, with the only drawback being a minimal loss of development flexibility.
- Override `.to()` in the model to handle the precision of the complex-valued buffers correctly?
- Forget the complex dtype and continue using the split representation? Which in turn makes using functions like the FFT more difficult.
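If the override route is taken, one possibility is to hook the private `Module._apply` method (so this may break across torch versions) and redo any conversion that would turn a complex buffer real, using the matching complex dtype instead. A sketch, where the dtype map and class names are my own:

```python
import torch
import torch.nn as nn

# assumed pairing of real dtypes with their complex counterparts
_COMPLEX_FOR = {
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}

class ComplexSafeModule(nn.Module):
    """Keeps complex buffers complex when the module is cast to a real dtype."""

    def _apply(self, fn, *args, **kwargs):
        def complex_safe(t):
            out = fn(t)  # still warns when it casts complex -> real
            if t.is_complex() and out.is_floating_point():
                # fn discarded the imaginary part; convert the original
                # tensor to the matching complex dtype (and device) instead
                return t.to(device=out.device, dtype=_COMPLEX_FOR[out.dtype])
            return out
        return super()._apply(complex_safe, *args, **kwargs)

class Model(ComplexSafeModule):
    def __init__(self):
        super().__init__()
        self.register_buffer("kernel", torch.full((3,), 1 + 2j, dtype=torch.complex128))
        self.register_buffer("weight", torch.zeros(3, dtype=torch.float64))

m = Model()
m.to(torch.float)
print(m.weight.dtype)  # torch.float32
print(m.kernel.dtype)  # torch.complex64, imaginary part preserved
```

With this, `model.to(torch.float)` behaves like the expected `complex128 -> complex64` conversion, while `model.float()` is unchanged (it never touches complex tensors anyway).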