Hi! I have a strange runtime problem when using torch v1.3.1 to run this code
import torch
import hiddenlayer
class LeNet5(torch.nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
self.convnet = torch.nn.Sequential(
# Conv Block 1
torch.nn.Conv2d(
in_channels=1,
out_channels=6,
kernel_size=(5, 5),
stride=1,
bias=True
),
torch.nn.ReLU(),
torch.nn.MaxPool2d(
kernel_size=(2, 2),
stride=2
),
# Conv Block 2
torch.nn.Conv2d(
in_channels=6,
out_channels=16,
kernel_size=(5, 5),
stride=1,
bias=True
),
torch.nn.ReLU(),
torch.nn.MaxPool2d(
kernel_size=(2, 2),
stride=2
),
# Conv Block 3
torch.nn.Conv2d(
in_channels=16,
out_channels=120,
kernel_size=(5, 5),
stride=1,
bias=True
),
torch.nn.ReLU(),
torch.nn.MaxPool2d(
kernel_size=(2, 2),
stride=2
)
)
self.fcn = torch.nn.Sequential(
# Fully Connected Layer 1
torch.nn.Linear(
in_features=120,
out_features=84,
bias=True
),
torch.nn.ReLU(),
# Classifier Layer 2
torch.nn.Linear(
in_features=84,
out_features=10,
bias=True
),
torch.nn.Softmax(dim=1)
)
def forward(self, batch):
ret = self.convnet(batch)
ret = ret.view(batch.size(0), -1)
ret = self.fcn(ret)
return ret
hiddenlayer.build_graph(LeNet5(), torch.zeros(1, 1, 32, 32)) # 1 image, greyscaled (1-channeled), 32 over 32
The error itself:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-844794a5d10c> in <module>()
2
3
----> 4 hiddenlayer.build_graph(LeNet5(), torch.zeros(1, 1, 32, 32)) # 1 image, shape-1, because greyscaled, 32 over 32
15 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in _max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
486 stride = torch.jit.annotate(List[int], [])
487 return torch.max_pool2d(
--> 488 input, kernel_size, stride, padding, dilation, ceil_mode)
489
490 max_pool2d = boolean_dispatch(
RuntimeError: Given input size: (120x1x1). Calculated output size: (120x0x0). Output size is too small
However, when I run the same code with pytorch 1.1.0 - there is no error, the graph is generated with no problem.
Please, advise what can be the issue with v1.3.1 and how to solve it?