Developing an intuition for how to replace RNNs with Self-Attention

Because self-attention is capable of learning long-term relationships in sequential data, I was wondering if there is an intuitive way in PyTorch to substitute attention mechanisms for recurrent networks. Suppose I devise a task in which an RNN must predict the future value of a time-series input (e.g., a 5 Hz sinusoid).

import numpy as np
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU if no GPU

class Temp(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # 2-layer GRU: input size 1, latent (hidden) size `hidden_size`
        self.layer1 = torch.nn.GRU(1, hidden_size, 2, batch_first=True)
        self.layer2 = torch.nn.Linear(hidden_size, 1)  # one predicted value per time step
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        x = self.layer1(x)[0]  # GRU outputs for every time step, shape=(B, L, hidden_size)
        x = self.tanh(self.layer2(x))
        return x

net = Temp(4).to(device)
n_pred = 1  # learn to predict one timestep forward
x = torch.sin(torch.linspace(0, 2*np.pi*5, 200+n_pred))  # make a sine wave
y = x[n_pred:].unsqueeze(0).unsqueeze(-1).to(device)  # target, shape=(B, L, 1)
x = x[:-n_pred].unsqueeze(0).unsqueeze(-1).to(device)  # source, shape=(B, L, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

n_iterations = 300
loss_vals = []
for _ in range(n_iterations):
    y_hat = net(x)
    loss = criterion(y_hat, y)
    loss_vals.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# net is now fully trained (loss reaches about 0.002)

For the working snippet above, and in particular by modifying the Temp model, how would one use nn.MultiheadAttention or nn.Transformer to accomplish the same task?

With this problem, the “input dimension” is simply 1 (a single floating-point value on the sine wave), so I’m a bit confused about how this input size can be split evenly over N heads in multi-head attention; both nn.MultiheadAttention and nn.Transformer require the input (embedding) dimension to be divisible by the number of heads. Also, with an RNN I can freely choose a latent dimension size of 4, but I’m not sure how to do the same with nn.MultiheadAttention or nn.Transformer. Changing “vdim” and “kdim” doesn’t seem to make sense, since attention is usually applied with Query = Key = Value = the input sequence, so “vdim” and “kdim” would also have to be 1.
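One common way around the divisibility constraint, sketched here under assumptions rather than as the definitive answer: first lift each scalar sample to a vector with a learned nn.Linear, so the embedding size becomes a free choice (playing roughly the role of the GRU's hidden_size), then run self-attention on that. With query = key = value, kdim and vdim can stay at their default (equal to embed_dim). The values d_model=16 and num_heads=4 are arbitrary illustrative choices.

```python
import torch

torch.manual_seed(0)

d_model, num_heads = 16, 4                 # arbitrary: d_model must be divisible by num_heads
proj_in = torch.nn.Linear(1, d_model)      # lifts each scalar sample to a d_model-dim vector
attn = torch.nn.MultiheadAttention(d_model, num_heads, batch_first=True)
proj_out = torch.nn.Linear(d_model, 1)     # maps back to one predicted value per step

x = torch.randn(1, 200, 1)                 # (B, L, 1), same layout as the GRU example
h = proj_in(x)                             # (B, L, d_model)
out, weights = attn(h, h, h)               # self-attention: query = key = value = h
y_hat = proj_out(out)                      # (B, L, 1)

print(h.shape, out.shape, weights.shape, y_hat.shape)
```

Each of the 4 heads then works on a 4-dim slice of the 16-dim projection, so the 1-dim input never has to be split.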

All the examples I’ve seen on blogs relate to neural machine translation, so the input sequence is made up of word embeddings; here, I want to understand how attention can be applied directly to real-valued time-series data. I set the sequence length to about 200 values here, but once I have a working attention-based model, I want to try shorter or longer sequences (e.g., as few as 60 but maybe up to 300).
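Two points matter for this setting, shown in the minimal sketch below (d_model, num_heads, and the helper name causal_mask are hypothetical choices). First, the attention weights do not depend on the sequence length, so the same module accepts 60, 200, or 300 steps. Second, unlike a unidirectional GRU, self-attention sees the whole sequence at once, so for next-step prediction a causal mask is assumed here to stop step t from looking at later steps.

```python
import math
import torch

torch.manual_seed(0)

def causal_mask(L):
    # hypothetical helper: True above the diagonal means "may not attend",
    # so step t only sees steps <= t, like a unidirectional GRU
    return torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

d_model, num_heads = 16, 4
proj_in = torch.nn.Linear(1, d_model)
attn = torch.nn.MultiheadAttention(d_model, num_heads, batch_first=True)

results = {}
for L in (60, 200, 300):                   # the same weights work for any sequence length
    x = torch.sin(torch.linspace(0, 2 * math.pi * 5, L)).view(1, L, 1)
    h = proj_in(x)
    out, w = attn(h, h, h, attn_mask=causal_mask(L))
    # largest attention weight placed on any future step (should be exactly 0)
    results[L] = (tuple(out.shape), w.triu(diagonal=1).abs().max().item())
    print(L, results[L])
```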

A positional encoder actually adds a sinusoidal “bias” to the embeddings. But for all intents and purposes, embeddings are just vectors with some number of dimensions. For a sinusoidal wave, you only have 2 dims, x and y.
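To make the "sinusoidal bias" concrete, here is a minimal sketch of the fixed positional encoding from the original Transformer paper (sin on even dimensions, cos on odd ones, at geometrically decreasing frequencies), added element-wise to a batch of embeddings. The function name and d_model=16 are illustrative assumptions; nn.Transformer itself does not add positional encodings, so something like this is usually applied to the inputs before the encoder.

```python
import math
import torch

def sinusoidal_positional_encoding(L, d_model):
    # hypothetical helper: fixed encoding, even dims use sin, odd dims use cos
    pos = torch.arange(L, dtype=torch.float32).unsqueeze(1)             # (L, 1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))                    # (d_model/2,)
    pe = torch.zeros(L, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe                                                             # (L, d_model)

torch.manual_seed(0)
d_model = 16                               # assumed; must be even for this layout
emb = torch.randn(1, 200, d_model)         # stand-in for projected inputs, (B, L, d_model)
pe = sinusoidal_positional_encoding(200, d_model)
emb_with_pos = emb + pe                    # broadcast over the batch: a position-dependent "bias"
print(pe.shape, emb_with_pos.shape)
```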

So in the simple case, you’re just feeding in a sequence of 2-dimensional embeddings, as many time steps long as you want.

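import math
import torch

alpha = 10
batch_size = 32
seq_len = 100
beta = torch.rand((batch_size, 1, 1)) * 2 * math.pi  # random phase offset in [0, 2π) per sample
t = torch.arange(seq_len, dtype=torch.float32).repeat(batch_size, 1).unsqueeze(2)  # t = 0..seq_len-1, shape=(B, L, 1)
emb1 = torch.cos(t / alpha + beta)
emb2 = torch.sin(t / alpha + beta)
data = torch.cat([emb1, emb2], dim=2)  # shape=(B, L, 2)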

Here alpha is a constant that sets the frequency (the phase advances by 1/alpha radians per step). Beta is a random offset between 0 and 2π that shifts the start of the sinusoid for each sample in the batch. To create a sequence of, say, length 100, t is simply the range 0 to 99.
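A shape-level sketch of how such 2-dim data could go straight into attention, assuming a next-step prediction setup like the GRU example (source = steps 0..L-2, target = steps 1..L-1). With embed_dim=2, num_heads=2 divides evenly (1 dim per head). The causal mask and the nn.Linear output head are assumptions added for illustration; the model is untrained, and in practice a wider projection (as in the 1-dim case) is often used instead of attending directly in 2 dims.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha, batch_size, seq_len = 10, 32, 100

# data built as described above, with t = 0, 1, ..., seq_len - 1
beta = torch.rand(batch_size, 1, 1) * 2 * math.pi                    # random phase per sample
t = torch.arange(seq_len, dtype=torch.float32).repeat(batch_size, 1).unsqueeze(2)
data = torch.cat([torch.cos(t / alpha + beta), torch.sin(t / alpha + beta)], dim=2)  # (B, L, 2)

src, tgt = data[:, :-1], data[:, 1:]       # predict step t+1 from steps <= t
L = src.shape[1]

# embed_dim = 2 splits evenly over num_heads = 2
attn = torch.nn.MultiheadAttention(embed_dim=2, num_heads=2, batch_first=True)
head = torch.nn.Linear(2, 2)               # assumed output head back to (cos, sin)
mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)   # block future steps

out, _ = attn(src, src, src, attn_mask=mask)
pred = head(out)
loss = F.mse_loss(pred, tgt)
print(data.shape, pred.shape, loss.item())
```

From here, the same optimizer loop as in the GRU example would train attn and head together.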