Checkpointing a BERT model

Hello, I’m trying to understand and implement gradient checkpointing for a BERT model that I have (because I cannot run the code on the GPU, and I already tried lowering the batch size as much as possible, but that didn’t help). Here is the model class:

import torch
import torch.nn as nn
from transformers import BertModel


class BertClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
        # BERT hidden size is 768; the bidirectional LSTM outputs 2 * 256 features per step
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(256 * 2, 2)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # last hidden states: (batch, seq_len, 768)
        sequence_output, _ = self.lstm(sequence_output)
        # classify from the last time step of the BiLSTM output
        linear_output = self.linear(sequence_output[:, -1])

        return linear_output
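For context, this is roughly how I build the inputs and call the model (the tokenizer setup and the toy sentences below are just placeholders, not my exact preprocessing):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
bert_classifier = BertClassifier(freeze_bert=False)

# toy batch just to illustrate the call signature
encoded = tokenizer(['example sentence one', 'example sentence two'],
                    padding=True, truncation=True, max_length=64,
                    return_tensors='pt')
logits = bert_classifier(encoded['input_ids'], encoded['attention_mask'])  # shape: (2, 2)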

And below is my attempt to apply the checkpointing technique from pytorch_memonger/Checkpointing_for_PyTorch_models.ipynb at master · prigoyal/pytorch_memonger · GitHub:

from torch.utils import checkpoint


class BertClassifier(nn.Module):

    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(256 * 2, 2)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def run_function(self, start, end):
        def custom_forward(*inputs):
            # run the LSTM on one slice of inputs[0], passing the hidden state through
            output, hidden = self.lstm(inputs[0][start:(end + 1)], (inputs[1], inputs[2]))
            return output, hidden[0], hidden[1]

        return custom_forward

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # checkpoint the self.lstm() computation, segment by segment
        # NOTE: modules, segments, and the initial hidden are never defined here;
        # this is what my questions below are about
        output = []
        segment_size = len(modules) // segments
        for start in range(0, segment_size * (segments - 1), segment_size):
            end = start + segment_size - 1

            out = checkpoint.checkpoint(self.run_function(start, end),
                                        sequence_output, hidden[0], hidden[1])
            output.append(out[0])
            hidden = (out[1], out[2])
        # last segment, up to the end
        out = checkpoint.checkpoint(self.run_function(end + 1, len(modules) - 1),
                                    sequence_output, hidden[0], hidden[1])
        output.append(out[0])
        hidden = (out[1], out[2])

        output = torch.cat(output, 0)
        hidden = (out[1], out[2])
        # NOTE: this still uses the BERT output, not the checkpointed LSTM output
        linear_output = self.linear(sequence_output[:, -1])

        return linear_output

I have a few questions about the above:

  1. What are segments and modules? I saw modules declared in the checkpointing-sequential-models example, but mine isn’t a sequential model, so how should I declare modules in this case? Also, segments was set to 2 there; what does that 2 mean?
  2. Will the above checkpointing technique work in my case? If not, how can I implement it properly? (See the sketch after this list for what I currently have in mind.)
  3. Does checkpointing only need to be done in the class definition, as above, or does the overall training code have to be modified as well?
  4. Will saving and loading the state_dict stay the same? (e.g. torch.save(bert_classifier.state_dict(), 'finetuned_model.pt'))
  5. Is it possible to train the model on an English corpus and then test it on another language? Or is it possible to make the model language-independent?
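
Regarding question 2, the direction I’m currently leaning toward is checkpointing the BERT encoder layers themselves rather than the LSTM, since that is where most of the activation memory should be. Below is only a rough sketch of what I have in mind, assuming the usual transformers BertModel internals (self.bert.embeddings, self.bert.encoder.layer, get_extended_attention_mask); please tell me if this is not how it should be done:

import torch
import torch.nn as nn
from torch.utils import checkpoint
from transformers import BertModel


class CheckpointedBertClassifier(nn.Module):
    def __init__(self, freeze_bert=False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
        self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(256 * 2, 2)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        # compute the embeddings and the extended (additive) attention mask,
        # the same things BertModel computes internally before the encoder
        hidden_states = self.bert.embeddings(input_ids=input_ids)
        extended_mask = self.bert.get_extended_attention_mask(
            attention_mask, input_ids.shape, input_ids.device)

        # checkpoint each encoder layer: its activations are recomputed during
        # backward instead of being kept in memory for the whole forward pass
        for layer in self.bert.encoder.layer:
            hidden_states = checkpoint.checkpoint(
                lambda hs, mask, layer=layer: layer(hs, mask)[0],
                hidden_states, extended_mask)

        sequence_output, _ = self.lstm(hidden_states)
        return self.linear(sequence_output[:, -1])

I also read that recent versions of transformers have built-in gradient checkpointing for BERT (something like model.gradient_checkpointing_enable(), or a gradient_checkpointing flag in the config), but I’m not sure whether that is available in my version, so I’d like to understand the manual approach as well.
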
Thanks in advance!