RNN generator weights not being updated

Hello there,
I’m currently working on a GAN that has 2 paired RNNs as generator (rnn and output in the following code).
My problem is that the weights of the generator are not being updated.
The following code performs a single forward step of the two RNNs to generate one observation.
The main idea is that the tensor y_pred_long is filled step by step with the predictions of the two RNNs.
This structure (y_pred_long) is then passed to a decoding function that builds the desired input for the discriminator.
I am not able to figure out whether (or where) the computational graph breaks, so any suggestion is much appreciated!
If anything from the following code is not clear and additional information is needed please let me know! (first time posting here)


def gen_rnn_epoch(epoch, rnn, output, device):
    """Sample one graph from the paired-RNN generator and decode it.

    The node-level RNN (``rnn``) proposes one node per step. For each node, the
    edge-level RNN (``output``) then proposes up to ``max_prev`` edge one-hots,
    one per step. Each step's row (node one-hot followed by the edge slots) is
    written into ``y_pred_long``, which is finally passed to
    ``decode_adj_generation``.

    Args:
        epoch: not used in this function.
        rnn: node-level RNN. It is called as ``h, logits = rnn(x)`` and must
            expose ``init_hidden``, ``hidden`` and ``num_layers``.
        output: edge-level RNN. It is called as ``logits = output(x)`` and must
            expose ``hidden``.
        device: device on which all working tensors are allocated.

    Returns:
        ``Data(x=..., edge_index=..., edge_attr=...)`` built from the decoded rows.
    """
    max_nodes = 9      # rows of y_pred_long == maximum number of node steps
    max_prev = 8       # edge slots per row
    node_classes = 4   # width of the node one-hot
    edge_classes = 5   # width of each edge one-hot slot
    row_width = node_classes + max_prev * edge_classes  # 44

    rnn.hidden = rnn.init_hidden(batch_size=1)  # this is the source of noise for the GAN generator

    # Filled step by step through slice assignment. Slice assignment is recorded
    # by autograd, so gradients flow from y_pred_long back into both RNNs.
    y_pred_long = torch.zeros(1, max_nodes, row_width, device=device)
    # The first input of the node RNN is an all-ones start token.
    x_step = torch.ones(1, 1, row_width, device=device)

    for i in range(max_nodes):
        h, node_prediction = rnn(x_step)
        # hard=True: one-hot forward sample with straight-through gradients.
        node_prediction = F.gumbel_softmax(node_prediction, dim=2, hard=True)

        # Seed the edge RNN: the node RNN's output becomes the first hidden
        # layer; the remaining layers are zeros.
        # NOTE(review): the layer count is taken from rnn.num_layers, not
        # output.num_layers -- confirm both RNNs have the same depth.
        hidden_null = torch.zeros(rnn.num_layers - 1, h.size(0), h.size(2), device=device)
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0)

        # Fresh row for this step: node one-hot first, edge slots filled below.
        x_step = torch.zeros(1, 1, row_width, device=device)
        x_step[:, :, :node_classes] = node_prediction

        # The first input of the edge RNN is an all-ones start token.
        output_x_step = torch.ones(1, 1, edge_classes, device=device)

        # Node i gets i + 1 edge slots, capped at max_prev.
        n_edge_steps = min(max_prev, i + 1)
        for j in range(n_edge_steps):
            output_y_pred_step = output(output_x_step)
            output_x_step = F.gumbel_softmax(output_y_pred_step, dim=2, hard=True)
            start = node_classes + j * edge_classes
            x_step[:, :, start:start + edge_classes] = output_x_step
            output.hidden = output.hidden.to(device)
        edge_rnn_step = n_edge_steps - 1

        y_pred_long[:, i:i + 1, :] = x_step

        # Stop criterion (EOS): from the second step on, stop once every edge
        # slot sampled at this step has its maximum at class 0. This check is
        # not differentiable, so it runs on detached data with a single sync.
        if edge_rnn_step > 0:
            sampled = x_step[0, 0, node_classes:].detach().reshape(max_prev, edge_classes)
            if bool((sampled[:n_edge_steps].argmax(dim=1) == 0).all()):
                break

        rnn.hidden = rnn.hidden.to(device)

    x, edg_idx, edg_attr = decode_adj_generation(y_pred_long[0], device)

    data = Data(x=x, edge_index=edg_idx, edge_attr=edg_attr)

    return data

