I have been experimenting a lot with transformers lately and have run into a strange problem that I cannot solve. Consider the following "task":
Input: a sequence of n ones followed by m zeros, where n + m equals a fixed length (e.g., 10). Output: the sum of the digits, i.e. n.
Example: Input: 1111000000. Output: 4
This is an easy task, and clearly no machine learning is needed to solve it. Still, I wanted to investigate how easily and quickly transformer models (encoder only) can learn this kind of function. For an MLP it seems to be no problem at all, but with transformers I cannot get anywhere close to the same performance. Here is an example implementation (note that the code uses n − 1 as the class label, so the classes are 0–9):
import random
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class BlockData(Dataset):
    """Synthetic dataset of "ones followed by zeros" sequences, labelled by their count.

    Each item is a 1-D ``torch.long`` tensor of length ``max_length``. Its first
    ``n`` entries are 1 and the remaining entries are 0, with ``n`` drawn
    uniformly from ``1..max_length``. The label is ``n - 1``, so classes are
    0-based and span ``0..max_length - 1``.

    Samples are drawn from the global ``random`` module on every access.
    ``index`` is therefore ignored, and the same index does not return the same
    sample twice. Seed ``random`` for reproducible runs.

    Args:
        amount: Number of samples reported by ``len()``, i.e. items per epoch.
        max_length: Sequence length, which is also the largest possible count of ones.

    Raises:
        ValueError: If ``amount`` is negative or ``max_length`` is less than 1.
    """

    def __init__(self, amount: int = 100000, max_length: int = 10) -> None:
        super().__init__()
        if amount < 0:
            raise ValueError(f"amount must be non-negative, got {amount}")
        if max_length < 1:
            # random.randint(1, 0) would fail later, inside __getitem__; fail early instead.
            raise ValueError(f"max_length must be at least 1, got {max_length}")
        self.amount = amount
        self.max_length = max_length

    def __len__(self) -> int:
        return self.amount

    def __getitem__(self, index: int):
        """Return a freshly sampled ``(sequence, label)`` pair; ``index`` is unused."""
        n_ones = random.randint(1, self.max_length)
        item = torch.zeros(self.max_length, dtype=torch.long)
        item[:n_ones] = 1
        # Shift to a 0-based class id for CrossEntropyLoss.
        return item, torch.tensor(n_ones - 1, dtype=torch.long)
class Model2(nn.Module):
    """Transformer-encoder classifier for the "count the ones" task.

    The pipeline has four stages:

    1. ``nn.Embedding`` over token ids.
    2. An ``nn.TransformerEncoder`` stack.
    3. Flattening of every token's hidden state.
    4. A small MLP that produces ``num_classes`` logits.

    No positional encoding is added. The flattening head gives each position its
    own weights, which is what makes the output position-aware.

    Every argument defaults to the value that used to be hard-coded, so
    ``Model2()`` builds the same architecture as before.

    Args:
        hidden: Model width (``d_model``), also used as the embedding size.
        nhead: Number of attention heads; PyTorch requires it to divide ``hidden``.
        num_layers: Number of stacked encoder layers.
        max_length: Sequence length the head is sized for. Inputs must have
            exactly this many tokens.
        num_classes: Number of output logits. The default of 20 is kept for
            compatibility, although ``BlockData`` labels only span
            ``0..max_length - 1``.
        vocab_size: Number of distinct token ids (2 for a binary 0/1 input).
        dropout: Dropout rate inside each encoder layer (0.1 is PyTorch's default).
        norm_first: Use pre-LayerNorm encoder layers instead of PyTorch's default
            post-LayerNorm layers.

    NOTE(review): deep post-LayerNorm stacks (the default here: 12 layers,
    ``norm_first=False``) are commonly reported to train poorly without
    learning-rate warmup. If the loss plateaus, try ``norm_first=True``, fewer
    layers, or warmup in the training loop. Confirm empirically.
    """

    def __init__(
        self,
        hidden: int = 32,
        nhead: int = 4,
        num_layers: int = 12,
        max_length: int = 10,
        num_classes: int = 20,
        vocab_size: int = 2,
        dropout: float = 0.1,
        norm_first: bool = False,
    ) -> None:
        super().__init__()
        # Submodules are created in the original order, so parameter
        # initialization (and state_dict key order) is unchanged for the defaults.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=nhead,
            dropout=dropout,
            batch_first=True,
            norm_first=norm_first,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # The embedding width must equal d_model; this used to be a separate literal 32.
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.head = nn.Sequential(
            # The flattened input holds one hidden vector per sequence position.
            nn.Linear(hidden * max_length, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes),
        )

    def forward(self, x):
        """Map token ids to class logits.

        Args:
            x: Integer token ids. Presumably shaped ``(batch, max_length)``, as
               produced by batching ``BlockData``; the head's reshape requires the
               sequence length to equal ``max_length``.

        Returns:
            Logits shaped ``(batch, num_classes)``.
        """
        h = self.embedding(x)
        # batch_first=True, so the encoder keeps the (batch, seq, hidden) layout.
        h = self.transformer_encoder(h)
        return self.head(h.reshape(h.shape[0], -1))
def main(*, lr=1e-3, batch_size=32, warmup_steps=1000, log_every=1000):
    """Train ``Model2`` on ``BlockData`` indefinitely, reporting loss and batch accuracy.

    Runs until interrupted: the outer ``while True`` re-iterates the loader forever.

    Keyword Args:
        lr: Peak Adam learning rate (previously a fixed 1e-3).
        batch_size: DataLoader batch size.
        warmup_steps: Number of optimizer steps over which the learning rate ramps
            linearly up to ``lr``. A value <= 0 disables warmup and uses a constant
            ``lr``, as the original loop did.
        log_every: Write a summary line every this many optimizer steps.
    """
    model = Model2()
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): post-LayerNorm transformer stacks (the nn.TransformerEncoderLayer
    # default) are commonly reported to get stuck without lr warmup. Warmup is a
    # likely fix for the loss plateau; confirm empirically.
    def warmup_factor(step):
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, (step + 1) / warmup_steps)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)

    data_loader = DataLoader(BlockData(), batch_size=batch_size)
    loss_f = nn.CrossEntropyLoss()
    step = 0
    while True:  # Intentionally endless; stop with Ctrl+C.
        with tqdm(data_loader) as pbar:
            for x, y in pbar:
                logits = model(x)  # (batch, num_classes); no squeeze needed
                loss = loss_f(logits, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
                step += 1

                # Report inside the progress bar instead of printing every batch,
                # which flooded stdout and slowed training.
                acc = (logits.argmax(dim=1) == y).float().mean().item()
                pbar.set_description(f"loss {loss.item():.4f} acc {acc:.2f}")
                if step % log_every == 0:
                    current_lr = optimizer.param_groups[0]["lr"]
                    tqdm.write(
                        f"step {step}: loss {loss.item():.4f}, "
                        f"batch acc {acc:.3f}, lr {current_lr:.2e}"
                    )
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when imported.
    main()
I expected the model to overfit, but the loss gets stuck at about 2.4, regardless of the hyperparameters. Are transformers even suitable for this kind of task? If not, what is an intuitive explanation for why? If they should be able to learn it, can you get a transformer model to converge?