How to get around the "module has parameters that were not used in producing loss" error?

There are two blocks (A and B) in my network architecture. I defined an optimizer_A and loss_A to update module_A, and an optimizer_B and loss_B to update module_B.

I want to disable all updates to module_B during the first 3 epochs, and allow both modules to be updated after the 3rd epoch. What I am currently doing is:

def forward(self, x):
    # Note: self.train is a method and is always truthy;
    # the boolean train/eval flag is self.training.
    if self.training and self.epoch > 3:
        x = x + self.module_B(x)
    prediction = self.module_A(x)
    return prediction
...

loss_A = calculate_loss_A(...)
if epoch > 3:
    loss_B = calculate_loss_B(...)
...

if epoch > 3:
    # retain_graph=True keeps the graph alive for loss_B.backward()
    loss_A.backward(retain_graph=True)
    loss_B.backward()
    optimizer_A.step()
    optimizer_B.step()
else:
    loss_A.backward()
    optimizer_A.step()
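To make the branching above easier to reason about, here is a hypothetical plain-Python helper (the name training_plan and the warmup_epochs parameter are invented for illustration) that spells out what is supposed to happen at each epoch:

```python
def training_plan(epoch, warmup_epochs=3):
    """Report which losses, optimizers, and backward flags apply at `epoch`."""
    train_b = epoch > warmup_epochs  # module_B only trains after warmup
    return {
        "losses": ["loss_A", "loss_B"] if train_b else ["loss_A"],
        # loss_A.backward() must retain the graph when loss_B.backward()
        # will traverse (part of) the same graph afterwards
        "retain_graph_for_A": train_b,
        "optimizers": ["optimizer_A", "optimizer_B"] if train_b else ["optimizer_A"],
    }

print(training_plan(2))  # warmup: module_B frozen
print(training_plan(5))  # both modules train
```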

I think the code above correctly illustrates what I want to do. Before adding the epoch > 3 condition, my code worked fine, but after adding it I got this error message:
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument 'find_unused_parameters=True' to 'torch.nn.parallel.DistributedDataParallel', and by making sure all 'forward' function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's 'forward' function.
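For context, the message itself points at find_unused_parameters=True. Below is a minimal single-process sketch (the two-Linear Net, the tensor sizes, and the file:// rendezvous are all invented for illustration) of where that argument is passed when constructing DistributedDataParallel:

```python
import os
import tempfile
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_A = torch.nn.Linear(4, 4)
        self.module_B = torch.nn.Linear(4, 4)
        self.epoch = 0

    def forward(self, x):
        # module_B only participates after epoch 3, so its parameters are
        # "unused" in early iterations -- the case DDP complains about.
        if self.training and self.epoch > 3:
            x = x + self.module_B(x)
        return self.module_A(x)

# Single-process "distributed" setup on CPU with the gloo backend,
# just so DDP can be constructed in a runnable example.
rendezvous = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group(
    backend="gloo", init_method=f"file://{rendezvous}",
    rank=0, world_size=1,
)

# This is where the keyword argument from the error message goes:
model = DDP(Net(), find_unused_parameters=True)

x = torch.randn(2, 4)
for epoch in range(1, 6):
    model.module.epoch = epoch
    loss = model(x).sum()
    loss.backward()

dist.destroy_process_group()
```

With this flag, DDP traverses the autograd graph each iteration to detect parameters that did not contribute to the loss, at the cost of some per-iteration overhead.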

Could someone provide some guidance on which part of my implementation is wrong, and how to get around this error?