Conditional custom autograd question

I understand that one can write a custom backward() override. I've read through the docs and some discussions and I feel a little stupid: the question is not hard, so someone is bound to give me a one-line answer.

I was reading Veit et al. 2016, "Residual Networks Behave Like Ensembles of Relatively Shallow Networks",
and I wanted to test the claims (page 6) on other models. But first I need to modify ResNet.

> To sample a path of length k, we first feed a batch forward through the whole network. During the backward pass, we randomly sample k residual blocks. For those k blocks, we only propagate through the residual module; for the remaining n−k blocks, we only propagate through the skip connection.

So I don't really need to modify any low-level computation, just the block, conditionally:
the forward will not change, but in some cases the backward should behave as if it corresponded to a shorter forward.

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

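    # What the backward pass should "see" for a skip-only block:
    # as if the forward had been just the skip connection plus the final ReLU.
    def run_backwards_on_sometimes_appearing_as_forward(self, x):
        out = self.relu(x)
        return out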

It seems simple, but I'm blanking out. Thanks.
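
One way this might be done without a custom backward() is to leave the forward values alone and cut the graph with detach() on whichever branch should be ignored. The following is only a sketch under assumptions: `PathBlock`, the `backward_mode` attribute and the `input_grad` helper are hypothetical names, and the block is a simplified residual block with no downsample path.

```python
import torch
import torch.nn as nn


class PathBlock(nn.Module):
    """Hypothetical simplified residual block (no downsample), for illustration."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # Assumed flag: "both" (normal backward), "residual" or "skip".
        self.backward_mode = "both"

    def forward(self, x):
        identity = x
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        if self.backward_mode == "skip":
            # Same value, no graph: gradients bypass the residual module.
            out = out.detach()
        elif self.backward_mode == "residual":
            # Same value, no graph: gradients bypass the skip connection.
            identity = identity.detach()
        return self.relu(out + identity)


def input_grad(block, x, mode):
    """Return d(sum of block output)/dx for the given backward mode."""
    block.backward_mode = mode
    x = x.clone().requires_grad_(True)
    block(x).sum().backward()
    return x.grad
```

Because the two branches are simply added, the input gradient in "both" mode should equal the sum of the "skip" and "residual" gradients, which gives a quick sanity check. In "skip" mode the conv and BatchNorm parameters of the block receive no gradient at all.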

Maybe it's as simple as adding a flag to the class, self.mode = 0, and then:

    def forward(self, x):
        identity = x
        if self.mode:
            # Skip-only backward: compute the residual branch without
            # recording a graph, so no gradient flows through it.
            with torch.no_grad():
                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)
                out = self.conv2(out)
                out = self.bn2(out)
        else:
            # Normal block: gradients flow through both branches.
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

The modified ResNet does not fail. I haven't checked what the gradients are doing yet, but it seems like it should work… If anyone sees an issue, flag it for me. Thanks.
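
A quick way to check what the gradients are doing is to stack a few flagged blocks, sample k of them as in the quoted procedure, and see which blocks end up with parameter gradients. This is a rough sketch, not the paper's code: `FlaggedBlock` and `sample_path` are hypothetical names, and the block is a cut-down stand-in with no downsample path. One caveat, as far as I can tell: the no_grad flag only covers the "skip connection only" half of the procedure. The k sampled blocks (mode 0) still propagate through both branches, so getting "residual module only" would also need something like `identity.detach()` in those blocks. Also, in training mode BatchNorm still updates its running statistics inside torch.no_grad().

```python
import random

import torch
import torch.nn as nn


class FlaggedBlock(nn.Module):
    """Hypothetical stand-in for the flagged block above (no downsample)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.mode = 0  # 1: residual branch runs under torch.no_grad()

    def residual(self, x):
        return self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))

    def forward(self, x):
        if self.mode:
            with torch.no_grad():
                out = self.residual(x)
        else:
            out = self.residual(x)
        return self.relu(out + x)


def sample_path(blocks, k, rng):
    """Pick k blocks at random; flag the other n - k as skip-only."""
    chosen = set(rng.sample(range(len(blocks)), k))
    for i, block in enumerate(blocks):
        block.mode = 0 if i in chosen else 1
    return chosen


torch.manual_seed(0)
blocks = nn.ModuleList([FlaggedBlock(4) for _ in range(4)])
chosen = sample_path(blocks, k=2, rng=random.Random(0))

x = torch.randn(2, 4, 8, 8)
for block in blocks:
    x = block(x)
x.sum().backward()

# Report which blocks received a gradient on their first conv layer.
for i, block in enumerate(blocks):
    has_grad = block.conv1.weight.grad is not None
    print(i, "sampled" if i in chosen else "skip-only", "conv1 grad:", has_grad)
```

Only the sampled blocks should report a gradient on `conv1`; the skip-only blocks keep `grad` as `None` because their residual branch never entered the graph.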