Why do I get this error? I pass CIFAR-10 images of size [3, 32, 32] to this model:
def conv_block(in_channels, out_channels, k):
    """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage.

    The convolution uses no padding (and the default stride of 1), so each
    spatial dimension shrinks by ``k - 1`` before the pool halves it,
    rounding down.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        k: Square kernel size of the convolution.

    Returns:
        An ``nn.Sequential`` whose children are, at indices 0-3:
        ``Conv2d``, ``BatchNorm2d``, ``ReLU``, ``MaxPool2d(2)``.
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, k, padding=0),  # "valid" conv
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    # Positional Sequential keeps the same state_dict keys ("0.weight", ...).
    return nn.Sequential(*stages)
from IPython.core.debugger import set_trace
class Top(nn.Module):
    """Root of a binary tree of convolutional sub-networks.

    The image is encoded once (3 -> 16 channels). The same features are fed
    to two ``Second`` children, and their outputs are concatenated. A linear
    layer then maps the 20 concatenated features to 10 outputs.

    Accepts a batch ``[N, 3, H, W]``, or a single image ``[3, H, W]`` (such
    as one CIFAR-10 sample), which is given a temporary batch axis.
    """

    def __init__(self):
        super().__init__()
        self.encoder = conv_block(3, 16, 3)
        # Each Second child ends in Linear(20, 10); two of them -> 20 features.
        self.lin = nn.Linear(20, 10)
        self.childone = Second()
        self.childtwo = Second()

    def forward(self, x):
        """Return a ``[N, 10]`` tensor, or ``[10]`` for an unbatched input."""
        unbatched = x.dim() == 3
        if unbatched:
            # BatchNorm2d inside conv_block only accepts 4-D input, so a
            # lone [C, H, W] image needs a batch axis first.
            x = x.unsqueeze(0)
        # Encode once and share the features. Encoding separately per child
        # doubles the compute and applies two BatchNorm running-stat updates
        # per forward pass.
        feats = self.encoder(x)
        out = torch.cat((self.childone(feats), self.childtwo(feats)), dim=-1)
        out = self.lin(out)
        return out.squeeze(0) if unbatched else out
class Second(nn.Module):
    """Second tree level: a 16 -> 32 channel 3x3 conv stage plus two ``Middle`` children."""

    def __init__(self):
        super().__init__()
        self.encoder = conv_block(16, 32, 3)
        # Each Middle child ends in Linear(20, 10); two of them -> 20 features.
        self.lin = nn.Linear(20, 10)
        self.childone = Middle()
        self.childtwo = Middle()

    def forward(self, x):
        """Map a 4-D ``[N, 16, H, W]`` feature map to ``[N, 10]``."""
        # Encode once and give both children the same features. Encoding
        # per child doubles the compute and applies two BatchNorm
        # running-stat updates per forward pass.
        feats = self.encoder(x)
        a = self.childone(feats)
        b = self.childtwo(feats)
        return self.lin(torch.cat((a, b), dim=-1))
class Middle(nn.Module):
    """Third tree level: a 32 -> 64 channel 1x1 conv stage plus two ``Bottom`` leaves."""

    def __init__(self):
        super().__init__()
        self.encoder = conv_block(32, 64, 1)
        # Each Bottom leaf ends in Linear(128, 10); two of them -> 20 features.
        self.lin = nn.Linear(20, 10)
        self.childone = Bottom()
        self.childtwo = Bottom()

    def forward(self, x):
        """Map a 4-D ``[N, 32, H, W]`` feature map to ``[N, 10]``."""
        # Encode once and give both leaves the same features. Encoding per
        # leaf doubles the compute and applies two BatchNorm running-stat
        # updates per forward pass.
        feats = self.encoder(x)
        a = self.childone(feats)
        b = self.childtwo(feats)
        return self.lin(torch.cat((a, b), dim=-1))
class Bottom(nn.Module):
    """Leaf of the tree: a 64 -> 128 channel 1x1 conv stage and a linear head.

    The head is ``Linear(128, 10)``, so the flattened encoder output must
    have exactly 128 features. That means the pooled map must be 1x1
    spatially.
    """

    def __init__(self):
        super().__init__()
        self.encoder = conv_block(64, 128, 1)
        self.lin_one = nn.Linear(128, 10)

    def forward(self, x):
        """Encode ``x``, flatten every non-batch axis, and apply the head."""
        feats = self.encoder(x)
        batch = feats.size(0)
        flat = feats.view(batch, -1)  # [N, C*H*W]
        return self.lin_one(flat)
# Build the model and place it on the GPU when one is available; a
# hard-coded 'cuda' crashes on CPU-only machines.
model = Top()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Inputs must be on the same device as the model and carry a batch axis.
# Note that `[None, img]` builds a Python list, not a batched tensor. Use
# `img[None]` or `img.unsqueeze(0)` instead. For example, assuming
# train_dataset[0][0] is a [3, 32, 32] image tensor:
#   inp = train_dataset[0][0].unsqueeze(0).to(device)   # [1, 3, 32, 32]
#   model.eval()  # BatchNorm then uses running stats for single images
#   logits = model(inp)                                  # [1, 10]
model.to(device)