Can the full forward pass f_x1
in PyTorch be done in 2 steps:
- Get
k
th hidden state representationsh_x1
- Do forward pass on
h_x1
and get12-k
th layer hidden state representationsf_x2
I expected f_x1
will be nearly equal to f_x2
.
def forward(self, x_input_ids=None, x_seg_ids=None, x_atten_masks=None, inputs_embeds=None, k=0, get_hidden=False):
if inputs_embeds is not None:
# get hidden representations of (12-k)th layer
outputs = self.bert(inputs_embeds=inputs_embeds, output_hidden_states=True)
query = outputs[2][12-k][:,0] # shape should be (batch_size, 768)
else:
# full forward pass
outputs = self.bert(input_ids=x_input_ids, attention_mask=x_atten_masks, token_type_ids=x_seg_ids,
output_hidden_states=True) # tuple of len 3
query = outputs[0][:,0]
hidden = outputs[2] # returns a tuple of len 13. hidden[0] is embedding layer
query = self.dropout(query)
linear = self.relu(self.linear(query))
out = self.out(linear)
if get_hidden:
return out, hidden[k] # return kth hidden layer
else:
return out
criterion = nn.MSELoss(reduction='mean')
k = 3
# training
for i, batch in enumerate(train_loader):
optimizer.zero_grad()
sup_batch = [t.to(device) for t in batch]
f_x1, h_x1 = model(*sup_batch[:3], get_hidden=True, k=k)
f_x2 = model(inputs_embeds=h_x1, k=k)
# Is f_x1 almost equal to f_x2 ??
loss = criterion(f_x1, f_x2)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()