MultiHead Models With ModuleList

I am relatively new to PyTorch, so please excuse if i got something fundamentally wrong.

I am trying to create a MultiHead network for testing purposes. In order to implement the heads i am using torch.nn.ModuleList (see the code below). In my understanding and intuition, after i run a forward pass through my network and then through one of the 3 heads, the backprob should only create gradients for my network and the one head i passed through. So after i call the optimizer step function, i expect that only the weights of my currently used head change. This is not the case. When i train this network for 20 epochs and calculate the current head with epoch%num_heads, i would expect to see only one head per epoch change its weights. But they all change their weights in each epoch. What am i missing?
Thanks in advance.

#simple testing model
class MyModel(torch.nn.Module):
    def __init__(self, layer_sizes: List[int], actf: Callable):
        super(MyModel, self).__init__()

        # ----------- please ignore this part -----------------
        self._layer_dict = OrderedDict()
        for next_idx, layer_size in enumerate(layer_sizes[:-2], 1):
            self._layer_dict[f"Linear{next_idx}"] = torch.nn.Linear(layer_size, layer_sizes[next_idx])
            self._layer_dict[f"Actf{next_idx}"] = actf()
        last_idx = len(layer_sizes) - 1
        self._layer_dict[f"Linear{last_idx}"] = torch.nn.Linear(layer_sizes[-2], layer_sizes[last_idx])
        self._model = torch.nn.Sequential(self._layer_dict)
        # -----------------------------------------------------

        # creating the heads
        self._heads = torch.nn.ModuleList()
        self._heads.append(torch.nn.Linear(layer_sizes[last_idx], 1))
        self._heads.append(torch.nn.Linear(layer_sizes[last_idx], 1))
        self._heads.append(torch.nn.Linear(layer_sizes[last_idx], 1))

    def forward(self, x):
        # retrieve the current head
        head = x["head"]
        # pass through model
        _res = self._model(x["value"])
        # pass through current head
        _res = self._heads[head](_res)
        return _res

    def print_head_weights(self):
        for idx, head in enumerate(self._heads):
            print(f"head{idx}: {head.weight}")

def main():
    # initialize stuff
    test_ds = MyDataset(nFeatures=5, nSamples=200, factor=2.5)
    test_dl = DataLoader(dataset=test_ds, batch_size=50, shuffle=True, drop_last=False)
    model = MyModel(layer_sizes=[5,5,5], actf=torch.nn.ReLU)
    loss = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    nEpochs = 20
    epochs = range(1,nEpochs+1)

    # start training
    results = []
    for epoch in epochs:
        print(f"Epoch {epoch}\{nEpochs}")
        print("before training")
        # print head weights before training
        model.print_head_weights()
        bres = []
        for batch_features, batch_targets in test_dl:
            optimizer.zero_grad()
            input = {"head" : epoch%3, "value" : batch_features}
            prediction = model(input)
            result = loss(prediction, batch_targets)
            result.backward()
            optimizer.step()
            bres.append(result.item())
        print("after training")
        # print head weights after training
        model.print_head_weights()
        results.append(np.mean(bres))

The resulting output:

