I made another post about this here.
Here is my encoder:
class Encoder(nn.Module):
    """Convolutional fMRI encoder mapping an input volume to a feature vector.

    Builds either a 2-D (``args.fMRI_twoD``) or a 3-D stack of four
    conv + ReLU stages, followed by ``Flatten`` and a ``Linear`` projection
    to ``args.fMRI_feature_size`` features.

    All convolutions use stride 1 and no padding, so the hard-coded
    ``final_conv_size`` only matches one input size per branch:

    * 2-D: the two spatial dims shrink by 29 and 33, so a 53 x 63 input
      yields the expected 24 x 30 maps.
    * 3-D: the three spatial dims shrink by 29, 33 and 8, so a
      53 x 63 x 20 input yields the expected 24 x 30 x 12 maps.

    Any other input size makes the final ``Linear`` fail with a shape
    mismatch.

    Args:
        input_channels: number of channels of the input tensor (dim 1).
        args: config object; this class reads ``fMRI_feature_size``,
            ``no_downsample``, ``fMRI_twoD``, ``end_with_relu`` and (in
            ``forward``) ``method``.
    """

    def __init__(self, input_channels, args):
        super().__init__()
        self.feature_size = args.fMRI_feature_size
        self.hidden_size = self.feature_size
        # Stored for callers; not used by the layers built below.
        self.downsample = not args.no_downsample
        self.input_channels = input_channels
        self.two_d = args.fMRI_twoD
        self.end_with_relu = args.end_with_relu
        self.args = args
        # Orthogonal weight init with ReLU gain, zero bias (``init`` is a
        # project helper defined elsewhere).
        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))
        self.flatten = Flatten()
        if self.two_d:
            # Valid only for a 53 x 63 input (see class docstring).
            self.final_conv_size = 128 * 24 * 30
            self.final_conv_shape = (128, 24, 30)
            self.main = nn.Sequential(
                init_(nn.Conv2d(self.input_channels, 32, (9, 10), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(32, 64, (9, 10), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(64, 128, (8, 9), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(128, 128, (7, 8), stride=1)),
                nn.ReLU(),
                Flatten(),
                init_(nn.Linear(self.final_conv_size, self.feature_size)),
            )
        else:
            # Valid only for a 53 x 63 x 20 input (see class docstring).
            self.final_conv_size = 10 * 24 * 30 * 12
            self.final_conv_shape = (10, 24, 30, 12)
            self.main = nn.Sequential(
                init_(nn.Conv3d(self.input_channels, 3, (9, 10, 4), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(3, 5, (9, 10, 3), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(5, 8, (8, 9, 3), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(8, 10, (7, 8, 2), stride=(1, 1, 1))),
                nn.ReLU(),
                Flatten(),
                init_(nn.Linear(self.final_conv_size, self.feature_size)),
            )
        self.train()

    @staticmethod
    def _channels_last(x):
        """Move dim 1 (channels) to the last position, for any rank >= 2.

        For a 4-D tensor this is ``permute(0, 2, 3, 1)``; for a 5-D tensor
        it is ``permute(0, 2, 3, 4, 1)``.
        """
        return x.permute(0, *range(2, x.dim()), 1)

    def forward(self, inputs, fmaps=False):
        """Encode ``inputs``, optionally also returning intermediate maps.

        The slice indices below rely on ``self.main`` being
        [conv, relu] x 4 followed by [Flatten, Linear].

        Args:
            inputs: batch tensor with channels in dim 1.
            fmaps: if True, return a dict instead of just the encoding.

        Returns:
            The encoded tensor ``out``, or, when ``fmaps`` is True, a dict with:

            * ``'f5'``: the output of conv stage 3, channels moved last;
            * ``'f7'``: the output of conv stage 4, channels moved last;
            * ``'out'``: the encoded tensor.

        Raises:
            AssertionError: if ``end_with_relu`` is set while
                ``args.method == "vae"``.
        """
        f5 = self.main[:6](inputs)   # conv stages 1-3, each + ReLU
        f7 = self.main[6:8](f5)      # conv stage 4 + ReLU
        out = self.main[8:](f7)      # Flatten + Linear
        if self.end_with_relu:
            assert self.args.method != "vae", "can't end with relu and use vae!"
            out = F.relu(out)
        if fmaps:
            # Channel-last permutation must handle both the 4-D (Conv2d) and
            # 5-D (Conv3d) branches; a fixed 4-axis permute breaks the latter.
            return {
                'f5': self._channels_last(f5),
                'f7': self._channels_last(f7),
                'out': out,
            }
        return out