So I’m currently at odds with how to approach this. At first I was only using Layer Conductance on a document-by-document basis, but now I’m trying to scale it up, since I’m using the attributions for a paper. The set I’m passing in for analysis of an already-trained model consists of 6658 documents, which takes 1 to 1.5 hours on average to process on a Titan RTX. So for around 120 models it takes just shy of 3 days, even when I split the work across 4-6 Titan RTXs (I have GPUs available via my work cluster, so the job grabs however many are free).
The process I’m currently following is the Captum BERT SQuAD tutorial, slightly adapted for classification. I’m using Hugging Face’s BERT model as the base, with dropout + a linear layer on top for the output.
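For reference, the wrapper looks roughly like this. This is a minimal sketch rather than my exact code, so the class name, label count, and checkpoint are placeholders; the relevant part is that `base` is the attribute the eval script below reaches into:

from torch import nn
from transformers import BertModel

class Classifier(nn.Module):
    # `base` is what the eval script uses for .embeddings and .encoder.layer[i]
    def __init__(self, num_labels=2):
        super().__init__()
        self.base = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.base.config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        outputs = self.base(input_ids=input_ids,
                            attention_mask=attention_mask,
                            inputs_embeds=inputs_embeds)
        pooled = outputs[1]  # BERT's pooled [CLS] representation
        return self.classifier(self.dropout(pooled))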
I was wondering what optimization strategies are available to me, given my situation. I’ve checked the input tensors and they’re float32. I’ve also noticed that calling cpu().detach().numpy() slows things down a fair bit. Should I be running this on a batch of documents at a time, to avoid re-running the whole embedding process for every document? (There’s a rough sketch of what I mean after the script.)
I’ve attached my eval script for reference:
import torch
from captum.attr import LayerConductance

def run(docs):
    res = []
    for doc_no, doc in enumerate(docs):
        print(f"Processing Doc #{doc_no + 1}/{len(docs)}")

        # Inputs plus a [PAD]-token baseline of the same length
        text_ids = tokenizer.encode(doc, add_special_tokens=False)
        input_ids = [tokenizer.cls_token_id] + text_ids + [tokenizer.sep_token_id]
        ref_input_ids = [tokenizer.cls_token_id] + [tokenizer.pad_token_id] * len(text_ids) + [tokenizer.sep_token_id]
        input_ids = torch.tensor([input_ids], device=DEVICE)
        ref_input_ids = torch.tensor([ref_input_ids], device=DEVICE)

        seq_len = input_ids.size(1)
        sep_idx = len(text_ids) + 1  # index of [SEP]; was len(text_ids), which pushed [SEP] into segment 1

        # Single sequence, so all tokens (including [SEP]) end up in segment 0
        token_type_ids = torch.tensor(
            [[0 if i <= sep_idx else 1 for i in range(seq_len)]],
            device=DEVICE
        )
        ref_token_type_ids = torch.zeros_like(token_type_ids)

        position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        ref_position_ids = torch.zeros(seq_len, dtype=torch.long, device=DEVICE)
        ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
        attention_mask = torch.ones_like(input_ids)

        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        # The baseline is now embedded with the reference segment/position ids
        # (previously it was embedded with the real ones, leaving the ref_* tensors unused)
        input_embeds = model.base.embeddings(
            input_ids, token_type_ids=token_type_ids, position_ids=position_ids
        )
        ref_input_embeds = model.base.embeddings(
            ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids
        )

        def forward(input_embeds, attention_mask):
            logits = model(inputs_embeds=input_embeds, attention_mask=attention_mask)
            return torch.softmax(logits, dim=1)

        def summarize_attributions(attributions):
            attributions = attributions.sum(dim=-1).squeeze(0)
            attributions = attributions / torch.norm(attributions)
            return attributions

        layer_attrs = []
        # Renamed the loop variable: it used to shadow the document index `i`,
        # so every result dict got doc_no == num_hidden_layers - 1
        for layer_idx in range(model.base.config.num_hidden_layers):
            lc = LayerConductance(forward, model.base.encoder.layer[layer_idx])
            layer_attributions = lc.attribute(
                inputs=input_embeds,
                baselines=ref_input_embeds,
                additional_forward_args=(attention_mask,),
                target=1
            )
            # .numpy() alone raises on CUDA tensors; detach and copy to host first
            layer_attrs.append(summarize_attributions(layer_attributions).cpu().detach().numpy())

        res.append({
            "doc_no": doc_no,
            "tokens": all_tokens,
            "conductance": layer_attrs
        })
        torch.cuda.empty_cache()
    return res
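And here’s the kind of batched variant I had in mind. This is a rough, untested sketch: run_batched and forward_fn are hypothetical names, batch_size and the internal_batch_size value are arbitrary, the baseline handling of segment/position ids is simplified compared to the script above, and I haven’t verified that the attributions exactly match the per-document run:

def forward_fn(input_embeds, attention_mask):
    logits = model(inputs_embeds=input_embeds, attention_mask=attention_mask)
    return torch.softmax(logits, dim=1)

def run_batched(docs, batch_size=8):
    res = []
    for start in range(0, len(docs), batch_size):
        enc = tokenizer(docs[start:start + batch_size], padding=True,
                        truncation=True, return_tensors="pt").to(DEVICE)
        input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]

        # Baseline: keep [CLS]/[SEP]/padding positions, replace real tokens with [PAD]
        keep = ((input_ids == tokenizer.cls_token_id)
                | (input_ids == tokenizer.sep_token_id)
                | (attention_mask == 0))
        ref_input_ids = torch.where(keep, input_ids,
                                    torch.full_like(input_ids, tokenizer.pad_token_id))

        # One embedding pass for the whole batch instead of one per document
        input_embeds = model.base.embeddings(input_ids)
        ref_input_embeds = model.base.embeddings(ref_input_ids)

        layer_attrs = []
        for layer_idx in range(model.base.config.num_hidden_layers):
            lc = LayerConductance(forward_fn, model.base.encoder.layer[layer_idx])
            attrs = lc.attribute(inputs=input_embeds, baselines=ref_input_embeds,
                                 additional_forward_args=(attention_mask,),
                                 target=1,
                                 internal_batch_size=batch_size * 4)  # cap the n_steps memory blow-up
            attrs = attrs.sum(dim=-1)                         # (batch, seq)
            attrs = attrs / attrs.norm(dim=-1, keepdim=True)  # per-document normalization
            layer_attrs.append(attrs.detach().cpu().numpy())  # one host transfer per layer, not per doc

        for j in range(input_ids.size(0)):
            res.append({
                "doc_no": start + j,
                "tokens": tokenizer.convert_ids_to_tokens(input_ids[j].tolist()),
                "conductance": [la[j] for la in layer_attrs]
            })
    return res

The token lists here would include [PAD] entries for the shorter documents in each batch, which I’d trim with the attention mask afterwards.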