Quickstart tutorial: __array__() takes 1 positional argument but 2 were given

I’ve tried the quickstart tutorial on two machines: one without a CUDA-compatible GPU, running PyTorch 1.9.0-CPU, and one with a CUDA-compatible GPU, running PyTorch 1.9.0+cu111. On the first (CPU-only) machine the tutorial works fine, but on the second (CUDA) machine it fails in the first section.

I copied and pasted the code directly from the tutorial:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Running this gives me the error:

TypeError: __array__() takes 1 positional argument but 2 were given

The for loop is what causes the error; everything runs fine if I comment it out. Searching this forum and Google, a common cause seems to be omitting the __init__ method from a class, but I haven’t defined any classes yet in this tutorial.
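The message is about a method signature rather than a missing __init__: some method named __array__ was called with one more positional argument than it accepts. NumPy calls an object's __array__ method when converting it to an array, and when a target dtype is requested it passes that dtype in as an argument. The sketch below reproduces the same kind of TypeError with two hypothetical classes, BrokenImage and FixedImage (they are illustrations, not Pillow's actual code), and assumes only that NumPy is installed. That ToTensor hits this path when it converts a PIL image is an assumption here, not something the thread confirms.

```python
import warnings

import numpy as np


class BrokenImage:
    """Hypothetical image-like object whose __array__ accepts no dtype."""

    def __init__(self, pixels):
        self._pixels = pixels

    def __array__(self):  # missing the optional dtype parameter
        return np.asarray(self._pixels)


class FixedImage(BrokenImage):
    """Same object, but __array__ accepts the dtype NumPy may pass in."""

    def __array__(self, dtype=None, copy=None):
        # copy is accepted so NumPy 2.x can pass it; this sketch ignores it
        return np.asarray(self._pixels, dtype=dtype)


pixels = [[0, 128], [255, 64]]

with warnings.catch_warnings():
    # NumPy 2.x may emit a DeprecationWarning about the missing copy keyword
    # before raising; silence it so only the TypeError is visible.
    warnings.simplefilter("ignore", DeprecationWarning)
    try:
        # Requesting a dtype makes NumPy call __array__ with that dtype,
        # which BrokenImage.__array__ cannot accept.
        np.array(BrokenImage(pixels), np.uint8)
    except TypeError as err:
        broken_error = str(err)
    else:
        broken_error = None

fixed = np.array(FixedImage(pixels), np.uint8)
print(broken_error)  # exact wording varies slightly across Python versions
print(fixed.dtype, fixed.shape)  # uint8 (2, 2)
```

Giving __array__ an optional dtype parameter, as FixedImage does, is one way such a mismatch is resolved in library code; from the user's side, the fix in this thread was simply to install a different Pillow version.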

Thanks for any help.

The issue might be related to PIL==8.3.0, as described here, and @tom has already provided a fix here.
If you are also using PIL==8.3.0, downgrade it to 8.2.0 for now. If that’s not the case, please let us know and we’ll take another look at your issue.
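Before downgrading (for example with pip install pillow==8.2.0), you can check which version is installed. This is a minimal sketch using the standard library's importlib.metadata (Python 3.8+); the helper names installed_pillow_version and is_suspect are hypothetical, and the check only flags the exact 8.3.0 release named in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Release implicated in this thread; other versions are not assessed here.
SUSPECT_VERSION = "8.3.0"


def installed_pillow_version():
    """Return the installed Pillow version string, or None if it is absent."""
    try:
        # Pillow is installed under the distribution name "Pillow",
        # even though it is imported as PIL.
        return version("Pillow")
    except PackageNotFoundError:
        return None


def is_suspect(ver):
    """True only for the exact Pillow release mentioned in this thread."""
    return ver == SUSPECT_VERSION


if __name__ == "__main__":
    ver = installed_pillow_version()
    if ver is None:
        print("Pillow is not installed")
    elif is_suspect(ver):
        print(f"Pillow {ver} found; the thread suggests downgrading to 8.2.0")
    else:
        print(f"Pillow {ver} found")
```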


Thanks, this solution does work! :+1: :+1: :+1:

This worked, thank you. Just to confirm: it was Pillow, a fork of PIL, that I had to downgrade to 8.2.0.