Local backpropagation in a CortexNet-like network?

Hi,
I am currently working on a hierarchical neural network based on the CortexNet architecture (https://arxiv.org/pdf/1706.02735.pdf). The architecture is built around simple auto-encoders:
Each encoder takes as input the current state (the main visual input, or the hidden representation of the layer below) together with the previous prediction generated by a decoder; each decoder takes the current hidden state and the prediction coming from the layer above. Each auto-encoder aims to generate the next representation for the layer below (xt -> xt+1 for the visual input, or ht_1 -> ht+1_1 for a hidden representation).
The inference process consists of a hierarchical bottom-up encoding phase followed by a hierarchical top-down decoding (generation) phase.
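
For reference, my encoder/decoder modules roughly follow the interface below. This is only a simplified placeholder I wrote so the call signatures and tensor shapes are clear; it is not my actual implementation, and the mixtures argument is ignored here:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Placeholder: merge the current input with the previous prediction,
    # then downsample with a strided convolution.
    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride, dilation, groups, mixtures, bias):
        super().__init__()
        # 'mixtures' is not used in this simplified version
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, prev_pred):
        h = self.act(self.conv(x + prev_pred))
        # first output feeds the decoder (skip), second feeds the upper encoder
        return h, h

class Decoder(nn.Module):
    # Placeholder: merge the top-down signal with the encoder skip,
    # then upsample with a transposed convolution.
    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride, bias):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size, stride=stride,
                                         padding=padding, bias=bias)

    def forward(self, top_down, skip):
        return self.deconv(top_down + skip)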

Here is the code, without the encoder/decoder definitions:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder layers
E1 = Encoder(in_channels=3, out_channels=128, kernel_size=3, padding=1, stride=2, dilation=1, groups=1, mixtures=1, bias=False).to(device)
E2 = Encoder(in_channels=128, out_channels=64, kernel_size=3, padding=1, stride=2, dilation=1, groups=8, mixtures=1, bias=False).to(device)
E3 = Encoder(in_channels=64, out_channels=32, kernel_size=3, padding=1, stride=2, dilation=1, groups=8, mixtures=1, bias=False).to(device)
E4 = Encoder(in_channels=32, out_channels=16, kernel_size=3, padding=1, stride=2, dilation=1, groups=8, mixtures=1, bias=False).to(device)
# Decoder layers
D1 = Decoder(in_channels=128, out_channels=3, kernel_size=4, padding=1, stride=2, bias=False).to(device)
D2 = Decoder(in_channels=64, out_channels=128, kernel_size=4, padding=1, stride=2, bias=False).to(device)
D3 = Decoder(in_channels=32, out_channels=64, kernel_size=4, padding=1, stride=2, bias=False).to(device)
D4 = Decoder(in_channels=16, out_channels=32, kernel_size=4, padding=1, stride=2, bias=False).to(device)

print("RPU stack construction -> DONE")
print(E1, "\n", D1, "\n", E2, "\n", D2, "\n", E3, "\n", D3, "\n", E4, "\n", D4, "\n")

#------------------------------------------------------------------------------#
#                   Optimizer and learning parameters :                        #
#------------------------------------------------------------------------------#

# RLPU trainable parameters
RPU_1_param = list(E1.parameters()) + list(D1.parameters())
RPU_2_param = list(E2.parameters()) + list(D2.parameters())
RPU_3_param = list(E3.parameters()) + list(D3.parameters())
RPU_4_param = list(E4.parameters()) + list(D4.parameters())

# Create optimizer
RPU_1_optim = optim.SGD(RPU_1_param, lr = 0.04, momentum = 0.90, weight_decay = 0.00001, nesterov = True)
RPU_2_optim = optim.SGD(RPU_2_param, lr = 0.04, momentum = 0.90, weight_decay = 0.00001, nesterov = True)
RPU_3_optim = optim.SGD(RPU_3_param, lr = 0.04, momentum = 0.90, weight_decay = 0.00001, nesterov = True)
RPU_4_optim = optim.SGD(RPU_4_param, lr = 0.04, momentum = 0.90, weight_decay = 0.00001, nesterov = True)

# cost function
reconstruction_loss = nn.MSELoss(reduction = 'sum')


#------------------------------------------------------------------------------#
#                       Learning/inference procedure :                         #
#------------------------------------------------------------------------------#

# batch size (online or offline mode)
batch_size = 1

# Init last prediction
last_next_xt_pred = torch.zeros(batch_size,3,256,256,requires_grad = True)
last_next_ht_1_pred = torch.zeros(batch_size,128,128,128,requires_grad = True)
last_next_ht_2_pred = torch.zeros(batch_size,64,64,64,requires_grad = True)
last_next_ht_3_pred = torch.zeros(batch_size,32,32,32,requires_grad = True)
# init closing input (from upper layer)
upper_input = torch.zeros(batch_size,16,16,16)

# Set network to learning mode
E1.train()
E2.train()
E3.train()
E4.train()
D1.train()
D2.train()
D3.train()
D4.train()
# Learning step
for i in range(0,10):
    # get state and normalize (here is a dummy input for the test)
    xt = torch.randn(batch_size,3,256,256,requires_grad = True)

    RPU_1_optim.zero_grad()
    RPU_2_optim.zero_grad()
    RPU_3_optim.zero_grad()
    RPU_4_optim.zero_grad()
    # Encoder pass (ht_n_res is the hidden representation fed to the decoder,
    # ht_n is the main representation passed to the upper encoder layer)
    ht_1_res, ht_1 = E1(xt.to(device), last_next_xt_pred.to(device))
    ht_1_res = Variable(ht_1_res, requires_grad = True)
    ht_1 = Variable(ht_1, requires_grad = True)
    ht_2_res, ht_2 = E2(ht_1.to(device), last_next_ht_1_pred.to(device))
    ht_2_res = Variable(ht_2_res, requires_grad = True)
    ht_2 = Variable(ht_2, requires_grad = True)
    ht_3_res, ht_3 = E3(ht_2.to(device), last_next_ht_2_pred.to(device))
    ht_3_res = Variable(ht_3_res, requires_grad = True)
    ht_3 = Variable(ht_3, requires_grad = True)
    ht_4_res, ht_4 = E4(ht_3.to(device), last_next_ht_3_pred.to(device))
    ht_4_res = Variable(ht_4_res, requires_grad = True)
    ht_4 = Variable(ht_4, requires_grad = True)
    # Decoder pass
    next_ht_3_pred = D4(upper_input.to(device), ht_4_res.to(device))
    next_ht_3_pred = Variable(next_ht_3_pred, requires_grad = True)
    next_ht_2_pred = D3(next_ht_3_pred.to(device), ht_3_res.to(device))
    next_ht_2_pred = Variable(next_ht_2_pred, requires_grad = True)
    next_ht_1_pred = D2(next_ht_2_pred.to(device), ht_2_res.to(device))
    next_ht_1_pred = Variable(next_ht_1_pred, requires_grad = True)
    next_xt_pred = D1(next_ht_1_pred.to(device), ht_1_res.to(device))
    next_xt_pred = Variable(next_xt_pred, requires_grad = True)
    # Compute loss
    loss_RPU_1 = reconstruction_loss(last_next_xt_pred, xt)
    loss_RPU_1 = Variable(loss_RPU_1, requires_grad = True)
    loss_RPU_2 = reconstruction_loss(last_next_ht_1_pred, ht_1)
    loss_RPU_2 = Variable(loss_RPU_2, requires_grad = True)
    loss_RPU_3 = reconstruction_loss(last_next_ht_2_pred, ht_2)
    loss_RPU_3 = Variable(loss_RPU_3, requires_grad = True)
    loss_RPU_4 = reconstruction_loss(last_next_ht_3_pred, ht_3)
    loss_RPU_4 = Variable(loss_RPU_4, requires_grad = True)
    # Backpropagate and optimize
    loss_RPU_1.backward()
    RPU_1_optim.step()
    loss_RPU_2.backward()
    RPU_2_optim.step()
    loss_RPU_3.backward()
    RPU_3_optim.step()
    loss_RPU_4.backward()
    RPU_4_optim.step()
    # Memorize last prediction for the next step
    last_next_xt_pred = next_xt_pred
    last_next_xt_pred = Variable(last_next_xt_pred, requires_grad = True)
    last_next_ht_1_pred = next_ht_1_pred
    last_next_ht_1_pred = Variable(last_next_ht_1_pred, requires_grad = True)
    last_next_ht_2_pred = next_ht_2_pred
    last_next_ht_2_pred = Variable(last_next_ht_2_pred, requires_grad = True)
    last_next_ht_3_pred = next_ht_3_pred
    last_next_ht_3_pred = Variable(last_next_ht_3_pred, requires_grad = True)

So, my question is:

I would like to train my network in a local manner, where each encoder-decoder pair has its own optimizer and its own local cost function. In my current implementation I combine the encoder and decoder parameters for each optimizer, but I don't really know how to run the backward pass locally…
Also, every layer gives me grad=None…
Do you have any idea how to do the local optimization process properly?
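
To make the question more concrete: is the idea to detach the tensors that cross layer boundaries, and to only call backward on the prediction made at the previous time step (which is still attached to that RPU's graph)? Here is a sketch of what I have in mind, for the first RPU only, reusing the objects defined above (this is just how I imagine it could work, not something I know to be correct):

# Sketch for RPU 1 only (E1 + D1), reusing E1, D1, RPU_1_optim,
# reconstruction_loss, batch_size and device from the code above.
# Idea: inputs crossing layer boundaries are detach()-ed, and the
# prediction made at step t-1 stays attached to E1/D1's graph, so the
# loss at step t only produces gradients for E1 and D1.
last_pred = torch.zeros(batch_size, 3, 256, 256, device=device)

for t in range(10):
    xt = torch.randn(batch_size, 3, 256, 256, device=device)

    RPU_1_optim.zero_grad()
    loss_RPU_1 = reconstruction_loss(last_pred, xt)
    if t > 0:
        # backward through the graph built at step t-1 (E1 and D1 only)
        loss_RPU_1.backward()
        RPU_1_optim.step()

    # forward pass for the current step; the previous prediction is
    # detached so gradients do not flow further back in time
    ht_1_res, ht_1 = E1(xt, last_pred.detach())
    top_down = torch.zeros(batch_size, 128, 128, 128, device=device)  # stand-in for D2's output
    last_pred = D1(top_down, ht_1_res)   # kept attached for the next step's loss
    ht_1_to_E2 = ht_1.detach()           # what would be passed up to E2

Is something like this the right direction, extended to every encoder-decoder pair?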

Thanks in advance!

Best regards :slight_smile: