Optimizing Layer Conductance

I’m not sure how best to approach this. At first I was only using Layer Conductance on a document-by-document basis, but I’m now trying to scale it up because I’m using the attributions for a paper. The set I’m passing in for analysis of an already-trained model consists of 6658 documents, which takes on average 1 to 1.5 hours to process on a Titan RTX. For around 120 models, that comes to just under 3 days when split across 4-6 Titan RTXs (the GPUs come from my work cluster, so the job grabs however many are available).

The process I am currently following is the Captum BERT SQuAD tutorial, slightly adapted for classification. I am using HuggingFace’s BERT model as a base, with dropout and a linear layer on top for output.

I was wondering what optimization strategies are available to me in this situation. I’ve checked the input tensors and they’re float32. I’ve also noticed that calling cpu().detach().numpy() slows things down a fair bit. Should I be running this on a batch of documents at a time, to avoid running the whole embedding process over and over?

I’ve attached my eval script for reference:

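import torch
from captum.attr import LayerConductance

# Assumes `tokenizer`, `model`, and `DEVICE` are defined elsewhere in the script.
def run(docs):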
    res = []
    for i, doc in enumerate(docs):
        print(f"Processing Doc #{i+1}/{len(docs)}")
        text_ids = tokenizer.encode(doc, add_special_tokens=False)
        input_ids = [tokenizer.cls_token_id] + text_ids + [tokenizer.sep_token_id]
        ref_input_ids = [tokenizer.cls_token_id] + [tokenizer.pad_token_id] * len(text_ids) + [tokenizer.sep_token_id]
        input_ids = torch.tensor([input_ids], device=DEVICE)
        ref_input_ids = torch.tensor([ref_input_ids], device=DEVICE)

        sep_id = len(text_ids)
        seq_len = input_ids.size(1)
        # Segment ids: 0 up to the separator index, 1 afterwards.
        token_type_ids = torch.tensor(
            [[0 if pos <= sep_id else 1 for pos in range(seq_len)]],
            device=DEVICE,
        )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=DEVICE)

        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=DEVICE)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=DEVICE)
        ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
        attention_mask = torch.ones_like(input_ids)

        indices = input_ids[0].detach().tolist()
        all_tokens = tokenizer.convert_ids_to_tokens(indices)
        input_embeds = model.base.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
        ref_input_embeds = model.base.embeddings(ref_input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
        def forward(input_embeds, attention_mask):
            logits = model(inputs_embeds=input_embeds, attention_mask=attention_mask)
            return torch.softmax(logits, dim=1)
        def summarize_attributions(attributions):
            attributions = attributions.sum(dim=-1).squeeze(0)
            attributions = attributions / torch.norm(attributions)
            return attributions
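        layer_attrs = []
        # A separate loop variable keeps the outer document index `i` intact for "doc_no".
        for layer_idx in range(model.base.config.num_hidden_layers):
            lc = LayerConductance(forward, model.base.encoder.layer[layer_idx])
            # The call was cut off in the post. The arguments below are an assumed
            # reconstruction from forward()'s signature; TARGET_CLASS is a hypothetical
            # constant for the class index to attribute and must be set for your task.
            layer_attributions = lc.attribute(
                inputs=input_embeds,
                baselines=ref_input_embeds,
                additional_forward_args=(attention_mask,),
                target=TARGET_CLASS,
            )
            # Assumed: results are moved to the CPU with the cpu().detach().numpy() call mentioned above.
            layer_attrs.append(summarize_attributions(layer_attributions).cpu().detach().numpy())
        res.append({
            "doc_no": i,
            "tokens": all_tokens,
            "conductance": layer_attrs,
        })
    return res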

Hi @ajhepburn, running layer conductance on a batch of documents can definitely speed things up. You just need to make sure you have enough memory, and perhaps create sub-batches combining multiple documents if you can’t fit everything into memory at once. Note that internally layer conductance expands the input to n_steps * #examples examples. n_steps defaults to 50, and you can change it if necessary.
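To make the memory note concrete, here is a minimal sketch of one way to plan sub-batches. It assumes, as a rough proxy only, that memory grows with n_steps * (documents in the sub-batch) * (padded length). `plan_sub_batches` and `max_expanded_tokens` are hypothetical names, and the budget value has to be found empirically for your GPU.

```python
from typing import List


def plan_sub_batches(doc_lengths: List[int], n_steps: int = 50,
                     max_expanded_tokens: int = 200_000) -> List[List[int]]:
    """Group document indices into sub-batches under a rough memory budget.

    Cost proxy for one sub-batch: n_steps * num_docs * longest_doc_length,
    since every document is padded to the longest one in its sub-batch.
    """
    # Sort by length so documents of similar size share a sub-batch (less padding).
    order = sorted(range(len(doc_lengths)), key=lambda i: doc_lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    for idx in order:
        # Lengths are ascending, so the candidate is the longest in the batch.
        cost = n_steps * (len(current) + 1) * doc_lengths[idx]
        if current and cost > max_expanded_tokens:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


lengths = [12, 300, 15, 280, 14, 510]
batches = plan_sub_batches(lengths, n_steps=50, max_expanded_tokens=30_000)
print(batches)
```

Captum's `attribute` also accepts an `internal_batch_size` argument, which splits the n_steps * #examples expansion into smaller chunks, so it can be combined with sub-batching like this when a single sub-batch still does not fit.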

As far as I can see, you are treating each document as one example, is that right? If your documents aren’t too large, you can definitely try to combine them.
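As a sketch of what combining documents could look like, the helper below pads several tokenized documents to a common length and builds the matching baseline ids and attention mask. It follows the [CLS] + tokens + [SEP] layout from the script above. `build_padded_batch` is a hypothetical helper, and the special-token ids in the usage line are the usual bert-base-uncased values, used here only for illustration.

```python
from typing import Dict, List


def build_padded_batch(docs_token_ids: List[List[int]], cls_id: int,
                       sep_id: int, pad_id: int) -> Dict[str, List[List[int]]]:
    """Pad tokenized documents to one length; return inputs, baselines and mask."""
    max_len = max(len(ids) for ids in docs_token_ids) + 2  # +2 for [CLS] and [SEP]
    input_ids, ref_input_ids, attention_mask = [], [], []
    for ids in docs_token_ids:
        n_pad = max_len - (len(ids) + 2)
        input_ids.append([cls_id] + ids + [sep_id] + [pad_id] * n_pad)
        # Baseline: same layout, with every real token replaced by [PAD].
        ref_input_ids.append([cls_id] + [pad_id] * len(ids) + [sep_id] + [pad_id] * n_pad)
        # 1 for real positions, 0 for padding so attention ignores it.
        attention_mask.append([1] * (len(ids) + 2) + [0] * n_pad)
    return {"input_ids": input_ids, "ref_input_ids": ref_input_ids,
            "attention_mask": attention_mask}


batch = build_padded_batch([[7592, 2088], [2023]], cls_id=101, sep_id=102, pad_id=0)
print(batch["input_ids"])
```

Each list can then be turned into a tensor with torch.tensor(..., device=DEVICE) and passed through the same embedding and attribute calls. With more than one document per batch, summarize_attributions would need to normalize each document separately instead of relying on squeeze(0), and the padded positions should be dropped from the result.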
Another option is to use DataParallel across multiple GPUs.
More info about DataParallel and DistributedDataParallel can be found here: