Optimizer.step() is the slowest part of training

Hi!
Could you tell me whether Optimizer.step() should really be the slowest part of the training process?
If so, are there any techniques to improve its performance?

I profiled my code with line_profiler; training is performed on the CPU.

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               @profile
    44                                               def train_models(self, train_df: pd.DataFrame, lr=0.001, train_batch=32,
    45                                                                pretrained_weights='bert-base-uncased',
    46                                                                val_batch=256, patience=100, epochs=1000, verbosity=100, shuffle=True):
    47
    48         1         26.0     26.0      0.0          if shuffle:
    49         1      24201.0  24201.0      0.0              train_df = train_df.sample(frac=1).reset_index(drop=True)
    50
    51         1   34376699.0 34376699.0      2.7          embedder = BERTModel(pretrained_weights) # load BERT embedder
    52         1       5456.0   5456.0      0.0          print ('BERT loaded successfully!')
    53
    54         3        197.0     65.7      0.0          for topic in self.topics:
    55         2       7287.0   3643.5      0.0              print("Preparing single model for '{}' topic".format(topic))
    56
    57                                                       ######################################
    58         2      66451.0  33225.5      0.0              X = train_df[train_df['Topic'] == topic]['Sentence']
    59         2      51771.0  25885.5      0.0              Xs = train_df[train_df['Topic'] == topic]['Section']
    60         2      47551.0  23775.5      0.0              labels = torch.tensor(train_df[train_df['Topic'] == topic]['IsCorrect'].values, dtype=torch.float)
    61         2       6627.0   3313.5      0.0              print('Starting generate BERT embeddings...')
    62         2  114700598.0 57350299.0      9.1              X_embs = embedder.predict(X, caching=True)
    63         2  122060608.0 61030304.0      9.6              Xs_embs = embedder.predict(Xs, caching=True)
    64         2        435.0    217.5      0.0              data = Data(X_embs, Xs_embs, labels)
    65         2      23461.0  11730.5      0.0              print('Starting training...')
    66                                                       ##pos_weight calculation
    67         2       2903.0   1451.5      0.0              pos = torch.sum(labels)
    68         2        841.0    420.5      0.0              lbls = torch.tensor(labels.shape, dtype=torch.float)
    69         2       1659.0    829.5      0.0              pos_weight = (lbls - pos) / pos
    70
    71         2        860.0    430.0      0.0              skf = StratifiedKFold(n_splits=self.folds, shuffle=True, random_state=2020)
    72         2         58.0     29.0      0.0              fold = -1
    73         8     120343.0  15042.9      0.0              for train_idx, valid_idx in skf.split(X=X_embs, y=labels):
    74         6        250.0     41.7      0.0                  fold += 1
    75         6      22295.0   3715.8      0.0                  print('Fold {}'.format(fold))
    76
    77         6      13344.0   2224.0      0.0                  train_loader = DataLoader(Subset(data, train_idx), batch_size=train_batch)
    78         6      23836.0   3972.7      0.0                  validation_loader = DataLoader(Subset(data, valid_idx), batch_size=val_batch)
    79
    80         6      12078.0   2013.0      0.0                  criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    81         6      40771.0   6795.2      0.0                  optimizer = torch.optim.Adam(self.models['{}_{}'.format(topic, fold)].parameters(), lr=lr, weight_decay=0.001)
    82         6      13631.0   2271.8      0.0                  scheduler = CyclicLR(optimizer, base_lr=0.0001, max_lr=0.001, cycle_momentum=False)
    83         6        465.0     77.5      0.0                  acc_train = []
    84         6        236.0     39.3      0.0                  acc_test = []
    85         6        207.0     34.5      0.0                  losses_test = []
    86                                                           # best_auc = 0
    87         6        239.0     39.8      0.0                  best_loss = 99
    88         6        205.0     34.2      0.0                  aucs = []
    89         6        293.0     48.8      0.0                  recalls = []
    90         6        185.0     30.8      0.0                  lrs = []
    91         6        189.0     31.5      0.0                  best_epoch = 0
    92
    93        66       2834.0     42.9      0.0                  for epoch in range(epochs):
    94        60       2519.0     42.0      0.0                      if epoch % verbosity == 1:
    95        30     321601.0  10720.0      0.0                          print('Epoch ', epoch, 'Test Loss ', losses_test[-1])
    96        60     328282.0   5471.4      0.0                      self.models['{}_{}'.format(topic, fold)].train()
    97        60       7857.0    130.9      0.0                      y_true = []
    98        60       3559.0     59.3      0.0                      y_score = []
    99      3510   41746453.0  11893.6      3.3                      for x, xs, y in train_loader:
   100      3450   27656656.0   8016.4      2.2                          optimizer.zero_grad()
   101      3450  153222734.0  44412.4     12.1                          yhat = self.models['{}_{}'.format(topic, fold)](x, xs)
   102      3450    9724355.0   2818.7      0.8                          loss = criterion(yhat, y)
   103      3450  202206004.0  58610.4     16.0                          loss.backward()
   104      3450  508723569.0 147456.1     40.2                          optimizer.step()
   105      3450    1720162.0    498.6      0.1                          yhat = torch.sigmoid(yhat)
   106      3450     203794.0     59.1      0.0                          y_true.append(y)
   107      3450     341063.0     98.9      0.0                          y_score.append(yhat.detach())
   108      3450    4140073.0   1200.0      0.3                          scheduler.step()
   109        60      70071.0   1167.8      0.0                      y_true = torch.cat(y_true)
   110        60      51460.0    857.7      0.0                      y_score = torch.cat(y_score)
   111        60     751454.0  12524.2      0.1                      acc_train.append(accuracy_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
   112
   113        60     320556.0   5342.6      0.0                      self.models['{}_{}'.format(topic, fold)].eval()
   114        60       7136.0    118.9      0.0                      y_true = []
   115        60       3611.0     60.2      0.0                      y_score = []
   116       300   17642286.0  58807.6      1.4                      for x, xs, y in validation_loader:
   117       240      39880.0    166.2      0.0                          with torch.no_grad():
   118       240   18828388.0  78451.6      1.5                              y_pred = self.models['{}_{}'.format(topic, fold)](x, xs)
   119       240      18126.0     75.5      0.0                          y_true.append(y)
   120       240      11873.0     49.5      0.0                          y_score.append(y_pred)
   121        60      20548.0    342.5      0.0                      y_true = torch.cat(y_true)
   122        60      14666.0    244.4      0.0                      y_score = torch.cat(y_score)
   123        60     130068.0   2167.8      0.0                      val_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
   124        60     161065.0   2684.4      0.0                      losses_test.append(val_loss(y_score, y_true).detach().item())
   125        60      22926.0    382.1      0.0                      y_score = torch.sigmoid(y_score)
   126        60    2390831.0  39847.2      0.2                      aucs.append(roc_auc_score(y_true=y_true.numpy(), y_score=y_score.numpy()))
   127        60     546383.0   9106.4      0.0                      acc_test.append(accuracy_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
   128        60    2220304.0  37005.1      0.2                      recalls.append(recall_score(y_true=y_true.numpy(), y_pred=np.rint(y_score.numpy())))
   129        60      56561.0    942.7      0.0                      lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])
   130
   131
   132
   133
   134                                                               # if aucs[-1] > best_auc:
   135                                                               #     best_epoch = epoch
   136                                                               #     best_auc = aucs[-1]
   137
   138        60       3338.0     55.6      0.0                      if losses_test[-1] < best_loss:
   139        43       1715.0     39.9      0.0                          best_epoch = epoch
   140        43       1814.0     42.2      0.0                          best_loss = losses_test[-1]
   141
   142        60       2775.0     46.2      0.0                      if epoch - best_epoch > patience:
   143                                                                   break
   144
   145         6        281.0     46.8      0.0                  print('Model for topic {} branch {} successfully trained for {} epochs with best val_loss={}'.format(
   146         6      35600.0   5933.3      0.0                      topic, fold, best_epoch + 1, best_loss))
   147                                                           self.train_history[topic]['fold_{}'.format(fold)] = {
   148         6        338.0     56.3      0.0                                                                       'best_epoch': best_epoch,
   149         6        332.0     55.3      0.0                                                                       'best_AUC': aucs[best_epoch],
   150         6        292.0     48.7      0.0                                                                       'best_accuracy': acc_test[best_epoch],
   151         6        220.0     36.7      0.0                                                                       'best_recall': recalls[best_epoch],
   152         6        783.0    130.5      0.0                                                                       'best_loss': losses_test[best_epoch]}
   153
   154         6        245.0     40.8      0.0                  self.train_wide_history[topic]['fold_{}'.format(fold)] = {'test_accuracy': acc_test,
   155         6        256.0     42.7      0.0                                                                            'test_loss': losses_test,
   156         6        299.0     49.8      0.0                                                                            'test_AUC': aucs,
   157         6        431.0     71.8      0.0                                                                            'learning rate': lrs}

It depends on the size of your model and on the optimizer.
Adam does some bookkeeping (it keeps running averages of each parameter's gradient and squared gradient), so if your model is very small, I guess it's possible.
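To make the "bookkeeping" concrete, here is a minimal sketch of one Adam update written out as separate elementwise operations, checked against torch.optim.Adam with the same L2-style weight_decay used in the training code above. `adam_step_` is a hypothetical helper for illustration, not a PyTorch API, and the tiny `nn.Linear(8, 4)` model is only a test fixture.

```python
import torch

@torch.no_grad()
def adam_step_(params, grads, exp_avgs, exp_avg_sqs, step, lr=1e-3,
               betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One Adam update, op by op (L2 weight decay folded into the gradient)."""
    beta1, beta2 = betas
    bias_c1 = 1 - beta1 ** step
    bias_c2_sqrt = (1 - beta2 ** step) ** 0.5
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
        if weight_decay != 0:
            g = g.add(p, alpha=weight_decay)           # op 1: add L2 term
        m.mul_(beta1).add_(g, alpha=1 - beta1)         # ops 2-3: running mean of g
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # ops 4-5: running mean of g^2
        denom = (v.sqrt() / bias_c2_sqrt).add_(eps)    # ops 6-8
        p.addcdiv_(m, denom, value=-lr / bias_c1)      # op 9: parameter update

# Compare against torch.optim.Adam on a tiny layer for a few steps
torch.manual_seed(0)
ref = torch.nn.Linear(8, 4)
mine = torch.nn.Linear(8, 4)
mine.load_state_dict(ref.state_dict())
opt = torch.optim.Adam(ref.parameters(), lr=1e-3, weight_decay=1e-3)
ms = [torch.zeros_like(p) for p in mine.parameters()]
vs = [torch.zeros_like(p) for p in mine.parameters()]
x = torch.randn(16, 8)
for step in range(1, 4):
    for model in (ref, mine):
        model.zero_grad()
        model(x).pow(2).mean().backward()
    opt.step()
    adam_step_(list(mine.parameters()), [p.grad for p in mine.parameters()],
               ms, vs, step, lr=1e-3, weight_decay=1e-3)

max_diff = max((a - b).abs().max().item()
               for a, b in zip(ref.parameters(), mine.parameters()))
print(max_diff)  # tiny float32 rounding difference at most
```

Every parameter tensor goes through roughly nine elementwise passes per step, and the optimizer also stores two extra tensors (`m`, `v`) the size of the parameters. For a network whose forward pass is a few small matrix products, that work is not negligible.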

I wouldn't say it is very small…
It has 2 topics x 2 branches (one for the text, one for the section) x 3 folds (trained with cross-validation), and every branch has the same sequence of linear layers with dropout:
__init__(self, emb_in=768, h0=512, h1=256, h2=128, h3=64, out=32)
followed by concatenation and averaging…
In total, 6.8M parameters.
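A quick arithmetic check of that figure, assuming (this is a guess at the architecture, not the poster's code) that each branch is a plain chain of linear layers 768 → 512 → 256 → 128 → 64 → 32 with bias, that dropout adds no parameters, and ignoring whatever layer follows the concatenation:

```python
def linear_params(d_in, d_out):
    # weight matrix (d_out x d_in) plus one bias per output unit
    return d_in * d_out + d_out

def branch_params(sizes=(768, 512, 256, 128, 64, 32)):
    # Hypothetical branch: one Linear per consecutive pair of sizes;
    # Dropout and activation layers contribute no parameters
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

per_branch = branch_params()
per_model = 2 * per_branch      # text branch + section branch
total = per_model * 2 * 3       # 2 topics x 3 folds
print(per_branch, per_model, total)  # 568288 1136576 6819456
```

Under this assumption the total matches the quoted 6.8M, and each per-fold model (and therefore each Adam instance) holds about 1.14M of those parameters.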

Hi,

Thanks for the details.
I think these layers are relatively small and will run very quickly, especially because the input/output sizes are fairly small.
But Adam has to handle all 6.8M parameters, so it doesn't surprise me that it accounts for so much of the runtime.

This is unusual for neural nets, because convolutional or recurrent layers usually make the forward computation much more expensive relative to the number of parameters.
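A back-of-the-envelope version of this argument, counting a multiply-add as 2 FLOPs and ignoring biases: with batch size B, a dense layer does about 2B FLOPs per weight in the forward pass, a convolution does about 2B·H_out·W_out, while an Adam step costs a constant c elementwise operations per parameter (on the order of ten) regardless of B. The symbols are introduced here for illustration only.

```latex
\begin{aligned}
\text{dense: } & P = d_{\text{in}}\, d_{\text{out}}, && \text{FLOPs}_{\text{fwd}} \approx 2B\, d_{\text{in}}\, d_{\text{out}} = 2B \cdot P \\
\text{conv: }  & P = C_{\text{in}}\, C_{\text{out}}\, k^2, && \text{FLOPs}_{\text{fwd}} \approx 2B\, C_{\text{in}}\, C_{\text{out}}\, k^2\, H_{\text{out}} W_{\text{out}} = 2B\, H_{\text{out}} W_{\text{out}} \cdot P \\
\text{Adam: }  & && \text{FLOPs}_{\text{step}} \approx c \cdot P
\end{aligned}
```

With B = 32 (the train_batch default above), the dense layers do about 64 FLOPs per parameter in the forward pass, whereas a convolution with a 56x56 output would do about 2 · 32 · 3136 ≈ 200,000. Adam's per-parameter work is elementwise and memory-bound, so it is plausibly comparable to the forward/backward cost in the first case and negligible in the second.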

Hello Alban!

Just for my education:

Would I be correct in deducing that in this situation a single SGD
optimizer step would run much faster than the Adam step?

(I’m leaving aside the question of whether Adam is a better
optimizer that would train the network faster.)

Thanks.

K. Frank

In this case, if the above analysis (that Adam involves a lot of computation relative to the forward pass) is correct, then yes, I would expect so. Note that momentum or weight decay would slow SGD down, of course.

@siarblack, if you can try that, it would be a good way to confirm our hypothesis.
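One way to run that experiment is to time only the optimizer.step() call for plain SGD, SGD with momentum and weight decay, and Adam with weight decay, on the same model. This is a sketch under assumptions: `make_mlp` is a hypothetical stand-in for one branch plus a single-logit head, and absolute timings depend entirely on the machine and PyTorch build.

```python
import time
import torch
import torch.nn as nn

def make_mlp():
    # Hypothetical stand-in for one branch (768 -> ... -> 32) plus a 1-logit head
    sizes = [768, 512, 256, 128, 64, 32, 1]
    layers = []
    for a, b in zip(sizes, sizes[1:]):
        layers += [nn.Linear(a, b), nn.ReLU()]
    return nn.Sequential(*layers[:-1])  # drop the trailing ReLU

def time_step(opt_factory, iters=100, batch=32):
    torch.manual_seed(0)
    model = make_mlp()
    opt = opt_factory(model.parameters())
    x, y = torch.randn(batch, 768), torch.randint(0, 2, (batch, 1)).float()
    crit = nn.BCEWithLogitsLoss()
    total = 0.0
    for _ in range(iters):
        opt.zero_grad()
        crit(model(x), y).backward()
        t0 = time.perf_counter()
        opt.step()  # only the optimizer update is timed
        total += time.perf_counter() - t0
    return total / iters

configs = {
    "SGD": lambda p: torch.optim.SGD(p, lr=1e-3),
    "SGD+momentum+wd": lambda p: torch.optim.SGD(p, lr=1e-3, momentum=0.9, weight_decay=1e-3),
    "Adam+wd": lambda p: torch.optim.Adam(p, lr=1e-3, weight_decay=1e-3),
}
results = {name: time_step(factory) for name, factory in configs.items()}
for name, sec in results.items():
    print(f"{name:>16}: {sec * 1e3:.3f} ms per step")
```

Timing step() in isolation separates the optimizer cost from the forward/backward cost, which is the quantity the hypothesis is about.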

I've already tried SGD: it was 3-4 times faster, but convergence was much slower.
I've also noticed that my training time increased when I added weight decay.

What I'm wondering, though, is why the same model architecture, on the same data and with the same parameters, trains much faster (3-4 times) in TF2 than mine does in PyTorch.


Sounds good!


Do you mean the overall runtime?
Or does the model itself take roughly the same time, with the optimizer mostly responsible for the difference?
It is possible that our Adam implementation could be made faster.

I'm talking about overall time, first of all. It is easy for me to profile the PyTorch training procedure and see the duration of every step, but in TF the equivalent training method is just model.fit(epochs=100), which takes about 3 times less time than my PyTorch model. I could certainly profile the .fit method, but from what I see, the optimizer consumes 40% of the time, and at first sight that looks like too much.
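line_profiler only attributes time to Python lines, so it cannot show what happens inside optimizer.step(). A sketch using torch.profiler to get an operator-level breakdown on the CPU; the toy model is hypothetical, and the labels "forward", "backward" and "optimizer_step" are arbitrary names chosen here for record_function.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

# Hypothetical toy model and data, only to show the profiler usage
model = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
crit = nn.BCEWithLogitsLoss()
x, y = torch.randn(32, 768), torch.randint(0, 2, (32, 1)).float()

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(20):
        opt.zero_grad()
        with record_function("forward"):
            loss = crit(model(x), y)
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer_step"):
            opt.step()

# Per-operator totals; the elementwise ops issued by Adam show up individually
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
```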

Indeed, TensorFlow has fused kernels for optimizers, unlike PyTorch, so step() is a few times slower than it could be for adaptive optimizers, since their update is a chain of elementwise math ops, each of which makes its own pass over the parameters.


Some people are working towards that: https://github.com/pytorch/pytorch/pull/41554
I'm not sure how much it will help in this particular case, though.
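Later PyTorch releases expose multi-tensor (`foreach`) and `fused` options on torch.optim.Adam. The sketch below compares them on a hypothetical stand-in model; it assumes a release that accepts these keywords, and falls back gracefully because older releases reject them and `fused=True` on the CPU requires a fairly recent release. The foreach path mainly cuts per-op launch overhead on CUDA, so on the CPU the gain may be small.

```python
import time
import torch
import torch.nn as nn

def make_mlp():
    # Hypothetical stand-in for one branch (768 -> ... -> 32) plus a 1-logit head
    sizes = [768, 512, 256, 128, 64, 32, 1]
    layers = []
    for a, b in zip(sizes, sizes[1:]):
        layers += [nn.Linear(a, b), nn.ReLU()]
    return nn.Sequential(*layers[:-1])

def avg_step_time(iters=50, **adam_kwargs):
    torch.manual_seed(0)
    model = make_mlp()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3, **adam_kwargs)
    x, y = torch.randn(32, 768), torch.randint(0, 2, (32, 1)).float()
    crit = nn.BCEWithLogitsLoss()
    total = 0.0
    for _ in range(iters):
        opt.zero_grad()
        crit(model(x), y).backward()
        t0 = time.perf_counter()
        opt.step()
        total += time.perf_counter() - t0
    return total / iters

variants = {"for-loop": dict(foreach=False), "foreach": dict(foreach=True), "fused": dict(fused=True)}
results = {}
for name, kwargs in variants.items():
    try:
        results[name] = avg_step_time(**kwargs)
    except (TypeError, RuntimeError) as err:
        # Unknown keyword on old releases, or fused unsupported for CPU tensors
        print(f"{name}: not available here ({err})")
for name, sec in results.items():
    print(f"{name:>8}: {sec * 1e3:.3f} ms per step")
```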

Nice. The funny thing is, it should be possible to write such optimizers with just the JIT, assuming the JIT fuser is reliable nowadays…
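A minimal sketch of that idea, assuming TorchScript: the whole Adam elementwise chain for one tensor lives in a single scripted function, which gives a fuser the chance to turn it into one kernel. Whether fusion actually happens depends on the PyTorch version, device, and fuser settings (CPU fusion may be disabled by default), so this illustrates the approach rather than providing a drop-in optimizer. `adam_update_` is a hypothetical name; bias corrections are computed in Python and passed in.

```python
import torch

@torch.jit.script
def adam_update_(p: torch.Tensor, g: torch.Tensor, m: torch.Tensor, v: torch.Tensor,
                 lr: float, beta1: float, beta2: float, eps: float,
                 bias_c1: float, bias_c2: float) -> None:
    # Entire per-tensor Adam chain in one scripted graph (fusion not guaranteed)
    m.mul_(beta1).add_(g, alpha=1.0 - beta1)         # running mean of g
    v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)  # running mean of g^2
    denom = (v / bias_c2).sqrt() + eps
    p.sub_(lr * (m / bias_c1) / denom)               # in-place parameter update

# Usage: first step on a small tensor (no autograd involved here)
p = torch.ones(3)
g = torch.tensor([0.5, -1.0, 2.0])
m, v = torch.zeros(3), torch.zeros(3)
beta1, beta2, step = 0.9, 0.999, 1
adam_update_(p, g, m, v, 0.1, beta1, beta2, 1e-8, 1 - beta1 ** step, 1 - beta2 ** step)
print(p)  # the first Adam step moves each entry by roughly lr * sign(g)
```

When updating real model parameters that require gradients, the call would need to run inside `torch.no_grad()`.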