Loss.backward() fails at the second iteration for a customized loss function

I’m computing a customized loss inside model.forward() in order to address imbalanced GPU usage across devices. The customized loss class is called “SplitCrossEntropyLoss”. See the relevant code below:

class RNNModel(nn.Module):
        def __init__(…):
                super(RNNModel, self).__init__()
                from splitcross import SplitCrossEntropyLoss
                splits = [2800, 20000, 76000]
                self.criterion = SplitCrossEntropyLoss(ninp, splits=splits, verbose=False)
                … …
        def forward(…):
                … …
                result = output
                # calculate loss
                result = result.view(result.size(0)*result.size(1), -1)
                raw_loss = self.criterion(decoder_weight, decoder_bias, result, target)
                loss = raw_loss
                # activation regularization
                if args.alpha: loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in outputs[-1:])
                # Temporal Activation Regularization (slowness)
                if args.beta: loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in raw_outputs[-1:])
                # expand loss to two dimensional space so it can be gathered via the second dimension
                loss = loss.unsqueeze(1)
                raw_loss = raw_loss.unsqueeze(1)
                if return_h:
                        return raw_loss, loss, hidden, raw_outputs, outputs
                return raw_loss, loss, hidden

Then, in my main.py, I collect the losses and call loss.mean().backward() to update the parameters. The interesting thing is that the first loss.mean().backward() completes successfully, but the second one fails with this error:

RuntimeError: invalid argument 3: Index tensor must have same dimensions as input tensor at
/pytorch/torch/lib/THC/generic/THCTensorScatterGather.cu:199
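
For reference, my training step in main.py looks roughly like this. It is a simplified sketch, not my literal code: the forward() argument order, the DataParallel wrapper with dim=1 (inferred from the “gathered via the second dimension” comment above), and the clipping value are placeholders.

import torch
import torch.nn as nn

def train_step(parallel_model, optimizer, data, targets, hidden, clip=0.25):
        # parallel_model wraps the RNNModel above, e.g. nn.DataParallel(model, dim=1).
        # Each replica computes its loss inside forward() and returns it with shape
        # (1, 1), so DataParallel gathers the per-GPU losses along dimension 1.
        optimizer.zero_grad()
        raw_loss, loss, hidden = parallel_model(data, hidden, targets)
        # Reduce the gathered per-replica losses to a scalar before backward().
        loss.mean().backward()
        torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), clip)
        optimizer.step()
        return raw_loss.mean().item(), hidden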

Can anyone help?
Thanks in advance!

The workaround is to pull the criterion out of model.forward() and mount it on multiple GPUs.

I have the same issue. Did you find a solution?

Yes. I pulled my customized loss function out of model.forward() and mounted it on multiple GPUs. I still don’t know the root cause of the earlier failure, but this workaround works for me.
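
Roughly, the change looks like this. A minimal sketch, not my exact code: it assumes forward() now returns only the decoder output and the hidden state, and names like model.decoder.weight, ninp, and the DataParallel dim=1 are inferred from the snippet above. It only shows moving the loss computation out of forward(); replicating the criterion itself across GPUs is an extra step I’m not showing here.

import torch.nn as nn
from splitcross import SplitCrossEntropyLoss

def build(model, ninp, splits=(2800, 20000, 76000)):
        # The criterion now lives outside the model; forward() just decodes.
        parallel_model = nn.DataParallel(model, dim=1).cuda()
        criterion = SplitCrossEntropyLoss(ninp, splits=list(splits), verbose=False).cuda()
        return parallel_model, criterion

def train_step(parallel_model, criterion, model, optimizer, data, targets, hidden):
        optimizer.zero_grad()
        output, hidden = parallel_model(data, hidden)  # outputs gathered back onto GPU 0
        output = output.view(output.size(0) * output.size(1), -1)
        # Compute the loss outside forward(), on the gathered output.
        loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)
        loss.backward()
        optimizer.step()
        return loss.item(), hidden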