Torch.cat and backpropagation

Yes, you should.
You should know this value based on the input to your network, no? That is why I mentioned that you should extract one input from your Dataset, so that you can compute these values when creating your network.
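
To illustrate the idea of pulling one sample out of the Dataset before building the network, here is a minimal sketch. The toy TensorDataset, the names n_in and n_out, and the layer sizes are all assumptions for illustration; how the real values are derived depends on your data.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

# Toy dataset standing in for the real one (assumption: 6 features, 1 target).
dataset = TensorDataset(torch.randn(10, 6), torch.randn(10, 1))

# Extract a single sample and read the sizes off its shape.
sample_input, sample_label = dataset[0]
n_in = sample_input.numel()   # hypothetical name for the input size
n_out = sample_label.numel()  # hypothetical name for the output size

# All layers exist up front, so the optimizer can be created right away.
net = nn.Sequential(nn.Linear(n_in, 8), nn.ReLU(), nn.Linear(8, n_out))
optimizer = torch.optim.Adam(net.parameters())
```

Because every parameter already exists when the optimizer is constructed, nothing has to be re-created inside the training loop.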

Oh, I see. But I need the whole batch to compute n1 and n2. Note that, because of the problem I am solving, my batch is the complete training set, not just a subset of the instances.

So maybe, in this case, the code is not so wrong? Keep in mind that inputs contains all the training instances.

Is this code correct, assuming that inputs contains all the training instances, i.e. there is only one batch in trainloader?

for epoch in range(n_epochs):
    for i, [inputs, labels] in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward + backward + optimize
        outputs = net(inputs.float())

        if epoch == 1:
            optimizer = torch.optim.Adam(net.parameters())

        optimizer.zero_grad()

        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
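
One thing to watch in a loop like this: range(n_epochs) starts at 0, so an optimizer created only when epoch == 1 does not exist yet during the first epoch, unless one was already defined earlier. A minimal sketch of one way around that, assuming the network creates its parameters during the first forward pass, is to build the optimizer once, right after that first forward call. LazyNet, the toy tensors, and the single-batch trainloader list below are hypothetical stand-ins, not the code from this thread.

```python
import torch
import torch.nn as nn

class LazyNet(nn.Module):
    """Hypothetical network that builds its layer on the first forward pass."""
    def __init__(self):
        super().__init__()
        self.fc = None  # size unknown until the full batch has been seen

    def forward(self, x):
        if self.fc is None:
            # Stand-in for computing n1 and n2 from the whole batch.
            self.fc = nn.Linear(x.shape[1], 1).to(x.device)
        return self.fc(x)

torch.manual_seed(0)
# A single batch holding the complete (toy) training set.
trainloader = [(torch.randn(16, 4), torch.randn(16, 1))]

net = LazyNet()
criterion = nn.MSELoss()
optimizer = None
losses = []
n_epochs = 20

for epoch in range(n_epochs):
    for inputs, labels in trainloader:
        outputs = net(inputs.float())  # parameters exist after this call

        # Create the optimizer exactly once, after the parameters exist.
        if optimizer is None:
            optimizer = torch.optim.Adam(net.parameters())

        optimizer.zero_grad()
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
```

Checking optimizer is None instead of the epoch number keeps Adam's running statistics intact across epochs and does not depend on where the epoch counter starts.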