Hi, I’m working on a CRNN project forked from another repo. The code felt a bit messy, so I decided to clean it up. Here is the original code:
class CRNN0(nn.Module):
    """Original CRNN: a 7-conv VGG-style backbone, a linear map-to-sequence
    layer, two bidirectional LSTMs, and a per-timestep linear classifier.

    Forward output shape: (seq_len, batch, num_class), where seq_len is the
    width of the final feature map (img_width // 4 - 1).
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN0, self).__init__()
        # Build the conv stack and capture its output geometry so the
        # map-to-sequence linear layer below can be sized correctly.
        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        # Each column of the feature map (channel * height values) is
        # flattened into one feature vector per timestep.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # rnn1 is bidirectional, so it emits 2 * rnn_hidden features per step.
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the conv stack; return (module, (out_channel, out_height, out_width)).

        Height is divided by 16 (two 2x2 pools and two (2,1) pools) and width
        by 4 (the 2x2 pools only); the final padding-0 kernel-2 conv then
        subtracts 1 from each dimension — hence the divisibility asserts.
        """
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # Per-layer hyperparameters, indexed by conv layer number 0..6.
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            # Appends conv i (+ optional batch norm) + activation to `cnn`.
            input_channel = channels[i]
            output_channel = channels[i + 1]
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))
            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        conv_relu(2)
        conv_relu(3)
        # (2, 1) pooling halves height but keeps width, preserving sequence length.
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )
        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )
        conv_relu(6)

        output_channel, output_height, output_width = channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        # images: (batch, img_channel, img_height, img_width)
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        # Fold channel and height into one feature axis per column.
        conv = conv.view(batch, channel * height, width)
        # (batch, feature, width) -> (width, batch, feature): the
        # seq-len-first layout nn.LSTM expects by default.
        conv = conv.permute(2, 0, 1)

        seq = self.map_to_seq(conv)
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)
        output = self.dense(recurrent)
        return output  # (seq_len, batch, num_class)
And here’s my code:
class ConvBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> ReLU (or LeakyReLU).

    Fixes vs. the posted version:
      * ``nn.BatchNorm2d`` is only instantiated when ``batch_norm=True``.
        The original always created ``self.bn``, registering parameters
        that were never used when ``batch_norm=False`` — bloating the
        parameter count and the state_dict.
      * ``leaky_relu`` is added (default ``False``, so fully backward
        compatible) to restore parity with CRNN0's activation option,
        which the refactor had silently dropped.
    """

    def __init__(self, input_channel, output_channel, kernel_sizes, strides,
                 paddings, batch_norm: bool = False, leaky_relu: bool = False):
        super(ConvBlock, self).__init__()
        self.do_batch_norm = batch_norm
        self.conv = nn.Conv2d(input_channel, output_channel, kernel_sizes, strides, paddings)
        # nn.Identity when batch norm is disabled: it registers no
        # parameters and its forward is a no-op, so the attribute always
        # exists without polluting the state_dict.
        self.bn = nn.BatchNorm2d(output_channel) if batch_norm else nn.Identity()
        self.relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.do_batch_norm:
            x = self.bn(x)
        x = self.relu(x)
        return x
class CRNN1(nn.Module):
    """Refactored CRNN: the same architecture as CRNN0, with the conv stack
    declared as a flat nn.Sequential of ConvBlocks.

    NOTE(review): parameter values depend on the order in which modules are
    constructed (each init draw consumes torch's global RNG), so a seeded
    comparison against CRNN0 only matches if both models are built from the
    same RNG state. Also, this refactor dropped CRNN0's ``leaky_relu``
    option — confirm that is intentional.
    """

    def __init__(self, img_channel, img_height, img_width, num_class, map_to_seq_hidden=64, rnn_hidden=256):
        super(CRNN1, self).__init__()
        # CNN block — mirrors CRNN0's backbone layer for layer:
        # two 2x2 pools (halve H and W), two (2,1) pools (halve H only),
        # then a padding-0 kernel-2 conv that trims 1 from H and W.
        self.cnn = nn.Sequential(
            ConvBlock(img_channel, 64, 3, 1, 1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(64, 128, 3, 1, 1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(128, 256, 3, 1, 1),
            ConvBlock(256, 256, 3, 1, 1),
            nn.MaxPool2d(kernel_size=(2, 1)),
            ConvBlock(256, 512, 3, 1, 1, batch_norm=True),
            ConvBlock(512, 512, 3, 1, 1, batch_norm=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
            ConvBlock(512, 512, 2, 1, 0)
        )
        # map CNN to sequence: each feature-map column (channel * height
        # values, with final height img_height // 16 - 1) becomes one timestep.
        self.map2seq = nn.Linear(512 * (img_height // 16 - 1), map_to_seq_hidden)
        # RNN: two stacked bidirectional LSTMs (hence 2 * rnn_hidden inputs
        # to the second one).
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        # fully connected per-timestep classifier
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def forward(self, x):
        # CNN block; x: (batch, img_channel, img_height, img_width)
        x = self.cnn(x)
        # reformat array: fold channel and height into one feature axis,
        # then move width to the front as the sequence dimension
        # (seq-len-first layout, nn.LSTM's default).
        batch, channel, height, width = x.size()
        x = x.view(batch, channel * height, width)
        x = x.permute(2, 0, 1)
        x = self.map2seq(x)
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        x = self.dense(x)
        return x  # (seq_len, batch, num_class)
And here’s the code to test them out:
if __name__ == '__main__':
    # THE BUG: seeding once is not enough. Every parameter initialization
    # consumes torch's global RNG, so constructing crnn0 advances the RNG
    # state and crnn1 then draws completely different weights — even though
    # the two architectures are identical. Re-seed immediately before EACH
    # construction so both models start from the same RNG state.
    torch.manual_seed(1234)
    crnn0 = CRNN0(1, 32, 100, 37)
    torch.manual_seed(1234)
    crnn1 = CRNN1(1, 32, 100, 37)

    # eval() so BatchNorm uses its running statistics instead of the
    # batch statistics of this single input.
    crnn0.eval()
    crnn1.eval()

    x = torch.ones(1, 1, 32, 100)
    out0 = crnn0(x)
    out1 = crnn1(x)
    print(out0)
    print(out1)
    # With matched seeding the outputs now agree.
    print(torch.allclose(out0, out1))
I’ve done a lot of testing and debugging on this but still can’t figure out why the two models produce different outputs for the same input. What am I missing here?