I’m tackling the following problem: there are buckets of long text documents, each bucket contains 1…N documents, and each bucket has one target label. First, I pretrain an encoder, a RoBERTa model, with a masked language modeling objective on my corpus of texts (Python code). Then I prepare the buckets by splitting documents into chunks of 512 tokens. Each document contains 1…M chunks, and usually there are many of them.
I can’t process a batch of buckets on a single GPU, so I go one bucket at a time. First, I compute chunk embeddings and average them to get a document embedding. Then I average the document embeddings to get a bucket embedding. Finally, I pass the bucket embedding to a classification head to get class logits.
    for bucket in train_dataset:
        document_embeddings = []
        for document in bucket:
            chunk_embeddings = []
            for chunk in document:
                chunk_embedding = encoder(**chunk)  # (1)
                chunk_embeddings.append(chunk_embedding)
            # assumes each chunk embedding has shape (1, hidden_size)
            document_embedding = torch.cat(chunk_embeddings, dim=0).mean(dim=0, keepdim=True)
            document_embeddings.append(document_embedding)
        # average over documents to get one (1, hidden_size) bucket embedding
        bucket_embedding = torch.cat(document_embeddings, dim=0).mean(dim=0, keepdim=True)
        logits = classification_head(bucket_embedding)
The problem is that I can’t get past the OOM error: I pass too many chunks through the pretrained encoder, and the computational graph for every one of them stays in memory until the backward pass. This can be avoided by detaching the retrieved embeddings (1) from the computational graph, but then optimization becomes tricky.
If I detach the embeddings at (1) in this algorithm, only the classification head gets trained. I’ve tried feeding the averaged document embeddings into a shallow transformer encoder block to produce the bucket embedding for classification. This seems to work only if the chunk encoder is well pretrained.
I’ve also tried accumulating gradients after each chunk embedding retrieval (1), but this reduces bucket classification to classifying each chunk separately and averaging the results. Moreover, it doesn’t seem fair, because the original setup calculates the loss on averaged embeddings. That makes sense: documents carry different weights in the bucket’s context, so we often can’t say much about the whole bucket by looking at one document.
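A minimal sketch of the detach effect, assuming toy `nn.Linear` modules as stand-ins for the real encoder and classification head (the shapes, data and names here are illustrative only, not the actual setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a linear "encoder" and a linear classification head.
encoder = nn.Linear(8, 4)
classification_head = nn.Linear(4, 3)

# One bucket with a single document of 5 "chunks", each an 8-dim feature vector.
chunks = torch.randn(5, 8)
target = torch.tensor([1])

# Detach every chunk embedding, i.e. step (1) followed by .detach().
chunk_embeddings = [encoder(chunk.unsqueeze(0)).detach() for chunk in chunks]
bucket_embedding = torch.cat(chunk_embeddings, dim=0).mean(dim=0, keepdim=True)

logits = classification_head(bucket_embedding)
loss = nn.functional.cross_entropy(logits, target)
loss.backward()

# The encoder never received a gradient; only the head did.
encoder_has_grad = encoder.weight.grad is not None
head_has_grad = classification_head.weight.grad is not None
print(encoder_has_grad, head_has_grad)  # prints: False True
```

Since the encoder's parameters never get a `.grad`, an optimizer step leaves them unchanged, which is why everything then depends on how good the pretrained encoder already is.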
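To illustrate the mismatch, here is a rough sketch of per-chunk gradient accumulation, again assuming toy `nn.Linear` stand-ins for the encoder and head and random data in place of real chunks. It compares the accumulated per-chunk loss with the loss computed on the averaged embedding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions), not the real RoBERTa encoder or head.
encoder = nn.Linear(8, 4)
classification_head = nn.Linear(4, 3)
chunks = torch.randn(5, 8)   # 5 "chunks" of one bucket
target = torch.tensor([1])   # the bucket label

# Per-chunk accumulation: each chunk is classified on its own,
# and backward() is called right away so its graph can be freed.
mean_of_chunk_losses = 0.0
for chunk in chunks:
    logits = classification_head(encoder(chunk.unsqueeze(0)))
    loss = nn.functional.cross_entropy(logits, target) / len(chunks)
    loss.backward()  # gradients add up in .grad across chunks
    mean_of_chunk_losses += loss.item()

# Original objective: one loss on the averaged embedding (no gradients here,
# this is only evaluated for comparison).
with torch.no_grad():
    bucket_embedding = encoder(chunks).mean(dim=0, keepdim=True)
    loss_on_mean = nn.functional.cross_entropy(
        classification_head(bucket_embedding), target
    ).item()

print(mean_of_chunk_losses, loss_on_mean)
```

The encoder does receive gradients this way, but the two printed values generally differ: what is being optimized is an average of per-chunk classification losses, not the loss on the averaged embedding.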
What else can I do? Is there a proper way to get all embeddings without losing their gradients?