You won’t be able to update the parameters via `optM.step()` at that point: the step would also update the shared parameters used to calculate `mid` in place, which invalidates the intermediate activations (and saved parameters) stored during the forward pass. Since `ld.backward()` needs those saved tensors to compute the gradients of all parameters, the later backward call will fail.
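
For concreteness, here is a minimal sketch of the failure. Everything except `optM`, `mid`, and `ld` is an assumed setup (a shared trunk feeding two heads, each with its own loss); in recent PyTorch versions the last line raises a `RuntimeError` about an in-place modification:

```python
import torch
import torch.nn as nn

# Assumed setup: a shared trunk produces `mid`, two heads produce two
# losses, and optM updates the shared trunk plus one head. All names
# except optM, mid, and ld are hypothetical.
shared = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
headM = nn.Linear(4, 1)
headD = nn.Linear(4, 1)

optM = torch.optim.SGD(
    list(shared.parameters()) + list(headM.parameters()), lr=0.1
)

x = torch.randn(8, 4)
mid = shared(x)            # forward activations (and weights) are saved here

lm = headM(mid).mean()     # loss for the branch optM is responsible for
ld = headD(mid).mean()     # second loss, sharing the graph through `mid`

lm.backward(retain_graph=True)  # keep the graph alive for ld.backward()
optM.step()                     # updates the shared parameters in place
ld.backward()                   # RuntimeError: a tensor needed for gradient
                                # computation was modified by an inplace op
```

The usual workaround is to run all the backward passes first and only call the optimizer steps afterwards, or to recompute `mid` with a fresh forward pass after stepping.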