I don’t get what you mean by executable code snippet so i’m posting the whole thing here…
import torch
import torch.nn as nn
class Block(nn.Module):
def __init__(self, channels, seq):
super(Block, self).__init__()
self.seq = seq
self.channels = channels
self.conv = nn.Conv2d(self.channels, self.channels, 3, padding = 1, stride = 1, groups = self.seq)
self.bn = nn.BatchNorm2d(self.channels)
self.relu = nn.ReLU()
def forward(self, tensor):
identity = tensor
tensor = self.conv(tensor)
tensor = self.bn(tensor)
tensor = self.relu(tensor)
tensor = self.conv(tensor)
tensor = self.bn(tensor)
tensor += identity
tensor = self.relu(tensor)
return
class Block_Temporal(nn.Module):
def __init__(self, channels, seq):
super(Block_Temporal, self).__init__()
self.sequence = seq
self.channels = channels
self.conv_std = nn.Conv2d(self.channels*2, self.channels,
kernel_size = 3, padding = 1, stride = 1)
self.conv_update = nn.Conv2d(self.channels*2, self.channels,
kernel_size = 2, stride = 1)
self.pad = nn.ZeroPad2d((0, 1, 0, 1))
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.bn = nn.BatchNorm2d(self.channels)
def forward(self, tensor):
tensor_seq = torch.split(tensor, self.channels, 1)
hidden_tensor = torch.zeros(tensor_seq[0].size())
for i in range(self.sequence):
x = torch.cat([tensor_seq[i], hidden_tensor], 1)
reset = self.sigmoid(self.conv_std(x))
x = self.pad(x)
update = self.sigmoid(self.conv_update(x))
cnd_memory = update * self.bn(
self.conv_std(torch.cat([tensor_seq[i], (reset * hidden_tensor)], 1)))
hidden_tensor = self.tanh(cnd_memory) + (hidden_tensor * (1 - update))
return hidden_tensor
class ResNet(nn.Module):
def __init__(self, Block, Block_Temporal, layers, img_channels, seq):
super(ResNet, self).__init__()
self.sequence = seq
self.in_channels = 64*self.sequence
self.temporal_channels = 64
self.conv_init = nn.Conv2d(img_channels*self.sequence, 64*self.sequence, kernel_size = 4,
stride = 2, padding = 1, groups = self.sequence)
self.bn = nn.BatchNorm2d(64*self.sequence)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)
self.layer1 = self._make_layer(Block, layers[0], 1)
self.layer2 = self._make_layer(Block, layers[1], 2)
self.layer3 = self._make_layer(Block, layers[2], 4)
self.layer4 = self._make_layer(Block, layers[3], 8)
self.temporal1 = self._make_temporal_layer(Block_Temporal, 2)
self.temporal2 = self._make_temporal_layer(Block_Temporal, 4)
self.temporal3 = self._make_temporal_layer(Block_Temporal, 8)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.sigmoid = nn.Sigmoid()
self.gru = nn.GRU(input_size = 1024, hidden_size = 512, batch_first = True, bias = False)
def _make_layer(self, Block, num_blocks, step):
layers = []
channels = self.in_channels * step
for i in range(num_blocks):
layers.append(Block(channels, self.sequence))
return nn.Sequential(*layers, nn.Conv2d(channels, channels*2, 3, stride = 2, padding = 1, groups = self.sequence),
nn.BatchNorm2d(channels*2))
def _make_temporal_layer(self, BLock_Temporal, step):
return nn.Sequential(BLock_Temporal(self.temporal_channels*step, self.sequence))
def forward(self, x):
temporal_list = []
x = self.conv_init(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
temporal = self.temporal1(x)
temporal_list.append(temporal)
x = self.layer2(x)
temporal = self.temporal2(x)
temporal_list.append(temporal)
x = self.layer3(x)
temporal = self.temporal3(x)
temporal_list.append(temporal)
x = self.layer4(x)
x = self.avgpool(x)
x = x.squeeze()
x = torch.split(x, int(x.size(1)/self.sequence), 1)
x = torch.cat([x[i].unsqueeze(1) for i in range(len(x))], 1)
x = self.gru(x)
x = self.Sigmoid(x)
temporal_list.append(x)
return temporal_list
def ResNet34_Temporal(img_channels, seq):
return ResNet(Block, Block_Temporal, [3, 4, 6, 3], img_channels = img_channels, seq = seq)
def test():
net = ResNet34_Temporal(3, seq = 5)
x = torch.randn(2, 15, 128, 128)
y = net(x)
print(len(y))
test()
Here’s an image of the error itself