I understand that one can write a custom backward() override. I've read through the docs and some discussions, and I feel a little stupid; the question isn't hard, so someone is bound to give me a single-line answer.

I was reading Veit et al. 2016, *Residual Networks Behave Like Ensembles of Relatively Shallow Networks*, and I wanted to test the claims (page 6) on other models. But first I need to modify ResNet.

> To sample a path of length k, we first feed a batch forward through the whole network. During the backward pass, we randomly sample k residual blocks. For those k blocks, we only propagate through the residual module; for the remaining n−k blocks, we only propagate through the skip connection.
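The sampling step in that quote could be sketched like this (a hedged sketch, assuming each block exposes a boolean flag, here hypothetically called `keep_residual_grad`, that the block's forward consults):

```python
import random

def sample_path(blocks, k):
    """Pick k of n residual blocks whose residual branch will receive
    gradient; the remaining n-k blocks backprop only through the skip."""
    chosen = set(random.sample(range(len(blocks)), k))
    for i, block in enumerate(blocks):
        block.keep_residual_grad = i in chosen
```

You would call this once per batch, before `backward()` is triggered (the sampling only needs to be decided before the backward pass runs).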

So I don't really need to modify any low-level computation, just the block, conditionally: the forward pass will not change, but in some cases the backward should appear to correspond to a shorter forward.
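For what it's worth, one standard way to get "same forward values, shorter backward" without a custom backward() is `Tensor.detach()`, which returns a tensor with the same values but cut out of the autograd graph. A minimal sketch on a toy residual computation (not your block, just the mechanics):

```python
import torch

x = torch.ones(3, requires_grad=True)

# Normal residual: gradient flows through both branches.
out = x + 2 * x              # value 3x, so d(out)/dx = 3
out.sum().backward()
grad_full = x.grad.clone()   # tensor([3., 3., 3.])

x.grad = None

# Detached residual branch: the forward value is identical,
# but the backward pass only sees the identity path.
out = x + (2 * x).detach()   # value still 3x, but d(out)/dx = 1
out.sum().backward()
grad_skip = x.grad.clone()   # tensor([1., 1., 1.])
```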

```python
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

    # What the backward of a "skipped" block should look like:
    # as if the forward had only been the identity (plus the final relu).
    def run_backwards_on_sometimes_appearing_as_forward(self, x):
        out = self.relu(x)
        return out
```

Seems simple, but I'm blanking out. Thanks.

Maybe it's as simple as adding a flag to the class, `self.mode = 0`, then:

```python
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if not self.mode:
            # Same forward values, but cut the residual branch out of the
            # autograd graph so backward only flows through the skip.
            out = out.detach()

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
```

The modified ResNet does not fail; I haven't checked what the gradients are doing yet, but it seems like it should work… if anyone sees an issue, flag me. Thanks.
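To actually check what the gradients are doing, you could compare parameter grads in the two modes on a tiny stand-in block (the block and names below are made up for illustration; substitute your modified ResNet block):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyBlock(nn.Module):
    """Minimal residual block with a mode flag: when mode is falsy,
    the residual branch is detached so backward only uses the skip."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)
        self.mode = 1

    def forward(self, x):
        residual = self.conv(x)
        if not self.mode:
            # Forward value unchanged; backward sees only the identity.
            residual = residual.detach()
        return torch.relu(x + residual)

block = TinyBlock()
x = torch.randn(1, 4, 8, 8, requires_grad=True)

# mode=1: conv weights receive gradient through the residual branch.
block(x).sum().backward()
grad_with_residual = block.conv.weight.grad

# mode=0: conv weights receive no gradient (grad stays None),
# but the input still gets a gradient through the skip path.
block.conv.weight.grad = None
x.grad = None
block.mode = 0
block(x).sum().backward()
grad_skip_only = block.conv.weight.grad
```

If the detach is working, the second backward leaves the conv weight's grad untouched while `x.grad` is still populated.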