I am having trouble understanding how moving custom modules between devices works in PyTorch.
I need to run some custom logic whenever a module is moved to a device.
Example:
import torch.nn as nn

class Fred(nn.Module):
    def __init__(self):
        super().__init__()

    def to(self, device):
        print("Fred::to")

    def _apply(self, device):
        print("Fred::_apply")

class Bert(nn.Module):
    def __init__(self):
        super().__init__()
        self.fred = Fred()

    def to(self, device):
        print("Bert::to")
        super().to(device)

bert = Bert()
bert.to("cpu")

Running this prints:

Bert::to
Fred::_apply
I can put my logic into the Fred::_apply method, but my question is: is that good practice?
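For concreteness, what I have in mind is roughly the following (just a sketch, and I am assuming that _apply actually receives the conversion function rather than a device, and that delegating to super()._apply keeps the normal parameter/buffer move working):

import torch
import torch.nn as nn

class Fred(nn.Module):
    def _apply(self, fn, *args, **kwargs):
        # _apply is handed the conversion function, not the device itself.
        # One way to find out where we are going (an assumption on my part)
        # is to run the function on a throwaway tensor and inspect its device.
        target_device = fn(torch.empty(0)).device
        print("Fred::_apply, moving to", target_device)

        # ... my custom logic for the device move would go here ...

        # Delegate to nn.Module._apply so parameters and buffers still move.
        return super()._apply(fn, *args, **kwargs)

As far as I can tell, .to(), .cuda(), .cpu() and the dtype conversions all funnel through _apply, which is why Fred::_apply fires in the example above even though Fred.to never does.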
I had imagined that the Bert.to call would recursively call the to methods of all the member modules, i.e. end up calling Fred.to, but this does not appear to happen.
My concern is that overriding _apply is just a “hack” that relies on internal PyTorch code and may break in future releases. Is there a more pytorchonic (i.e. better-style) way of handling this?
Thank you,
Julian.