Hi ptrblck, please refer to the code below.
class Net(nn.Module):
    """VGG-style CNN classifier.

    Architecture: four conv stages with 64, 128, 256 and 512 channels. Each
    stage is two (3x3 Conv -> BatchNorm -> ReLU) units followed by a 2x2
    max-pool. A five-layer MLP head follows the conv stages.

    The 3x3 / stride 1 / padding 1 convs preserve height and width. Each
    max-pool halves them, rounding down. For the 51x51 input the layer
    comments target, the spatial size goes 51 -> 25 -> 12 -> 6 -> 3. The
    flattened feature vector therefore has 512 * 3 * 3 = 4608 entries.

    Any input whose height and width lie in [48, 63] also pools down to 3x3.
    Other sizes raise a shape error in the first Linear layer.

    Args:
        in_channels: number of input image channels (default 1).
        num_classes: number of output logits (default 3).
    """

    def __init__(self, in_channels=1, num_classes=3):
        super().__init__()

        layers = []
        channels = in_channels
        for width in (64, 128, 256, 512):
            for _ in range(2):
                # bias=True is kept so parameter names / state_dict match the
                # original, although BatchNorm's learnable shift makes a conv
                # bias redundant.
                layers += [
                    nn.Conv2d(channels, width, kernel_size=3, stride=1,
                              padding=1, bias=True),
                    nn.BatchNorm2d(width),
                    nn.ReLU(),
                ]
                channels = width
            # Halves height and width (floor): 51 -> 25 -> 12 -> 6 -> 3.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Built from one flat list in the original order, so module indices
        # (cnn.0, cnn.1, ...) are unchanged and existing checkpoints still load.
        self.cnn = nn.Sequential(*layers)

        # cnn output for a 51x51 input: (batch, 512, 3, 3).
        self.linear = nn.Sequential(
            nn.Linear(512 * 3 * 3, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: image tensor, presumably (batch, in_channels, H, W) with H and W
               in [48, 63] (e.g. 51x51) so the conv stack ends at 3x3.
        """
        out = self.cnn(x)
        # Flatten per sample and keep the batch dimension. A wrong spatial size
        # then raises in the first Linear layer instead of being silently
        # reshaped into extra batch rows, which view(-1, ...) would do.
        out = torch.flatten(out, start_dim=1)
        return self.linear(out)