How to handle .detach() in an Opacus-enabled model

The overall structure of my model is shown in the following pseudocode (SimSiam-style, PyTorch-like; here `f` denotes the backbone plus the projection MLP):

```python
import torch.nn.functional as F

# f (backbone + projection MLP), aug, loader and optimizer are defined elsewhere.

def D(p, z):  # negative cosine similarity
    z = z.detach()  # stop gradient
    p = F.normalize(p, dim=1)  # l2-normalize
    z = F.normalize(z, dim=1)  # l2-normalize
    return -(p * z).sum(dim=1).mean()

for x in loader:  # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x)  # random augmentation
    z1, z2 = f(x1), f(x2)  # projections, N x D
    L = D(z1, z2)  # loss
    optimizer.zero_grad()  # clear gradients from the previous step
    L.backward()  # back-propagate
    optimizer.step()  # SGD update of f
```
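As a quick sanity check of `D`, the snippet below evaluates a standalone copy of it (using `torch.nn.functional.normalize`) on small made-up tensors: aligned rows give -1, opposite rows give +1, and orthogonal rows give 0. It also hints at why the stop-gradient matters: a collapsed model that maps every input to the same vector would trivially reach the minimum value of -1.

```python
import torch
import torch.nn.functional as F


def D(p, z):
    """Negative cosine similarity with a stop-gradient on z (as in the pseudocode above)."""
    z = z.detach()  # stop gradient: nothing flows back through z
    p = F.normalize(p, dim=1)  # l2-normalize each row
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()


# Illustrative inputs only: two 2-D rows per tensor.
a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
b = torch.tensor([[0.0, 3.0], [5.0, 0.0]])  # each row orthogonal to the matching row of a

aligned = D(a, a).item()     # -1.0
opposite = D(a, -a).item()   # +1.0
orthogonal = D(a, b).item()  # 0 (may print as -0.0)
print(aligned, opposite, orthogonal)
```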

I want to incorporate Opacus to train the model with differential privacy, but I have found that the `.detach()` call causes problems with Opacus. I get this error: "ValueError: Per sample gradient is not initialized. Not updated in backward pass?"
When I remove `.detach()`, the code runs without the error, but accuracy and performance suffer: in Siamese networks such as SimSiam, one of the projector outputs must be kept out of the gradient flow (stop-gradient) to avoid collapse.
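For reference, the stop-gradient can be checked in plain PyTorch without Opacus. The sketch below uses a hypothetical toy `nn.Linear` layer as a stand-in for `f` and random tensors as the two "views". It shows that with `.detach()` inside `D`, the target output `z2` receives no gradient, so `f` is updated only through the `z1` branch. It then computes the target branch under `torch.no_grad()`, a common alternative way to express the same stop-gradient, and confirms the gradients on `f` are identical. Whether the `torch.no_grad()` form avoids the Opacus error is an assumption: it relies on Opacus not capturing activations when autograd is disabled, which should be checked against the installed Opacus version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def D(p, z):
    # Negative cosine similarity; z is detached, so no gradient flows through it.
    z = z.detach()
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()


# Hypothetical toy stand-in for f (backbone + projection MLP).
f = nn.Linear(4, 3)
x1, x2 = torch.randn(8, 4), torch.randn(8, 4)  # two "augmented views"

# Variant 1: stop-gradient via .detach() inside D (as in the question).
z1, z2 = f(x1), f(x2)
z2.retain_grad()  # keep z2.grad if any gradient ever reaches z2
D(z1, z2).backward()
target_grad = z2.grad  # None: the target branch received no gradient
grad_detach = f.weight.grad.clone()

# Variant 2: build no autograd graph at all for the target branch.
f.zero_grad()
z1 = f(x1)
with torch.no_grad():
    z2 = f(x2)  # same values as before, but no graph is recorded
D(z1, z2).backward()
grad_no_grad = f.weight.grad.clone()

print(target_grad)  # None
print(torch.allclose(grad_detach, grad_no_grad))  # True
```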

I have also raised the problem in an issue on the Opacus GitHub repository. Any suggestions would be appreciated. Thanks in advance.