Moving custom modules between devices

I am having trouble understanding how moving custom modules between devices works in Pytorch.
I have to implement some custom logic on a device move.

class Fred(nn.Module):
    def __init__(self):
        super().__init__()
        ...
    def to(self, device):
        ...
    def _apply(self, fn):
        ...

class Bert(nn.Module):
    def __init__(self):
        super().__init__()
        self.fred = Fred()
    def to(self, device):
        ...

bert = Bert()"cpu")


I can put my logic into Fred._apply, but my question is: is that good practice?
I had imagined that the call would recursively call the to methods of all member modules, i.e. end up calling, but this does not appear to happen.

My concern is that overriding _apply is a “hack” that relies on internal PyTorch code and may break in future releases. Is there a more idiomatic PyTorch way (i.e. good style) of handling this?

Thank you,

The internal to() method calls self._apply(convert), as seen here. _apply then iterates over the child modules:

for module in self.children():
    module._apply(fn)

so you might take a look at these methods and try to add your custom logic to them.

I also agree that overriding internal methods might easily break, but your use case does seem to call for overriding the to() method, so I’m unsure whether there would be a cleaner way.
Maybe the Extending PyTorch docs might be helpful.

Thank you for your reply.

My specific use case is also related to issue #7795 (codification API for distributions), which doesn’t seem to have gained much traction, although I recognise there can be some subtleties, especially related to parameters which are updated.

I have just read that the PyTorch documentation suggests nn.Module subclasses should support device placement at construction, so if I follow that convention consistently it should remove my need to make the .to override work.
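For reference, the construction-time placement pattern might look like this minimal sketch (class and attribute names are illustrative): the module accepts a device argument and creates its state directly on that device, so no post-hoc move is needed:

```python
import torch
import torch.nn as nn

class Fred(nn.Module):
    """Sketch: accept a device at construction and place all state
    there directly, instead of relying on a later to() call."""
    def __init__(self, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3, 3, device=device))
        # Registered buffers also land on the requested device, and
        # will additionally follow any later to()/cuda()/cpu() calls.
        self.register_buffer("scale", torch.ones(3, device=device))

fred = Fred(device="cpu")
print(fred.weight.device.type)  # cpu
```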

Thank you for explaining the _apply mechanism.
