How to customize data types that can be transferred between CPU and GPU

When I design my network, I need to work with “complex” data whose leaf attributes are tensors. For example, CameraSpecification contains the attributes:

  • position, a tensor
  • transform, a tensor
  • intrinsics, data of another type CameraIntrinsics containing:
    • fov, a tensor
    • pixel_coordinates, a tensor

These two classes are nothing but groups of tensors, so in theory they can be transferred safely between CPU and GPU and generated during forward(). However, if I generate them during forward(), I have to move each of their tensor attributes manually using to(). This is especially inconvenient when I use PyTorch Lightning. Is there a more elegant way to group tensors, and to group groups of tensors? Thank you!
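For concreteness, the nested structure described above could be sketched as plain dataclasses (the class and attribute names come from the question; the `move_spec` helper is hypothetical and only illustrates the manual device transfer the question wants to avoid):

```python
from dataclasses import dataclass

import torch


@dataclass
class CameraIntrinsics:
    fov: torch.Tensor
    pixel_coordinates: torch.Tensor


@dataclass
class CameraSpecification:
    position: torch.Tensor
    transform: torch.Tensor
    intrinsics: CameraIntrinsics


def move_spec(spec: CameraSpecification, device: torch.device) -> CameraSpecification:
    # Without any built-in support, moving the group means touching every leaf tensor.
    return CameraSpecification(
        position=spec.position.to(device),
        transform=spec.transform.to(device),
        intrinsics=CameraIntrinsics(
            fov=spec.intrinsics.fov.to(device),
            pixel_coordinates=spec.intrinsics.pixel_coordinates.to(device),
        ),
    )
```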

I’m not sure if the posted object is used as the input or inside the model, e.g. as a custom parameter.
In the former case, I would guess you could push the data to the GPU before executing the forward pass, while parameters would be pushed to the device after initializing the model, so I’m unsure why the device transfer is performed in the forward pass.
In any case, I don’t know what limitations Lightning has, so could you explain why pushing the data to the device beforehand won’t work?

I mean, how can I get rid of .to()? Is there a way to tell PyTorch or Lightning that my class is just an aggregation of tensors, so that they can automatically implement and call .to() for me? I was expecting something like this:

class TensorGroup(torch.TensorAggregation):  # hypothetical base class
    def __init__(self):
        self.tensor1 = torch.rand(3)
        self.tensor2 = torch.rand(3)

# in a Lightning training step
def training_step(self, batch, batch_idx):
    some_tensor_group = TensorGroup().to(self.device)  # same as some_tensor = torch.rand(1).to(self.device)
    # or even better
    some_tensor_group = TensorGroup()  # so I can get rid of .to()

I don’t know about Lightning, but that’s not possible in PyTorch. You could, however, create a new custom object and implement the to() method yourself, if you think it makes the code cleaner.
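A minimal sketch of that suggestion: a base class (reusing the `TensorGroup` name from the question; the recursive-attribute approach is one possible design, not a PyTorch API) whose `to()` walks its attributes and moves every tensor or nested group:

```python
import torch


class TensorGroup:
    """Hypothetical base class: subclasses whose attributes are tensors
    (or other TensorGroups) get a working .to() for free."""

    def to(self, *args, **kwargs):
        # Iterate over a snapshot of the attributes and move each leaf;
        # nested TensorGroups are handled by the same method recursively.
        for name, value in list(vars(self).items()):
            if isinstance(value, (torch.Tensor, TensorGroup)):
                setattr(self, name, value.to(*args, **kwargs))
        return self


class CameraIntrinsics(TensorGroup):
    def __init__(self):
        self.fov = torch.rand(1)
        self.pixel_coordinates = torch.rand(2, 2)


class CameraSpecification(TensorGroup):
    def __init__(self):
        self.position = torch.rand(3)
        self.transform = torch.rand(4, 4)
        self.intrinsics = CameraIntrinsics()
```

With this, `CameraSpecification().to(self.device)` moves all leaf tensors, including the nested intrinsics, in one call. Note that unlike `Tensor.to`, this sketch mutates the group in place and returns it.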

I still don’t understand the use case, but given your current TensorGroup code snippet, you could try using a custom nn.Module instead and registering the tensors as buffers, which would then be pushed to the desired device via the to() operation.
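Applied to the classes from the question, that approach could look like this (class names from the question; `register_buffer` is the standard way to attach non-trainable tensors to a module so that `Module.to()` moves them):

```python
import torch
import torch.nn as nn


class CameraIntrinsics(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are moved by Module.to() but are not trainable parameters.
        self.register_buffer("fov", torch.rand(1))
        self.register_buffer("pixel_coordinates", torch.rand(2, 2))


class CameraSpecification(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("position", torch.rand(3))
        self.register_buffer("transform", torch.rand(4, 4))
        # Child modules are moved recursively along with the parent.
        self.intrinsics = CameraIntrinsics()


spec = CameraSpecification().to("cpu")  # one call moves every buffer
```

A side effect of this design is that the buffers also show up in the module’s state_dict(), so they are saved and loaded with checkpoints, which may or may not be what you want here.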