I am working on a model that contains multiple instances of a sub-network, and I am having trouble making each instance of the sub-module learn independently. During backprop, the gradients from all the slices are accumulated into a single gradient for the shared parameters, which defeats the per-slice learning I want. Below is a simplified version of my model:
```python
import torch
import torch.nn.functional as F

net_a = ModuleA()
net_b = ModuleB()

for batch_x, label in dataset:  # iterate through the dataset
    out_b = torch.zeros(10)
    output_arr = []
    for sub_sample in some_sampler(batch_x):  # sampling data along feature dimensions
        out_a = net_a(sub_sample)  # gradient should update `net_a` per slice
        out_b = net_b(out_a, out_b)
        product = out_a.bmm(out_b)
        output_arr.append(product)
    output = torch.cat(output_arr)
    loss = F.binary_cross_entropy(output, label, reduction='mean')
    loss.backward()
```
This model is somewhat similar to an RNN, but I cannot unroll it the way an RNN is unrolled. When back-propagating through the graph, `net_a` should receive a separate gradient for each sub-sample. I believe the problem is that the instances share parameters, but I cannot quite work out how to keep the gradients separate.
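To make the behaviour I am describing concrete, here is a minimal, self-contained sketch (my assumption: `nn.Linear` stands in for `ModuleA`, and two inputs stand in for two sub-samples). It shows that autograd sums the gradient contributions from every call to a shared module into one `.grad` tensor, rather than keeping a gradient per call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear stands in for ModuleA (assumption for illustration)
net_a = nn.Linear(4, 4, bias=False)
x1 = torch.randn(1, 4)  # stands in for sub-sample 1
x2 = torch.randn(1, 4)  # stands in for sub-sample 2

# Forward through the shared module twice, as the inner sampling loop does
loss = net_a(x1).sum() + net_a(x2).sum()
loss.backward()
shared_grad = net_a.weight.grad.clone()

# Per-call gradients, computed separately for comparison
net_a.weight.grad = None
net_a(x1).sum().backward()
g1 = net_a.weight.grad.clone()

net_a.weight.grad = None
net_a(x2).sum().backward()
g2 = net_a.weight.grad.clone()

# The shared gradient is the sum of the per-call gradients
print(torch.allclose(shared_grad, g1 + g2))  # True
```

So each sub-sample's gradient does reach `net_a`, but they all land in the same `weight.grad` buffer; what I am looking for is a way to keep (or act on) the per-slice contributions individually.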