Returning a vector but computing the loss on a single element

Hi all,

I have a weird problem (or so I think). I want to train a model many times, each time on a different element of the output. This only works if the forward pass returns that element directly; it doesn’t work if I return a vector with all elements and compute the loss on only one of them.
Let’s say I have an output tensor r of size 1x4 and I’m interested in the loss on the last element, r[3].

If I return from my forward pass:

return r[3]

And then train:

y_hat = model(X)
loss = criterion(y_hat, y)

The predictions on the test set are perfectly fine. But if I return all elements:

return r

And train:

y_hat = model(X)
loss = criterion(y_hat[3], y)

In the latter case the predictions are way off. In debug mode, I can see that the value passed to the loss function is exactly the same in both cases. Am I missing something here?
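A minimal sketch of the two training setups described above. It assumes a hypothetical TinyNet with 4 outputs, random data, nn.MSELoss and plain SGD; none of these come from the post. Both copies start from the same weights and build the same autograd graph, so their parameters should stay identical step after step. Note that once a batch dimension is present, the output column has to be selected with [:, 3]; for an output of shape [batch, 4], indexing with [3] selects the fourth sample instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the poster's model: 4 outputs per sample.
class TinyNet(nn.Module):
    def __init__(self, return_slice: bool):
        super().__init__()
        self.fc = nn.Linear(5, 4)
        self.return_slice = return_slice

    def forward(self, x):
        r = self.fc(x)  # shape [batch, 4]
        # Either return only output column 3, or the full vector.
        return r[:, 3] if self.return_slice else r

X = torch.randn(32, 5)
y = torch.randn(32)

sliced = TinyNet(return_slice=True)
full = TinyNet(return_slice=False)
full.load_state_dict(sliced.state_dict())  # identical starting weights

criterion = nn.MSELoss()
opt_sliced = torch.optim.SGD(sliced.parameters(), lr=0.1)
opt_full = torch.optim.SGD(full.parameters(), lr=0.1)

for step in range(20):
    # Variant 1: forward returns only element 3
    opt_sliced.zero_grad()
    loss_a = criterion(sliced(X), y)
    loss_a.backward()
    opt_sliced.step()

    # Variant 2: forward returns all 4 elements; slice inside the loss call
    opt_full.zero_grad()
    loss_b = criterion(full(X)[:, 3], y)
    loss_b.backward()
    opt_full.step()

# Compare the trained parameters of both copies.
for (name, p_a), p_b in zip(sliced.named_parameters(), full.parameters()):
    print(name, torch.allclose(p_a, p_b))
```

If the two variants disagree in a real project, the difference is likely to come from somewhere other than where the slice is taken (for example, the indexing dimension or the evaluation code).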

Thank you very much for your help!
Greetings

This shouldn’t make any difference, as this example shows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 10, bias=False)
        self.fc2 = nn.Linear(10, 3, bias=False)
        
    def forward(self, x, return_slice=False):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        if return_slice:
            return x[:, -1]
        else:
            return x

model = MyModel()
x = torch.randn(10, 10)

# Full-output approach: forward returns all columns, slice afterwards
out = model(x)
out[:, -1].mean().backward()

for name, param in model.named_parameters():
    print(name, param.grad.abs().sum())

# Sliced approach: forward returns only the last column
model.zero_grad()
out = model(x, return_slice=True)
out.mean().backward()

for name, param in model.named_parameters():
    print(name, param.grad.abs().sum())

You will get the same gradients for both methods.
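One way to see why returning the full vector is harmless: autograd only propagates gradient through the element used in the loss. In the layer sizes above, row i of fc2.weight produces output column i, so the rows for the unused outputs receive exactly zero gradient, while everything feeding the selected output is updated normally. A minimal sketch using the same layer sizes as the example (the standalone fc1/fc2 layers here are just for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

fc1 = nn.Linear(10, 10, bias=False)
fc2 = nn.Linear(10, 3, bias=False)
x = torch.randn(10, 10)

out = fc2(F.relu(fc1(x)))      # full output, shape [10, 3]
out[:, -1].mean().backward()   # the loss uses only the last output column

# Rows of fc2.weight that produce the unused columns: gradient is exactly zero.
print(fc2.weight.grad[:-1].abs().sum())
# Row that produces the selected column: non-zero gradient.
print(fc2.weight.grad[-1].abs().sum())
```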

How reproducible is this issue with your code?
I.e., do you always get good results with the former approach and bad ones with the latter?

Thank you for your help; it made me look elsewhere, and in the end I found a mistake in my testing code. Everything works fine now.
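The post doesn’t say what the testing mistake was, so the following is only a hypothetical illustration of the general point: when the model returns the full vector, the evaluation code has to apply the same slice as the training loss. The model, data and TARGET_IDX below are assumptions for the sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model that returns all 4 outputs per sample.
model = nn.Linear(5, 4)
X_test = torch.randn(8, 5)
y_test = torch.randn(8)

criterion = nn.MSELoss()
TARGET_IDX = 3  # hypothetical: the same output column used in the training loss

model.eval()                  # put layers like dropout/batch norm in eval mode
with torch.no_grad():         # no autograd graph needed for evaluation
    y_hat = model(X_test)[:, TARGET_IDX]  # slice exactly as during training
    test_loss = criterion(y_hat, y_test)

print(y_hat.shape, test_loss.item())
```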