[Code review] Combining BCE and MSE loss

Hi all — I’m new to deep learning and PyTorch and am playing with some small examples to learn the ropes. I just wanted to check that I’ve implemented something correctly.

I’m looking to use a CNN to detect whether there’s a quasi-diagonal line in an image of a distance matrix (1 = Yes, 0 = No), and if so where (t_start, t_end), like in the image below.

Here is the code I’ve written:

import torch

BCE      = torch.nn.BCELoss()
MSE      = torch.nn.MSELoss()

labels   = torch.tensor([
    # Class, t_start, t_end
    [0, 0, 0],
    [1, 0.25, 0.31],
    [0, 0, 0]
])

# Simulate output = model(distance_matrix)
output = torch.tensor([
    [0.1478, 0.1370, 0.2288],
    [0.7052, 0.3207, 0.5404],
    [0.4511, 0.2978, 0.5279]
])

# Binary cross-entropy error on class labels, mean-squared error on time points
# Use labels to zero-out irrelevant rows
loss = BCE(output[:,:1], labels[:,:1]) + MSE(labels[:,:1] * output[:,-2:], labels[:,-2:])

# etc ...
# loss.backward()
# optimizer.step()

I was wondering whether using the labels to zero out the rows of the output I wasn’t interested in (where class = 0) was the right way to go about this. Thanks!

Hey Nay!

I don’t really understand your overall use case – “test utterance,”
applying the model to a distance matrix – so I don’t have any
thoughts on your “real” problem.

Three lower-level comments:

Your code for loss looks sensible and correct. (I’ll leave it up to
you as to whether at the higher level it does anything useful or
what you want it to do.)

As written, your loss doesn’t care at all about the values of
output[0, 1], output[0, 2], output[2, 1], and output[2, 2].
So, in principle, these values – and the weights that produce
them – could drift off and become large. My gut tells me that it
might be beneficial (although hardly essential) to regularize these
output values (that don’t logically matter).

Assuming that a “class-0” labels row – that is, a row i for which
labels[i, 0] is 0 – is all zeros – that is, for that i, labels[i, j]
is 0.0 for all j – then simply using:

MSE (output[:, -2:], labels[:, -2:])

(or MSE (output[:, 1:], labels[:, 1:]), see below), will
provide that regularization. (Weight decay or a weight-regularization
term in your loss would presumably provide equivalent benefit.)
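Concretely, the whole loss would then be something like this (just
a sketch, and only valid under the all-zeros assumption above):

# every row now contributes to the MSE term; for class-0 rows the
# targets are 0.0, so their time-point outputs get pulled toward zero
loss = BCE (output[:, :1], labels[:, :1]) + MSE (output[:, -2:], labels[:, -2:])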

As a minor stylistic comment: Your choice of slicing notation looks
rather perverse to me. For readability (at least for me), I would prefer,
e.g., output[:, 0] to extract the first column, and output[:, 1:] to
extract the remaining columns. output[:, :1] and output[:, 1:]
could also work, as you could argue that using :1 and 1: together
emphasizes that you are using all of the columns (but I still like 0
better for the first column).
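For concreteness (a throwaway tensor, not your data):

t = torch.arange (6.0).reshape (2, 3)   # tensor([[0., 1., 2.], [3., 4., 5.]])
t[:, 0]    # tensor([0., 3.])       -- first column as a 1-d tensor, shape [2]
t[:, :1]   # tensor([[0.], [3.]])   -- first column kept as a column, shape [2, 1]
t[:, 1:]   # remaining columns, shape [2, 2]

(Note that output[:, :1] keeps the column dimension, so it lines up
shape-wise with labels[:, :1] in your BCE call, whereas output[:, 0]
would be 1-d.)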

Best.

K. Frank

Hi K. Frank!

Thanks so much for your response. Low-level comments are super helpful because I’m at the stage where I’m mainly trying to discover and get used to the best practices for writing PyTorch code.

I don’t really understand your overall use case – “test utterance,” applying the model to a distance matrix – so I don’t have any thoughts on your “real” problem.

Ah yes, sorry, I didn’t explain it at all. The general task is called ‘query-by-example spoken term detection’ (QbE-STD) where you search for a spoken query (e.g. Q = ‘coffee’) in a corpus of audio documents (D1 = ‘I had coffee today’, D2 = ‘Where is the car?’) and output how likely the query is to occur in each document. If you don’t have access to a speech-to-text system, you typically just try to match on the spectral features (e.g. MFCCs) extracted from the audio files.

So the distance matrix for a query Q of length M with F features and a document D of length N with F features is an M x N matrix (taking the standardised Euclidean distance between each pair of feature vectors in Q and D). If the query occurs in the document, you typically get a quasi-diagonal band showing high spectrotemporal correlation somewhere along the document. I’m slowly making my way through replicating this paper while also learning PyTorch and looking at the GitHub repo associated with the paper.
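Roughly, here’s how I’m building the distance matrix (a sketch with random tensors standing in for the MFCC features; the exact standardisation used in the paper may differ):

import torch

M, N, F = 80, 500, 39    # query frames, document frames, feature dims (just illustrative)
Q = torch.randn(M, F)    # query features (e.g. MFCCs)
D = torch.randn(N, F)    # document features

# standardised Euclidean distance = Euclidean distance after scaling each
# feature dimension by its standard deviation
std = torch.cat([Q, D]).std(dim=0, keepdim=True) + 1e-8
dist_matrix = torch.cdist(Q / std, D / std)   # shape (M, N)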

While the approach in that paper is interesting, it occurred to me while learning about object detection in my deep learning class that I could try to extend it to QbE-STD. My (naive) thought was to change the final layer from nn.Linear(60, 1) to nn.Linear(60, 3) (full model here) and make the network also estimate, if a query occurs in the document, at which time points it occurs (where t_start and t_end are proportions of the document length, so in [0, 1]). Happy to hear any thoughts/cautions about this if you have any.
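Something like this for the end of the network (just a sketch of the head I have in mind, not the actual model from the repo):

head = torch.nn.Sequential(
    torch.nn.Linear(60, 3),   # was nn.Linear(60, 1)
    torch.nn.Sigmoid()        # squashes class score and t_start/t_end into [0, 1]
)

cnn_features = torch.randn(4, 60)   # pretend CNN features for a batch of 4 distance matrices
output = head(cnn_features)         # shape (4, 3): [class prob, t_start, t_end]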

Assuming that a “class-0” labels row – that is, a row i for which labels[i, 0] is 0 – is all zeros – that is, for that i, labels[i, j] is 0.0 for all j – then simply using: MSE (output[:, -2:], labels[:, -2:]) (or MSE (output[:, 1:], labels[:, 1:]), see below), will provide that regularization. (Weight decay or a weight-regularization term in your loss would presumably provide equivalent benefit.)

Ah, I see. I don’t think I’ve grasped the entirety of what the loss function is doing then. I guess when I was playing around I was only looking at the output of MSE(), which has shape torch.Size([]). I’m guessing that since MSE and BCE are instantiated as classes, e.g. MSE = torch.nn.MSELoss(), there are some internals in the object that keep track of the pairwise differences? So leaving the outputs as-is and doing the MSE will encourage the network to predict t_start = 0 and t_end = 0 when class = 0, therefore constraining the weights associated with the 2nd and 3rd outputs when something like loss.backward() is called? Please correct me if I’m wrong anywhere.

I was also wondering whether using 0 for values I wasn’t interested in was a good approach. One issue I see is confounding for t_start because for many cases where class = 1, the t_start is also 0, e.g. for ‘coffee’ in ‘Coffee’s the best!’ the t_start is 0 since the query ‘coffee’ is at the start of the document.

As a minor stylistic comment: Your choice of slicing notation looks rather perverse to me. For readability (at least for me), I would prefer, e.g., output[:, 0] to extract the first column, and output[:, 1:] to extract the remaining columns. output[:, :1] and output[:, 1:] could also work, as you could argue that using :1 and 1: together emphasizes that you are using all of the columns (but I still like 0 better for the first column).

Ah yes, this is the result of me trying various things in the console until the slice gave me what I wanted. Something I haven’t fully grasped is how : interacts with indices. I see that suffixing the : to an index, e.g. output[:, 1:], means return columns from index 1 onwards. But when : is prefixed, as in output[:, :1], the slice means everything up to but not including index 1? So output[:, :2] returns the 1st and 2nd columns (columns 0 and 1, not columns 0, 1, 2). For small label/output matrices, I think I might stick to explicit indexing with output[:, [0]] for the first column and output[:, [1, 2]] for the 2nd and 3rd columns (coming from an R background).
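For example, with the output tensor from my snippet above:

output[:, :2]       # columns 0 and 1 (the slice stops before index 2)
output[:, 1:]       # columns 1 and 2
output[:, [0]]      # column 0, kept 2-d, shape (3, 1)
output[:, [1, 2]]   # columns 1 and 2, shape (3, 2)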

Thanks!
Nay

Hey Nay!

Just one clarifying comment on this point:

In general, torch.nn.SomePytorchLoss (the class version) and
torch.nn.functional.some_pytorch_loss() (the function version)
are equivalent. Instantiating the class version gives you an instance
of a function object (that is, an object instance that has a __call__
method, so that instance (input, target) does essentially the same
thing as the function version).

The class version doesn’t really have any “smart,” loss-specific
internals. You can typically pass the constructor some “control”
parameters such as class weights or reduction = 'sum'. These
are stored internally to the specific instance constructed. But the
instance doesn’t remember anything from call to call about the actual
loss values calculated.
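For example (an illustrative check, nothing to do with your data):

import torch

preds   = torch.rand (5, 1)
targets = torch.rand (5, 1)

mse_object = torch.nn.MSELoss()                          # class version
loss_a = mse_object (preds, targets)                     # call the instance
loss_b = torch.nn.functional.mse_loss (preds, targets)   # function version

print (torch.allclose (loss_a, loss_b))                  # True -- same computation, no hidden state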

Best.

K. Frank

The class version doesn’t really have any “smart,” loss-specific
internals.

Ah, I see. Hm.

As written, your loss doesn’t care at all about the values of
output[0, 1], output[0, 2], output[2, 1], and output[2, 2].
So, in principle, these values – and the weights that produce
them – could drift off and become large. My gut tells me that it
might be beneficial (although hardly essential) to regularize these
output values (that don’t logically matter).

So including output[0, 1], output[0, 2], output[2, 1], and output[2, 2] in the loss calculation even though they don’t matter has a regularizing effect, so that the network finds weights that consistently predict [0, 0] for [t_start, t_end] when class = 0, and penalises any changes that move away from that? I guess my original line of thinking was that having the network predict [t_start, t_end] for class = 0 would ‘distract’ it from learning good weights for the cases I do care about, where class = 1. But I guess that’s theory vs. practice.
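I convinced myself with a tiny check (made-up numbers, with requires_grad on output just so I can inspect the gradients):

import torch

MSE = torch.nn.MSELoss()

labels = torch.tensor([[0., 0., 0.],        # class 0
                       [1., 0.25, 0.31]])   # class 1
output = torch.tensor([[0.1, 0.4, 0.7],
                       [0.7, 0.3, 0.5]], requires_grad=True)

# masked version: class-0 time points contribute nothing, so they get no gradient
masked = MSE(labels[:, :1] * output[:, 1:], labels[:, 1:])
masked.backward()
print(output.grad[0, 1:])   # tensor([0., 0.]) -- no pressure on the class-0 time points

output.grad = None

# unmasked version: class-0 time points are pulled toward their 0.0 targets
unmasked = MSE(output[:, 1:], labels[:, 1:])
unmasked.backward()
print(output.grad[0, 1:])   # nonzero -- these outputs get nudged toward 0

Thanks for your help!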