ValueError: Using a target size (torch.Size([32, 1])) that is different to the input size (torch.Size([7200, 1])) is deprecated. Please ensure they have the same size

```
Starting Training Loop…
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:42: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

ValueError                                Traceback (most recent call last)
in ()
     36 # train with real
     37 s_output, c_output = netD(real_image)
---> 38 s_errD_real = s_criterion(s_output, real_target) # realfake
     39 c_errD_real = c_criterion(c_output, real_label) # class
     40 errD_real = s_errD_real + c_errD_real

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2753         raise ValueError(
   2754             "Using a target size ({}) that is different to the input size ({}) is deprecated. "
-> 2755             "Please ensure they have the same size.".format(target.size(), input.size())
   2756         )
   2757

ValueError: Using a target size (torch.Size([32, 1])) that is different to the input size (torch.Size([7200, 1])) is deprecated. Please ensure they have the same size.
```

I am unable to solve the above problem. Would anyone kindly help me to solve it?

The error is raised because s_output and real_target have different shapes, which is not supported.
I don’t know how the output is calculated or what the batch size is supposed to be, so you would have to check where this size difference comes from and what the expected shape would be.

PS: you can post code snippets by wrapping them in three backticks ```, which makes debugging easier. :wink:
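A minimal sketch of such a check, assuming a real/fake output of shape [batch_size, 1] with nn.BCELoss as in the traceback. The tensors are random placeholders, not the poster's netD: comparing the shapes right before the criterion makes the mismatch visible where it happens, and a [7200, 1] output with a [32, 1] target reproduces the ValueError.

```python
import torch
from torch import nn

criterion = nn.BCELoss()
batch_size = 32

# Placeholder discriminator output with the expected shape [batch_size, 1]
s_output = torch.sigmoid(torch.randn(batch_size, 1))
real_target = torch.ones(batch_size, 1)

# Compare shapes before calling the criterion so a mismatch is reported early
assert s_output.shape == real_target.shape, f"{s_output.shape} vs {real_target.shape}"
loss = criterion(s_output, real_target)

# An output of shape [7200, 1] with a [32, 1] target raises the ValueError from the traceback
bad_output = torch.sigmoid(torch.randn(7200, 1))
raised = False
try:
    criterion(bad_output, real_target)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```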

Hi @ptrblck, I am having a similar problem and wonder if you could help me. It goes like this:
My target is a tensor of shape torch.Size([731022, 1]), but my model returns torch.Size([1462044, 2]).


And after I run this:
```python
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(lstm1.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    outputs = lstm1(X_train_tensors_final)  # forward pass
    optimizer.zero_grad()  # reset the accumulated gradients to 0
    loss = criterion(outputs, y_train_tensors)
    loss.backward()  # compute the gradients of the loss w.r.t. the parameters
    optimizer.step()  # update the parameters using the gradients
    if epoch % 1 == 0:
        print("Epoch: %d, loss: %1.5f" % (epoch, loss.item()))
```
And the warning I get is this:
```
UserWarning: Using a target size (torch.Size([731022, 1])) that is different to the input size (torch.Size([1462044, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
```

Now I know it’s returning the output, hn and cn, but what I don’t understand is why it is returning 1462044 rows, which is double the number of input samples. I hope you can help me, thanks!

I guess the issue is created by flattening hn, which has the shape [num_layers*num_directions, batch_size, hidden_size] as described in the docs, so you would need to check the shape before and after the flattening.
You might have missed my last sentence, so your code also cannot be debugged, since you posted a picture :wink:
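A small sketch of where num_directions comes from, assuming the layer sizes from this thread and a random input: it is not a constructor argument but follows from the bidirectional flag (2 if True, else 1), and it shows up in the first dimension of hn and the last dimension of output.

```python
import torch
from torch import nn

batch_size, seq_len, input_size, hidden_size, num_layers = 3, 4, 42, 128, 2
x = torch.randn(batch_size, seq_len, input_size)

shapes = {}
for bidirectional in (False, True):
    lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                   batch_first=True, bidirectional=bidirectional)
    output, (hn, cn) = lstm(x)
    # num_directions is implied by the flag: 2 if bidirectional=True, else 1
    num_directions = 2 if bidirectional else 1
    shapes[bidirectional] = (tuple(output.shape), tuple(hn.shape))
    print(bidirectional, num_directions, output.shape, hn.shape)
```

With num_layers=2, hn already has 2 entries in its first dimension for a unidirectional LSTM, so flattening it into the batch dimension doubles the number of rows.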


Hi @ptrblck, thank you for answering. After looking at the docs, it’s possible that’s the problem, but they also specify “If the LSTM is bidirectional, num_directions should be 2, else it should be 1.” My end goal is to make it bidirectional, but right now I am trying without it being bidirectional, so num_directions should be 1. I don’t really get where this variable comes from or where it is being defined. Also, I’m sorry about the code format; I thought it would be easier with the picture, so I’ll just put it all down so you can get the whole picture.
```python
import torch
from torch import nn
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# loading data
training_data = pd.read_excel(r'rawDATA_2outputs.xlsx')
test_data = pd.read_excel(r'test_2outputs.xlsx')
# shuffle data
training = training_data.sample(frac=1)
test = test_data.sample(frac=1)

# data preprocessing
X = training_data.iloc[:, :-1]  # training data without labels
y = training_data.iloc[:, 42:43]  # labels of training_data
X_test = test_data.iloc[:, :-1]  # test_data without labels
y_test = test_data.iloc[:, 42:43]  # labels of test_data

# scaling the data: fit the scalers on the training data only and reuse them for the test data
mm = MinMaxScaler()
ss = StandardScaler()
X_ss = ss.fit_transform(X)
y_mm = mm.fit_transform(y)
X_test_ss = ss.transform(X_test)
y_test_mm = mm.transform(y_test)  # same scaler as the training labels

# data as tensors (Variable is deprecated; plain tensors are enough)
X_train_tensors = torch.Tensor(X_ss)
y_train_tensors = torch.Tensor(y_mm)
X_test_tensors = torch.Tensor(X_test_ss)
y_test_tensors = torch.Tensor(y_test_mm)

# reshape to [num_samples, seq_len=1, num_features] for the batch_first LSTM
X_train_tensors_final = torch.reshape(X_train_tensors, (X_train_tensors.shape[0], 1, X_train_tensors.shape[1]))
X_test_tensors_final = torch.reshape(X_test_tensors, (X_test_tensors.shape[0], 1, X_test_tensors.shape[1]))

# build the network
class LSTM1(nn.Module):
    def __init__(self, num_classes, input_size, hidden_size, num_layers, seq_length):
        super(LSTM1, self).__init__()
        self.num_classes = num_classes  # number of classes
        self.num_layers = num_layers  # number of layers
        self.input_size = input_size  # input size
        self.hidden_size = hidden_size  # hidden state
        self.seq_length = seq_length  # sequence length

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)  # lstm
        self.fc_1 = nn.Linear(hidden_size, 128)  # fully connected 1
        self.fc = nn.Linear(128, num_classes)  # fully connected last layer

        self.relu = nn.ReLU()

    def forward(self, x):
        h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # hidden state
        c_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # internal state
        # Propagate input through LSTM
        output, (hn, cn) = self.lstm(x, (h_0, c_0))  # lstm with input, hidden, and internal state
        hn = hn.view(-1, self.hidden_size)  # reshaping the data for Dense layer next
        out = self.relu(hn)
        out = self.fc_1(out)  # first Dense
        out = self.relu(out)  # relu
        out = self.fc(out)  # Final Output
        return out

# hyperparameters
num_epochs = 1
learning_rate = 0.001
input_size = 42  # number of features
hidden_size = 128  # number of features in hidden state
num_layers = 2  # number of stacked lstm layers
num_classes = 2  # number of output classes

# instantiate the network and define the loss criterion and optimizer
lstm1 = LSTM1(num_classes, input_size, hidden_size, num_layers, X_train_tensors_final.shape[1])
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(lstm1.parameters(), lr=learning_rate)

# train
for epoch in range(num_epochs):
    outputs = lstm1(X_train_tensors_final)  # forward pass
    optimizer.zero_grad()  # reset the accumulated gradients to 0

    # obtain the loss function
    loss = criterion(outputs, y_train_tensors)

    loss.backward()  # compute the gradients of the loss w.r.t. the parameters

    optimizer.step()  # update the parameters using the gradients
    if epoch % 1 == 0:
        print("Epoch: %d, loss: %1.5f" % (epoch, loss.item()))
```
I’ll also try to debug and see what happens, but anyway thank you for your time and help!

This code illustrates your usage and the wrong flattening:

```python
import torch
from torch import nn

# setup
input_size = 42
hidden_size = 128
num_layers = 2

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)

batch_size, seq_len = 3, 4
x = torch.randn(batch_size, seq_len, input_size)
output, (hn, cn) = lstm(x)
print(output.shape)  # [batch_size, seq_len, num_directions*hidden_size]
# torch.Size([3, 4, 128])
print(hn.shape)  # [num_layers*num_directions, batch_size, hidden_size]
# torch.Size([2, 3, 128])
print(cn.shape)  # [num_layers*num_directions, batch_size, hidden_size]
# torch.Size([2, 3, 128])

# this is most likely not what you want!
hn = hn.view(-1, hidden_size)
print(hn.shape)  # [num_layers*num_directions*batch_size, hidden_size]
# torch.Size([6, 128])
```

I guess you might want to use the output (or its last step) instead of flattening the batch dimension into the layer/direction dim of the hidden state.
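A minimal sketch of that suggestion applied to the posted model, assuming the same layer sizes (the class name LSTM1Fixed is hypothetical): the last time step of output is passed to the dense layers, so the batch dimension stays intact. For a unidirectional LSTM this step equals hn[-1], the hidden state of the last layer.

```python
import torch
from torch import nn

class LSTM1Fixed(nn.Module):  # hypothetical name for the modified model
    def __init__(self, num_classes, input_size, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc_1 = nn.Linear(hidden_size, 128)
        self.fc = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # initial hidden and cell states default to zeros when omitted
        output, (hn, cn) = self.lstm(x)
        # last time step: [batch_size, hidden_size], batch dimension untouched
        last = output[:, -1, :]
        out = self.relu(last)
        out = self.relu(self.fc_1(out))
        return self.fc(out)

model = LSTM1Fixed(num_classes=2, input_size=42, hidden_size=128, num_layers=2)
x = torch.randn(5, 1, 42)  # [batch_size, seq_len, input_size], random placeholder
out = model(x)
print(out.shape)  # torch.Size([5, 2])
```

The output now has one row per sample. With num_classes = 2 and a single-column target, the loss would still see [N, 2] vs. [N, 1], so num_classes (or the target) would also have to be adapted.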


Thank you for your response.
I’m working with an ACGAN to generate images from noise. I am trying to generate images with the same size as my input images (299x299). With image_size=64 it works perfectly, but when I change it, the training part raises the error.

Could you change this line of code:

x = x.view(-1, self.ndf * 1)

to

x = x.view(x.size(0), -1)

and check if you still get a shape mismatch?
The latter approach makes sure the batch dimension stays constant, while the former one could push additional dimensions into the batch dim.
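A small sketch of the difference, assuming a hypothetical activation with 8 channels and a 15x15 spatial size (not the actual netD shapes): since 7200 = 32 * 15 * 15, view(-1, channels) would fold the spatial positions into the batch dimension, which is one way an output with 7200 rows can appear for a batch of 32.

```python
import torch

# Hypothetical activation: batch of 32, 8 channels, 15x15 spatial size
batch_size, channels = 32, 8
x = torch.randn(batch_size, channels, 15, 15)

# view(-1, channels) folds the 15*15 spatial positions into the batch dimension
a = x.view(-1, channels)
print(a.shape)  # torch.Size([7200, 8])

# view(x.size(0), -1) keeps the batch dimension and flattens everything else
b = x.view(x.size(0), -1)
print(b.shape)  # torch.Size([32, 1800])
```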

After changing

x = x.view(-1, self.ndf * 1)

to

x = x.view(x.size(0), -1)

it shows a “RuntimeError: mat1 dim 1 must match mat2 dim 0” error.

This new error points to the expected shape mismatch in the activation and the next layer (I assume it’s self.aux_linear and self.disc_linear).
You can print the shape of x after the proper flattening and set the in_features of both linear layers to this value.
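Instead of reading the value off a print, the in_features could also be computed once with a dummy forward pass. This is a sketch under assumptions: the two conv layers, image_size, and n_class below are placeholders, and only the names disc_linear and aux_linear mirror the thread.

```python
import torch
from torch import nn

# Hypothetical convolutional feature extractor standing in for the discriminator body
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
)

image_size = 64  # assumed input resolution
with torch.no_grad():
    dummy = torch.zeros(1, 3, image_size, image_size)
    n_features = features(dummy).view(1, -1).size(1)
print(n_features)  # 32 * 16 * 16 = 8192

# Both heads use the flattened activation size as in_features
n_class = 10  # placeholder number of classes
disc_linear = nn.Linear(n_features, 1)
aux_linear = nn.Linear(n_features, n_class)

x = torch.randn(8, 3, image_size, image_size)
h = features(x).view(x.size(0), -1)  # keep the batch dimension
print(disc_linear(h).shape, aux_linear(h).shape)
```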

Why is it not printing the shape of anything else?

```python
import torch

n_class = 120  # assumption: number of classes, defined earlier in the original script

def onehot_encode(label, device, n_class=n_class):
    eye = torch.eye(n_class, device=device)
    return eye[label].view(-1, n_class, 1, 1)

def concat_image_label(image, label, device, n_class=n_class):
    B, C, H, W = image.shape
    oh_label = onehot_encode(label, device=device, n_class=n_class)
    oh_label = oh_label.expand(B, n_class, H, W)
    return torch.cat((image, oh_label), dim=1)

def concat_noise_label(noise, label, device):
    oh_label = onehot_encode(label, device=device)
    return torch.cat((noise, oh_label), dim=1)
```

Also, could this part be causing any issue?

I don’t know why the print isn’t working, as I cannot see the error message in your screenshot.
You could post an executable code snippet by wrapping it in three backticks ```, in case you get stuck.
This would allow us to copy-paste the code and debug it easily, which is not possible with unformatted source code or screenshots.

Thanks for the code.
The printing works as expected and shows that the x activation has a shape of torch.Size([32, 67275]) before being passed to self.aux_linear and self.disc_linear.
However, changing the in_features to this value won’t solve the issue, since you are using different input shapes for netD. While real_image has a shape of [batch_size, 3, 299, 299], fake_image is [batch_size, 3, 64, 64] (which would work for in_features=299 as was originally set), so you would have to resize the inputs or use an adaptive pooling layer internally and make sure the activation has the same size.
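A sketch of the adaptive pooling option, assuming a hypothetical two-layer conv body (not the poster's netD): nn.AdaptiveAvgPool2d produces a fixed spatial size, so 64x64 and 299x299 inputs end up with the same flattened activation. The resize option is shown with F.interpolate.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Hypothetical discriminator body; layer sizes are assumptions
body = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.AdaptiveAvgPool2d((4, 4)),  # fixed 4x4 spatial output for any input size
)
disc_linear = nn.Linear(32 * 4 * 4, 1)

shapes = {}
for size in (64, 299):
    x = torch.randn(2, 3, size, size)
    h = body(x).view(x.size(0), -1)
    shapes[size] = tuple(h.shape)
    print(size, h.shape, disc_linear(h).shape)

# Alternative: resize the real images to the generator's output size
resized = F.interpolate(torch.randn(2, 3, 299, 299), size=(64, 64),
                        mode='bilinear', align_corners=False)
print(resized.shape)  # torch.Size([2, 3, 64, 64])
```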

Thank you for your continuous support :blush:


I have updated my generator and discriminator, and as the output shows, both work properly, but now I am getting a “RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 14.76 GiB total capacity; 12.20 GiB already allocated; 523.75 MiB free; 13.21 GiB reserved in total by PyTorch)” error. Note that I have set image_size=256, ndf and ngf=256, nz=128 and n_classes=120. I reduced batch_size from 32 to 8, but reducing the batch size didn’t help. What is causing this error?

The error shows that you are running out of memory, so you would need to either reduce the batch size further or save memory elsewhere, e.g. by using smaller models or by trading compute for memory via torch.utils.checkpoint.
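A minimal sketch of the checkpointing idea, assuming a hypothetical block standing in for part of the generator or discriminator: torch.utils.checkpoint.checkpoint does not store the intermediate activations of the wrapped block and recomputes them during backward, so the result is the same at the cost of extra compute. A block this small saves almost nothing; it only illustrates the call.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Hypothetical block standing in for part of a larger model
block = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

x = torch.randn(8, 256, requires_grad=True)

# Regular forward: intermediate activations of `block` are kept for backward
y_ref = block(x)

# Checkpointed forward: activations inside `block` are recomputed during backward
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()
print(torch.allclose(y_ref, y_ckpt), x.grad.shape)
```

In a real model one would typically wrap a few large sub-blocks (e.g. groups of conv layers) rather than the whole network.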

I used a batch size of 8. Should I reduce it further?

Hi @ptrblck, I wonder if you could clarify something for me. I’m running this bidirectional LSTM, which I feed with DataLoaders, and the loss is decreasing, but the accuracy remains the same. I feel like I’m doing something wrong with the preds variable when I call the softmax function. The code is below.

```python
import torch
from torch import nn
import torch.nn.functional as F

class BiLSTM(nn.Module):

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, dropout=0.2, bidirectional=True, batch_first=True)
        self.fc_1 = nn.Linear(hidden_dim*2, 128)  # fully connected 1 (2*hidden_dim for both directions)
        self.fc = nn.Linear(128, output_dim)  # fully connected last layer (takes the 128 outputs of fc_1)
        self.relu = nn.ReLU()  # activation function

    def forward(self, x):
        h_0 = torch.zeros(self.layer_dim*2, x.size(0), self.hidden_dim).requires_grad_()  # hidden state
        c_0 = torch.zeros(self.layer_dim*2, x.size(0), self.hidden_dim).requires_grad_()  # internal state
        # Propagate input through LSTM
        output, (hn, cn) = self.lstm(x, (h_0, c_0))  # lstm with input, hidden, and internal state
        output = output[:, -1, :]
        out = self.relu(output)
        out = self.fc_1(out)  # first Dense
        out = self.relu(out)  # relu
        out = self.fc(out)  # Final Output
        # sigmoid
        out = torch.sigmoid(out)
        # out = F.log_softmax(out, dim=1).argmax(dim=1)

        return out

input_dim = 42
hidden_dim = 128
layer_dim = 2
output_dim = 1

lr = 0.001
n_epochs = 100
iterations_per_epoch = len(trn_ds)
best_acc = 0
patience, trials = 10, 0

model = BiLSTM(input_dim, hidden_dim, layer_dim, output_dim)
model = model.train()
criterion = nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=lr)

print('Start model training')

for epoch in range(1, n_epochs + 1):
    model.train()  # re-enable dropout after model.eval() in the previous epoch

    for i, (x_batch, y_batch) in enumerate(trn_dl):
        y_batch = y_batch.unsqueeze(1)
        y_batch = y_batch.float()
        out = model(x_batch)
        loss = criterion(out.float(), y_batch)
        # loss.requires_grad = True
        opt.zero_grad()
        loss.backward()
        opt.step()

    model.eval()
    correct, total = 0, 0
    for x_val, y_val in val_dl:
        x_val, y_val = [t.cpu() for t in (x_val, y_val)]
        out = model(x_val)
        preds = F.log_softmax(out, dim=1).argmax(dim=1)
        total += y_val.size(0)
        correct += (preds == y_val).sum().item()

    acc = correct / total

    # if epoch % 5 == 0:
    print(f'Epoch: {epoch:3d}. Loss: {loss.item():.4f}. Acc.: {acc:2.2%}')

    if acc > best_acc:
        trials = 0
        best_acc = acc
        torch.save(model.state_dict(), 'bestSimpleLSTM.pth')
        print(f'Epoch {epoch} best model saved with accuracy: {best_acc:2.2%}')
    else:
        trials += 1
        if trials >= patience:
            print(f'Early stopping on epoch {epoch}')
            break
```

Thank you.

torch.argmax(output, dim=1) is usually used on outputs containing the logits (or probabilities) with nb_classes in dim1 (i.e. in the shape [batch_size, nb_classes]), which would be the case for a multi-class classification using nn.CrossEntropyLoss.
In your case you are using a single output unit, apply sigmoid on it, and use nn.BCELoss as the criterion, which points towards a binary classification (small advice: remove the sigmoid and use nn.BCEWithLogitsLoss for more numerical stability).
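In that case you should use a threshold to get the predicted class labels, e.g. via:

```python
preds = out > 0.0  # if out contains logits (sigmoid removed from forward)
preds = torch.sigmoid(out) > 0.5  # equivalent: the sigmoid turns the logits into probabilities first
preds = out > 0.5  # if out already contains probabilities (sigmoid applied in forward, as now)
```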
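A sketch of the suggested setup, using random placeholder logits and labels instead of the poster's model and DataLoaders: nn.BCEWithLogitsLoss applies the sigmoid internally, and the predictions are squeezed to [batch_size] before comparing them to the labels. Without the squeeze, a [batch_size, 1] prediction compared to a [batch_size] target would broadcast to [batch_size, batch_size] and skew the accuracy.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Placeholder logits from a model with a single output unit (no sigmoid in forward)
logits = torch.randn(16, 1)
targets = torch.randint(0, 2, (16,))  # class labels 0/1, shape [batch_size]

# BCEWithLogitsLoss expects float targets with the same shape as the output
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, targets.float().unsqueeze(1))

# Threshold the logits and squeeze so the shape matches the [batch_size] targets
preds = (logits > 0.0).squeeze(1).long()
preds_from_probs = (torch.sigmoid(logits) > 0.5).squeeze(1).long()

correct = (preds == targets).sum().item()
acc = correct / targets.size(0)
print(f"loss: {loss.item():.4f}, acc: {acc:.2%}")
```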