Tensor dtype for input batches to PyTorch networks

When adding torchvision.transforms.ToTensor() to our dataloader's transformations, ToTensor() sets the dtype of the input images X to torch.float32 and of the labels y to torch.int64. Is there a specific reason for this?

float32 is the default dtype in PyTorch, and floating-point inputs are expected when training a neural network, which is why ToTensor() converts integer image types (e.g. uint8) to float32 and scales them to the range [0, 1].

I don’t think ToTensor() is what sets the label dtype, though, as it expects an image-like input, as seen here.

Thank you for your answer.
So in this piece of code, after applying the transformations on the data set, we get the aforementioned types for x and y:

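import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# The constructor arguments were cut off in the original post;
# root="data" and download=True are assumed here.
training_data = torchvision.datasets.CIFAR10(
    root="data", train=True, download=True, transform=ToTensor()
)

test_data = torchvision.datasets.CIFAR10(
    root="data", train=False, download=True, transform=ToTensor()
)

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Inspect the first batch only.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break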

The output is:

Shape of X [N, C, H, W]: torch.Size([64, 3, 32, 32]) torch.float32
Shape of y: torch.Size([64]) torch.int64

So I concluded that these dtypes are set after applying ToTensor() to the dataset. Is torch.int64 the default data type for labels? What happens if we change these types to torch.float64 and torch.int32?

The CIFAR10 dataset applies the transform only to the data samples as seen here and uses target_transform for the targets:

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
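To make the split between transform and target_transform concrete, here is a minimal sketch that mirrors the same pattern. TinyDataset is a hypothetical stand-in, not the torchvision implementation, and the zero-filled tensors are placeholder images.

```python
import torch
from torch.utils.data import Dataset


class TinyDataset(Dataset):
    """Hypothetical dataset mirroring the transform / target_transform pattern."""

    def __init__(self, images, targets, transform=None, target_transform=None):
        self.images = images
        self.targets = targets  # plain Python ints, as in CIFAR10
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img, target = self.images[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


images = [torch.zeros(3, 32, 32) for _ in range(4)]
targets = [0, 3, 1, 2]

# Without a target_transform the label stays a Python int.
plain = TinyDataset(images, targets)
print(type(plain[0][1]))   # <class 'int'>

# With a target_transform the label dtype can be chosen explicitly.
typed = TinyDataset(
    images, targets,
    target_transform=lambda t: torch.tensor(t, dtype=torch.int32),
)
print(typed[0][1].dtype)   # torch.int32
```

So if a different label dtype is really wanted, target_transform is the place to set it, not ToTensor().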

The targets are stored as integers in a list, and the DataLoader’s collate_fn stacks them into a tensor, which then defaults to int64:

# torch.int64
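A short way to check this without downloading CIFAR10 is to batch a few hand-made samples with the default collate function. The sample list below is a hypothetical stand-in for the dataset.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples: (image tensor, Python int label) pairs.
samples = [(torch.zeros(3, 32, 32), label) for label in [0, 3, 1, 2]]

# A plain list works as a map-style dataset; the default collate_fn is used.
loader = DataLoader(samples, batch_size=4)
X, y = next(iter(loader))

print(X.dtype, y.dtype)  # torch.float32 torch.int64
print(y.shape)           # torch.Size([4])
```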

nn.CrossEntropyLoss as well as nn.NLLLoss expect int64 tensors as class-index targets, which is why it’s the default type (unless CrossEntropyLoss is used with class probabilities as targets, which newer PyTorch releases support).

While the data loading would still work with other dtypes, your loss calculation might fail depending on which loss you are using.
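The following sketch shows one way to probe this with nn.CrossEntropyLoss. Whether int32 targets are rejected may depend on the PyTorch version, so the example catches the error instead of assuming a particular outcome; the random logits are placeholders for real model outputs.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)             # float32 outputs for 10 classes
target = torch.tensor([1, 0, 4, 9])     # int64 class indices

loss = criterion(logits, target)        # int64 targets are the expected case
print(loss.dtype)                       # torch.float32

target_int32 = target.to(torch.int32)
try:
    criterion(logits, target_int32)
    print("int32 targets were accepted by this PyTorch version")
except (RuntimeError, TypeError) as err:
    print("int32 targets were rejected:", err)

# Casting back to int64 before the loss call avoids the issue either way.
loss_fixed = criterion(logits, target_int32.long())
```

If labels arrive in another integer dtype, calling .long() on them right before the loss is usually the simplest fix.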

Thank you very much for the thorough answer.