Adding weights to BCELoss gives a broadcast error

I have a model that trains fine when I use BCELoss without weights, but when I add weights I get the following broadcast error:

RuntimeError: output with shape [128, 1] doesn't match the broadcast shape [128, 2]

The code is quite straightforward:

model = LinearRegression()
model.to(device)
lr = 1e-1
optimizer = optim.SGD(model.parameters(), lr=lr)
n_epochs = 50
loss_fn = nn.BCELoss(weight=torch.tensor([3, 7]))
model = train(n_epochs=n_epochs, model=model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn)

And here is the code for the trainer:

def train(model, train_loader, optimizer, loss_fn, n_epochs):
    training_losses = []

    for epoch in range(n_epochs):
        batch_losses = []
        for x_train, y_train in train_loader:
            if(str(type(loss_fn)).startswith("<class 'torch.nn.modules.loss")):
                y_true = y_train[:, 0]
                y_true = y_true.view(y_true.shape[0], 1)
            else:
                y_true = y_train
            model.train()   
            y_pred = model(x_train).view(y_true.shape[0], 1)
            loss = loss_fn(y_pred, y_true)
            loss.backward()    
            optimizer.step()
            optimizer.zero_grad()
            batch_losses.append(loss.item())
        training_loss = np.mean(batch_losses)
        training_losses.append(training_loss)

        print(f"[{epoch+1}] {training_loss:.3f}")
    return(model)

I have tried going through the source of the loss function on GitHub, but I can't see what to change to fix this. Sometimes I use custom losses and sometimes the built-in PyTorch ones, which is why the first few lines of the trainer look the way they do. I have also made sure that y_train[:, 0] corresponds to the target. Thanks for your help!

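Hi, if you look closely at the documentation, the weight argument of BCELoss is a per-element rescaling weight (the docs describe it as a tensor of size nbatch), not a per-class weight, and it is broadcast against the target. torch.tensor([3, 7]) has shape [2], so together with your [128, 1] targets it broadcasts to [128, 2], which no longer matches the [128, 1] loss output.
In your case, the weight tensor must have shape [128, 1] (the same shape as y_pred and y_true), not [2].
Since the weight depends on the labels, build it per batch from the target (y_true in your loop), for example: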
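To see where the [128, 2] in the error message comes from, here is a small shape check using torch.broadcast_shapes (assumed available, i.e. a reasonably recent PyTorch) with the shapes from the question. It only inspects shapes; it does not call the loss itself.

```python
import torch

# Shapes from the question: y_pred / y_true are [128, 1],
# and the BCELoss weight torch.tensor([3, 7]) has shape [2].
target_shape = torch.Size([128, 1])
class_weight = torch.tensor([3.0, 7.0])

# Broadcasting [128, 1] with [2] expands the last dimension to 2,
# so the weight no longer lines up with the [128, 1] loss output.
bad_shape = torch.broadcast_shapes(target_shape, class_weight.shape)
print(bad_shape)  # torch.Size([128, 2])

# A per-sample weight with the target's own shape broadcasts to itself.
per_sample_weight = torch.ones(target_shape)
good_shape = torch.broadcast_shapes(target_shape, per_sample_weight.shape)
print(good_shape)  # torch.Size([128, 1])
```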

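import torch.nn.functional as F
weight = torch.zeros_like(y_true)  # same shape as the target: [128, 1]
weight[y_true == 0] = 3
weight[y_true == 1] = 7
loss = F.binary_cross_entropy(y_pred, y_true, weight=weight)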
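Below is a minimal runnable sketch of how the per-sample weight could fit into a training step like the one in the question. The model (a hypothetical Linear + Sigmoid stand-in for LinearRegression, since BCELoss expects probabilities), the random 4-feature data, and the helper name weighted_bce are all assumptions for illustration; only the 3 / 7 class weights come from the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the question's model and data:
# 128 samples, 4 features, a linear layer followed by a sigmoid.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
x_train = torch.randn(128, 4)
y_true = torch.randint(0, 2, (128, 1)).float()  # shape [128, 1], like the reshaped y_train[:, 0]

CLASS_WEIGHTS = {0: 3.0, 1: 7.0}  # the 3 / 7 weights from the question

def weighted_bce(y_pred, y_true):
    # Per-sample weight with exactly the target's shape,
    # so nothing broadcasts to [batch, 2].
    weight = torch.where(y_true == 1,
                         torch.full_like(y_true, CLASS_WEIGHTS[1]),
                         torch.full_like(y_true, CLASS_WEIGHTS[0]))
    return F.binary_cross_entropy(y_pred, y_true, weight=weight)

losses = []
for epoch in range(5):
    y_pred = model(x_train).view(y_true.shape[0], 1)
    loss = weighted_bce(y_pred, y_true)  # weight rebuilt from this batch's labels
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())
print([round(l, 3) for l in losses])
```

The functional F.binary_cross_entropy is used because nn.BCELoss fixes its weight when it is constructed, while here the weight has to be rebuilt from each batch's labels. Note that a plain function like weighted_bce does not pass the trainer's str(type(loss_fn)) check, so if you pass it in as loss_fn you would need to slice y_true = y_train[:, 0] yourself.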