Multiple submodule instances without merging gradients?

I am working on a model that contains multiple instances of a sub-network, and I am having trouble making each instance of the sub-module learn independently. During backprop the gradients from the different instances are averaged, which makes the resulting gradient useless. Below is a simplified version of my model:

```python
import torch
import torch.nn.functional as F

# ModuleA, ModuleB and some_sampler are defined elsewhere (not shown)
net_a = ModuleA()
net_b = ModuleB()

for batch_x, label in dataset:  # iterate through the dataset
    out_b = torch.zeros(10)  # initial state passed to net_b
    output_arr = []
    for sub_sample in some_sampler(batch_x):  # sample slices along the feature dimension
        out_a = net_a(sub_sample)  # gradient should update `net_a` per slice
        out_b = net_b(out_a, out_b)  # net_b also receives its own previous output
        product = out_a.bmm(out_b)
        output_arr.append(product)

    output = torch.cat(output_arr)
    loss = F.binary_cross_entropy(output, label, reduction='mean')
    loss.backward()
```
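As a point of reference, here is a small standalone check of how PyTorch combines gradients when one module is called several times in the same graph. The `nn.Linear` layer and the random `slices` are hypothetical stand-ins for `net_a` and the sub-samples. The per-call gradients are summed into the shared `.grad`, and a `mean` reduction over the calls scales that sum by 1/N, which is why the result looks like an average.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: one shared layer (like `net_a`) and four sub-samples.
layer = nn.Linear(3, 1, bias=False)
slices = [torch.randn(3) for _ in range(4)]

# Gradient of each call, computed in isolation.
per_call = []
for x in slices:
    layer.zero_grad()
    layer(x).sum().backward()
    per_call.append(layer.weight.grad.clone())

# All calls in one graph with a summed loss: the shared weight gets the sum.
layer.zero_grad()
torch.stack([layer(x).sum() for x in slices]).sum().backward()
summed_grad = layer.weight.grad.clone()
print(torch.allclose(summed_grad, torch.stack(per_call).sum(dim=0)))  # True

# Same graph with a mean loss: the sum is scaled by 1/len(slices).
layer.zero_grad()
torch.stack([layer(x).sum() for x in slices]).mean().backward()
mean_grad = layer.weight.grad.clone()
print(torch.allclose(mean_grad, torch.stack(per_call).mean(dim=0)))  # True
```

Every sub-sample still contributes its own term; the terms land in one `.grad` tensor because every call uses the same weight tensor.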

This model is somewhat similar to an RNN, but I cannot unfold it like an RNN. When back-propagating through the graph, `net_a` should receive a separate gradient for each sub-sample. I believe the problem is that the instances share parameters, but I cannot work out how to keep their gradients separate. How can I do that?
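One possible approach, assuming the number of slices is fixed and each slice position really should get its own weights: build one independent copy of the sub-network per slice with `nn.ModuleList`, so each copy's `.grad` only ever comes from its own slice. `SubNet`, `PerSliceModel`, and the equal-width `torch.split` sampler below are hypothetical stand-ins for `ModuleA` and `some_sampler`, and `net_b` is left out for brevity.

```python
import torch
import torch.nn as nn


class SubNet(nn.Module):
    """Hypothetical stand-in for ModuleA."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.linear(x))


class PerSliceModel(nn.Module):
    """One independent SubNet per slice position (assumes a fixed slice count)."""

    def __init__(self, n_slices: int, slice_width: int, hidden: int):
        super().__init__()
        # Separate instances: no parameters are shared between slices.
        self.subnets = nn.ModuleList(
            [SubNet(slice_width, hidden) for _ in range(n_slices)]
        )
        self.slice_width = slice_width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical sampler: equal-width slices along the feature dimension.
        slices = torch.split(x, self.slice_width, dim=1)
        outs = [net(s) for net, s in zip(self.subnets, slices)]
        return torch.cat(outs, dim=1)


torch.manual_seed(0)
model = PerSliceModel(n_slices=3, slice_width=4, hidden=2)
x = torch.randn(5, 12)  # batch of 5, 12 features -> 3 slices of width 4
out = model(x)          # shape (5, 6)
out.sum().backward()

# Each copy now holds its own gradient, computed only from its own slice.
for i, net in enumerate(model.subnets):
    print(i, net.linear.weight.grad.norm().item())
```

The cost is that the parameter count grows with the number of slices. If `net_a` is instead meant to stay a single shared network, the summed gradient is the correct gradient for those shared weights (by the chain rule, each use contributes one term), and per-slice gradients, if needed for inspection, can be computed separately with `torch.autograd.grad` on each slice's output.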