My aim is to build a linear layer with a very large output dimension. To achieve this I store the weights of the linear layer the way an embedding layer stores its rows, so I can gather them with F.embedding. Furthermore, I only need to forward and backward through some of the connections of the fully connected layer (hence the "shortlist"). Since the output size is large, I split the layer across 2 GPUs.
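For context, the core trick is just row-gathering from the weight matrix. A minimal standalone sketch of the idea, with made-up sizes (not my actual code):

import torch
import torch.nn.functional as F

W = torch.randn(1_000_000, 300)                     # full classifier weight (num_labels x hidden)
x = torch.randn(32, 300)                            # a batch of document embeddings
shortlist = torch.randint(0, 1_000_000, (32, 500))  # 500 candidate labels per sample
rows = F.embedding(shortlist, W)                    # (32, 500, 300): gathers only the needed rows
logits = torch.einsum('bh,bkh->bk', x, rows)        # (32, 500): scores for shortlisted labels only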
Relevant parts of the code:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class SparseLinear(nn.Module):
    def __init__(self, num_labels, hidden_size, device_embeddings, bias=True):
        super(SparseLinear, self).__init__()
        self.device_embeddings = device_embeddings
        self.input_size = hidden_size
        self.output_size = num_labels
        # Weight lives on this partition's GPU; rows are gathered via F.embedding
        self.weight = Parameter(torch.empty(self.output_size, self.input_size,
                                            device=device_embeddings))
        if bias:
            self.bias = Parameter(torch.empty(self.output_size, 1,
                                              device=device_embeddings))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.sparse = True  # Required for optimizer (weight gradients come back sparse)

    def reset_parameters(self):
        # Same initialization scheme as nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, embed, shortlist):
        # Gather only the weight rows for the shortlisted labels: (B, K, H)
        short_weights = F.embedding(shortlist,
                                    self.weight,
                                    sparse=self.sparse)
        # (B, 1, H) @ (B, H, K) -> (B, 1, K)
        out = torch.matmul(embed.unsqueeze(1), short_weights.permute(0, 2, 1))
        short_bias = F.embedding(shortlist,
                                 self.bias,
                                 sparse=self.sparse)
        out = out + short_bias.permute(0, 2, 1)
        del short_weights
        return out.squeeze()  # (B, K)
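A quick shape check of the forward pass, with hypothetical sizes (B=8 documents, K=500 shortlisted labels):

layer = SparseLinear(num_labels=500_000, hidden_size=300,
                     device_embeddings=torch.device('cuda:0'))
embed = torch.randn(8, 300, device='cuda:0')                      # (B, hidden_size)
shortlist = torch.randint(0, 500_000, (8, 500), device='cuda:0')  # (B, K) label indices
out = layer(embed, shortlist)                                     # (B, K) after squeeze()

Because sparse=True, layer.weight.grad comes back as a sparse tensor after backward, which is why the optimizer has to support sparse gradients (e.g. optim.SparseAdam).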
class DividedLinear(DeepXMLBase):
    def __init__(self, <params>):
        # Say I have an output size of 1000000, and I divide it into 2 parts
        self.label_partition_lengths = [(500000, "cuda:0"), (500000, "cuda:1")]
        # ModuleList so the partitions' parameters are registered with the module
        self.classifier = nn.ModuleList(
            [SparseLinear(num_labels, 300, torch.device(device_name))
             for num_labels, device_name in self.label_partition_lengths])
        <init other params>

    def encode(self, batch_data):
        # self.transform is some network that transforms the document embeddings
        return self.transform(batch_data["doc_embeddings"].to(self.device_embeddings))

    def forward_with_error_calc(self, batch_data, criterion):
        print("before", torch.cuda.memory_allocated(0) / (1024 * 1024 * 1024),
              torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024))
        encoded = self.encode(batch_data)
        device_embeddings = [torch.device(num_labels_device[1])
                             for num_labels_device in self.label_partition_lengths]
        # Move each partition's shortlist and a copy of the encoded batch to its GPU
        shortlists = [x.to(device_embeddings[i]) for i, x in enumerate(batch_data["shortlist"])]
        encoded_replicate = [encoded.to(device_embeddings[i]) for i in range(len(device_embeddings))]
        # Run both partitions in parallel, one thread per GPU
        outputs = nn.parallel.parallel_apply(self.classifier, list(zip(encoded_replicate, shortlists)))
        targets = [batch_data["shortlist_weights"][i].to(device_embeddings[i])
                   for i in range(len(device_embeddings))]
        # Replicate the criterion onto each GPU and compute per-partition losses
        errors = nn.parallel.parallel_apply(nn.parallel.replicate(criterion, device_embeddings),
                                            list(zip(outputs, targets)))
        errors_gather = nn.parallel.gather(errors, target_device=device_embeddings[0])
        total_error = errors_gather.sum()
        print("after", torch.cuda.memory_allocated(0) / (1024 * 1024 * 1024),
              torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024))
        # Try to release the intermediate references explicitly
        for output in outputs:
            del output
        for target in targets:
            del target
        for x in shortlists:
            del x
        for x in encoded_replicate:
            del x
        torch.cuda.empty_cache()
        print("after del", torch.cuda.memory_allocated(0) / (1024 * 1024 * 1024),
              torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024))
        return total_error
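For reference, nn.parallel.parallel_apply just runs each module on its own inputs in a separate thread, and nn.parallel.gather moves the results to one device. A standalone toy version of the pattern, with made-up modules and sizes:

import torch
import torch.nn as nn

# Two toy modules, one per GPU, each already on its own device
mods = [nn.Linear(4, 2).to('cuda:0'), nn.Linear(4, 2).to('cuda:1')]
# One args tuple per module, with inputs on the matching device
inputs = [(torch.randn(3, 4, device='cuda:0'),),
          (torch.randn(3, 4, device='cuda:1'),)]
outs = nn.parallel.parallel_apply(mods, inputs)  # list of outputs, each on its own GPU
merged = nn.parallel.gather(outs, target_device=torch.device('cuda:0'))  # (6, 2) on cuda:0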
But the GPUs run out of memory after some batches. In particular, I observe behavior like this (values are allocated memory in GB, one column per GPU):
before 5.049468994140625 5.049465179443359
after 5.1367316246032715 5.1367268562316895
after del 5.1367316246032715 5.1367268562316895
before 5.136678695678711 5.1366729736328125
after 5.223941326141357 5.223934650421143
after del 5.223941326141357 5.223934650421143
So there seems to be some leak in the forward_with_error_calc function: allocated memory grows by about 0.087 GB per batch on each GPU, and the explicit del calls plus empty_cache() free nothing. I can't figure out what is holding on to the memory. Can someone please help me figure this out? TIA.