Epoch 1\20
before training
head0: Parameter containing:
tensor([[ 0.3052, -0.0316, -0.3892, -0.3379,  0.2097]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.0076, -0.3838,  0.3574,  0.4275, -0.0871]], requires_grad=True)
head2: Parameter containing:
tensor([[-0.0246, -0.1727,  0.2504,  0.3131, -0.3634]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.3052, -0.0316, -0.3892, -0.3379,  0.2097]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0320, -0.4229,  0.3180,  0.3881, -0.0472]], requires_grad=True)
head2: Parameter containing:
tensor([[-0.0246, -0.1727,  0.2504,  0.3131, -0.3634]], requires_grad=True)
Epoch 2\20
before training
head0: Parameter containing:
tensor([[ 0.3052, -0.0316, -0.3892, -0.3379,  0.2097]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0320, -0.4229,  0.3180,  0.3881, -0.0472]], requires_grad=True)
head2: Parameter containing:
tensor([[-0.0246, -0.1727,  0.2504,  0.3131, -0.3634]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.3052, -0.0316, -0.3892, -0.3379,  0.2097]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0723, -0.4629,  0.2778,  0.3479, -0.0078]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0144, -0.1340,  0.2894,  0.3522, -0.3248]], requires_grad=True)
Epoch 3\20
before training
head0: Parameter containing:
tensor([[ 0.3052, -0.0316, -0.3892, -0.3379,  0.2097]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0723, -0.4629,  0.2778,  0.3479, -0.0078]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0144, -0.1340,  0.2894,  0.3522, -0.3248]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.3438,  0.0072, -0.3503, -0.2989,  0.2383]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1132, -0.5037,  0.2370,  0.3071,  0.0315]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0541, -0.0943,  0.3293,  0.3921, -0.2849]], requires_grad=True)
Epoch 4\20
before training
head0: Parameter containing:
tensor([[ 0.3438,  0.0072, -0.3503, -0.2989,  0.2383]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1132, -0.5037,  0.2370,  0.3071,  0.0315]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0541, -0.0943,  0.3293,  0.3921, -0.2849]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.3832,  0.0470, -0.3103, -0.2590,  0.2161]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1457, -0.5382,  0.1972,  0.2674,  0.0464]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0949, -0.0535,  0.3701,  0.4329, -0.2442]], requires_grad=True)
Epoch 5\20
before training
head0: Parameter containing:
tensor([[ 0.3832,  0.0470, -0.3103, -0.2590,  0.2161]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1457, -0.5382,  0.1972,  0.2674,  0.0464]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.0949, -0.0535,  0.3701,  0.4329, -0.2442]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.4239,  0.0878, -0.2695, -0.2182,  0.1827]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1483, -0.5498,  0.1607,  0.2312,  0.0199]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.1373, -0.0112,  0.4120,  0.4748, -0.2135]], requires_grad=True)
Epoch 6\20
before training
head0: Parameter containing:
tensor([[ 0.4239,  0.0878, -0.2695, -0.2182,  0.1827]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1483, -0.5498,  0.1607,  0.2312,  0.0199]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.1373, -0.0112,  0.4120,  0.4748, -0.2135]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.4664,  0.1295, -0.2279, -0.1765,  0.1470]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1323, -0.5451,  0.1262,  0.1972, -0.0182]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.1820,  0.0333,  0.4552,  0.5180, -0.2129]], requires_grad=True)
Epoch 7\20
before training
head0: Parameter containing:
tensor([[ 0.4664,  0.1295, -0.2279, -0.1765,  0.1470]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.1323, -0.5451,  0.1262,  0.1972, -0.0182]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.1820,  0.0333,  0.4552,  0.5180, -0.2129]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.5112,  0.1723, -0.1850, -0.1337,  0.1061]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0994, -0.5279,  0.0944,  0.1660, -0.0631]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.2284,  0.0787,  0.4991,  0.5620, -0.2304]], requires_grad=True)
Epoch 8\20
before training
head0: Parameter containing:
tensor([[ 0.5112,  0.1723, -0.1850, -0.1337,  0.1061]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0994, -0.5279,  0.0944,  0.1660, -0.0631]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.2284,  0.0787,  0.4991,  0.5620, -0.2304]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.5575,  0.2156, -0.1417, -0.0905,  0.0612]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0543, -0.5016,  0.0667,  0.1392, -0.1130]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.2766,  0.1243,  0.5437,  0.6066, -0.2628]], requires_grad=True)
Epoch 9\20
before training
head0: Parameter containing:
tensor([[ 0.5575,  0.2156, -0.1417, -0.0905,  0.0612]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0543, -0.5016,  0.0667,  0.1392, -0.1130]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.2766,  0.1243,  0.5437,  0.6066, -0.2628]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.6045,  0.2589, -0.0979, -0.0469,  0.0139]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0042, -0.4696,  0.0420,  0.1154, -0.1658]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.3270,  0.1693,  0.5893,  0.6520, -0.3069]], requires_grad=True)
Epoch 10\20
before training
head0: Parameter containing:
tensor([[ 0.6045,  0.2589, -0.0979, -0.0469,  0.0139]], requires_grad=True)
head1: Parameter containing:
tensor([[-0.0042, -0.4696,  0.0420,  0.1154, -0.1658]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.3270,  0.1693,  0.5893,  0.6520, -0.3069]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.6530,  0.3014, -0.0535, -0.0027, -0.0360]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.0483, -0.4347,  0.0196,  0.0939, -0.2201]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.3782,  0.2138,  0.6353,  0.6979, -0.3559]], requires_grad=True)
Epoch 11\20
before training
head0: Parameter containing:
tensor([[ 0.6530,  0.3014, -0.0535, -0.0027, -0.0360]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.0483, -0.4347,  0.0196,  0.0939, -0.2201]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.3782,  0.2138,  0.6353,  0.6979, -0.3559]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.7020,  0.3433, -0.0087,  0.0417, -0.0870]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.1023, -0.3992, -0.0006,  0.0746, -0.2753]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.4285,  0.2584,  0.6808,  0.7433, -0.4049]], requires_grad=True)
Epoch 12\20
before training
head0: Parameter containing:
tensor([[ 0.7020,  0.3433, -0.0087,  0.0417, -0.0870]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.1023, -0.3992, -0.0006,  0.0746, -0.2753]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.4285,  0.2584,  0.6808,  0.7433, -0.4049]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.7487,  0.3860,  0.0350,  0.0852, -0.1340]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.1566, -0.3634, -0.0192,  0.0568, -0.3305]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.4757,  0.3040,  0.7248,  0.7872, -0.4475]], requires_grad=True)
Epoch 13\20
before training
head0: Parameter containing:
tensor([[ 0.7487,  0.3860,  0.0350,  0.0852, -0.1340]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.1566, -0.3634, -0.0192,  0.0568, -0.3305]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.4757,  0.3040,  0.7248,  0.7872, -0.4475]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.7888,  0.4309,  0.0755,  0.1256, -0.1674]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2072, -0.3226, -0.0404,  0.0365, -0.3812]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5207,  0.3501,  0.7677,  0.8302, -0.4854]], requires_grad=True)
Epoch 14\20
before training
head0: Parameter containing:
tensor([[ 0.7888,  0.4309,  0.0755,  0.1256, -0.1674]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2072, -0.3226, -0.0404,  0.0365, -0.3812]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5207,  0.3501,  0.7677,  0.8302, -0.4854]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.8239,  0.4771,  0.1138,  0.1640, -0.1903]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2472, -0.2722, -0.0698,  0.0083, -0.4193]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5588,  0.3983,  0.8075,  0.8701, -0.5043]], requires_grad=True)
Epoch 15\20
before training
head0: Parameter containing:
tensor([[ 0.8239,  0.4771,  0.1138,  0.1640, -0.1903]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2472, -0.2722, -0.0698,  0.0083, -0.4193]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5588,  0.3983,  0.8075,  0.8701, -0.5043]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.8461,  0.5265,  0.1465,  0.1970, -0.1885]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2792, -0.2168, -0.1048, -0.0256, -0.4479]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5814,  0.4507,  0.8408,  0.9039, -0.4860]], requires_grad=True)
Epoch 16\20
before training
head0: Parameter containing:
tensor([[ 0.8461,  0.5265,  0.1465,  0.1970, -0.1885]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2792, -0.2168, -0.1048, -0.0256, -0.4479]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5814,  0.4507,  0.8408,  0.9039, -0.4860]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.8422,  0.5812,  0.1676,  0.2187, -0.1528]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2955, -0.1552, -0.1477, -0.0678, -0.4586]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5922,  0.5051,  0.8694,  0.9332, -0.4480]], requires_grad=True)
Epoch 17\20
before training
head0: Parameter containing:
tensor([[ 0.8422,  0.5812,  0.1676,  0.2187, -0.1528]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2955, -0.1552, -0.1477, -0.0678, -0.4586]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5922,  0.5051,  0.8694,  0.9332, -0.4480]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.8207,  0.6384,  0.1800,  0.2318, -0.1023]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2841, -0.0870, -0.2009, -0.1209, -0.4405]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5883,  0.5622,  0.8921,  0.9565, -0.3961]], requires_grad=True)
Epoch 18\20
before training
head0: Parameter containing:
tensor([[ 0.8207,  0.6384,  0.1800,  0.2318, -0.1023]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2841, -0.0870, -0.2009, -0.1209, -0.4405]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5883,  0.5622,  0.8921,  0.9565, -0.3961]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[ 0.7848,  0.6978,  0.1836,  0.2357, -0.0436]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2552, -0.0167, -0.2590, -0.1793, -0.4053]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5645,  0.6231,  0.9061,  0.9703, -0.3335]], requires_grad=True)
Epoch 19\20
before training
head0: Parameter containing:
tensor([[ 0.7848,  0.6978,  0.1836,  0.2357, -0.0436]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2552, -0.0167, -0.2590, -0.1793, -0.4053]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5645,  0.6231,  0.9061,  0.9703, -0.3335]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[0.7364, 0.7599, 0.1770, 0.2286, 0.0206]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2154,  0.0541, -0.3194, -0.2401, -0.3599]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5281,  0.6856,  0.9135,  0.9770, -0.2667]], requires_grad=True)
Epoch 20\20
before training
head0: Parameter containing:
tensor([[0.7364, 0.7599, 0.1770, 0.2286, 0.0206]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.2154,  0.0541, -0.3194, -0.2401, -0.3599]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.5281,  0.6856,  0.9135,  0.9770, -0.2667]], requires_grad=True)
after training
head0: Parameter containing:
tensor([[0.6816, 0.8228, 0.1633, 0.2136, 0.0867]], requires_grad=True)
head1: Parameter containing:
tensor([[ 0.1682,  0.1246, -0.3807, -0.3021, -0.3084]], requires_grad=True)
head2: Parameter containing:
tensor([[ 0.4844,  0.7481,  0.9163,  0.9788, -0.1989]], requires_grad=True)

I found the answer myself in this stack overflow thread. I still think it is a little bit unintuitive though.