top_p, top_class = ps.topk(1, dim=1)

Please scroll about a third of the way down this page.

For

equals = top_class == labels

should it be

equals = top_class == labels.view(-1,64)

(based on a small test that I did)

rather than

equals = top_class == labels.view(*top_class.shape)

*I understand that the unpacking of top_class.shape is based on this.

I ask because the former has a shape of (64, 64), based on this explanation:

equals will have shape (64, 64), try it yourself. What it’s doing is comparing the one element in each row of top_class with each element in labels which returns 64 True/False boolean values for each row. This is because of broadcasting.

Please advise.
Cheers and Happy New Year.

I haven’t executed the notebook, but if top_class has the shape [64, 64], your view operation will return the same shape as the current implementation in the notebook.
Also, using the shape of top_class seems to account for use cases with a different number of classes.

Please correct me if I misunderstood something.
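A quick way to check this without running the notebook is a minimal sketch like the one below. The tensors are hypothetical stand-ins (random values, batch size 64, 10 classes assumed), not the notebook's actual data. It suggests that labels.view(-1, 64) only turns labels into a (1, 64) row, so the comparison still broadcasts to (64, 64), just like the plain version.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: 64 predicted classes (one per row) and 64 labels.
top_class = torch.randint(0, 10, (64, 1))
labels = torch.randint(0, 10, (64,))

# view(-1, 64) reshapes the 64 labels into a single row of shape (1, 64).
row_labels = labels.view(-1, 64)

# (64, 1) == (1, 64) broadcasts to (64, 64), same as comparing with the 1D labels.
row_equals = top_class == row_labels
plain_equals = top_class == labels

# Matching the shape of top_class gives an element-wise (64, 1) comparison.
column_equals = top_class == labels.view(*top_class.shape)

print(row_labels.shape, row_equals.shape, column_equals.shape)
```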

Thanks for your reply.

Here top_class is a 2D tensor with shape (64, 1) while labels is 1D with shape (64).

We need top_class with shape (64, 1) and labels with shape (1, 64) for broadcasting to kick in. With this, the shape of the comparison result would be (64, 64). The 64 refers to the batch size in this scenario.

Thanks.

As far as I understand the notebook, broadcasting should not be used:

To get the equality to work out the way we want, top_class and labels must have the same shape.
If we do
equals = top_class == labels
equals will have shape (64, 64), try it yourself. What it’s doing is comparing the one element in each row of top_class with each element in labels which returns 64 True/False boolean values for each row.

So this behavior is not desired, but instead each element of top_class should be compared to only the corresponding element in labels.
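To make the difference concrete, here is a minimal sketch, assuming a batch of 64 samples and 10 classes with random values standing in for the notebook's ps and labels. The broadcast comparison checks every prediction against every label, and only its diagonal corresponds to the element-wise comparison we actually want.

```python
import torch

torch.manual_seed(0)

# Assumed stand-ins for the notebook's tensors: 64 samples, 10 classes.
ps = torch.softmax(torch.randn(64, 10), dim=1)  # class probabilities
labels = torch.randint(0, 10, (64,))            # 1D ground-truth labels

top_p, top_class = ps.topk(1, dim=1)            # both have shape (64, 1)

# (64, 1) == (64,) broadcasts to (64, 64): every prediction vs. every label.
broadcast_equals = top_class == labels

# Reshaping labels to (64, 1) compares each prediction with its own label only.
equals = top_class == labels.view(*top_class.shape)

print(broadcast_equals.shape)  # torch.Size([64, 64])
print(equals.shape)            # torch.Size([64, 1])

# Only the diagonal of the broadcast result pairs sample i with label i.
diagonal_matches = torch.equal(broadcast_equals.diagonal(), equals.squeeze(1))
print(diagonal_matches)
```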

Thanks, I finally understood.

Would it be the same if I use

equals = clas == lab.view(-1,1)

as the output is the same?

.view(-1, 1) should work in this case, but .view(*clas.shape) seems to be the safest way here. :wink:
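As a rough check, the sketch below compares the two reshapes on hypothetical stand-ins for clas and lab (random values, batch size 64 and 10 classes assumed). With topk(1) both should give the same (64, 1) result, and the accuracy line follows the usual pattern of averaging the boolean matches.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the poster's `clas` and `lab` tensors.
ps = torch.softmax(torch.randn(64, 10), dim=1)
lab = torch.randint(0, 10, (64,))
_, clas = ps.topk(1, dim=1)                    # shape (64, 1)

equals_col = clas == lab.view(-1, 1)           # explicit column shape
equals_shape = clas == lab.view(*clas.shape)   # shape copied from clas

# With k=1 both reshapes produce identical (64, 1) boolean tensors.
same = torch.equal(equals_col, equals_shape)

# Fraction of correct predictions in the batch.
accuracy = equals_shape.float().mean().item()
print(same, accuracy)
```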