Hello, I’m trying to understand and implement gradient checkpointing for a BERT model that I have (because I cannot run the code on GPU, and I already tried lowering the batch size as much as possible, but that didn’t help). Here is the model class:
```python
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(256 * 2, 2)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]                     # (batch, seq_len, 768)
        sequence_output, _ = self.lstm(sequence_output)  # (batch, seq_len, 512)
        linear_output = self.linear(sequence_output[:, -1])
        return linear_output
```
And below is my attempt at using the checkpointing approach from pytorch_memonger/Checkpointing_for_PyTorch_models.ipynb at master · prigoyal/pytorch_memonger · GitHub:
```python
from torch.utils import checkpoint

class BertClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(256 * 2, 2)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def run_function(self, start, end):
        def custom_forward(*inputs):
            output, hidden = self.lstm(inputs[0][start:(end + 1)], (inputs[1], inputs[2]))
            return output, hidden[0], hidden[1]
        return custom_forward

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # checkpoint self.lstm() computation
        # NOTE: `modules`, `segments` and the initial `hidden` are not defined
        # anywhere in my code yet -- this is copied from the sequential-model
        # tutorial, and that is exactly what my questions below are about.
        output = []
        segment_size = len(modules) // segments
        for start in range(0, segment_size * (segments - 1), segment_size):
            end = start + segment_size - 1
            out = checkpoint.checkpoint(self.run_function(start, end), sequence_output, hidden[0], hidden[1])
            output.append(out[0])
            hidden = (out[1], out[2])
        out = checkpoint.checkpoint(self.run_function(end + 1, len(modules) - 1), sequence_output, hidden[0], hidden[1])
        output.append(out[0])
        hidden = (out[1], out[2])
        output = torch.cat(output, 0)
        linear_output = self.linear(sequence_output[:, -1])
        return linear_output
```
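To make my second question below more concrete, here is a small standalone sketch of what I *think* the checkpointed LSTM part would look like if I split the sequence along the time dimension (dim=1, since `batch_first=True`) instead of relying on `modules`. The class name `CheckpointedLSTM`, the `run_lstm_segment` helper and the zero-initialised hidden state are my own guesses, not something from the notebook, so this may well be wrong:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedLSTM(nn.Module):
    """My guess at checkpointing an LSTM by chunking the time dimension."""
    def __init__(self, segments=2):
        super().__init__()
        self.segments = segments
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)

    def run_lstm_segment(self):
        def custom_forward(seq_chunk, h, c):
            out, (h, c) = self.lstm(seq_chunk, (h, c))
            return out, h, c
        return custom_forward

    def forward(self, sequence_output):
        batch_size = sequence_output.size(0)
        # hidden state shape: (num_layers * num_directions, batch, hidden_size)
        h = sequence_output.new_zeros(2, batch_size, 256)
        c = sequence_output.new_zeros(2, batch_size, 256)
        outputs = []
        # split the BERT output along the time dimension into `segments` chunks
        for seq_chunk in torch.chunk(sequence_output, self.segments, dim=1):
            out, h, c = checkpoint(self.run_lstm_segment(), seq_chunk, h, c)
            outputs.append(out)
        return torch.cat(outputs, dim=1)  # (batch, seq_len, 2 * 256)

# quick shape check
dummy = torch.randn(4, 128, 768, requires_grad=True)
print(CheckpointedLSTM()(dummy).shape)  # torch.Size([4, 128, 512])
```

If something like this is the right idea, I suppose the classifier head should then be applied to the concatenated LSTM output (`output[:, -1]`) rather than to `sequence_output` as in my code above, but I’m not sure.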
I have a few questions about the above:
- What are `segments` and `modules`? I saw `modules` declared in the "Checkpointing sequential models" example, but mine isn’t a sequential model, so how can I declare the modules in this case? Also, `segments` was declared as 2; what does that 2 mean? (See the small sketch after this list for the pattern I’m referring to.)
- Will the above checkpointing technique work in my case? If not, how can I implement it properly?
- Does the checkpointing only have to be done in the class definition, like above, or do I also have to modify the overall training code?
- Will saving and loading the state_dict stay the same? (e.g. `torch.save(bert_classifier.state_dict(), 'finetuned_model.pt')`)
- Is it possible to train the model on an English corpus and then test it on another language? Or is there a way to make the model language-independent?
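For context on the first question, this is roughly the `modules`/`segments` pattern from the sequential-model part of the notebook that I’m referring to. The toy `nn.Sequential` and the layer sizes here are just placeholders of mine, not the notebook’s actual example:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy sequential model, only to illustrate what `modules` and `segments` mean
# in the tutorial; my BERT + LSTM classifier does not look like this.
model = nn.Sequential(
    nn.Linear(768, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 2),
)

modules = [module for _, module in model._modules.items()]  # list of layers
segments = 2                                                # number of checkpointed chunks

x = torch.randn(8, 768, requires_grad=True)
out = checkpoint_sequential(modules, segments, x)
out.sum().backward()
```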
Thanks in advance!