DeepMask's PyTorch Implementation

I’m trying to implement DeepMask with PyTorch. So far I have defined the joint loss function and the model’s learnable parameters.
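For reference, here is a minimal sketch of the joint loss as I read it in the DeepMask paper. It assumes patch labels y_k and pixel labels m_k^ij in {-1, +1}, the mask term counted only for positive patches, and λ = 1/32. The function name `deepmask_joint_loss` and the tensor shapes are my own choices. The sketch writes log(1 + e^{-z}) as `softplus(-z)`, which stays finite for large |z|. Computing `log(sigmoid(x))` directly can hit log(0) and is one common source of NaN losses.

```python
import torch
import torch.nn.functional as F


def deepmask_joint_loss(mask_logits, score_logits, mask_targets, labels, lam=1.0 / 32):
    """Sketch of the DeepMask joint loss (names and shapes are assumptions).

    mask_logits:  (N, H, W) raw segmentation outputs f_segm
    score_logits: (N,)      raw objectness outputs f_score
    mask_targets: (N, H, W) pixel labels in {-1, +1}
    labels:       (N,)      patch labels in {-1, +1}
    """
    n, h, w = mask_logits.shape
    # softplus(-z) == log(1 + exp(-z)), computed without overflow
    pixel_loss = F.softplus(-mask_targets * mask_logits)        # (N, H, W)
    # (1 + y_k) / 2 is 1 for positive patches and 0 for negative ones,
    # so negative patches contribute no mask loss
    pos_weight = (1 + labels) / 2                                # (N,)
    mask_loss = pos_weight * pixel_loss.sum(dim=(1, 2)) / (h * w)
    # objectness term, weighted by lambda
    score_loss = F.softplus(-labels * score_logits)              # (N,)
    return (mask_loss + lam * score_loss).sum()
```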

I then moved on to the training phase. The paper says that training must alternate back-propagation between the two branches (segmentation and scoring), and I have written code that does this. However, something goes wrong during training: when I train the model on a fake dataset, the loss becomes NaN for every mini-batch after the first one.
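To make the question concrete, here is a minimal sketch of what I mean by alternating back-propagation. Even steps update through the segmentation loss only and odd steps through the scoring loss only. `TinyDeepMask` is a hypothetical stand-in with a shared trunk and two heads, not the paper's VGG-based architecture. The per-step `torch.isfinite` check is an extra I added so the run stops at the first NaN instead of continuing silently. The fake data is random tensors, and the losses use the numerically stable `softplus` form.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDeepMask(nn.Module):
    """Hypothetical stand-in: shared trunk + segmentation and score heads."""

    def __init__(self, mask_size=8):
        super().__init__()
        self.mask_size = mask_size
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),        # -> (N, 16 * 4 * 4)
        )
        self.seg_head = nn.Linear(256, mask_size * mask_size)
        self.score_head = nn.Linear(256, 1)

    def forward(self, x):
        feats = self.trunk(x)
        masks = self.seg_head(feats).view(-1, self.mask_size, self.mask_size)
        scores = self.score_head(feats).squeeze(1)
        return masks, scores


def train_step(model, optimizer, images, mask_targets, labels, step, lam=1.0 / 32):
    masks, scores = model(images)
    if step % 2 == 0:
        # segmentation turn: only positive patches (label +1) carry mask loss
        pos = labels > 0
        if not pos.any():
            return None
        loss = F.softplus(-mask_targets[pos] * masks[pos]).mean()
    else:
        # scoring turn: logistic loss on the objectness score
        loss = lam * F.softplus(-labels * scores).mean()
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss at step {step}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# fake dataset: random images, labels and masks in {-1, +1}
torch.manual_seed(0)
model = TinyDeepMask()
opt = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
losses = []
for step in range(6):
    images = torch.randn(4, 3, 32, 32)
    labels = torch.tensor([1.0, -1.0, 1.0, -1.0])
    mask_targets = torch.randint(0, 2, (4, 8, 8)).float() * 2 - 1
    losses.append(train_step(model, opt, images, mask_targets, labels, step))
print(losses)
```

Since both heads share the trunk, each step still updates the trunk. Only the head whose loss is active on that step receives a gradient.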

Can somebody help me with this?

Here is the link to the current version of the code: