Sure! The code is quite long, so I won’t post the whole thing here:
- The main model class
class STGCN(nn.Module):
    """Spatio-temporal graph convolutional network for skeleton sequences.

    Stacks nine STGC blocks (channel widths 64 -> 128 -> 256, with temporal
    stride-2 down-sampling at blocks 4 and 7), then global-average-pools over
    time/joints and people before a linear classification head.

    Args:
        graph: adjacency source passed to ``get_normalized_adj``; a tuple is
            treated as several sub-graphs stacked along a leading dimension.
        num_kpts: number of skeleton keypoints (joints) per person.
        in_features: channels per keypoint (e.g. 3 for x, y, confidence).
        num_classes: number of output action classes.
        num_people: maximum number of people per sample.
        use_attention / use_tem: forwarded to every ``STGC_Block``.
    """

    def __init__(self, graph, num_kpts=18, in_features=3, num_classes=20, num_people=2, use_attention=True, use_tem=True):
        super(STGCN, self).__init__()
        # NOTE(review): hard-coded .cuda() breaks CPU-only runs and ignores the
        # model's device; consider self.register_buffer("adj", ...) so the
        # adjacency follows .to(device) — confirm checkpoint compatibility first.
        if isinstance(graph, tuple):
            self.adj = torch.cat([get_normalized_adj(subgraph).unsqueeze(0) for subgraph in graph]).cuda()
        else:
            self.adj = get_normalized_adj(graph).cuda()
        init_filters = 64
        # FIX: derive the final width once instead of hard-coding 256 in two
        # places — changing init_filters previously broke forward() silently.
        self.out_channels = init_filters * 4
        self.b1 = STGC_Block(self.adj, in_features, init_filters, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b2 = STGC_Block(self.adj, init_filters, init_filters, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b3 = STGC_Block(self.adj, init_filters, init_filters, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b4 = STGC_Block(self.adj, init_filters, init_filters * 2, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem, stride=2)
        self.b5 = STGC_Block(self.adj, init_filters * 2, init_filters * 2, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b6 = STGC_Block(self.adj, init_filters * 2, init_filters * 2, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b7 = STGC_Block(self.adj, init_filters * 2, init_filters * 4, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem, stride=2)
        self.b8 = STGC_Block(self.adj, init_filters * 4, init_filters * 4, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        self.b9 = STGC_Block(self.adj, init_filters * 4, init_filters * 4, num_kpts=num_kpts, use_attention=use_attention, use_tem=use_tem)
        # Data-level batch norm over all input channels (people x feats x kpts).
        self.bn = nn.BatchNorm1d(num_people * in_features * num_kpts)
        self.classify = nn.Linear(self.out_channels, num_classes)
        # self.softmax = nn.Softmax(-1)
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        """x: (batch, in_feats, seq_len, num_kpts, num_people) -> (batch, num_classes) logits."""
        batch_size, in_feats, seq_len, num_kpts, num_people = x.size()
        # Flatten everything but time into channels so BatchNorm1d normalizes
        # each (person, joint, feature) channel over the sequence.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(batch_size, num_people * num_kpts * in_feats, seq_len)
        x = self.bn(x)
        x = x.view(batch_size, num_people, num_kpts, in_feats, seq_len)
        # Fold people into the batch dimension: (batch*people, feats, seq, kpts).
        x = x.permute(0, 1, 3, 4, 2).contiguous().view(batch_size * num_people, in_feats, seq_len, num_kpts)
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.b6(x)
        x = self.b7(x)
        x = self.b8(x)
        x = self.b9(x)
        # shape: batch x people x out_channels x (seq * kpts)
        x = x.view(batch_size, num_people, self.out_channels, -1)
        # Global average pool over space-time, then average over people.
        x = x.mean(-1).mean(1)
        x = self.drop(x)
        return self.classify(x)
- The STGC blocks:
class STGC_Block(nn.Module):
    """ST-GCN-like basic block:

    - a graph convolution (optionally with attention, inside AdaptiveGraphConv)
    - a temporal convolution
    - a residual connection and a final ReLU

    Args:
        adj: normalized adjacency tensor shared across blocks.
        in_features / out_features: channel counts in and out.
        num_kpts: number of graph nodes (keypoints).
        stride: temporal stride (2 halves the sequence length).
        temporal_kernel: temporal convolution kernel size.
        use_attention / use_tem: feature toggles (consumed by sub-modules).
    """

    def __init__(self, adj, in_features, out_features, num_kpts, stride=1, temporal_kernel=9, use_attention=True, use_tem=True):
        super(STGC_Block, self).__init__()
        self.adj = adj
        self.graph_conv = AdaptiveGraphConv(adj, in_features, in_features, out_features, num_kpts)
        self.temp_conv = TemporalConv(out_features, temporal_kernel=temporal_kernel, stride=stride)
        # Residual connection — three cases, checked from most to least specific:
        if stride != 1:
            # Temporal down-sampling: the shortcut must match both the new
            # channel count and the shorter sequence length.
            self.residual_connection = TemporalConv(in_features, out_features, temporal_kernel, stride)
        elif in_features != out_features:
            # Same length, different width: 1x1 conv to match channels.
            self.residual_connection = nn.Sequential(
                nn.Conv2d(in_features, out_features, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_features)
            )
        else:
            # FIX: nn.Identity() instead of a bare lambda — a lambda is not a
            # registered sub-module and makes the model unpicklable; Identity
            # is parameterless, so state_dict keys are unchanged.
            self.residual_connection = nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, features):
        """features: (N, in_features, T, V) -> (N, out_features, T', V)."""
        x = self.graph_conv(features)
        x = self.temp_conv(x)
        # Out-of-place add keeps temp_conv's output intact for autograd.
        x = x + self.residual_connection(features)
        x = self.relu(x)
        return x
The adaptive graph-conv is basically made of two torch.matmul operations and an nn.Conv2d layer (kernel=1), whereas the temporal-conv is only made of a conv2d layer (kernel = 9x1).
The input I used to measure the time performances is a torch.rand tensor with dimensions (batch_size, 3, 150, 18, 1)