Because self-attention is capable of learning long-term relationships in sequential data, I was wondering whether there is an intuitive way in PyTorch to substitute attention mechanisms for recurrent networks. Suppose I devise a task in which an RNN must predict the future value of a time-series input (e.g., a 5 Hz sinusoid):
import numpy as np
import torch

torch.manual_seed(0)

class Temp(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # 2-layer GRU over a 1-dimensional input, followed by a linear read-out
        self.layer1 = torch.nn.GRU(1, hidden_size, 2, batch_first=True)
        self.layer2 = torch.nn.Linear(hidden_size, 1)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        x = self.layer1(x)[0]            # keep the output sequence, drop the final hidden state
        x = self.tanh(self.layer2(x))
        return x

net = Temp(4).cuda()

n_pred = 1  # learn to predict one timestep forward
x = torch.sin(torch.linspace(0, 2*np.pi*5, 200+n_pred))  # make a sine wave
y = x[n_pred:].unsqueeze(0).unsqueeze(-1).cuda()   # target, shape=(B, L, 1)
x = x[:-n_pred].unsqueeze(0).unsqueeze(-1).cuda()  # source, shape=(B, L, 1)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

n_iterations = 300
loss_vals = []
for _ in range(n_iterations):
    y_hat = net(x)
    loss = criterion(y_hat, y)
    loss_vals.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# net is now fully trained (loss reaches about 0.002)
For the working snippet above, and in particular by modifying the Temp model, how would one leverage nn.MultiheadAttention or nn.Transformer to accomplish the same task?
With this problem, the "input dimension" is simply 1 (a floating-point y value on the sine wave), so I'm a bit confused about how that input size can be split evenly over N heads with multi-head attention; both nn.MultiheadAttention and nn.Transformer require the embedding dimension to be divisible by the number of heads. Also, with an RNN I can arbitrarily pick a latent dimension of 4, but I'm not sure how to do the same with nn.MultiheadAttention or nn.Transformer. Changing the values of kdim and vdim doesn't seem to help either, since attention is usually applied with Query = Key = Value = the input sequence, so kdim and vdim would also have to be 1.
All the examples I've seen on online blogs relate to neural machine translation, where the input sequence is made up of word embeddings; here I want to understand how attention can be applied directly to real-valued time-series data. I've set the sequence length to about 200 values, but once I have a working attention-based model I'd like to try shorter or longer sequences as well (e.g., as low as 60 and maybe up to 300).
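Along the same lines, here is a rough attempt with the encoder half of nn.Transformer; again, d_model, n_heads, n_layers, dim_feedforward, and max_len are placeholders I made up, and I'm not sure a learned positional encoding is the right choice for real-valued sequences of varying length:

import torch

class TransformerTemp(torch.nn.Module):
    # rough sketch using nn.TransformerEncoder; all sizes are placeholders
    def __init__(self, d_model=4, n_heads=2, n_layers=2, max_len=300):
        super().__init__()
        self.embed = torch.nn.Linear(1, d_model)
        # attention is order-agnostic, so add a (learned) positional encoding
        self.pos = torch.nn.Parameter(torch.zeros(1, max_len, d_model))
        layer = torch.nn.TransformerEncoderLayer(d_model, n_heads,
                                                 dim_feedforward=16, batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(layer, n_layers)
        self.out = torch.nn.Linear(d_model, 1)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):                              # x: (B, L, 1)
        L = x.size(1)
        h = self.embed(x) + self.pos[:, :L]            # (B, L, d_model)
        # causal mask: position t may only attend to positions <= t
        mask = torch.triu(torch.full((L, L), float('-inf'), device=x.device), diagonal=1)
        h = self.encoder(h, mask=mask)
        return self.tanh(self.out(h))                  # (B, L, 1)

My assumption is that either of these could be dropped in place of Temp(4) in the training loop above without any other changes, since the input/output shapes are the same, but I'd like to know whether this is the right way to think about it.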