Torchvision model with task-specific BN layers

Hi,

Is it possible to create task-specific BN layers starting from a torchvision model? Let’s say I have two tasks for the same data and want them to share all the convolutional layers of a torchvision model, but have separate BN parameters for each task. How can that be implemented?

Thanks!

Can you provide some more info? What do you mean by ‘task’? Are you going to train the same model using two sets of hyperparameters?

Hi @Karthik_Ganesan,

Thank you for your answer!

By “task” I mean a multi-task learning scenario, where typically there is a shared backbone consisting of convolutional blocks shared across all tasks, followed by task-specific fully-connected heads (one head per task).

What I want to do, inspired by this paper https://openaccess.thecvf.com/content_CVPR_2019/papers/Chang_Domain-Specific_Batch_Normalization_for_Unsupervised_Domain_Adaptation_CVPR_2019_paper.pdf, is to have not only task-specific FC heads but also task-specific batch norm layers, while sharing all the remaining layers across all tasks.

I hope it’s clearer now. Thanks!

Hi @SofiaCP, thank you for clarifying. I took a quick look at the paper and it seems like building the model should be straightforward. My understanding is that, during training, they pass one mini-batch from a single domain at a time.

The easiest option then would be to write your own domain-specific BN (DSBN) layer which stores one set of BN parameters per task (two in your case). You would also have an additional input to this layer which indicates which task the current mini-batch belongs to. Based on this, you can set the task for the DSBN layers in your network.

I’m not familiar with unsupervised domain adaptation, but I believe the backward pass will work the same: you just select the layer’s parameters based on the task, and autograd will generate the necessary gradients.
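To illustrate the point about the backward pass, here is a minimal sketch, assuming a single shared convolution followed by two separate `nn.BatchNorm2d` layers. The names `conv`, `bn_task1`, and `bn_task2` are made up for this example and do not come from the paper. Only the BN layer that was used in the forward pass should end up with gradients, while the shared convolution receives gradients either way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # layer shared by all tasks
bn_task1 = nn.BatchNorm2d(4)  # BN parameters for task 1 (hypothetical name)
bn_task2 = nn.BatchNorm2d(4)  # BN parameters for task 2 (hypothetical name)

x = torch.randn(8, 3, 16, 16)  # a mini-batch that belongs to task 1

# Forward through the shared conv and only the task-1 BN layer.
out = bn_task1(conv(x))
loss = out.pow(2).mean()
loss.backward()

# Parameters that took part in the forward pass have a gradient;
# the unused task-2 BN parameters are left untouched (grad is None).
shared_has_grad = conv.weight.grad is not None
task1_has_grad = bn_task1.weight.grad is not None
task2_has_grad = bn_task2.weight.grad is not None
print(shared_has_grad, task1_has_grad, task2_has_grad)
```

So an optimizer step after a task-1 mini-batch would update the shared weights and the task-1 BN parameters only, which is the behaviour you want here.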

Dear @Karthik_Ganesan,

Thank you for replying. Based on your answer, I am guessing that it is not possible to simply import a torchvision model (e.g. densenet121) and incorporate the DSBN by replacing the model’s BN layers. Is that so?

Could you possibly provide some more details on how to achieve this?

Thanks!
Sofia

Yes, basically that is what you need to do. But in PyTorch it’s a bit awkward to change the layers of a model once it is already built. So you can import a model like densenet121, but then you will have to iterate over all of its layers and replace the BN layers with DSBN layers.

Another way would be to create the model directly by specifying all the layers yourself, and in that file use the custom DSBN layers wherever the original uses BN. Since the network will have to be trained with the DSBN layers anyway, this shouldn’t be much more difficult than directly importing a model.
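The first option (iterating over the layers of an existing model) could look roughly like the sketch below. It is only an illustration: `replace_bn_with_dsbn` is a hypothetical helper, the `DSBN` class here is a simplified stand-in, and a small `nn.Sequential` is used instead of densenet121 so the example stays self-contained. I would expect the same recursion to work on a torchvision model as long as its BN layers are `nn.BatchNorm2d` modules.

```python
import torch
import torch.nn as nn


class DSBN(nn.Module):
    """Simplified stand-in: one BatchNorm2d per group, chosen by self.group."""

    def __init__(self, num_channels, num_groups=3):
        super().__init__()
        self.group = 0  # index of the active BN layer
        self.norms = nn.ModuleList(
            [nn.BatchNorm2d(num_channels) for _ in range(num_groups)]
        )

    def forward(self, x):
        return self.norms[self.group](x)


def replace_bn_with_dsbn(module):
    """Recursively swap every nn.BatchNorm2d child for a DSBN layer (hypothetical helper)."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            # Re-registering under the same name replaces the layer in place.
            setattr(module, name, DSBN(child.num_features))
        elif not isinstance(child, DSBN):
            # Descend into containers and blocks, but not into DSBN itself.
            replace_bn_with_dsbn(child)
    return module


# Toy backbone standing in for an imported model.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16)),
)
replace_bn_with_dsbn(backbone)
out = backbone(torch.randn(2, 3, 8, 8))
num_dsbn = sum(isinstance(m, DSBN) for m in backbone.modules())
```

Note that the new BN layers in this sketch are freshly initialized. If you start from pretrained weights and want to keep the original BN statistics, you would have to copy them into each group's BN layer yourself.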

I’ve managed to iterate over the layers of the model and replace the BN layers with DSBN layers, which I have defined the following way:


import torch.nn as nn

class DSBN(nn.Module):
    """Domain-specific BN: one BatchNorm2d per group."""
    def __init__(self, num_channels):
        super(DSBN, self).__init__()

        # Each group gets its own BN parameters and running statistics.
        self.norm1 = nn.BatchNorm2d(num_channels)
        self.norm2 = nn.BatchNorm2d(num_channels)
        self.norm3 = nn.BatchNorm2d(num_channels)

    def forward(self, x, group="1"):
        # The group is passed as a string: "1", "2", or "3".
        assert group in ["1", "2", "3"]
        if group == "1":
            print("using group 1 BN")
            x = self.norm1(x)
        elif group == "2":
            print("using group 2 BN")
            x = self.norm2(x)
        elif group == "3":
            print("using group 3 BN")
            x = self.norm3(x)
        return x

How can I insert a switch that decides which BN layer of DSBN to use depending on the input? My goal is to pass the group as an argument to the model:

model = model()
out_domain1 = model(x, group="1")
out_domain2 = model(x, group="2")
out_domain3 = model(x, group="3")

Thanks!

You should be able to just store the group as an attribute of the DSBN class. Since this isn’t a trainable parameter, you can add it as a plain Python attribute like this:

import torch.nn as nn

class DSBN(nn.Module):
    def __init__(self, num_channels):
        super(DSBN, self).__init__()

        # Plain attribute, not a parameter; must be set to "1", "2", or "3"
        # before the first forward pass.
        self.group = None
        self.norm1 = nn.BatchNorm2d(num_channels)
        self.norm2 = nn.BatchNorm2d(num_channels)
        self.norm3 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        assert self.group in ["1", "2", "3"]
        if self.group == "1":
            print("using group 1 BN")
            x = self.norm1(x)
        elif self.group == "2":
            print("using group 2 BN")
            x = self.norm2(x)
        elif self.group == "3":
            print("using group 3 BN")
            x = self.norm3(x)
        return x

Then, when you run each task, go through and set the group value for each DSBN layer to be the group you want. That should be very similar to how you accessed all the layers and replaced BatchNorm with DSBN.
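As a rough sketch of that last step, assuming the attribute-based DSBN layout from above (the prints are dropped to keep it short): `set_group` is a hypothetical helper name, and the small `nn.Sequential` model is only there to make the example runnable.

```python
import torch
import torch.nn as nn


class DSBN(nn.Module):
    # Same layout as the class above: three BN layers and a group attribute.
    def __init__(self, num_channels):
        super().__init__()
        self.group = None
        self.norm1 = nn.BatchNorm2d(num_channels)
        self.norm2 = nn.BatchNorm2d(num_channels)
        self.norm3 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        norms = {"1": self.norm1, "2": self.norm2, "3": self.norm3}
        assert self.group in norms
        return norms[self.group](x)


def set_group(model, group):
    """Set the active group on every DSBN layer in the model (hypothetical helper)."""
    for m in model.modules():
        if isinstance(m, DSBN):
            m.group = group


# Toy model with two DSBN layers.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    DSBN(4),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    DSBN(4),
)

x = torch.randn(2, 3, 8, 8)
outputs = {}
for group in ["1", "2", "3"]:
    set_group(model, group)     # switch all DSBN layers to this task
    outputs[group] = model(x)   # forward pass uses only that task's BN
```

In a training loop you would call `set_group` once before each mini-batch, using the task that mini-batch belongs to.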

That is exactly what I ended up doing. Initially, I was trying to pass the group as an argument to the forward method of the DSBN module, but I ended up storing it as an attribute because that is easier to implement.

Thank you!
