Can a forward pass be broken into two (or multiple) steps?

Can the full forward pass `f_x1` of a PyTorch BERT model be done in two steps:

  • Get the k-th hidden-state representation `h_x1`
  • Run a forward pass on `h_x1` through the remaining (12 − k) layers to get the representation `f_x2`

I expected `f_x1` to be nearly equal to `f_x2`.

def forward(self, x_input_ids=None, x_seg_ids=None, x_atten_masks=None, inputs_embeds=None, k=0, get_hidden=False):
    """Classify BERT's [CLS] representation, optionally splitting the pass at layer ``k``.

    Two modes:

    * Full pass (``inputs_embeds is None``): run ``self.bert`` on token ids and
      feed the [CLS] token of the last hidden state to the classifier head.
    * Resume pass (``inputs_embeds`` given): treat ``inputs_embeds`` as
      ``hidden_states[k]`` of a full pass and run only the remaining encoder
      layers ``self.bert.encoder.layer[k:]`` on it. In that tuple, index 0 is
      the embedding output and index i is the output of encoder layer i.

    Why not ``self.bert(inputs_embeds=h_k)``: BertModel treats
    ``inputs_embeds`` as *word* embeddings. It would add position/token-type
    embeddings, apply LayerNorm and dropout, and then run *all* encoder layers
    again, so the result is unrelated to the full pass. Running
    ``encoder.layer[k:]`` directly makes
    ``forward(ids, get_hidden=True, k=k)`` followed by
    ``forward(inputs_embeds=h_k, k=k)`` reproduce the full pass. The match is
    exact up to floating point in ``eval()``. In ``train()`` every call samples
    fresh dropout masks, so the two only agree approximately.

    NOTE(review): assumes ``self.bert`` is a Hugging Face ``BertModel``
    (exposes ``encoder.layer``). It also assumes ``self.dropout``,
    ``self.linear``, ``self.relu`` and ``self.out`` form the classifier head.

    Args:
        x_input_ids, x_seg_ids, x_atten_masks: token ids, token-type ids and
            attention mask forwarded to ``self.bert`` in the full pass.
            ``x_atten_masks`` is also used in the resume pass. Pass the same
            mask there, otherwise padded positions are attended to.
        inputs_embeds: hidden state at index ``k`` to resume from, e.g. the
            tensor returned by an earlier ``get_hidden=True`` call.
        k: hidden-state index at which the pass is split.
        get_hidden: if True, also return the k-th hidden state.

    Returns:
        ``out``, or ``(out, hidden_k)`` when ``get_hidden`` is True.
    """
    if inputs_embeds is not None:
        import torch  # local import: only this branch needs torch.finfo

        ext_mask = None
        if x_atten_masks is not None:
            # Additive mask in the form BERT layers expect: 0 for real tokens,
            # a very large negative value for padding, broadcast over
            # (batch, heads, query, key).
            mask = x_atten_masks[:, None, None, :].to(dtype=inputs_embeds.dtype)
            ext_mask = (1.0 - mask) * torch.finfo(inputs_embeds.dtype).min

        h = inputs_embeds
        # Continue from layer k: hidden_states[k] feeds encoder.layer[k].
        for layer in self.bert.encoder.layer[k:]:
            layer_out = layer(h, attention_mask=ext_mask)
            # BertLayer returns a tuple (hidden, *optional_attentions).
            h = layer_out[0] if isinstance(layer_out, tuple) else layer_out
        query = h[:, 0]

    else:
        # full forward pass
        outputs = self.bert(input_ids=x_input_ids, attention_mask=x_atten_masks, token_type_ids=x_seg_ids,
                            output_hidden_states=True)  # (last_hidden_state, pooler_output, hidden_states)
        query = outputs[0][:, 0]
        hidden = outputs[2]  # tuple of num_layers + 1 tensors; hidden[0] is the embedding output

    query = self.dropout(query)
    linear = self.relu(self.linear(query))
    out = self.out(linear)
    if get_hidden:
        # In the resume pass the k-th hidden state is the input itself.
        return out, (inputs_embeds if inputs_embeds is not None else hidden[k])
    return out


criterion = nn.MSELoss(reduction='mean')
k = 3  # split point: h_x1 is hidden_states[k] (0 = embedding output, i = output of layer i)

# training
# NOTE: f_x1 and f_x2 can only match if model.forward resumes from encoder
# layer k (rather than re-running BERT on h_x1 via inputs_embeds). Even then,
# in train() mode each call samples new dropout masks, so they agree only
# approximately. In eval() they agree up to floating-point error.
for batch in train_loader:
    optimizer.zero_grad()
    sup_batch = [t.to(device) for t in batch]
    # Positional order matches forward(x_input_ids, x_seg_ids, x_atten_masks, ...).
    input_ids, seg_ids, atten_masks = sup_batch[:3]
    f_x1, h_x1 = model(input_ids, seg_ids, atten_masks, get_hidden=True, k=k)
    # Resume from h_x1 with the same attention mask so padded tokens stay masked.
    f_x2 = model(x_atten_masks=atten_masks, inputs_embeds=h_x1, k=k)

    loss = criterion(f_x1, f_x2)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer.step()