Split batch unevenly across multiple GPUs


Is there a way to not split batches evenly between GPUs?

Let’s give an example. Imagine we have a dataset where each sample has one document with one or more query vectors and the position of the related answer.


(Doc1, (q11, p11), (q12, p12)),
(Doc2, (q21, p21), (q22, p22), (q23, p23), (q24, p24))
(Doc3, (q31, p31))
(Doc4, (q41, p41))

Our batches could be:

batch = {
    'doc':     [doc1_tokens, doc2_tokens, doc3_tokens, doc4_tokens],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
}

repetition_vector = [2, 4, 1, 1]

For efficiency, a first module could compute the representation of each document (an RNN), and a second one could repeat each embedding and combine it with its queries.

doc_embds = model1(batch['doc'])
# repeats must be an int or a Tensor, not a Python list
doc_embds = torch.repeat_interleave(
    doc_embds, torch.tensor(repetition_vector), dim=1)

y_hat_pos = model2(doc_embds, batch['query'])

loss = criterion(y_hat_pos, batch['y_pos'])
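As a side note on the semantics assumed above: torch.repeat_interleave with a repetition vector repeats each entry along the chosen dimension by its own count. A pure-Python sketch of that behaviour over the batch entries (the helper name is hypothetical):

```python
def repeat_interleave(items, repeats):
    """Repeat each items[i] repeats[i] times, preserving order."""
    out = []
    for item, count in zip(items, repeats):
        out.extend([item] * count)
    return out

doc_embds = ["d1", "d2", "d3", "d4"]
repetition_vector = [2, 4, 1, 1]
expanded = repeat_interleave(doc_embds, repetition_vector)
# expanded == ["d1", "d1", "d2", "d2", "d2", "d2", "d3", "d4"]
```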

Let’s imagine we have 2 GPUs. doc_embds will be split evenly into [doc1, doc2] and [doc3, doc4]; after repetition this gives [doc1, doc1, doc2, doc2, doc2, doc2] on the first GPU and [doc3, doc4] on the second, while the queries will be split into [q11, q12, q21, q22] and [q23, q24, q31, q41], and y_pos into [0, 6, 3, 7] and [4, 5, 5, 9].

In this case, the tensors will not match on each GPU, as the query vectors have been split evenly (which shouldn’t be the case).

I don’t know what solution is available for this problem:

  • Write a sub-class of DataParallel for this?
  • Should I reorganize my batches in another way, e.g. using padding (I would like to avoid this approach as it consumes GPU RAM)?
  • Forget the repeat_interleave and redo the computation for each doc_embds (heavy computation).
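On the first option, one possible direction is a scatter that splits the doc list evenly across devices and then slices the query list according to the cumulative repetition counts, so each GPU receives exactly the queries belonging to its documents. A pure-Python sketch of that splitting logic (all names hypothetical; hooking this into DataParallel's scatter is not shown):

```python
from itertools import accumulate

def scatter_uneven(docs, queries, nb_q, num_devices):
    """Split docs into num_devices contiguous chunks, then give each
    chunk the queries that belong to its documents."""
    per_dev = (len(docs) + num_devices - 1) // num_devices  # ceil division
    offsets = [0] + list(accumulate(nb_q))  # query start index per doc
    shards = []
    for dev in range(num_devices):
        lo, hi = dev * per_dev, min((dev + 1) * per_dev, len(docs))
        q_lo, q_hi = offsets[lo], offsets[hi]
        shards.append({
            'doc': docs[lo:hi],
            'query': queries[q_lo:q_hi],
            'nb_q': nb_q[lo:hi],
        })
    return shards

shards = scatter_uneven(
    ['d1', 'd2', 'd3', 'd4'],
    ['q11', 'q12', 'q21', 'q22', 'q23', 'q24', 'q31', 'q41'],
    [2, 4, 1, 1],
    num_devices=2,
)
# shards[0] holds d1, d2 with their 6 queries; shards[1] holds d3, d4 with 2
```

Note this splits the number of *documents* evenly; the per-GPU query load stays unbalanced (6 vs. 2 here), which is a separate load-balancing question.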

The gain from multi-GPU would be quite significant for this use case.

Thank you for your help!

I’m not sure I understand your use case completely, but this description sounds more like model parallel / sharding, where the model itself is split to multiple devices, not like data parallel.

I’m working with documents where, given a query, I have to find the position of the “answer”. The number of queries varies across documents.

I have a document encoder and a query encoder. Normally, we could have

batch = {
    'doc':     [d1,  d1,  d2,  d2,  d2,  d2,  d3,  d4],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
}

As computing document embedding takes a while with RNN, I thought of doing:

batch = {
    'doc':     [d1,       d2,                 d3,  d4],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
    'nb_q':    [2,        4,                  1,   1  ],
}

So one model could compute the document embeddings with only 4 computations (d1, d2, d3, d4) and then use torch.repeat_interleave, instead of having 8 computations (2×d1, 4×d2, 1×d3, 1×d4).
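The saving can be made concrete with a toy count of encoder calls (pure Python; `encode` is just a stand-in for the RNN document encoder):

```python
# "encode" stands in for the RNN document encoder; counter tallies calls.
def encode(doc, counter):
    counter[0] += 1
    return f"emb({doc})"

# Layout 1: one row per (doc, query) pair -> one encoder call per row.
calls_flat = [0]
flat_docs = ['d1', 'd1', 'd2', 'd2', 'd2', 'd2', 'd3', 'd4']
flat_embds = [encode(d, calls_flat) for d in flat_docs]

# Layout 2: encode each unique doc once, then repeat the embeddings.
calls_unique = [0]
unique_docs = ['d1', 'd2', 'd3', 'd4']
nb_q = [2, 4, 1, 1]
embds = [encode(d, calls_unique) for d in unique_docs]
expanded = [e for e, n in zip(embds, nb_q) for _ in range(n)]

# calls_flat[0] == 8, calls_unique[0] == 4, and expanded == flat_embds
```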

Is that clearer?

Regarding your answer: I could split the model across 2 GPUs, but the models aren’t big; both fit within a single GPU. Would the gain be worth it?