class Network(nn.Module):
def init(self):
super().init()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=60)
self.out = nn.Linear(in_features=60, out_features=10)
def forward(self,t):
t = self.conv1(t)
t = F.relu(self.conv1(t))
t = F.max_pool2d(t, kernel_size=2, stride=2)
t = self.conv2(t)
t = F.relu(self.conv2(t))
t = F.max_pool2d(t, kernel_size=2, stride=2)
t = F.relu(self.fc1(t.reshape(-1, 12*4*4)))
t = F.relu(self.fc2(t))
t = self.out(t)
return t
network = Network()
sample = next(iter(train_set))
image, label = sample
image.shape
image.unsqueeze(0).shape
pred = network(image.unsqueeze(0))
donβt know why I am getting this error
RuntimeError Traceback (most recent call last)
in
----> 1 pred = network(image.unsqueeze(0))
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
β> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
in forward(self, t)
13 def forward(self,t):
14 t = self.conv1(t)
β> 15 t = F.relu(self.conv1(t))
16 t = F.max_pool2d(t, kernel_size=2, stride=2)
17
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
β> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
343
344 def forward(self, input):
β> 345 return self.conv2d_forward(input, self.weight)
346
347 class Conv3d(_ConvNd):
~\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in conv2d_forward(self, input, weight)
340 _pair(0), self.dilation, self.groups)
341 return F.conv2d(input, weight, self.bias, self.stride,
β> 342 self.padding, self.dilation, self.groups)
343
344 def forward(self, input):
RuntimeError: Given groups=1, weight of size 6 1 5 5, expected input[1, 6, 24, 24] to have 1 channels, but got 6 channels instead