I am working on a model that contains multiple instances of a sub-network, and I am having trouble making each instance of the sub-module learn independently. During backprop, the gradients from all the slices are accumulated into a single gradient for the shared parameters, which defeats the per-slice learning I want. Below is a simplified version of my model:
```python
import torch
import torch.nn.functional as F

net_a = ModuleA()
net_b = ModuleB()

for batch_x, label in dataset:  # iterate through the dataset
    out_b = torch.zeros(10)
    output_arr = []
    for sub_sample in some_sampler(batch_x):  # sampling data along feature dimensions
        out_a = net_a(sub_sample)  # gradient should update `net_a` per slice
        out_b = net_b(out_a, out_b)
        product = out_a.bmm(out_b)
        output_arr.append(product)
    output = torch.cat(output_arr)
    loss = F.binary_cross_entropy(output, label, reduction='mean')
    loss.backward()
```
This model is somewhat similar to an RNN, but I cannot unroll it the way an RNN is unrolled. When back-propagating through the graph, `net_a` should receive a separate gradient for each sub-sample. I believe the problem is that the instances share parameters, but I cannot quite work out how to keep the gradients separate.
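To make the behaviour I am describing concrete, here is a minimal, self-contained sketch (my assumption: `nn.Linear` stands in for `ModuleA`, and two inputs stand in for two sub-samples). It shows that autograd sums the gradient contributions from every call to a shared module into one `.grad` tensor, rather than keeping a gradient per call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear stands in for ModuleA (assumption for illustration)
net_a = nn.Linear(4, 4, bias=False)
x1 = torch.randn(1, 4)  # stands in for sub-sample 1
x2 = torch.randn(1, 4)  # stands in for sub-sample 2

# Forward through the shared module twice, as the inner sampling loop does
loss = net_a(x1).sum() + net_a(x2).sum()
loss.backward()
shared_grad = net_a.weight.grad.clone()

# Per-call gradients, computed separately for comparison
net_a.weight.grad = None
net_a(x1).sum().backward()
g1 = net_a.weight.grad.clone()

net_a.weight.grad = None
net_a(x2).sum().backward()
g2 = net_a.weight.grad.clone()

# The shared gradient is the sum of the per-call gradients
print(torch.allclose(shared_grad, g1 + g2))  # True
```

So each sub-sample's gradient does reach `net_a`, but they all land in the same `weight.grad` buffer; what I am looking for is a way to keep (or act on) the per-slice contributions individually.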