Hello all, I have a 5D tensor such as BxCxHxWxD
. I want to use nn.Linear
function to classify from C to C//2 channel. How should I do it?
This is my code
class 5D_Linear(nn.Module):
def __init__(self, channel):
super(5D_Linear, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.linear_5d = nn.Sequential(nn.Linear(channel, channel// 2),
nn.ReLU(inplace=True),
nn.Linear(channel // 2, channel))
def forward(self, x):
x = self.avg_pool(x)
B, C, D, H, W = x.size()
x_4d = x.view(B, C, D, -1)
x_4d = self.linear_5d(x_4d)
x_5d = x_4d.view(B, C, D, H, W)
print(x_5d.size())
return x_5d