The expanded size of the tensor must match the existing size

RuntimeError: The expanded size of the tensor (19) must match the existing size (7484) at non-singleton dimension 0. Target sizes: [19]. Tensor sizes: [7484]

When I run the function "train_one_epoch" (see the code below), I get the error above:

def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn):
    # Enumerate over the data
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for _, batch in enumerate(tqdm(train_loader)):
        # Use GPU
        batch.to(device)  
        # Reset gradients
        optimizer.zero_grad() 
        # Passing the node features and the connection info
        x= torch.from_numpy(batch[0].x)
        edge_attr= torch.from_numpy(batch[0].edge_attr)
        edge_index= torch.from_numpy(batch[0].edge_index)
        pred = model(x.to(torch.float32), 
                                edge_attr.to(torch.float32),
                                edge_index, 
                                batch.batch.to(torch.float32)) 
        # Calculating the loss and gradients
        loss = loss_fn(torch.squeeze(pred), batch.y)
        loss.backward()  
        optimizer.step()  
        # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_labels.append(batch.y.cpu().detach().numpy())
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_preds, all_labels, epoch, "train")
    return running_loss/step

I’m not sure if it’s related to the model architecture (see the code below):

class GNN(torch.nn.Module):
    def forward(self, x):
        pass
   
    def __init__(self, feature_size, model_params):
        super(GNN, self).__init__()
  
        embedding_size = model_params["model_embedding_size"][0]
        n_heads = model_params["model_attention_heads"][0]
        self.n_layers = model_params["model_layers"][0]
        dropout_rate = model_params["model_dropout_rate"][0]
        top_k_ratio = model_params["model_top_k_ratio"][0]
        self.top_k_every_n = model_params["model_top_k_every_n"][0]
        dense_neurons = model_params["model_dense_neurons"][0]
        edge_dim = model_params["model_edge_dim"][0]

        self.conv_layers = ModuleList([]) 
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])

        # Initial Layers
        # Transformation layer
        self.conv1 = TransformerConv(feature_size, 
                                    embedding_size, 
                                    heads=n_heads, 
                                    dropout=dropout_rate,
                                    edge_dim=edge_dim,
                                    beta=True) 
        # Linear Layer
        self.transf1 = Linear(embedding_size*n_heads, embedding_size)
        # Batch Normalization Layer
        self.bn1 = BatchNorm1d(embedding_size)

        # Other layers
        for i in range(self.n_layers):
            self.conv_layers.append(TransformerConv(embedding_size, 
                                                    embedding_size, 
                                                    heads=n_heads, 
                                                    dropout=dropout_rate,
                                                    edge_dim=edge_dim,
                                                    beta=True))

            self.transf_layers.append(Linear(embedding_size*n_heads, embedding_size))
            self.bn_layers.append(BatchNorm1d(embedding_size))
            # TopKPooling layer
            if i % self.top_k_every_n == 0:
                self.pooling_layers.append(TopKPooling(embedding_size, ratio=top_k_ratio))
            

        #  Final linear layers
        self.linear1 = Linear(embedding_size*2, dense_neurons)
        self.linear2 = Linear(dense_neurons, int(dense_neurons/2))  
        self.linear3 = Linear(int(dense_neurons/2), 1)  

    def forward(self, x, edge_attr, edge_index, batch_index):
        # Initial transformation
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(self.transf1(x))
        x = self.bn1(x)

        # Holds the intermediate graph representations
        global_representation = []

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index, edge_attr)
            x = torch.relu(self.transf_layers[i](x))
            x = self.bn_layers[i](x)
            # Always aggregate last layer
            if i % self.top_k_every_n == 0 or i == self.n_layers:
                x , edge_index, edge_attr, batch_index, _, _ = self.pooling_layers[int(i/self.top_k_every_n)](
                    x, edge_index, edge_attr, batch_index
                    )
                # Add current representation
                global_representation.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))
    
        x = sum(global_representation)

        # Output block
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.linear3(x)

        return x

Could someone point me in the right direction to fix this issue?

Based on the error message, it seems an expand call is failing, as seen in this code snippet:

a = torch.randn(7484, 1, 1)

# works as we are expanding singleton dimensions
b = a.expand(-1, 100, 200)
print(b.shape)
# torch.Size([7484, 100, 200])

# fails
b = a.expand(19, 100, 200)
# RuntimeError: The expanded size of the tensor (19) must match the existing size (7484) at non-singleton dimension 0.  Target sizes: [19, 100, 200].  Tensor sizes: [7484, 1, 1]

Note that I’ve added two additional dimensions to show how expand can be used.
I don’t know which line of code raises this error, but you should be able to check the stack trace to narrow down the failing operation.
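
For reference, here is a minimal, self-contained sketch that reproduces the exact message from your traceback. The shapes 7484 and 19 are taken from the error message, and expand_as is just one example of an operation that calls expand internally, not necessarily the one failing in your code:

import torch

a = torch.randn(7484)
b = torch.randn(19)

# expand_as tries to expand a to b's shape; since dim 0 of a is not a
# singleton dimension, this raises the same error as in the question
a.expand_as(b)
# RuntimeError: The expanded size of the tensor (19) must match the existing size (7484) at non-singleton dimension 0.  Target sizes: [19].  Tensor sizes: [7484]

If you suspect the loss computation, printing torch.squeeze(pred).shape and batch.y.shape right before the loss_fn call would show whether those two tensors actually have matching shapes.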
