Hi
I am currently trying to train a very simple transformer based on this codebase, with some minor changes.
The forward pass on my data seems to be very quick (about 0.01 s), but the backward pass is extremely slow (about 7 s on average). This is very concerning; I would expect the backward pass to take 2, maybe 3 times as long as the forward pass, but not this long.
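For context, here is a minimal sketch of how I would time the two passes separately (not my exact harness; it reuses the model, get_batch and optimizer defined in the script below). The torch.cuda.synchronize() calls are there because CUDA kernels are launched asynchronously, so without them the queued forward work tends to get billed to whichever later call forces a sync:

import time
import torch

def time_step():
    # one training step, timed with explicit synchronization
    xb, yb = get_batch('train')
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    logits, loss = model(xb, yb)           # forward
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()                        # backward
    optimizer.step()
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print(f"forward: {t1 - t0:.3f}s  backward+step: {t2 - t1:.3f}s")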
I had to change some parameters to fit the requirements:
Vocabulary size is 256
The sequence length is 1024
Batch size is set to 64
Heads and blocks are both set to 4 to fit this on my GPU.
The embedding size is reduced to 256 (n_embd in the code; the feed-forward hidden size is 4x that).
I am running this on an Intel i7-7700K, 16 GB of RAM and an Nvidia 3080 Ti, on Windows.
Python is 3.10.7, PyTorch 2.3, CUDA 12.1.
So perhaps the performance is fine and it is simply going to be that slow. But is there a way to check that the setup is working optimally? Can this train faster, or is it pretty much at its limit and I should just be patient?
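One check I can think of is torch.profiler; a minimal sketch (again reusing model, get_batch and optimizer from the script below) to see which ops dominate a single training step would be:

from torch.profiler import profile, ProfilerActivity

xb, yb = get_batch('train')
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
# print the ops sorted by total CUDA time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))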
Here is the code I run with random data to check the performance:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from datetime import datetime
from tqdm import trange
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 1024 # what is the maximum context length for predictions?
max_iters = 500000
eval_interval = 1
learning_rate = 3e-4
eval_iters = 200
n_embd = 256
n_head = 4
n_layer = 4
dropout = 0.2
vocab_size = 256
# ------------
# data loading
def get_batch(split):
# generate a small batch of data of inputs x and targets y
data = train_data if split == 'train' else val_data
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x, y = x.to(device), y.to(device)
return x, y
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(split)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out
class Tokenizer():
    """ buckets raw price differences into a fixed set of symmetric bins """
    def __init__(self, priceDiffRange, binsCount, scalingFactor):
        # instance attributes (class-level lists would be shared between instances)
        self.binRanges = []
        self.binValues = []
        oneSide = binsCount // 2
        # build symmetric bin edges, shrinking by scalingFactor at each step
        for _ in range(oneSide):
            self.binRanges.append(priceDiffRange)
            self.binRanges.append(-priceDiffRange)
            priceDiffRange = priceDiffRange / scalingFactor
        self.binRanges.sort()
        # representative value for each bin: the two outer edges plus the midpoints
        self.binValues.append(-priceDiffRange)
        for i in range(1, binsCount - 1):
            self.binValues.append((self.binRanges[i] + self.binRanges[i - 1]) / 2)
        self.binValues.append(priceDiffRange)

    def tokenize(self, raw_data):
        # map each raw value to the index of the first bin edge above it
        result = []
        for val in raw_data:
            token = -1
            for i in range(vocab_size - 1):
                if self.binRanges[i] > val:
                    token = i
                    break
            if token == -1:
                token = vocab_size - 1
            result.append(token)
        return result

    def detokenize(self, tokens):
        # map token ids back to representative bin values
        result = []
        for val in tokens:
            if val > vocab_size - 1:
                result.append(self.binValues[vocab_size - 1])
            else:
                result.append(self.binValues[val])
        return result
class Head(nn.Module):
""" one head of self-attention """
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# input of size (batch, time-step, channels)
# output of size (batch, time-step, head size)
B,T,C = x.shape
k = self.key(x) # (B,T,hs)
q = self.query(x) # (B,T,hs)
# compute attention scores ("affinities")
wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
wei = F.softmax(wei, dim=-1) # (B, T, T)
wei = self.dropout(wei)
# perform the weighted aggregation of the values
v = self.value(x) # (B,T,hs)
out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
return out
class MultiHeadAttention(nn.Module):
""" multiple heads of self-attention in parallel """
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(head_size * num_heads, n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedFoward(nn.Module):
""" a simple linear layer followed by a non-linearity """
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Block(nn.Module):
""" Transformer block: communication followed by computation """
def __init__(self, n_embd, n_head):
# n_embd: embedding dimension, n_head: the number of heads we'd like
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
self.ffwd = FeedFoward(n_embd)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class GPTLanguageModel(nn.Module):
def __init__(self):
super().__init__()
# each token directly reads off the logits for the next token from a lookup table
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd) # final layer norm
self.lm_head = nn.Linear(n_embd, vocab_size)
# better init, not covered in the original GPT video, but important, will cover in followup video
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.shape
# idx and targets are both (B,T) tensor of integers
tok_emb = self.token_embedding_table(idx) # (B,T,C)
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
x = tok_emb + pos_emb # (B,T,C)
x = self.blocks(x) # (B,T,C)
x = self.ln_f(x) # (B,T,C)
logits = self.lm_head(x) # (B,T,vocab_size)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# crop idx to the last block_size tokens
idx_cond = idx[:, -block_size:]
# get the predictions
logits, loss = self(idx_cond)
# focus only on the last time step
logits = logits[:, -1, :] # becomes (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
################################################################################
######## Main
################################################################################
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
torch.manual_seed(1983)
# random sample datasets (for testing, so we do not have to wait for the tokenizer)
# note: np.random.rand() gives floats in [0, 1), which all collapse to 0 when cast to long,
# so draw random token ids in [0, vocab_size) instead
val_data = torch.tensor(np.random.randint(0, vocab_size, 100000), dtype=torch.long)
train_data = torch.tensor(np.random.randint(0, vocab_size, 100000), dtype=torch.long)
model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
lastLoss = float("inf")
with trange(max_iters, desc="Training Progress", unit="batches") as pbar:
for iter in pbar:
# every once in a while evaluate the loss on train and val sets
# if iter % eval_interval == 0 or iter == max_iters - 1:
# losses = estimate_loss()
# print(datetime.now().strftime("%H:%M:%S"))
# print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
# sample a batch of data
xb, yb = get_batch('train')
# evaluate the loss
logits, loss = model(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
pbar.set_postfix(loss=loss.item())
# generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))