CUDA: Out of memory error when using multi-gpu

I made another post about this issue here.

Here is my encoder:

class Encoder(nn.Module):
    """Convolutional encoder for fMRI inputs.

    Builds either a 2D conv stack (when ``args.fMRI_twoD`` is truthy) or a
    3D conv stack, each followed by a Flatten and a Linear projection down
    to ``args.fMRI_feature_size``.

    Args:
        input_channels: number of channels of the input tensor.
        args: namespace providing fMRI_feature_size, no_downsample,
            fMRI_twoD, end_with_relu, and method.
    """

    def __init__(self, input_channels, args):
        super().__init__()
        self.feature_size = args.fMRI_feature_size
        self.hidden_size = self.feature_size
        self.downsample = not args.no_downsample
        self.input_channels = input_channels
        self.two_d = args.fMRI_twoD
        self.end_with_relu = args.end_with_relu
        self.args = args
        # Orthogonal weight init, zero bias, ReLU gain (project-level `init` helper).
        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))
        self.flatten = Flatten()

        if self.two_d:
            # Hard-coded flattened size of the last conv's output (128 x 24 x 30);
            # assumes a fixed input spatial size — TODO confirm against the caller.
            self.final_conv_size = 128 * 24 * 30
            self.final_conv_shape = (128, 24, 30)
            self.main = nn.Sequential(
                init_(nn.Conv2d(self.input_channels, 32, (9, 10), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(32, 64, (9, 10), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(64, 128, (8, 9), stride=1)),
                nn.ReLU(),
                init_(nn.Conv2d(128, 128, (7, 8), stride=1)),
                nn.ReLU(),
                Flatten(),
                init_(nn.Linear(self.final_conv_size, self.feature_size)),
            )
        else:
            # Hard-coded flattened size of the last 3D conv's output (10 x 24 x 30 x 12);
            # likewise assumes a fixed input volume size — TODO confirm.
            self.final_conv_size = 10 * 24 * 30 * 12
            self.final_conv_shape = (10, 24, 30, 12)
            self.main = nn.Sequential(
                init_(nn.Conv3d(self.input_channels, 3, (9, 10, 4), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(3, 5, (9, 10, 3), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(5, 8, (8, 9, 3), stride=(1, 1, 1))),
                nn.ReLU(),
                init_(nn.Conv3d(8, 10, (7, 8, 2), stride=(1, 1, 1))),
                nn.ReLU(),
                Flatten(),
                init_(nn.Linear(self.final_conv_size, self.feature_size)),
            )
        self.train()

    def forward(self, inputs, fmaps=False):
        """Encode `inputs`; optionally return intermediate feature maps.

        NOTE(review): in the original paste this function was dedented to
        module level while still using ``self``, so nn.Module dispatch could
        never reach it — it must be an Encoder method, as it is here.

        Args:
            inputs: input tensor for ``self.main``.
            fmaps: if True, return a dict with intermediate activations.
        """
        # Slice indices assume the 10-module Sequential built in __init__:
        # [:6] -> through the third ReLU, [6:8] -> fourth conv + ReLU,
        # [8:] -> Flatten + Linear.
        f5 = self.main[:6](inputs)
        f7 = self.main[6:8](f5)
        out = self.main[8:](f7)
        if self.end_with_relu:
            assert self.args.method != "vae", "can't end with relu and use vae!"
            out = F.relu(out)
        if fmaps:
            # NOTE(review): permute(0, 2, 3, 1) is valid only for the 4D
            # activations of the 2D branch; the 3D branch produces 5D tensors
            # and would raise here — confirm fmaps=True is 2D-only.
            return {
                'f5': f5.permute(0, 2, 3, 1),
                'f7': f7.permute(0, 2, 3, 1),
                'out': out,
            }
        return out