Hi,
I tried to use SDPA to speed up the attention computation. I used time.time()
to record the time of training one effective batch (I use gradient accumulation). Here is how I record the time:
for epoch in range(config.epochs):
    if os.path.exists(config.model_save_path) and config.resume:
        losses = checkpoint['loss']
    else:
        losses = 0
    t = time.time()
    for idx, batch in enumerate(train_iter):
        if config.ddp:
            # only sync gradients across ranks on the last micro-batch of each accumulation window
            model.require_backward_grad_sync = (idx % config.gradient_accumulation_steps == config.gradient_accumulation_steps - 1)
        # inputs come in as [batch_size, seq_len]; transpose to [seq_len, batch_size]
        b_token_ids = batch['input_ids'].t().to(config.device)
        b_segs = batch['token_type_ids'].t().to(config.device)
        b_mask = batch['attention_mask'].to(config.device)
        b_mlm_label = batch['labels'].t().to(config.device)
        with amp.autocast(device_type='cuda', dtype=torch.bfloat16):  # run the forward pass in bfloat16
            # with autocast(enabled=False):
            loss, mlm_logits = model(input_ids=b_token_ids,
                                     attention_mask=b_mask,
                                     token_type_ids=b_segs,
                                     masked_lm_labels=b_mlm_label,
                                     next_sentence_labels=None)
        loss = loss / config.gradient_accumulation_steps  # scale the loss for gradient accumulation
        scaler.scale(loss).backward()
        if (idx % config.gradient_accumulation_steps) == (config.gradient_accumulation_steps - 1):
            if config.grad_clip != 0.0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            print(f'Time for one step (one effective batch): {time.time() - t}')
            t = time.time()
        losses += (loss.item() * config.gradient_accumulation_steps)
        mlm_acc, _, _ = accuracy(mlm_logits, b_mlm_label, config.pad_index)
        ....
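One caveat on the timing itself (this is a suggestion, not part of the original code): CUDA kernels are launched asynchronously, so time.time() readings only line up with GPU work if the device is synchronized before each reading (here the loss.item() call forces a sync, but the optimizer-step kernels launched just before the print may still be in flight). A minimal sketch of the same measurement with CUDA events, with placeholder comments standing in for the training step:

import torch

# Hedged sketch (names are illustrative): time one accumulation window with
# CUDA events so that asynchronously launched kernels are fully accounted for.
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

start_evt.record()
# ... forward / backward over gradient_accumulation_steps micro-batches,
#     then scaler.step(optimizer), scaler.update(), scheduler.step() ...
end_evt.record()

torch.cuda.synchronize()                      # wait for every queued kernel to finish
elapsed_ms = start_evt.elapsed_time(end_evt)  # milliseconds between the two events
print(f'Time for one step: {elapsed_ms / 1000:.4f}s')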
And I changed my attention to SDPA.
This is my original version:
.....
if attn_mask is not None:
    attn_output_weights += attn_mask  # [batch_size * num_heads, tgt_len, src_len]
.....
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True)
attn_output = torch.bmm(attn_output_weights, v)
This is the SDPA version:
....
attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
....
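One thing worth checking here (a suggestion, not part of the original code): F.scaled_dot_product_attention is only faster when it can dispatch to the FlashAttention or memory-efficient backend; otherwise it falls back to the math backend, which does essentially the same work as the manual version. An explicit attn_mask, FP32 inputs, or an unsupported head dimension can rule out the FlashAttention backend in particular. On recent PyTorch releases you can force a specific backend to see whether it is usable for your inputs (older releases expose torch.backends.cuda.sdp_kernel instead); a rough sketch with made-up shapes:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # available in recent PyTorch versions

# Hypothetical shapes: [batch, num_heads, seq_len, head_dim] in bf16 on the GPU.
q = torch.randn(8, 12, 512, 64, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the FlashAttention backend; if the inputs (dtype, mask,
# head_dim) are not supported, this raises instead of silently falling back,
# which tells you whether the fast kernel is actually being used.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

If only the math backend is eligible for your inputs, the SDPA path ends up doing roughly the same work as the hand-written attention.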
It is weird because these two versions take almost the same time. I am not sure whether I am using F.scaled_dot_product_attention
in the right way. I tested FP32, FP16, and BF16, and they all take almost the same time.
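It may also help to rule out the rest of the training step: if most of the time goes to the embedding/MLM head matmuls, the data pipeline, or the optimizer, a faster attention kernel will barely move the per-step number. A small standalone comparison of the two attention implementations, with made-up shapes rather than the real model's (dropout omitted to keep the comparison like-for-like), could look like this:

import math
import time
import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # Plain softmax attention, mirroring the structure of the original version.
    weights = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    weights = F.softmax(weights, dim=-1)
    return weights @ v

def bench(fn, *args, iters=50):
    for _ in range(5):            # warm-up iterations
        fn(*args)
    torch.cuda.synchronize()
    t = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()      # make sure all kernels finished before reading the clock
    return (time.time() - t) / iters

# Hypothetical shapes: [batch, num_heads, seq_len, head_dim] in bf16 on the GPU.
q = torch.randn(32, 12, 512, 64, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

print('manual attention:', bench(manual_attention, q, k, v))
print('SDPA            :', bench(F.scaled_dot_product_attention, q, k, v))

If SDPA is clearly faster in this isolated test but the end-to-end step time does not change, attention is probably not the bottleneck of the step.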