How to call specific layer index during training?

I am tying to modify ResNet34 for Task Incremental learning as part of my academic work. It done by replacing every BatchNorm2d locations with Sequential block of multiple BatchNorm2d layers. Let say have a Sequential block of 4 BN layers at each location of BN in ResNet34. For training, I will do separate training on each BN blocks (here, 4 trainig) while using all other CNN layers (with task specific masks) during all training. Issue is, how can I call or set specific BN sublayer (like as index) in the Sequential block prior to its training inorder to use that BN layer. Ex; set model to use BN sublayer 2 (2 index) at every Sequential() block for 2nd training.

I am able to modify model like what I mentioned. But not able to train likey. What I noticed is; I have set an attribute self.bn_nr initialized with 0, so it able to train for 1st index (or 0), and for subsequent training model.bn_nr is getting updated. but not find any scope to update this attribute for MultiBatchNorm() forward() method for next training. Is there anyway to get this updation on self.bn_nr within forward() of MultiBatchNorm(). Can anyone help me out to solve this?

Here I skipped Incremental learning part and trying to summarizing my problem with least code because it might be redundant with the question I asked.

class BaseModel(nn.Module, metaclass=abc.ABCMeta):
    """Abstract module to add CL capabilities to multitask classifier"""

    def __init__(self):
        super().__init__()
        #self.scenario = None
        self.num_bn_layers = 0
        self.bn_nr = 0

    @abc.abstractmethod
    def forward(self, x):
        pass

class MultiBatchNorm(BaseModel):
    def __init__(self, num_features=None, num_bn_layers=4):
        super().__init__()
        self.batchnorms = nn.ModuleList([nn.BatchNorm2d(num_features) for _ in range(num_bn_layers)])
        self.num_bn_layers = num_bn_layers

    def forward(self, x):
        x = self.batchnorms[self.bn_nr](x)
        return x

class ResidualBlock(BaseModel):
    def __init__(self, in_channels, out_channels, num_bn_layers, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                        MultiBatchNorm(out_channels, num_bn_layers),
                        nn.ReLU())
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        MultiBatchNorm(out_channels, num_bn_layers))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(BaseModel):
    def __init__(self, block, layers, num_classes=10, classes=10, num_bn_layers=2):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
                        MultiBatchNorm(64, num_bn_layers),
                        nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], num_bn_layers, stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], num_bn_layers, stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], num_bn_layers, stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], num_bn_layers, stride = 2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, classes)

        self.num_bn_layers = num_bn_layers

    def _make_layer(self, block, planes, blocks, num_bn_layers, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:

            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                MultiBatchNorm(planes, num_bn_layers),
            )
        layers = []
        layers.append(block(self.inplanes, planes, num_bn_layers, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, num_bn_layers))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = nn.functional.interpolate(x, (224, 224))
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
        

num_bn_layers = 4
model = ResNet(ResidualBlock,
               layers=[3, 4, 6, 3],
               num_bn_layers=num_bn_layers,
               ).to(device)

training (pseudo code)

epochs = 10
for task_id in range(num_bn_layers):
  model.bn_nr = task_id
  for epoch in range(epochs):
    for b, data in enumerate(train_loader):
      X, y = data
      loss = model.train_epoch(X)