I have just started with pytorch. I cannot seem to be capable of importing my custom cat & dog image dataset 0f 3 * 64 * 64 size for my simple convolutional class.
Data importing
from torchvision import datasets,transforms
mean = 0.5
std_dev = 0.5
transformer = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
(mean,
mean,
mean),
(std_dev,
std_dev,
std_dev))])
train_data = datasets.ImageFolder(root="./dataset/training_set",
transform=transformer)
test_data = datasets.ImageFolder(root="./dataset/test_set",
transform=transformer)
Converting the data into iterable forms
from torch.utils import data
batch_size=64
trainset = data.DataLoader(dataset=train_data,
batch_size=batch_size,
shuffle=True)
testset = data.DataLoader(dataset=test_data,
batch_size=batch_size,
shuffle=False)
Architecture
import torch.nn as nn
from torch.nn import functional as F
class CNN(nn.Module):
def init(self):
super(CNN, self).init()
# Input size 3 * 64 * 64
self.c1 = nn.Conv2d(in_channels=3,
out_channels=10,
kernel_size=5,
stride=3,
padding=1)
# 10 * 21 * 21
self.cl2 = nn.Conv2d(in_channels=10,
out_channels=20,
kernel_size=5,
stride=2,
padding=1)
# 10 * 10 * 10
self.mp = nn.MaxPool2d(kernel_size=2)
# 10 * 5 * 5
self.fc1 = nn.Linear(in_features=1055,
out_features=50)
self.dp = nn.Dropout(p=0.5)
self.a = nn.ReLU()
self.fc2 = nn.Linear(in_features=50, out_features=1)
def forward(self, x):
pred = self.c1(x)
pred = F.relu(pred)
pred = self.c2(pred)
pred = F.dropout(pred)
pred = F.relu(pred)
pred = self.mp(pred)
pred = pred.view(-1, 1055)
pred = self.fc1(pred)
pred = self.dp(pred)
pred = self.a(pred)
pred = self.fc2(pred)
pred = F.sigmoid(pred)
return pred
I am getting the following error
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 360 and 375 in dimension 2 at c:\programdata\miniconda3\conda-bld\pytorch_1524549877902\work\aten\src\th\generic/THTensorMath.c:3586
I am sorry as I could not convert the text into code format. I am fairly new to this forums platform as well as pytorch
hope this helps developers in the future.
Thanks in advance