Hi! I'm using Captum with a transformer-based protein language model to identify correlations between the inputs (embeddings) and the outputs. I took inspiration from the tutorials on the Captum website (the BERT one), but I'm not able to run the last chunk of code, the part related to Captum.
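For context, the pattern I'm adapting from the tutorial looks roughly like this (my paraphrase, not the exact tutorial code): LayerIntegratedGradients wraps a forward function together with the model's embedding layer, and attribute() is then called with integer token IDs and baseline IDs of the same shape:

    from captum.attr import LayerIntegratedGradients

    # forward_func: callable mapping token-ID tensors to output scores
    # embedding_layer: the nn.Embedding to attribute over
    lig = LayerIntegratedGradients(forward_func, embedding_layer)
    attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=baseline_input_ids,
                                        target=1,
                                        return_convergence_delta=True)

My full code follows: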
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.nn.utils.rnn as rnn_utils
import os
import time
from sklearn.metrics import auc, roc_curve, average_precision_score, precision_recall_curve
from termcolor import colored
import pdb
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def generate_data(file):
    # Amino acid dictionary
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    # open csv file
    with open(file, 'r') as inf:
        lines = inf.read().splitlines()
    pep_codes = []
    labels = []
    peps = []
    for pep in lines:  # for every row
        pep, label = pep.split(",")  # sequence and label split
        peps.append(pep)
        labels.append(int(label))
        current_pep = []
        for aa in pep:
            current_pep.append(aa_dict[aa])
        pep_codes.append(torch.tensor(current_pep))
    data = rnn_utils.pad_sequence(pep_codes, batch_first=True)  # Fill the sequences to the same length
    return data, torch.tensor(labels)
data, label = generate_data("./SSP_dataset.csv")
train_data, train_label = data[:1894], label[:1894]  # the first 1894 samples are used for training
test_data, test_label = data[1894:], label[1894:]
train_dataset = Data.TensorDataset(train_data, train_label)
test_dataset = Data.TensorDataset(test_data, test_label)
batch_size = 64
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
class xAInet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_dim = 25
        self.batch_size = 32
        self.embedding_dim = 512
        self.embedding_layer = nn.Embedding(24, self.embedding_dim, padding_idx=0)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, num_layers=2,
                          bidirectional=True, dropout=.2)
        self.block_seq = nn.Sequential(nn.Linear(15050, 2048),
                                       nn.BatchNorm1d(2048),
                                       nn.LeakyReLU(),
                                       nn.Linear(2048, 1024),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(),
                                       nn.Linear(1024, 256),
                                       nn.BatchNorm1d(256),
                                       nn.ReLU(),
                                       nn.Linear(256, 8),
                                       nn.Linear(8, 2))

    def forward(self, seq):
        embeddings = self.embedding_layer(seq)
        output = self.transformer_encoder(embeddings).permute(1, 0, 2)
        output, hn = self.gru(output)
        output = output.permute(1, 0, 2)
        hn = hn.permute(1, 0, 2)
        output = output.reshape(output.shape[0], -1)
        hn = hn.reshape(output.shape[0], -1)
        output = torch.cat([output, hn], 1)
        output = self.block_seq(output)
        return output, embeddings

    def train_model(self, seq):
        # with torch.no_grad():
        output, _ = self.forward(seq)
        return output
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive
def collate(batch):
    seq1_ls = []
    seq2_ls = []
    label1_ls = []
    label2_ls = []
    label_ls = []
    batch_size = len(batch)
    for i in range(int(batch_size / 2)):
        seq1, label1 = batch[i][0], batch[i][1]
        seq2, label2 = batch[i + int(batch_size / 2)][0], \
                       batch[i + int(batch_size / 2)][1]
        label1_ls.append(label1.unsqueeze(0))
        label2_ls.append(label2.unsqueeze(0))
        label = (label1 ^ label2)  # 1 if the pair has different labels, 0 otherwise
        seq1_ls.append(seq1.unsqueeze(0))
        seq2_ls.append(seq2.unsqueeze(0))
        label_ls.append(label.unsqueeze(0))
    seq1 = torch.cat(seq1_ls).to(device)
    seq2 = torch.cat(seq2_ls).to(device)
    label = torch.cat(label_ls).to(device)
    label1 = torch.cat(label1_ls).to(device)
    label2 = torch.cat(label2_ls).to(device)
    return seq1, seq2, label, label1, label2
train_iter_cont = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn=collate)
device = torch.device("cuda", 0)
def evaluate(data_iter, net):
    pred_prob = []
    label_pred = []
    label_real = []
    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        outputs = net.train_model(x)
        outputs_cpu = outputs.cpu()
        y_cpu = y.cpu()
        pred_prob_positive = outputs_cpu[:, 1]
        pred_prob = pred_prob + pred_prob_positive.tolist()
        label_pred = label_pred + outputs.argmax(dim=1).tolist()
        label_real = label_real + y_cpu.tolist()
    performance, roc_data, prc_data = caculate_metric(pred_prob, label_pred, label_real)
    return performance, roc_data, prc_data
def caculate_metric(pred_prob, label_pred, label_real):
    test_num = len(label_real)
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    for index in range(test_num):
        if label_real[index] == 1:
            if label_real[index] == label_pred[index]:
                tp = tp + 1
            else:
                fn = fn + 1
        else:
            if label_real[index] == label_pred[index]:
                tn = tn + 1
            else:
                fp = fp + 1
    # Accuracy
    ACC = float(tp + tn) / test_num
    # Sensitivity
    if tp + fn == 0:
        Recall = Sensitivity = 0
    else:
        Recall = Sensitivity = float(tp) / (tp + fn)
    # Specificity
    if tn + fp == 0:
        Specificity = 0
    else:
        Specificity = float(tn) / (tn + fp)
    # MCC
    if (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) == 0:
        MCC = 0
    else:
        MCC = float(tp * tn - fp * fn) / (np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    # ROC and AUC
    FPR, TPR, thresholds = roc_curve(label_real, pred_prob, pos_label=1)
    AUC = auc(FPR, TPR)
    # PRC and AP
    precision, recall, thresholds = precision_recall_curve(label_real, pred_prob, pos_label=1)
    AP = average_precision_score(label_real, pred_prob, average='macro', pos_label=1, sample_weight=None)
    performance = [ACC, Sensitivity, Specificity, AUC, MCC]
    roc_data = [FPR, TPR, AUC]
    prc_data = [recall, precision, AP]
    return performance, roc_data, prc_data
def to_log(log):
    with open("./results/ExamPle_Log.log", "a+") as f:
        f.write(log + '\n')
net = xAInet().to(device)
lr = 0.0001
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = ContrastiveLoss()
criterion_model = nn.CrossEntropyLoss(reduction='sum')
best_acc = 0
EPOCH = 5
for epoch in range(EPOCH):
    loss_ls = []
    loss1_ls = []
    loss2_3_ls = []
    t0 = time.time()
    net.train()
    for seq1, seq2, label, label1, label2 in train_iter_cont:
        output1, emb = net(seq1)
        output2, emb = net(seq2)
        # pdb.set_trace()
        output3 = net.train_model(seq1)
        output4 = net.train_model(seq2)
        loss1 = criterion(output1, output2, label)
        loss2 = criterion_model(output3, label1)
        loss3 = criterion_model(output4, label2)
        loss = loss1 + loss2 + loss3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ls.append(loss.item())
        loss1_ls.append(loss1.item())
        loss2_3_ls.append((loss2 + loss3).item())

    net.eval()
    with torch.no_grad():
        train_performance, train_roc_data, train_prc_data = evaluate(train_iter, net)
        test_performance, test_roc_data, test_prc_data = evaluate(test_iter, net)

    results = f"\nepoch: {epoch + 1}, loss: {np.mean(loss_ls):.5f}, loss1: {np.mean(loss1_ls):.5f}, loss2_3: {np.mean(loss2_3_ls):.5f}\n"
    results += f'train_acc: {train_performance[0]:.4f}, time: {time.time() - t0:.2f}'
    results += '\n' + '=' * 16 + ' Test Performance. Epoch[{}] '.format(epoch + 1) + '=' * 16 \
               + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
                   test_performance[0], test_performance[1], test_performance[2], test_performance[3],
                   test_performance[4]) + '\n' + '=' * 60
    print(results)
    # to_log(results)
    test_acc = test_performance[0]  # test_performance: [ACC, Sensitivity, Specificity, AUC, MCC]
    if test_acc > best_acc:
        best_acc = test_acc
        best_performance = test_performance
        filename = '{}, {}[{:.3f}].pt'.format('ExamPle' + ', epoch[{}]'.format(epoch + 1), 'ACC', best_acc)
        save_path_pt = os.path.join('./Model', filename)
        # torch.save(net.state_dict(), save_path_pt, _use_new_zipfile_serialization=False)
        best_results = '\n' + '=' * 16 + colored(' Best Performance. Epoch[{}] ', 'red').format(epoch + 1) + '=' * 16 \
                       + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
                           best_performance[0], best_performance[1], best_performance[2], best_performance[3],
                           best_performance[4]) + '\n' + '=' * 60
        print(best_results)
        best_ROC = test_roc_data
        best_PRC = test_prc_data
def model_output(inputs):
    inputs = inputs[0].unsqueeze(0)
    out, embeddings = net(inputs)  # `net` is the trained xAInet instance from above
    # Apply softmax to convert prediction scores to probabilities
    probabilities = torch.softmax(out, dim=1)
    # Get the predicted classes by selecting the class with the highest probability
    predicted_classes = torch.argmax(probabilities, dim=1)
    return predicted_classes
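(One doubt I have while writing this up: torch.argmax returns integer class indices, so no gradient can flow through this forward function; maybe it should return the class scores or probabilities instead, and the class of interest should be selected via the target argument of attribute(). A sketch of what I mean, not tested:)

    def model_output_probs(inputs):
        # hypothetical alternative: return differentiable class probabilities
        # instead of the argmax'd class index
        out, _ = net(inputs)
        return torch.softmax(out, dim=1)  # shape (batch, 2); then pass target=1 to attribute()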
import captum
from captum.attr import LayerIntegratedGradients

lig = LayerIntegratedGradients(model_output, net.embedding_layer)
def construct_input_and_baseline(text):
    max_length = 512
    # baseline_token_id = rnn_utils.pad_sequence()
    input_ids = []
    token_list = []
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    for char in text:
        if char in aa_dict:
            input_ids.append(aa_dict[char])
            token_list.append(char)
    baseline_token_id = 13
    baseline_input_ids = [baseline_token_id] * len(input_ids)
    input_ids_tensor = torch.tensor([input_ids], device='cpu')
    baseline_input_ids_tensor = torch.tensor([baseline_input_ids], device='cpu')
    return input_ids_tensor, baseline_input_ids_tensor, token_list
text = 'MSKSKMLVFKSKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFK'
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')
print(f'all tokens: {all_tokens}')
desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length
                    else torch.cat((seq, torch.zeros(desired_length - len(seq))))
                    for seq in input_ids]
# Apply pad_sequence on the padded_sequences
# data = pad_sequence(padded_sequences, batch_first=True)
input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length
                    else torch.cat((seq, torch.zeros(desired_length - len(seq))))
                    for seq in baseline_input_ids]
# Apply pad_sequence on the padded_sequences
# data = pad_sequence(padded_sequences, batch_first=True)
baseline_input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1)
I got this error:
RuntimeError Traceback (most recent call last)
<ipython-input-42-e20407bb5215> in <cell line: 1>()
----> 1 ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=0)
/usr/local/lib/python3.10/dist-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
43
44 return wrapper
/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
284 )
285 else:
--> 286 attributions = self._attribute(
287 inputs=inputs,
288 baselines=baselines,
/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in _attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
349
350 # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351 grads = self.gradient_func(
352 forward_fn=self.forward_func,
353 inputs=scaled_features_tpl,
/usr/local/lib/python3.10/dist-packages/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
110 with torch.autograd.set_grad_enabled(True):
111 # runs forward pass
--> 112 outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
113 assert outputs[0].numel() == 1, (
114 "Target not provided when necessary, cannot"
/usr/local/lib/python3.10/dist-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
480 additional_forward_args = _format_additional_forward_args(additional_forward_args)
481
--> 482 output = forward_func(
483 *(*inputs, *additional_forward_args)
484 if additional_forward_args is not None
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-5-d2063671fd62> in forward(self, seq)
28
29 def forward(self, seq):
---> 30 embeddings = self.embedding_layer(seq)
31 output = self.transformer_encoder(embeddings).permute(1, 0, 2)
32 output, hn = self.gru(output)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
160
161 def forward(self, input: Tensor) -> Tensor:
--> 162 return F.embedding(
163 input, self.weight, self.padding_idx, self.max_norm,
164 self.norm_type, self.scale_grad_by_freq, self.sparse)
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2208 # remove once script supports set_grad_enabled
2209 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2211
2212
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
So I tried using

input_ids.to(torch.long)

but nothing changed… Can you help me, please?
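One thing I noticed afterwards: .to(torch.long) returns a new tensor rather than converting in place, so maybe I need to reassign the result before calling attribute()? Something like this (my guess, not verified):

    input_ids = input_ids.long()
    baseline_input_ids = baseline_input_ids.long()
    attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=baseline_input_ids,
                                        return_convergence_delta=True,
                                        internal_batch_size=1)

I also suspect the float tensors come from my padding step, since torch.zeros defaults to float32 and torch.cat then promotes the whole sequence to float; padding with torch.zeros(desired_length - len(seq), dtype=torch.long) might avoid the cast in the first place.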