I understand that one can write a custom backward() override. I have read through the docs and some discussions and I feel a little stupid: the question is not hard, so someone is bound to give me a single-line answer.
I was reading Veit et al. 2016, "Residual Networks Behave Like Ensembles of Relatively Shallow Networks",
and I wanted to test the claims (page 6) on other models. But first I need to modify ResNet.
> To sample a path of length k, we first feed a batch forward through the whole network. During the backward pass, we randomly sample k residual blocks. For those k blocks, we only propagate through the residual module; for the remaining n−k blocks, we only propagate through the skip connection.
So I don't really need to modify any low-level computation, just the block, conditionally:
the forward will not change, but in some cases the backward should behave as if it corresponded to a shorter forward.
def forward(self, x):
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample is not None:
        identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out
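
# Pseudocode: what the backward pass should "see" for a block where
# only the skip connection is propagated (the residual module is ignored).
def run_backwards_on_sometimes_appearing_as_forward(self, x):
    out = self.relu(x)
    return out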
Seems simple, but I'm blanking out. Thanks.
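For what it's worth, here is a minimal sketch of what I think the block could look like without any custom backward(), assuming that calling detach() on one branch right before the addition is an acceptable way to cut its gradient. The forward value stays the same; only the graph that autograd walks backward changes. The names PathBlock, backward_mode and sample_path are hypothetical (not from torchvision or the paper), and the block is a simplified BasicBlock with a fixed channel count.

```python
import random

import torch
import torch.nn as nn


class PathBlock(nn.Module):
    """Simplified BasicBlock with a switchable backward path (hypothetical).

    backward_mode is one of:
      "both"     - normal backward through residual module and skip
      "residual" - gradient flows only through the residual module
      "skip"     - gradient flows only through the skip connection
    The forward output is the same in all three modes.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()  # not in-place, to keep autograd simple
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.downsample = None
        self.backward_mode = "both"

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # detach() keeps the value but removes the branch from the graph,
        # so backward only sees the branch that was not detached.
        if self.backward_mode == "residual":
            identity = identity.detach()
        elif self.backward_mode == "skip":
            out = out.detach()
        out = out + identity  # out-of-place add instead of `out += identity`
        return self.relu(out)


def sample_path(blocks, k):
    """Pick k blocks to backprop through the residual module; rest use skip."""
    chosen = set(random.sample(range(len(blocks)), k))
    for i, block in enumerate(blocks):
        block.backward_mode = "residual" if i in chosen else "skip"
```

Usage would be something like calling sample_path(list_of_blocks, k) before each forward/backward pass and then reading the gradient at the input. Two caveats I am aware of: in "skip" mode the conv and batch-norm weights of that block receive no gradient at all, and the final ReLU still sits on every path in this block layout, so the "skip" path is not a pure identity.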