def decode_adj_generation(encoded_adj, device):
    """Decode generated rows into node features, edge index and edge attributes.

    Args:
        encoded_adj: 2-D tensor with one row per generation step. Each row holds
            a 4-wide node part followed by 40 edge columns. The edge columns are
            viewed as ``encoded_adj.shape[0] - 1`` slots of width 5, so the
            reshape only succeeds for 9 rows. Slots are stored flipped and are
            un-flipped here.
        device: device the returned tensors are moved to.

    Returns:
        ``(node_x, edge_index, edge_attr)``:

        * ``node_x``: float32; the leading rows whose node part sums to non-zero.
        * ``edge_index``: long, shape ``(2, 2E)``; each edge is listed in both
          directions.
        * ``edge_attr``: float32, shape ``(2E, 4)``; the edge slot without its
          class-0 entry, duplicated to line up with ``edge_index``.

        ``node_x`` and ``edge_attr`` stay differentiable with respect to
        ``encoded_adj``; ``edge_index`` does not.
    """
    src_device = encoded_adj.device
    max_prev_TEST = encoded_adj.shape[0] - 1
    node_x, edges = torch.split(encoded_adj, [4, 40], dim=1)
    edges = torch.reshape(edges, (encoded_adj.shape[0], max_prev_TEST, 5))

    # Keep the leading run of rows that sampled a node (non-zero node part).
    # Everything from the first empty row on is dropped. One host sync.
    n_keep = 0
    for row_sum in node_x.detach().sum(dim=1).tolist():
        if row_sum == 0:
            break
        n_keep += 1

    # index_select (rather than slicing) produces a copy. This matters because
    # `edges` above is a view of encoded_adj, and the un-flip below writes in place.
    idxs_to_keep = torch.arange(n_keep, dtype=torch.long, device=src_device)
    node_x = node_x.index_select(0, idxs_to_keep)
    edges = edges.index_select(0, idxs_to_keep)

    # Un-flip the edges: row i's first i + 1 slots are stored in reverse order.
    for i in range(edges.shape[0]):
        edges[i, :i + 1, :] = torch.flip(edges[i, :i + 1, :], [0])

    # Edge detection (non-differentiable). Node i (i >= 1) reads its edges to
    # earlier nodes j < i from row i - 1. A slot counts as an edge when it is
    # non-empty and its class-0 entry is zero (i.e. its first non-zero entry
    # is not class 0). torch.nonzero returns matches in row-major order.
    prev = edges[:-1].detach()
    k = torch.arange(prev.shape[0], device=src_device).unsqueeze(1)
    j = torch.arange(prev.shape[1], device=src_device).unsqueeze(0)
    has_edge = (prev.sum(dim=2) != 0) & (prev[..., 0] == 0) & (j <= k)
    rows, cols = torch.nonzero(has_edge, as_tuple=True)
    first_dim_idx = rows + 1
    second_dim_idx = cols

    # Undirected edge index: every (i, j) is followed by its reverse (j, i).
    edge_idx = torch.stack((
        torch.cat((first_dim_idx, second_dim_idx), 0),
        torch.cat((second_dim_idx, first_dim_idx), 0),
    ), 0)

    # Edge attributes (differentiable). Advanced indexing keeps the autograd
    # graph back to encoded_adj. The result is duplicated to match edge_idx.
    edge_attr = edges[first_dim_idx - 1, second_dim_idx, 1:]
    edge_attr = torch.cat((edge_attr, edge_attr), 0)

    return (
        node_x.to(torch.float32).to(device),
        edge_idx.to(torch.long).to(device),
        edge_attr.to(torch.float32).to(device),
    )