Understanding loss function code and cat a tensor with itself

I’m extending open source code based on a paper and seek help understanding code related to the loss (the author already has a backlog of questions and comments from me at this point). My hope is that someone with more expert knowledge on losses can understand the code, and then I can make it easier for the next person. My fork with the file in question is here: Domain-Agnostic-Sentence-Specificity-Prediction train.py
Link to paper: Domain Agnostic Real-Valued Specificity Prediction

This model involves both supervised training and then unsupervised training, with a teacher/student learning paradigm (the supervised model is the teacher and the student is learning on unlabeled data for a new domain). The supervised training portion uses labels with 2 classes. I’m extending it to 4 classes (there was even a parameter to support this, but it didn’t work). I’m getting an error at the function cat: ou=ou* torch.cat((a,a), 1)

Traceback (most recent call last):
  File "train.py", line 491, in <module>
    train_acc = trainepoch(epoch)
  File "train.py", line 394, in trainepoch
    ou=ou* torch.cat((a,a), 1)
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1

My question is, how can I get past this error while also making sure the code does what it’s supposed to as this is within the training loop?

The variable names are hard to understand, so I’ve added comments with educated guesses about what they mean. All the comments other than “backward” are mine.

        output = pdtb_net((s1_batch, s1_len),s1_batchf) # Supervised model
        output2 = pdtb_net((s1_batch2, s1_len2),s1_batchf2)
        outputu = pdtb_net((su_batch, su_len),su_batchf) # "u" for unsupervised or unlabeled
        outputu2 = pdtb_net2((su_batch2, su_len2),su_batchf2)
        if params.loss==0:
            pred = output.data.max(1)[1]
        else:
            pred=output.data[:,0]>0
        

        assert len(pred) == len(s1[stidx:stidx + params.batch_size])
        if params.loss==0: # This code is used because params.loss = 0
            ou = F.softmax(outputu, dim=1) # output unlabeled
            
            ou2 = F.softmax(outputu2, dim=1)
            sou = F.softmax(output, dim=1) # supervised model output
            
            sou2 = F.softmax(output2, dim=1)
 
            a,_=torch.max(ou,1)
            sa,_=torch.max(sou,1)

            a=(a.detach()>params.th).view(-1,1).float()
            sa=(sa.detach()>params.th).view(-1,1).float()
            ou=ou*  torch.cat((a,a), 1)
            ou2=ou2*  torch.cat((a,a), 1)
            sou=sou*  torch.cat((sa,sa), 1)
            sou2=sou2*  torch.cat((sa,sa), 1)
        
        else: # This code is not used but may be an alternative to consider
            ou=outputu[:,0]
            ou2=outputu2[:,0]
            a=(ou.detach()>params.th).view(-1,1).float()
            ou=ou*  a
            ou2=ou2* a

        loss2=( F.mse_loss(ou, ou2.detach(), reduction='sum')+F.mse_loss(sou, sou2.detach(), reduction='sum')) / params.n_classes/params.batch_size
        # loss
        if params.loss==0:
            tgt_batch=torch.cat([1.0-tgt_batch.view(-1,1),tgt_batch.view(-1,1) ],dim=1)
            oop=F.softmax(output, dim=1)
            oop2=F.softmax(outputu, dim=1)
            loss3=0
            if params.use_gpu:
                pppp=Variable(torch.FloatTensor([1/oop.size(0)]).cuda())
            else:
                pppp=Variable(torch.FloatTensor([1/oop.size(0)]))
            dmiu=torch.mean(oop2[:,1])
            dstd=torch.std(oop2[:,1])
            loss3=loss3+torch.abs(torch.mean(oop2[:,1])-params.klmiu)+torch.abs(torch.std(oop2[:,1])-params.klsig)
            
            kss=float(params.klsig)
            
            
            loss1 = loss_fn(oop, tgt_batch.float())
        else: # This code is not used
            loss1 = loss_fn(output[:,0], (tgt_batch*2-1).float())
        if epoch>=params.se_epoch_start: # I think SE is for self ensembling
            loss=loss1+params.c*loss2+params.c2*loss3
        else:
            loss=loss1+params.c2*loss3
        all_costs.append(loss.item())
        words_count += (s1_batch.nelement()) / params.word_emb_dim

        # backward
        optimizer.zero_grad()
        loss.backward()
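For anyone trying to follow the loss terms, here is a stripped-down, self-contained sketch of `loss1` and `loss3` from the snippet above, in the original two-class setting. It is an illustration, not the repo's code: random tensors stand in for the network outputs, the `klmiu`/`klsig` values are made up, and `loss_fn` is assumed to be `BCELoss`, which is what the traceback later in this thread (`F.binary_cross_entropy` called from `loss.py`) suggests. Note that `pppp`, `dmiu`, `dstd` and `kss` are computed but never used in the lines shown.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder sizes and hyperparameters (assumed values, not taken from the repo)
batch_size = 8
klmiu, klsig = 0.4, 0.2   # stand-ins for params.klmiu / params.klsig

# Random logits stand in for the model outputs (2 classes, as in the original code)
output = torch.randn(batch_size, 2, requires_grad=True)    # labeled batch
outputu = torch.randn(batch_size, 2, requires_grad=True)   # unlabeled batch
tgt_batch = torch.randint(0, 2, (batch_size,)).float()     # binary labels, 0.0 or 1.0

# loss1: binary cross-entropy on softmax probabilities.
# Column 0 of the target is 1 - label, column 1 is the label itself.
oop = F.softmax(output, dim=1)
tgt_2col = torch.cat([1.0 - tgt_batch.view(-1, 1), tgt_batch.view(-1, 1)], dim=1)
loss1 = F.binary_cross_entropy(oop, tgt_2col)

# loss3: push the batch statistics of P(class 1) on the unlabeled data
# toward a target mean (klmiu) and a target standard deviation (klsig)
p1 = F.softmax(outputu, dim=1)[:, 1]
loss3 = torch.abs(p1.mean() - klmiu) + torch.abs(p1.std() - klsig)

loss = loss1 + loss3   # the real code also weights loss3 by params.c2
loss.backward()
```

Read this way, `loss3` never looks at labels: it nudges the spread of predicted class-1 probabilities on the new domain toward a chosen mean and standard deviation.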

Hi Josh!

What’s going on is that when you increase the number of classes from 2
to 4, the shapes of some tensors change. You need to account for this
or you will get these shape-mismatch errors.

In this particular instance ou has the length of its rows increase to 4
(from 2), but torch.cat ((a, a), 1) creates a tensor with row-lengths
of only 2, hence the mismatch. You could cat() four copies of a together,
but using cat() to “expand” a isn’t really the right approach.

This example, based on a simplified version of your code, illustrates the
cause of the error and shows how broadcasting lets you use a to mask ou.
Consider:

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.random.manual_seed (2021)
>>>
>>> param = 0.65
>>>
>>> outputA = torch.randn (10, 2)   #  batch of ten two-class predictions
>>> outputB = torch.randn (10, 4)   #  batch of ten four-class predictions
>>> ouA = outputA.softmax (dim = 1)
>>> ouB = outputB.softmax (dim = 1)
>>>
>>> aA, _ = torch.max (ouA, 1)
>>> aB, _ = torch.max (ouB, 1)
>>> aA = aA > param
>>> aB = aB > param
>>>
>>> ouA.shape
torch.Size([10, 2])
>>> aA.shape    # want to use a to mask ou, but it's a different shape
torch.Size([10])
>>> ouB.shape
torch.Size([10, 4])
>>> aB.shape    # want to use a to mask ou, but it's a different shape
torch.Size([10])
>>>
>>> # round-about way to make shapes match
>>> aAV = aA.view (-1, 1).float()
>>> torch.cat ((aAV, aAV), 1).shape                     # works for two classes
torch.Size([10, 2])
>>> ouAMasked = ouA * torch.cat ((aAV, aAV), 1)
>>> aBV = aB.view (-1, 1).float()
>>> torch.cat ((aBV, aBV), 1).shape                     # won't work for four classes
torch.Size([10, 2])
>>> try:
...     ouBMasked = ouB * torch.cat ((aBV, aBV), 1)
... except:
...     print ('failed')
...
failed
>>> # could cat four copies of a together ...  but, yuck!
>>> torch.cat ((aBV, aBV, aBV, aBV), 1).shape
torch.Size([10, 4])
>>> ouBMasked = ouB * torch.cat ((aBV, aBV, aBV, aBV), 1)
>>>
>>> # simpler approach:
>>> #   use unsqueeze() to add trailing singleton dimensions
>>> #   then use broadcasting to "match" dimensions of ou and mask
>>> #   (also, pytorch will convert boolean to 0 and 1 in this case)
>>>
>>> ouAMaskedB = ouA * aA.unsqueeze (-1)
>>> ouBMaskedB = ouB * aB.unsqueeze (-1)
>>>
>>> torch.equal (ouAMaskedB, ouAMasked)
True
>>> torch.equal (ouBMaskedB, ouBMasked)
True
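Broadcasting works here because PyTorch lines up trailing dimensions: a (10, 1) mask multiplied by a (10, C) tensor behaves as if its single column were repeated C times. A small sketch, with made-up sizes and the same arbitrary threshold as the session above, checking that the unsqueeze version matches both an explicit expand_as() (a view, so no data is copied) and the "cat C copies" approach, for several class counts:

```python
import torch

torch.manual_seed(2021)
th = 0.65  # arbitrary threshold, as in the session above

for n_classes in (2, 4, 7):
    probs = torch.randn(10, n_classes).softmax(dim=1)
    keep = probs.max(dim=1).values > th            # shape (10,), dtype bool
    mask = keep.unsqueeze(-1)                      # shape (10, 1)

    masked_broadcast = probs * mask                # (10, C) * (10, 1) -> (10, C)
    masked_expand = probs * mask.expand_as(probs)  # explicit view, no data copied
    masked_cat = probs * torch.cat([mask.float()] * n_classes, dim=1)  # the "cat" approach

    assert masked_broadcast.shape == (10, n_classes)
    assert torch.equal(masked_broadcast, masked_expand)
    assert torch.equal(masked_broadcast, masked_cat)
```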

Best.

K. Frank


Thank you for the quick answer! Replacing both the “view” lines and the “cat” lines with “unsqueeze” lines resolved the issue and I was able to test successfully with 4 classes.
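For reference, a sketch of how the masking and `loss2` lines from the first post might look after that change, written so it no longer depends on the number of classes. The function name `masked_consistency_loss` is hypothetical, `th` stands in for `params.th`, and the batch size and class count are read from the tensor shapes instead of from `params` (an assumption that the labeled and unlabeled batches are the same size). It also uses `reduction='sum'`, the current spelling of the deprecated `size_average=False`.

```python
import torch
import torch.nn.functional as F

def masked_consistency_loss(outputu, outputu2, output, output2, th):
    """Hypothetical helper: the loss2 term for any number of classes.

    Rows whose (detached) top softmax probability is not above th are zeroed
    out; unsqueeze(-1) gives a (batch, 1) mask that broadcasts over classes.
    """
    batch_size, n_classes = output.shape
    ou, ou2 = F.softmax(outputu, dim=1), F.softmax(outputu2, dim=1)
    sou, sou2 = F.softmax(output, dim=1), F.softmax(output2, dim=1)

    a = (ou.detach().max(dim=1).values > th).float().unsqueeze(-1)
    sa = (sou.detach().max(dim=1).values > th).float().unsqueeze(-1)

    ou, ou2 = ou * a, ou2 * a          # ou2 is masked with a, as in the original
    sou, sou2 = sou * sa, sou2 * sa

    return (F.mse_loss(ou, ou2.detach(), reduction='sum')
            + F.mse_loss(sou, sou2.detach(), reduction='sum')) / n_classes / batch_size

# Usage with random logits standing in for the four network outputs
torch.manual_seed(0)
for n_classes in (2, 4):
    outs = [torch.randn(32, n_classes, requires_grad=True) for _ in range(4)]
    loss2 = masked_consistency_loss(*outs, th=0.5)
    loss2.backward()   # runs the same way for 2 and 4 classes
```

For two classes this gives the same value as the original `torch.cat((a,a), 1)` version; the difference is only that the mask shape now adapts to `n_classes`.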

It turns out I ran into another, similar error right after that, though I didn’t notice it until now because it was being caught. The error is in the last line of the following code:

            tgt_batch=torch.cat([1.0-tgt_batch.view(-1,1),tgt_batch.view(-1,1) ],dim=1)
            oop=F.softmax(output, dim=1)
            oop2=F.softmax(outputu, dim=1)
            loss3=0
            if params.use_gpu:
                pppp=Variable(torch.FloatTensor([1/oop.size(0)]).cuda())
            else:
                pppp=Variable(torch.FloatTensor([1/oop.size(0)]))
            dmiu=torch.mean(oop2[:,1])
            dstd=torch.std(oop2[:,1])
            loss3=loss3+torch.abs(torch.mean(oop2[:,1])-params.klmiu)+torch.abs(torch.std(oop2[:,1])-params.klsig)
            
            kss=float(params.klsig)
            
            loss1 = loss_fn(oop, tgt_batch.float())

Error:

Traceback (most recent call last):
  File "train.py", line 422, in trainepoch
    loss1 = loss_fn(oop, tgt_batch.float())
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 613, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 3076, in binary_cross_entropy
    "Please ensure they have the same size.".format(target.size(), input.size())
ValueError: Using a target size (torch.Size([32, 2])) that is different to the input size (torch.Size([32, 4])) is deprecated. Please ensure they have the same size.

Similar to the error in my first post, the problem is the view followed by cat, which assumes exactly 2 classes:

tgt_batch=torch.cat([1.0-tgt_batch.view(-1,1),tgt_batch.view(-1,1) ],dim=1)

I’m not sure how to use unsqueeze to solve this, so I thought I would write a loop that creates a list with one entry per class and pass that to cat. But I see that the first column here is the complement of the second, because of the “1.0-”. It appears that just copying the label for classes 3 and 4 will lead to unintended results.

Follow up:

  1. I used unsqueeze and it didn’t work because the other variable oop didn’t match the dimensions.
  2. I used a loop to dynamically create the list, and the code runs. Whether it does what it’s supposed to is another question, as no “1.0-” was used.

Code I’m using:

target_classes = [tgt_batch.view(-1, 1) for _ in range(config_nli_model['n_classes'])]
tgt_batch = torch.cat(target_classes, dim=1)

Training appears to continue successfully.
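On the “1.0-” question: for 0/1 labels, the original line builds a one-hot target, with column 0 = 1 - label and column 1 = label. Copying the label into every column, as the loop does, gives rows like [3, 3, 3, 3] that are not a distribution over classes. A sketch, assuming the 4-class labels are class indices 0..3 stored as floats (an assumption about the data; the label values below are made up), comparing the two and showing F.one_hot as the n-class generalization:

```python
import torch
import torch.nn.functional as F

# Hypothetical binary labels for a batch of 6 examples
binary_labels = torch.tensor([0., 1., 1., 0., 1., 0.])

# Original two-class construction: [1 - t, t] ...
two_col = torch.cat([1.0 - binary_labels.view(-1, 1), binary_labels.view(-1, 1)], dim=1)
# ... is exactly the one-hot encoding of the labels
assert torch.equal(two_col, F.one_hot(binary_labels.long(), num_classes=2).float())

# Hypothetical four-class labels (class indices 0..3)
n_classes = 4
labels4 = torch.tensor([0., 3., 2., 1., 3., 0.])

# The loop from the post repeats the label in every column
repeated = torch.cat([labels4.view(-1, 1) for _ in range(n_classes)], dim=1)
print(repeated[1])   # tensor([3., 3., 3., 3.]), not a distribution over classes

# One-hot generalizes [1 - t, t] to n classes: each row has a single 1.0
one_hot = F.one_hot(labels4.long(), num_classes=n_classes).float()
print(one_hot[1])    # tensor([0., 0., 0., 1.])
```

If the loss is switched to CrossEntropyLoss, the class-index form (`labels4.long()`) can be passed directly as the target, with no matrix to build at all.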

Hi Josh!

The error message indicates that you are using binary_cross_entropy
as your loss_fn. But you’ve increased your number of classes, so you
now have a multi-class problem (rather than a binary problem), and you
would want to use something like CrossEntropyLoss as your loss criterion.

With CrossEntropyLoss, you would not pass the output of your model
through softmax() – your predictions would typically be the output of
your final Linear layer. Make sure you understand what shapes and
types CrossEntropyLoss expects for its input and target.
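To make those shape and type requirements concrete, a minimal sketch with made-up sizes (a batch of 32, 4 classes) and random logits standing in for the output of a real model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_classes = 32, 4

# Raw scores ("logits") straight from the final Linear layer; no softmax()
logits = torch.randn(batch_size, n_classes, requires_grad=True)   # float, shape (N, C)
# Class-index targets
target = torch.randint(0, n_classes, (batch_size,))               # int64, shape (N,)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
loss.backward()

# CrossEntropyLoss applies log_softmax internally, so this gives the same value
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(loss, manual)

# Predicted classes, analogous to output.data.max(1)[1] in the original code
pred = logits.argmax(dim=1)
```

PyTorch 1.10 and later also accept a float target of shape (N, C) holding class probabilities, but for hard labels the class-index form shown here is the usual choice.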

You really shouldn’t be using cat() to obtain the correct tensor shapes.
Make sure you understand the shape, type, and meaning of your labels
and the output of your model. Those will determine how they need to
be processed in order to be passed into CrossEntropyLoss (or similar
loss criterion).

Best.

K. Frank

Thanks! Good catch and my mistake not to pay attention to that. I tested with CrossEntropy instead of BCE and it ran, though not with good results yet. Still working on the rest of your answer.