Select subset of Conv2d layer filters in the forward method


Let’s assume I have a set of data instances, each of shape [1,64,8,8], that can belong to one of two groups (A or B), and a Conv2d layer with the arguments (in_channels=64, out_channels=128, kernel_size=3).
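For reference, here is a minimal sketch of this setup, assuming PyTorch and the Conv2d defaults (stride 1, no padding). Each of the 128 filters corresponds to one [64, 3, 3] slice along the first dimension of the weight tensor, and that dimension is what a split into two halves would have to index.

```python
import torch
import torch.nn as nn

# The layer described above: 64 input channels, 128 filters, 3x3 kernels
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

# One data instance of shape [1, 64, 8, 8]
x = torch.randn(1, 64, 8, 8)

# Filter i is conv.weight[i], a [64, 3, 3] tensor, with bias conv.bias[i]
print(conv.weight.shape)  # torch.Size([128, 64, 3, 3])
print(conv.bias.shape)    # torch.Size([128])

# With no padding and stride 1, the 8x8 input shrinks to 6x6
print(conv(x).shape)      # torch.Size([1, 128, 6, 6])
```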

Can I use a subset of the 128 filters of the Conv2d, depending on the group of the instance? For example, can I divide the Conv2d into two halves (64 filters + 64 filters) and use the first half if the instance belongs to group A and the second half if the instance belongs to group B?

My idea would be to define a standard Conv2d in the init method and then somehow select the subset of filters that I want to use in the forward method. During the backward pass, I expect only the filters that were used for that instance to be updated.
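One way this idea could be sketched, assuming PyTorch and a group label passed to forward as a string (the class name GroupedConv and the "A"/"B" labels are hypothetical): keep the standard Conv2d from __init__, slice its weight and bias in forward, and call torch.nn.functional.conv2d with only the selected filters. Because indexing is differentiable, the gradient for the full weight tensor is exactly zero in the rows of the filters that were not used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedConv(nn.Module):
    """Hypothetical module: one standard Conv2d whose 128 filters are split
    into two halves; group "A" uses filters 0-63, group "B" uses 64-127."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.half = self.conv.out_channels // 2

    def forward(self, x, group):
        # Pick the slice of filters (and matching biases) for this group
        if group == "A":
            sl = slice(0, self.half)
        elif group == "B":
            sl = slice(self.half, self.conv.out_channels)
        else:
            raise ValueError(f"unknown group: {group!r}")
        weight = self.conv.weight[sl]
        bias = self.conv.bias[sl]
        # Convolve with the selected filters only; the slice is a view,
        # so gradients flow back into self.conv.weight and self.conv.bias
        return F.conv2d(x, weight, bias, stride=self.conv.stride,
                        padding=self.conv.padding,
                        dilation=self.conv.dilation)


model = GroupedConv()
x = torch.randn(1, 64, 8, 8)
out = model(x, "A")
print(out.shape)  # torch.Size([1, 64, 6, 6])

out.sum().backward()
grad = model.conv.weight.grad
# Group A's filters received gradients; group B's rows are all zeros
print(bool(grad[:64].abs().sum() > 0))   # True
print(bool((grad[64:] == 0).all()))      # True
```

Two caveats on this sketch. A zero gradient does not by itself guarantee that the unused filters stay unchanged: an optimizer with weight decay or momentum (for example SGD with momentum, or Adam after a few steps) can still modify them. And computing self.conv(x) on all 128 filters and then slicing the output channels would give the same result and the same gradient pattern, but it wastes the computation for the unused half.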