PyTorch ArrayRef invalid index problem with linear nn

I have a headache and the following code, which doesn’t let me sleep at night:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np

    samples = torch.linspace(0, 100, 100)  # GENERATE THE SET
    train_split = int(len(samples) * 0.8)
    x_train, x_test = samples[:train_split], samples[train_split:]
    y_labels = 2 * samples - 4  # define the function
    y_labels += torch.tensor(np.random.normal(0, 5, len(samples)))  # ADD NOISE

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 1)

        def forward(self, x):
            return self.fc1(x)

    model = NeuralNetwork()
    loss_func = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 50
    for epoch in range(num_epochs):
        for inputs, labels in zip(x_train, y_labels[:train_split]):
            y_pred = model(inputs)  # raises the RuntimeError shown below
            loss = loss_func(y_pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Epoch %d - loss: %.4f" % (epoch, loss.item()))

Running it fails at the model(inputs) call with:

      File ".../torch/nn/modules/linear.py", line 116, in forward
        return F.linear(input, self.weight, self.bias)
    RuntimeError: ArrayRef: invalid index Index = 18446744073709551615; Length = 0

It is a simple one-layer thing, and I’m doing it intuitively, so no need to call me a goof. I ran it on a couple of machines: no difference…

P.S. ANY SUGGESTIONS ON FURTHER IMPROVEMENTS IN THE PROCESS OF WRITING THOSE BEASTS WOULD BE APPRECIATED!!!

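Iterating over x_train yields 0-dim (scalar) tensors, while nn.Linear(1, 1) expects the last dimension of its input to have size 1 (in_features). Unsqueeze a dimension of the input tensor before passing it to the model: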

y_pred = model(inputs.unsqueeze(0))

and it should work.
This seems like a nasty bug and I will create a GitHub issue for it.
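A minimal sketch of that fix applied to the training loop from the question, assuming only 2 epochs and using nn.Linear directly in place of the NeuralNetwork wrapper (same single layer). It also unsqueezes the label, so the target has shape [1] like the prediction; without that, MSELoss broadcasts a [] target against a [1] prediction and emits a UserWarning about mismatched sizes. The seeds are an addition, only there to make the run reproducible.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # added for reproducibility, not in the original code
np.random.seed(0)

samples = torch.linspace(0, 100, 100)
train_split = int(len(samples) * 0.8)
x_train = samples[:train_split]
y_labels = 2 * samples - 4
y_labels += torch.tensor(np.random.normal(0, 5, len(samples)))  # in-place add keeps float32

model = nn.Linear(1, 1)  # same single layer as NeuralNetwork.fc1
loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(2):  # assumption: 2 epochs, just to exercise the loop
    for inputs, labels in zip(x_train, y_labels[:train_split]):
        # inputs is 0-dim (shape []); nn.Linear(1, 1) wants [..., 1]
        x = inputs.unsqueeze(0)  # shape [1]
        y = labels.unsqueeze(0)  # shape [1], matches y_pred for MSELoss
        y_pred = model(x)        # shape [1]
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch %d - loss: %.4f" % (epoch, loss.item()))
```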

EDIT: #119161

I wrapped my data properly (…torch.Tensor([[x_train]])… ), and now it works 😜 but thank you!!! Also, I just wanted to express my gratitude and awe at how much you do for the PyTorch community by answering so many stupid and not-stupid questions. Thank you!
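On the P.S. about improvements: instead of feeding one scalar sample at a time, a common pattern is to reshape the data once into [N, 1] tensors (N samples, 1 feature each) and run the model on the whole batch in a single forward pass. The sketch below does full-batch training on the same synthetic line (2x - 4 plus noise) and then evaluates on the held-out 20%. The learning rate (0.1) and the epoch count (2000) are assumptions chosen for this sketch, not tuned values.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
np.random.seed(0)

samples = torch.linspace(0, 100, 100)
train_split = int(len(samples) * 0.8)
noise = torch.from_numpy(np.random.normal(0, 5, len(samples))).float()
y_labels = 2 * samples - 4 + noise

# Reshape once to [N, 1]: N samples, 1 feature each.
x_train = samples[:train_split].unsqueeze(1)
y_train = y_labels[:train_split].unsqueeze(1)
x_test = samples[train_split:].unsqueeze(1)
y_test = y_labels[train_split:].unsqueeze(1)

model = nn.Linear(1, 1)
loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)  # assumption: not tuned

with torch.no_grad():
    first_loss = loss_func(model(x_train), y_train).item()  # loss before training

for epoch in range(2000):  # assumption: not tuned
    y_pred = model(x_train)  # one forward pass over all 80 training samples
    loss = loss_func(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():  # no gradients needed for evaluation
    test_loss = loss_func(model(x_test), y_test).item()

print("train loss: %.2f, test loss: %.2f" % (loss.item(), test_loss))
print("learned w=%.3f, b=%.3f" % (model.weight.item(), model.bias.item()))
```

For larger datasets, wrapping x_train and y_train in torch.utils.data.TensorDataset and iterating with a DataLoader gives mini-batches with the same [batch, 1] shapes.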
