Does casting one datatype to another break the computation graph?

I have a simple NER model that uses a word2vec embeddings layer followed by a linear layer and a softmax function.

To compute the loss after the softmax function, I use the BCELoss function.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SemanticCRF(nn.Module):

    def __init__(self):
        super(SemanticCRF, self).__init__()

        # log reg
        self.hidden2tag = nn.Linear(300, 2)

        # crf
        self.crf_layer = CRF(2, batch_first=True)

        # binary cross entropy loss
        self.criterion = nn.BCELoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None):

        # log reg
        # softmax over the tag dimension (dim=2), matching the argmax below
        probablities_ = F.softmax(self.hidden2tag(input_ids.float()), dim=2)
        probablities__ = torch.argmax(probablities_, dim=2)
        probablities___ = torch.masked_select(probablities__, attention_mask)
        labels = torch.masked_select(labels, attention_mask)

        print('Logits: ', probablities__.dtype)
        print('Labels: ', labels.dtype)

        loss = self.criterion(probablities___.float() , labels.float())
        # print(loss)
        # return loss, torch.Tensor(probablities.float()), labels.float()
        return loss, probablities___, labels

However, calling backward() on this loss raises the error below. The same model worked fine when the loss came from the CRF layer, which backpropagated without error; the error appears only with BCELoss. I am using BCELoss because my NER task has only two entity classes. With BCELoss, loss.requires_grad is False.

Traceback (most recent call last):
  File "", line 264, in <module>
    train(model, optimizer, scheduler, train_dataloader, development_dataloader, args)
  File "", line 128, in train
  File "/home//lib/python3.6/site-packages/torch/", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home//lib/python3.6/site-packages/torch/autograd/", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Changing the dtype won’t break the computation graph as long as you are using floating point tensors.
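A quick way to check this is to cast a small tensor and look at requires_grad. This is a minimal sketch; the tensor x and its shape are made up for illustration and are not part of the model above.

```python
import torch

x = torch.randn(3, 4, requires_grad=True)  # float32 leaf tensor

# float32 -> float64: the cast is recorded in the graph, so gradients flow back to x.
y = x.double()
y.sum().backward()
float_cast_tracked = y.requires_grad and x.grad is not None

# float32 -> int64: integer tensors cannot require gradients, so the graph ends here.
z = x.long()
int_cast_tracked = z.requires_grad

print(float_cast_tracked, int_cast_tracked)
```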
However, the operation probablities__ = torch.argmax(probablities_, dim=2) will break it, because argmax is not differentiable: it returns integer indices that are detached from the graph. Casting the output back to .float() won’t restore the graph.
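One possible way around this, sketched under assumptions, is to compute the loss on the softmax probabilities and keep argmax only for reporting predictions. The batch below is random placeholder data, attention_mask is assumed to be a boolean tensor, and labels are assumed to be 0/1 per token; none of this is confirmed by the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

hidden2tag = nn.Linear(300, 2)
criterion = nn.BCELoss()

# Hypothetical batch: 4 sentences, 7 tokens each, 300-dim embeddings.
embeddings = torch.randn(4, 7, 300)
attention_mask = torch.ones(4, 7, dtype=torch.bool)
attention_mask[:, 5:] = False  # pretend the last two positions are padding
labels = torch.randint(0, 2, (4, 7))

probs = F.softmax(hidden2tag(embeddings), dim=2)

# What the original forward() does: argmax yields integer indices, so the
# graph is already cut; casting to float does not reattach it.
hard = torch.argmax(probs, dim=2).float()
hard_requires_grad = hard.requires_grad

# Differentiable path: use the probability of class 1 directly.
positive_probs = torch.masked_select(probs[..., 1], attention_mask)
targets = torch.masked_select(labels, attention_mask).float()
loss = criterion(positive_probs, targets)

# Non-differentiable path, used only for metrics and never for the loss.
predictions = torch.masked_select(torch.argmax(probs, dim=2), attention_mask)

loss.backward()
print(hard_requires_grad, loss.requires_grad, hidden2tag.weight.grad is not None)
```

An alternative that may be simpler is to skip the explicit softmax and pass the raw two-class logits to nn.CrossEntropyLoss, which applies log-softmax internally.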
