I am missing a transformation in the torchvision datasets. I think it would be great to have a method that turns a batch × colormap × 28 × 28 MNIST image (as an example) into a 1 × 784 vector, to train a fully connected network or whatever.
Maybe I do not understand your question. What you want is some operation like np.reshape? You can use the view operation.
Examples from the docs:
>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
You can do that in torch instead of torchvision using Tensor.view, as @paul_c suggested. The call would be x.view(x.size(0), -1), which keeps the batch dimension and flattens everything else.
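A minimal sketch of that call on a batch shaped like what a DataLoader over MNIST yields (64 × 1 × 28 × 28; the batch size of 64 and the random values are arbitrary assumptions). It also shows two equivalent options that exist in current PyTorch: torch.flatten with start_dim=1, and an nn.Flatten layer at the start of the model, which leaves the dataset untouched.

```python
import torch
import torch.nn as nn

# Fake batch with the MNIST DataLoader shape: (batch, channels, height, width)
images = torch.randn(64, 1, 28, 28)

# Keep the batch dimension, let view infer the rest (1 * 28 * 28 = 784)
flat_view = images.view(images.size(0), -1)

# Same result with torch.flatten, flattening every dimension from 1 onward
flat_fn = torch.flatten(images, start_dim=1)

# Or flatten inside the model itself; nn.Flatten defaults to start_dim=1
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)
logits = model(images)

print(flat_view.shape)  # torch.Size([64, 784])
print(logits.shape)     # torch.Size([64, 10])
```

Note that view raises an error on tensors whose memory layout is incompatible (for example after a transpose); torch.reshape and torch.flatten fall back to copying in that case.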
Yes, for a torch tensor it is pretty clear.
I meant as part of torchvision.transforms (http://pytorch.org/docs/master/torchvision/transforms.html). That way I could, for example, load MNIST using the MNIST loader and iterate over train_loader so that each batch has shape batch_size × dimension instead of batch_size × feature_map × height × width.
It seems the torchvision package is geared toward convolutional networks (which makes sense), but I think this functionality could be useful if you want to train a fully connected network for whatever reason.
It is my first day with PyTorch; I should learn it first.
I found another way. For anyone interested:
train_feat and train_labl are NumPy arrays holding the MNIST features and labels in the shape I need.
import torch
# Wrap the arrays (already moved to the GPU) in a dataset, then batch and shuffle them
mnist_dataset_train = torch.utils.data.TensorDataset(torch.from_numpy(train_feat).cuda(), torch.from_numpy(train_labl).cuda())
train_loader = torch.utils.data.DataLoader(mnist_dataset_train, batch_size=100, shuffle=True)
and then just iterate:
for features, labels in train_loader:
    ...  # forward/backward pass on one batch of 100 samples
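A self-contained sketch of this TensorDataset approach, with random arrays standing in for train_feat and train_labl (hypothetical data; the real arrays come from wherever you loaded MNIST). It stays on the CPU so it runs anywhere, and the two-layer network and learning rate are arbitrary choices for illustration.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: 1000 flattened 28x28 images and integer labels 0-9
train_feat = np.random.rand(1000, 784).astype(np.float32)
train_labl = np.random.randint(0, 10, size=1000).astype(np.int64)

dataset = TensorDataset(torch.from_numpy(train_feat), torch.from_numpy(train_labl))
train_loader = DataLoader(dataset, batch_size=100, shuffle=True)

# Small fully connected classifier on 784-dimensional inputs
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

batch_shapes = []
for features, labels in train_loader:  # each batch is a (features, labels) pair
    batch_shapes.append(tuple(features.shape))
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

print(len(batch_shapes), batch_shapes[0])  # 10 (100, 784)
```

The labels are cast to int64 because CrossEntropyLoss expects integer class indices of type long.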
Here's a ReshapeTransform I wrote:
import torch

class ReshapeTransform:
    """Reshape a tensor to new_size, e.g. (-1,) to flatten it."""
    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        return torch.reshape(img, self.new_size)
You can use it for MNIST like this:
from torchvision import datasets, transforms

mnist_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ReshapeTransform((-1,)),  # flatten 1x28x28 into 784
]
# "data" is an arbitrary download directory; MNIST requires a root argument
train_ds = datasets.MNIST("data", train=True, download=True,
                          transform=transforms.Compose(mnist_transforms))