I have a lot of graphs; every node carries a text, and I want to use BERT to extract a feature vector from that text.
Every graph may have around 1000 nodes, and every node's text is tokenized to a fixed length of 64 (token_ids).
Because a graph can have so many nodes, I split them into chunks of 100 and run each chunk through BERT to get the embeddings, but when computing the 5th chunk's embedding, CUDA went out of memory:
torch.Size([1366, 64])
torch.Size([100, 64])
torch.Size([100, 64])
torch.Size([100, 64])
torch.Size([100, 64])
torch.Size([100, 64])
torch.Size([100, 64])
Traceback (most recent call last):
  File "/data1/liyushen/projects/assessment/gcn/embedding_graph_model/main.py", line 249, in <module>
    main(sys.argv[1],sys.argv[2],sys.argv[3],\
  File "/data1/liyushen/projects/assessment/gcn/embedding_graph_model/main.py", line 149, in main
    logits = model(g,token_ids)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/projects/assessment/gcn/embedding_graph_model/codebert_gnn.py", line 24, in forward
    output : BaseModelOutputWithPoolingAndCrossAttentions = self.encoder(split_token_ids)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 852, in forward
    encoder_outputs = self.encoder(
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 528, in forward
    layer_outputs = layer_module(
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 413, in forward
    self_attention_outputs = self.attention(
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 340, in forward
    self_outputs = self.self(
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 227, in forward
    value_layer = self.transpose_for_scores(self.value(hidden_states))
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data1/liyushen/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 21.75 GiB already allocated; 2.88 MiB free; 22.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
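Side note: the max_split_size_mb hint at the end of the message addresses allocator fragmentation only; it does not reduce the ~21.75 GiB that is actually allocated. If you want to try it anyway, it can be set before the first CUDA allocation (the 128 MiB value below is just an illustrative choice):

import os
# Must run before any CUDA tensor is created, or be exported in the
# shell before launching the script; it only tunes the caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"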
I tried the code below:
from torch.nn import Module
import torch
import dgl
from transformers import RobertaModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

class CodeBertGNN(Module):
    def __init__(self) -> None:
        super(CodeBertGNN, self).__init__()
        assert dgl.__version__ != ''  # sanity check that dgl imported correctly
        self.encoder = RobertaModel.from_pretrained('microsoft/codebert-base').to('cuda:0')

    def forward(self, g: dgl.DGLGraph, token_ids: torch.Tensor):
        num_node = token_ids.shape[0]
        context_result = []
        print(token_ids.shape)
        # Encode the nodes in chunks of 100 so the full graph never goes
        # through the encoder in one batch.
        for i in range(0, num_node, 100):
            split_token_ids = token_ids[i:i + 100, :].to('cuda:0')
            print(split_token_ids.shape)
            output: BaseModelOutputWithPoolingAndCrossAttentions = self.encoder(split_token_ids)
            # Keep only the [CLS] embedding of each node and park it on the
            # second GPU to free memory on cuda:0.
            cls_context_embedding = output[0][:, 0, :].to('cuda:1')
            context_result.append(cls_context_embedding)
        context_embedding = torch.cat(context_result)
        print(context_embedding.shape)
        # Return the stacked node embeddings (the GNN part is omitted here).
        return context_embedding
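For reference, here is a minimal sketch of the same loop with the encoder run under torch.no_grad(), assuming the CodeBERT weights are meant to stay frozen here. Without it, every chunk's activations are kept alive for a later backward pass through torch.cat, so memory grows with each chunk instead of staying roughly one chunk wide:

def forward(self, g: dgl.DGLGraph, token_ids: torch.Tensor):
    context_result = []
    # no_grad frees each chunk's autograd state as soon as the
    # [CLS] vectors are copied out, so peak memory stays bounded.
    with torch.no_grad():
        for i in range(0, token_ids.shape[0], 100):
            split_token_ids = token_ids[i:i + 100, :].to('cuda:0')
            output = self.encoder(split_token_ids)
            context_result.append(output[0][:, 0, :].to('cuda:1'))
    return torch.cat(context_result)

If the encoder has to be fine-tuned instead, no_grad is not an option; smaller chunks or gradient checkpointing (self.encoder.gradient_checkpointing_enable() in transformers) are the usual trade-offs.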