Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead
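The usual way around this is to cast the data to float32 on the CPU before the tensor reaches the MPS device. The sketch below makes three assumptions: PyTorch is installed with the MPS backend, the input comes from NumPy (whose arrays default to float64, a common source of this error), and the code falls back to CPU where MPS is not available.

```python
import numpy as np
import torch

# Use MPS when available; fall back to CPU so the sketch also runs elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# np.random.rand returns float64, so torch.from_numpy(data) would be float64 too.
data = np.random.rand(4, 3)

# Cast to float32 while the tensor is still on the CPU...
x = torch.from_numpy(data).to(dtype=torch.float32)

# ...and only then move it to the (possibly MPS) device.
x = x.to(device)
print(x.dtype, x.device)
```

The order matters. Converting on the CPU first means the MPS device never has to hold a float64 tensor.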
