Split batch unevenly across multiple GPUs

Hi,

Is there a way to not split batches evenly between GPUs?

Let’s give an example. Imagine we have a dataset where each sample has one document with one or more query vectors and the position of the related answer.

Example:

(Doc1, (q11, p11), (q12, p12)),
(Doc2, (q21, p21), (q22, p22), (q23, p23), (q24, p24))
(Doc3, (q31, p31))
(Doc4, (q41, p41))

Our batches could be:

batch = {
    'doc':     [doc1_tokens, doc2_tokens, doc3_tokens, doc4_tokens],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
}

repetition_vector = [2, 4, 1, 1]

For efficiency, a first module could compute the document representations (an RNN), and a second one could then repeat each embedding and combine it with the corresponding queries:

doc_embds = model1(batch['doc'])   # encode each document once
# repeat each document embedding once per query; the repeats argument must be an int or a LongTensor
doc_embds = torch.repeat_interleave(doc_embds, torch.tensor(repetition_vector), dim=0)

y_hat_pos = model2(doc_embds, batch['query'])

loss = criterion(y_hat_pos, batch['y_pos'])

PROBLEM
Let’s imagine we have 2 GPUs. doc_embds will be split evenly into [doc1, doc2] and [doc3, doc4]; after the repetition this gives [doc1, doc1, doc2, doc2, doc2, doc2] on the first GPU and [doc3, doc4] on the second, while the queries and y_pos will be split into [q11, q12, q21, q22] / [q23, q24, q31, q41] and [0, 6, 3, 7] / [4, 5, 5, 9].

In this case, the tensors will not match on each GPU, because the query vectors have been split evenly (which shouldn’t be the case).
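To make the mismatch concrete, here is a tiny stand-alone illustration; the tensors are just integer stand-ins, and torch.chunk is used to mimic the even dim-0 split that DataParallel's default scatter performs:

import torch

repetition_vector = torch.tensor([2, 4, 1, 1])
docs    = torch.arange(4)   # stand-ins for the doc1..doc4 embeddings
queries = torch.arange(8)   # stand-ins for q11..q41

doc_chunks   = torch.chunk(docs, 2)               # (doc1, doc2) | (doc3, doc4)
query_chunks = torch.chunk(queries, 2)            # (q11..q22)   | (q23..q41)
rep_chunks   = torch.chunk(repetition_vector, 2)  # (2, 4)       | (1, 1)

for d, q, r in zip(doc_chunks, query_chunks, rep_chunks):
    repeated = torch.repeat_interleave(d, r, dim=0)
    print(repeated.shape[0], "repeated doc rows vs", q.shape[0], "queries")
# -> 6 vs 4 on the first "GPU", 2 vs 4 on the second: the per-GPU shapes no longer line up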

I don’t know what solution is available for this problem:

  • Write a subclass of DataParallel for this? (a rough sketch of this option follows the list)
  • Reorganize my batches in another way, e.g. using padding (I would like to avoid this approach as it wastes GPU memory)?
  • Forget the repeat_interleave and redo the computation for each doc embedding (heavy computation).
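Regarding the first option, below is a minimal, untested sketch of what such a subclass could look like. It assumes the whole pipeline (model1 + model2 + the repeat_interleave) is wrapped in a single module that takes the batch dict from above, that 'doc', 'query' and 'y_pos' are batch-first tensors, and that 'nb_q' (the repetition vector) is a LongTensor; DocQueryModel and GroupedDataParallel are made-up names.

import torch
import torch.nn as nn


class DocQueryModel(nn.Module):
    # Hypothetical wrapper: encode the documents once, repeat each embedding
    # once per query, then run the query model (same pipeline as above).
    def __init__(self, doc_encoder, query_model):
        super().__init__()
        self.doc_encoder = doc_encoder    # "model1" above
        self.query_model = query_model    # "model2" above

    def forward(self, batch):
        doc_embds = self.doc_encoder(batch['doc'])
        doc_embds = torch.repeat_interleave(doc_embds, batch['nb_q'], dim=0)
        return self.query_model(doc_embds, batch['query'])


class GroupedDataParallel(nn.DataParallel):
    # Override scatter so each document lands on the same GPU as all of its
    # queries/positions, instead of the default even split of every tensor.
    def scatter(self, inputs, kwargs, device_ids):
        batch = inputs[0]
        nb_q = batch['nb_q']                       # LongTensor, e.g. [2, 4, 1, 1]
        n_dev = min(len(device_ids), len(nb_q))

        # Contiguous split of the documents across devices (simple, not load-balanced).
        doc_splits = torch.tensor_split(torch.arange(len(nb_q)), n_dev)
        # Query offsets, so 'query'/'y_pos' can be sliced consistently with 'doc'.
        q_offsets = torch.cat([torch.zeros(1, dtype=torch.long), nb_q.cumsum(0)])

        scattered = []
        for dev_id, doc_idx in zip(device_ids, doc_splits):
            d_start, d_end = doc_idx[0].item(), doc_idx[-1].item() + 1
            q_start, q_end = q_offsets[d_start].item(), q_offsets[d_end].item()
            device = torch.device('cuda', dev_id)
            scattered.append(({
                'doc':   batch['doc'][d_start:d_end].to(device, non_blocking=True),
                'query': batch['query'][q_start:q_end].to(device, non_blocking=True),
                'y_pos': batch['y_pos'][q_start:q_end].to(device, non_blocking=True),
                'nb_q':  nb_q[d_start:d_end].to(device, non_blocking=True),
            },))
        # One (args, kwargs) pair per device, as DataParallel expects.
        return scattered, [{} for _ in scattered]

Usage would be along the lines of model = GroupedDataParallel(DocQueryModel(model1, model2)).cuda() followed by y_hat_pos = model(batch). Since the document split is contiguous and DataParallel gathers the per-GPU outputs along dim 0 in device order, y_hat_pos comes back in the original query order, so criterion(y_hat_pos, batch['y_pos']) should still line up.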

The gain from multi-GPU would be quite substantial for this use case.

Thank you for your help!

I’m not sure I understand your use case completely, but this description sounds more like model parallelism / sharding, where the model itself is split across multiple devices, rather than data parallelism.

I’m working with documents where, given a query, I have to find the position of the “answer”. The number of queries varies across documents.

I have a document encoder and a query encoder. Normally, we could have

batch = {
    'doc':     [d1,  d1,  d2,  d2,  d2,  d2,  d3,  d4],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
}

As computing document embeddings takes a while with an RNN, I thought of doing:

batch = {
    'doc':     [d1,       d2,                 d3,  d4],
    'query':   [q11, q12, q21, q22, q23, q24, q31, q41],
    'y_pos':   [p11, p12, p21, p22, p23, p24, p31, p41],
    'nb_q':    [2,        4,                  1,   1  ],
}

So one model can compute each document embedding with only 4 forward passes (d1, d2, d3, d4) and then use torch.repeat_interleave, instead of having 8 computations (2×d1, 4×d2, 1×d3, 1×d4).
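For completeness, a tiny self-contained example of that repetition step, with integer stand-ins for the document embeddings:

import torch

doc_embds = torch.tensor([1, 2, 3, 4])   # stand-ins for the d1..d4 embeddings
nb_q      = torch.tensor([2, 4, 1, 1])

expanded = torch.repeat_interleave(doc_embds, nb_q, dim=0)
print(expanded)  # tensor([1, 1, 2, 2, 2, 2, 3, 4]) i.e. [d1, d1, d2, d2, d2, d2, d3, d4]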

Is it clearer?

Regarding your answer: I could split the model across 2 GPUs, but the models aren’t big; both fit on a single GPU. Would the gain be worth it?