Hi!
Could you tell me whether Optimizer.step() should really be the slowest part of the training process?
If so, are there any techniques to improve its performance?
I profiled my code with line_profiler (output below); training is performed on the CPU.
Line # Hits Time Per Hit % Time Line Contents
==============================================================
43 @profile
44 def train_models(self, train_df: pd.DataFrame, lr=0.001, train_batch=32,
45 pretrained_weights='bert-base-uncased',
46 val_batch=256, patience=100, epochs=1000, verbosity=100, shuffle=True):
47
48 1 26.0 26.0 0.0 if shuffle:
49 1 24201.0 24201.0 0.0 train_df = train_df.sample(frac=1).reset_index(drop=True)
50
51 1 34376699.0 34376699.0 2.7 embedder = BERTModel(pretrained_weights) # load BERT embedder
52 1 5456.0 5456.0 0.0 print ('BERT loaded successfully!')
53
54 3 197.0 65.7 0.0 for topic in self.topics:
55 2 7287.0 3643.5 0.0 print("Preparing single model for '{}' topic".format(topic))
56
57 ######################################
58 2 66451.0 33225.5 0.0 X = train_df[train_df['Topic'] == topic]['Sentence']
59 2 51771.0 25885.5 0.0 Xs = train_df[train_df['Topic'] == topic]['Section']
60 2 47551.0 23775.5 0.0 labels = torch.tensor(train_df[train_df['Topic'] == topic]['IsCorrect'].values, dtype=torch.float)
61 2 6627.0 3313.5 0.0 print('Starting generate BERT embeddings...')
62 2 114700598.0 57350299.0 9.1 X_embs = embedder.predict(X, caching=True)
63 2 122060608.0 61030304.0 9.6 Xs_embs = embedder.predict(Xs, caching=True)
64 2 435.0 217.5 0.0 data = Data(X_embs, Xs_embs, labels)
65 2 23461.0 11730.5 0.0 print('Starting training...')
66 ##pos_weight calculation
67 2 2903.0 1451.5 0.0 pos = torch.sum(labels)
68 2 841.0 420.5 0.0 lbls = torch.tensor(labels.shape, dtype=torch.float)
69 2 1659.0 829.5 0.0 pos_weight = (lbls - pos) / pos
70
71 2 860.0 430.0 0.0 skf = StratifiedKFold(n_splits=self.folds, shuffle=True, random_state=2020)
72 2 58.0 29.0 0.0 fold = -1
73 8 120343.0 15042.9 0.0 for train_idx, valid_idx in skf.split(X=X_embs, y=labels):
74 6 250.0 41.7 0.0 fold += 1
75 6 22295.0 3715.8 0.0 print('Fold {}'.format(fold))
76
77 6 13344.0 2224.0 0.0 train_loader = DataLoader(Subset(data, train_idx), batch_size=train_batch)
78 6 23836.0 3972.7 0.0 validation_loader = DataLoader(Subset(data, valid_idx), batch_size=val_batch)
79
80 6 12078.0 2013.0 0.0 criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
81 6 40771.0 6795.2 0.0 optimizer = torch.optim.Adam(self.models['{}_{}'.format(topic, fold)].parameters(), lr=lr, weight_decay=0.001)
82 6 13631.0 2271.8 0.0 scheduler = CyclicLR(optimizer, base_lr=0.0001, max_lr=0.001, cycle_momentum=False)
83 6 465.0 77.5 0.0 acc_train = []
84 6 236.0 39.3 0.0 acc_test = []
85 6 207.0 34.5 0.0 losses_test = []
86 # best_auc = 0
87 6 239.0 39.8 0.0 best_loss = 99
88 6 205.0 34.2 0.0 aucs = []
89 6 293.0 48.8 0.0 recalls = []
90 6 185.0 30.8 0.0 lrs = []
91 6 189.0 31.5 0.0 best_epoch = 0
92
93 66 2834.0 42.9 0.0 for epoch in range(epochs):
94 60 2519.0 42.0 0.0 if epoch % verbosity == 1:
95 30 321601.0 10720.0 0.0 print('Epoch ', epoch, 'Test Loss ', losses_test[-1])
96 60 328282.0 5471.4 0.0 self.models['{}_{}'.format(topic, fold)].train()
97 60 7857.0 130.9 0.0 y_true = []
98 60 3559.0 59.3 0.0 y_score = []
99 3510 41746453.0 11893.6 3.3 for x, xs, y in train_loader:
100 3450 27656656.0 8016.4 2.2 optimizer.zero_grad()
101 3450 153222734.0 44412.4 12.1 yhat = self.models['{}_{}'.format(topic, fold)](x, xs)
102 3450 9724355.0 2818.7 0.8 loss = criterion(yhat, y)
103 3450 202206004.0 58610.4 16.0 loss.backward()
104 3450 508723569.0 147456.1 40.2 optimizer.step()
105 3450 1720162.0 498.6 0.1 yhat = torch.sigmoid(yhat)
106 3450 203794.0 59.1 0.0 y_true.append(y)
107 3450 341063.0 98.9 0.0 y_score.append(yhat.detach())
108 3450 4140073.0 1200.0 0.3 scheduler.step()
109 60 70071.0 1167.8 0.0 y_true = torch.cat(y_true)
110 60 51460.0 857.7 0.0 y_score = torch.cat(y_score)
111 60 751454.0 12524.2 0.1 acc_train.append(accuracy_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
112
113 60 320556.0 5342.6 0.0 self.models['{}_{}'.format(topic, fold)].eval()
114 60 7136.0 118.9 0.0 y_true = []
115 60 3611.0 60.2 0.0 y_score = []
116 300 17642286.0 58807.6 1.4 for x, xs, y in validation_loader:
117 240 39880.0 166.2 0.0 with torch.no_grad():
118 240 18828388.0 78451.6 1.5 y_pred = self.models['{}_{}'.format(topic, fold)](x, xs)
119 240 18126.0 75.5 0.0 y_true.append(y)
120 240 11873.0 49.5 0.0 y_score.append(y_pred)
121 60 20548.0 342.5 0.0 y_true = torch.cat(y_true)
122 60 14666.0 244.4 0.0 y_score = torch.cat(y_score)
123 60 130068.0 2167.8 0.0 val_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
124 60 161065.0 2684.4 0.0 losses_test.append(val_loss(y_score, y_true).detach().item())
125 60 22926.0 382.1 0.0 y_score = torch.sigmoid(y_score)
126 60 2390831.0 39847.2 0.2 aucs.append(roc_auc_score(y_true=y_true.numpy(), y_score=y_score.numpy()))
127 60 546383.0 9106.4 0.0 acc_test.append(accuracy_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
128 60 2220304.0 37005.1 0.2 recalls.append(recall_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
129 60 56561.0 942.7 0.0 lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])
130
131
132
133
134 # if aucs[-1] > best_auc:
135 # best_epoch = epoch
136 # best_auc = aucs[-1]
137
138 60 3338.0 55.6 0.0 if losses_test[-1] < best_loss:
139 43 1715.0 39.9 0.0 best_epoch = epoch
140 43 1814.0 42.2 0.0 best_loss = losses_test[-1]
141
142 60 2775.0 46.2 0.0 if epoch - best_epoch > patience:
143 break
144
145 6 281.0 46.8 0.0 print('Model for topic {} branch {} successfully trained for {} epochs with best val_loss={}'.format(
146 6 35600.0 5933.3 0.0 topic, fold, best_epoch + 1, best_loss))
147 self.train_history[topic]['fold_{}'.format(fold)] = {
148 6 338.0 56.3 0.0 'best_epoch': best_epoch,
149 6 332.0 55.3 0.0 'best_AUC': aucs[best_epoch],
150 6 292.0 48.7 0.0 'best_accuracy': acc_test[best_epoch],
151 6 220.0 36.7 0.0 'best_recall': recalls[best_epoch],
152 6 783.0 130.5 0.0 'best_loss': losses_test[best_epoch]}
153
154 6 245.0 40.8 0.0 self.train_wide_history[topic]['fold_{}'.format(fold)] = {'test_accuracy': acc_test,
155 6 256.0 42.7 0.0 'test_loss': losses_test,
156 6 299.0 49.8 0.0 'test_AUC': aucs,
157 6 431.0 71.8 0.0 'learning rate': lrs}