Hi,
tl;dr: is it possible to checkpoint a whole sub-network and learn its weights with a clean API?
I’ve learned that I can replace
result = self.net(x, y)[0]
by
result = torch.utils.checkpoint.checkpoint(self.net, x, y)[0]
where in my case self.net is a sub-model with some logic and several layers, and x and y are inputs coming from an earlier layer. As far as I understand, x and y are stored during the forward pass, but no intermediate values inside self.net are. On the backward pass, the gradient arriving at result hits the checkpoint, which re-runs the forward of self.net to recompute everything needed for the gradients of x and y. This trades computation for memory.
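For concreteness, here is a minimal, self-contained sketch of what I mean (SubNet is a made-up stand-in for my actual sub-model):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class SubNet(nn.Module):
    """Hypothetical stand-in for my actual sub-model."""
    def __init__(self):
        super().__init__()
        self.fc_x = nn.Linear(16, 16)
        self.fc_y = nn.Linear(16, 16)

    def forward(self, x, y):
        # These intermediate activations are what checkpointing
        # avoids keeping alive until the backward pass.
        h = torch.relu(self.fc_x(x) + self.fc_y(y))
        return h, h.sum()

net = SubNet()
x = torch.randn(8, 16, requires_grad=True)
y = torch.randn(8, 16, requires_grad=True)

# Without checkpointing: every intermediate inside net is stored.
result = net(x, y)[0]

# With checkpointing: only x and y are stored; the intermediates
# are recomputed inside net during the backward pass.
result = torch.utils.checkpoint.checkpoint(net, x, y)[0]
result.sum().backward()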
Previously I didn’t need the gradients in self.net (it was pre-trained), but now I do. So I would like to have something like
result = torch.utils.checkpoint.checkpoint(self.net, self.net.parameters(), x, y)[0]
As far as I understand, this is technically very similar to the first case, but the API doesn’t allow it, since self.net.parameters() is an iterator rather than a tensor argument.
I’m not very experienced with Python and PyTorch (I only started last summer), but I managed to do the following, and it seems to work:
class _NetCheckpointWrapper:
    def __init__(self, net, x, y):
        self.net = net
        self.x = x
        self.y = y

    def __call__(self, *args, **kwargs):
        # The parameters passed in by checkpoint are ignored here; they
        # are forwarded only so that checkpoint treats them as inputs.
        return self.net(self.x, self.y)
# forward:
wrapper = _NetCheckpointWrapper(self.net, x, y)
net_params = tuple(self.net.parameters())
result = torch.utils.checkpoint.checkpoint(wrapper, *net_params)[0]
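For what it’s worth, the same workaround can be written a bit more compactly with a closure instead of a wrapper class (I assume the semantics are identical: the parameters are still passed only so that checkpoint sees them as inputs, and the closure ignores them):

# forward, compact variant:
net_params = tuple(self.net.parameters())
result = torch.utils.checkpoint.checkpoint(
    lambda *params: self.net(x, y), *net_params)[0]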
Neither version looks very good to me as a programmer, but maybe this is idiomatic in Python/PyTorch? I’m also not sure about the use of net_params: coming from C++, I’m afraid the tuple might end up referring to different objects than self.net.parameters() yields, or that this behaviour might change in a future version. It is quite verbose as well.
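To reassure myself about the reference worry, I checked that the tuple really holds references to the same Parameter objects that parameters() yields (tuple() over an iterator stores references, not copies):

# sanity check: net_params aliases the live parameters of self.net
assert all(p1 is p2 for p1, p2 in zip(net_params, self.net.parameters()))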
Is there a better way to achieve that?