The following CNN works fine:
class myLeNet5(torch.nn.Module):  # for CIFAR10 image classification
    """LeNet-5-style CNN: two conv/avg-pool stages followed by three linear layers.

    ``conv_unit`` maps an input of spatial size 32x32 to a (batch, 16, 5, 5)
    feature map; ``forward`` flattens that to (batch, 400) for ``fc_unit``,
    which produces 10 outputs per sample.

    ``channel_size`` is read from module scope when the model is built
    (presumably 3 for RGB CIFAR10 images -- it is defined elsewhere, confirm).
    """

    def __init__(self):
        # Zero-argument super() binds to *this* class at definition time.
        # The explicit form super(myLeNet5, self) looks the name up at call
        # time and raises TypeError if myLeNet5 has been rebound, e.g. by a
        # second class with the same name in the same module.
        super().__init__()
        # NOTE(review): there is no activation between the conv/pool layers,
        # so this stage is purely linear -- confirm that is intended.
        self.conv_unit = torch.nn.Sequential(
            # 32x32 -> 28x28 (5x5 kernel, stride 1, no padding)
            torch.nn.Conv2d(channel_size, 6, kernel_size=5, stride=1, padding=0),
            # 28x28 -> 14x14
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # 14x14 -> 10x10
            torch.nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            # 10x10 -> 5x5
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Fully connected head: 16*5*5 = 400 flattened features -> 10 outputs.
        self.fc_unit = torch.nn.Sequential(
            torch.nn.Linear(16 * 5 * 5, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 10),
        )

    def forward(self, x):
        """Run the network on a batch ``x``; returns a (batch, 10) tensor.

        Assumes ``x`` is (batch, channel_size, 32, 32) so that ``conv_unit``
        yields exactly 16*5*5 features per sample; other sizes make the
        ``view`` below fail.
        """
        bsz = x.size(0)
        x = self.conv_unit(x)
        # Flatten (batch, 16, 5, 5) -> (batch, 400) for the linear layers.
        x = x.view(bsz, 16 * 5 * 5)
        x = self.fc_unit(x)
        return x
Can I write it this way instead?
class myLeNet5(torch.nn.Module):  # for CIFAR10 image classification
    """LeNet-5-style CNN expressed as a single ``nn.Sequential``.

    Two conv/avg-pool stages map an input of spatial size 32x32 to a
    (batch, 16, 5, 5) feature map. ``nn.Flatten`` reshapes it to (batch, 400),
    and three linear layers produce 10 outputs per sample.

    ``channel_size`` is read from module scope when the model is built
    (presumably 3 for RGB CIFAR10 images -- it is defined elsewhere, confirm).
    """

    def __init__(self):
        # Zero-argument super() binds to *this* class at definition time.
        # The explicit form super(myLeNet5, self) looks the name up at call
        # time and raises TypeError if myLeNet5 has been rebound, e.g. by a
        # second class with the same name in the same module.
        super().__init__()
        self.model = torch.nn.Sequential(
            # 32x32 -> 28x28 (5x5 kernel, stride 1, no padding)
            torch.nn.Conv2d(channel_size, 6, kernel_size=5, stride=1, padding=0),
            # 28x28 -> 14x14
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # 14x14 -> 10x10
            torch.nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            # 10x10 -> 5x5
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # nn.Flatten() (start_dim=1) is the module equivalent of
            # x.view(x.size(0), -1): (batch, 16, 5, 5) -> (batch, 400).
            # Without it, Linear would be applied to the last (width)
            # dimension of size 5 and fail with a shape mismatch.
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 5 * 5, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 10),
        )

    def forward(self, x):
        """Run the network on a batch ``x``; returns a (batch, 10) tensor.

        Assumes ``x`` is (batch, channel_size, 32, 32) so that the flattened
        feature size matches the first Linear layer's 16*5*5 inputs.
        """
        return self.model(x)
How can I perform the x.view() reshape inside nn.Sequential()? Thanks.