How to properly apply causal mask for next char prediction in MLP

This is a bit of a noob NLP question that I’m trying to figure out, since I’m not an NLP expert.

Suppose we have some data X = [seq_len, batch_size] and corresponding labels Y = [seq_len, batch_size, vocab_size/num_classes].

And, now we want to train an MLP for next character prediction.

Question: How do you restrict the model from peeking at future tokens?

Using causal masking? If so, where do you apply it, and on which layer or output?

Everything I’ve tried with this approach of applying a causal mask to the logits along the seq_len dimension makes the model stagnate, i.e., it stops learning at around 52% train accuracy.

Without the mask the model works fine on the data and reaches 99.99% train accuracy.

Most of the examples I’ve seen use X and a shifted version of it, Y = X', as labels to prevent the model from looking at future tokens, but that’s not my case.

Any suggestions or an MWE showing how to achieve this properly would be more than welcome.

Thanks.

Hello,

Why Causal Masking Is Needed
Causal masking ensures that the model can only access information from past tokens, not future ones, when making predictions. Without it, the model can “peek” at future tokens, which leads to overly optimistic performance (such as the 99.99% train accuracy you’re seeing).

How to Apply Causal Masking in Your Case
Since you’re using an MLP (as opposed to an RNN or a Transformer), causality has to be enforced explicitly, either in how the input data is constructed or inside the model itself.

Input Handling: Ensure that at every time step t, the model only sees tokens from positions t′ ≤ t. This means truncating or masking the input data to block future information (a minimal windowing sketch follows these two points).

Masking Inside the Model: If your MLP mixes information across the seq_len dimension (e.g., a position-mixing linear layer), apply the causal mask to that mixing operation so that position t cannot read from positions greater than t. Masking only the final logits before the loss does not remove the leakage, because by that point they already depend on future inputs, which is consistent with the stagnation you observed.
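For the input-handling route, here is a minimal sketch of one way to do it, under my own assumptions (the window size k and the padding value are illustrative, not taken from your setup): at step t the MLP is only ever fed the k most recent tokens, so future positions are physically absent from its input.

```python
import torch
import torch.nn.functional as F

# Minimal sketch: per-step windows of the k most recent tokens (k is an assumed
# hyperparameter), so the model literally never receives future tokens.
seq_len, batch_size, vocab_size, k = 10, 2, 100, 4
X = torch.randint(0, vocab_size, (seq_len, batch_size))          # [seq_len, batch_size]

# Left-pad along seq_len so early positions still get a full-size window
padded = F.pad(X, (0, 0, k - 1, 0), value=0)                     # [seq_len + k - 1, batch_size]

# windows[t] holds the tokens at positions t - k + 1 .. t
windows = torch.stack([padded[t : t + k] for t in range(seq_len)])  # [seq_len, k, batch_size]
windows = windows.permute(0, 2, 1)                               # [seq_len, batch_size, k]

# An MLP can now embed and flatten each window and predict the next character
# from it, with no access to anything after position t.
```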

Implementation Details
Causal Masking Inside the Model
The requirement is that for a sequence X of length L, the logits at time step t only depend on tokens 1, …, t. You get this by building an upper-triangular causal mask and applying it to the position-mixing weights inside the model (not to the predictions themselves). Here’s how:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Example dimensions
seq_len, batch_size, vocab_size, hidden = 10, 2, 100, 64

# Toy data in the shapes from the question: X = [seq_len, batch_size]
X = torch.randint(0, vocab_size, (seq_len, batch_size))
targets = torch.randint(0, vocab_size, (seq_len, batch_size))  # example targets

# Embed the characters: [seq_len, batch_size, hidden]
embed = nn.Embedding(vocab_size, hidden)
h = embed(X)

# Causal mask: upper triangular, True at future positions (s > t)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Position-mixing weights [seq_len, seq_len]; zero the masked entries so that
# position t cannot read from positions s > t
# (if the mixing went through a softmax, you would use float('-inf') here instead)
mixing = torch.randn(seq_len, seq_len).masked_fill(causal_mask, 0.0)

# Mix across the sequence dimension causally, then project to the vocabulary
mixed = torch.einsum('ts,sbh->tbh', mixing, h)    # [seq_len, batch_size, hidden]
to_vocab = nn.Linear(hidden, vocab_size)
logits = to_vocab(mixed)                          # [seq_len, batch_size, vocab_size]

# Loss calculation
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
```
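To answer the “which layer” part directly: the mask lives wherever positions get mixed, never at the output. Below is a hedged sketch of how that could be packaged as a module; CausalTokenMixingMLP is a hypothetical name, and the architecture (embedding, one masked token-mixing layer, vocabulary projection) is just one reasonable choice, not your model.

```python
import torch
import torch.nn as nn

class CausalTokenMixingMLP(nn.Module):
    """Hypothetical MLP whose only cross-position layer is causally masked."""

    def __init__(self, vocab_size, seq_len, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.mix = nn.Parameter(torch.randn(seq_len, seq_len) * 0.02)  # position mixing
        self.out = nn.Linear(hidden, vocab_size)
        # Buffer so the mask moves with the module (True above the diagonal = future)
        self.register_buffer(
            "causal_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        )

    def forward(self, x):                       # x: [seq_len, batch_size] of token ids
        h = self.embed(x)                       # [seq_len, batch_size, hidden]
        w = self.mix.masked_fill(self.causal_mask, 0.0)  # the causal mask is applied HERE
        h = torch.einsum('ts,sbh->tbh', w, h)   # position t mixes only positions <= t
        return self.out(h)                      # [seq_len, batch_size, vocab_size]

# Usage
model = CausalTokenMixingMLP(vocab_size=100, seq_len=10)
logits = model(torch.randint(0, 100, (10, 2)))
```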
Shifted Input Labels
Alternatively, you can shift the input X to generate the target labels Y: the label at time step t is X[t + 1]. This is simpler to implement and avoids causal masking:

```python
# Shifted targets: at every position, predict the following character
inputs = X[:-1]            # exclude the last token
shifted_targets = X[1:]    # skip the first token

# Forward pass
logits = model(inputs)     # model outputs: [seq_len - 1, batch_size, vocab_size]

# Calculate loss
loss = F.cross_entropy(logits.view(-1, vocab_size), shifted_targets.view(-1))
```
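The `model` above is whatever you are training and has to be defined before the forward pass; as a hedged example, here is a minimal per-position MLP (my own sketch, not your architecture). Because it never mixes information across the seq_len dimension, it cannot see the future when combined with shifted targets, so no mask is needed at all.

```python
import torch.nn as nn

# Assumed minimal per-position MLP: each position is processed independently,
# so with shifted targets there is nothing left to mask.
model = nn.Sequential(
    nn.Embedding(vocab_size, 64),   # [seq_len - 1, batch_size] -> [seq_len - 1, batch_size, 64]
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, vocab_size),     # -> [seq_len - 1, batch_size, vocab_size]
)
```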

Best Regards