I have the forward function below, in which I use the binary masks mask1 and mask2, but module1 gets zero gradient. How can I mask out1 so that module1 still gets updated?

```python
# x, z: 2D matrices; mask1, mask2: column vectors with matching height
def forward(self, x, z, mask2, mask1):
    out1 = self.module1(x)
    in_2 = z * mask1 + out1 * mask2
    return self.module2(in_2)
```
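A minimal reproduction sketch, assuming PyTorch and that module1 and module2 are plain `nn.Linear` layers (hypothetical stand-ins; the real modules are not shown). It also assumes mask1 is the complement of mask2, which the question does not state. It shows that the gradient reaching module1 is multiplied element-wise by mask2: rows where mask2 is 0 contribute nothing, so an all-zero mask2 leaves module1 with exactly zero gradient.

```python
import torch
import torch.nn as nn


class MaskedNet(nn.Module):
    # Hypothetical stand-ins for the question's module1/module2 (assumption: linear layers).
    def __init__(self, dim):
        super().__init__()
        self.module1 = nn.Linear(dim, dim)
        self.module2 = nn.Linear(dim, 1)

    # Same forward as in the question: x, z are (rows, dim); masks are (rows, 1) columns.
    def forward(self, x, z, mask2, mask1):
        out1 = self.module1(x)
        in_2 = z * mask1 + out1 * mask2
        return self.module2(in_2)


def module1_grad_norm(mask2_values):
    """Run one backward pass and return the norm of module1's weight gradient."""
    torch.manual_seed(0)
    net = MaskedNet(dim=3)
    x = torch.randn(4, 3)
    z = torch.randn(4, 3)
    mask2 = torch.tensor(mask2_values, dtype=torch.float32).unsqueeze(1)  # shape (4, 1)
    mask1 = 1.0 - mask2  # complementary mask (assumption for this sketch)
    net(x, z, mask2, mask1).sum().backward()
    return net.module1.weight.grad.norm().item()


# d(in_2)/d(out1) is mask2, so module1 only receives gradient from rows where mask2 == 1.
print(module1_grad_norm([0, 0, 0, 0]))  # every row masked out: 0.0
print(module1_grad_norm([1, 0, 1, 0]))  # some rows kept: a positive value
```

In other words, multiplying by a binary mask does not block gradient in general; it zeroes the gradient only for the masked-out rows. If mask2 is zero for every row in a batch, module1 gets no update from that batch.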