When adding torchvision.transforms.ToTensor() to our dataloader's transformations, ToTensor() sets the datatype of the input images X to torch.float32 and of the labels y to torch.int64. Is there a specific reason for this?
float32 is the default dtype in PyTorch, and floating point types are expected when training a neural network, which is why the integer image types are transformed to float32.
I don’t think that’s the case, as ToTensor() would expect an image-like input as seen here.
Thank you for your answer.
So in this piece of code, after applying the transformations on the data set, we get the aforementioned types for x and y:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

training_data = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
output is:
Shape of X [N, C, H, W]: torch.Size([64, 3, 32, 32]) torch.float32
Shape of y: torch.Size([64]) torch.int64
So I concluded that these types are set after applying ToTensor() to the dataset. Is torch.int64 the default data type for labels? What happens if we change these types to torch.float64 and torch.int32?
The CIFAR10 dataset applies the transform only to the data samples, as seen here, and uses target_transform for the targets:
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
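The same dispatch can be mimicked in a minimal custom dataset (a sketch only, not torchvision's actual implementation) to show that target_transform, not ToTensor(), is what would change the label type:

```python
import torch
from torch.utils.data import Dataset

class TinyDataset(Dataset):
    # mimics CIFAR10's __getitem__: transform for the image,
    # target_transform for the label
    def __init__(self, transform=None, target_transform=None):
        self.data = [torch.zeros(3, 32, 32) for _ in range(4)]
        self.targets = [0, 1, 2, 3]  # labels stored as plain Python ints
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx], self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

# target_transform explicitly forces the label dtype
ds = TinyDataset(target_transform=lambda y: torch.tensor(y, dtype=torch.int32))
_, y = ds[0]
print(y.dtype)  # torch.int32
```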
The targets are stored as integers in a list, and the DataLoader's collate_fn would stack them into tensors, which then default to int64:
torch.tensor(1).dtype
# torch.int64
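You can reproduce this batching step directly via torch.utils.data.default_collate (exposed as a public function in newer PyTorch releases), which the DataLoader uses by default:

```python
import torch
from torch.utils.data import default_collate

# labels come out of the dataset as plain Python ints;
# collating a batch of them produces an int64 tensor
batch = default_collate([0, 3, 1, 2])
print(batch.dtype)  # torch.int64
print(batch)        # tensor([0, 3, 1, 2])
```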
nn.CrossEntropyLoss as well as nn.NLLLoss expect int64 tensors as the targets, which is why it's the default type (unless CrossEntropyLoss is used with probabilities in newer PyTorch releases).
While the data loading would work, your loss calculation might fail depending on which loss you are using.
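A quick sketch of that failure mode with randomly generated logits (assuming a recent PyTorch version; the exact error message may differ between releases):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)           # hypothetical model outputs, shape [N, C]
targets = torch.tensor([1, 0, 3, 2])  # class indices, defaults to torch.int64

loss = criterion(logits, targets)     # works with int64 targets
print(loss.dtype)  # torch.float32

try:
    # int32 targets are rejected as class indices
    criterion(logits, targets.to(torch.int32))
except RuntimeError as e:
    print("int32 targets failed:", e)
```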
Thank you very much for the thorough answer.