AttnRNN: An Efficient Recurrent Neural Network Model Based on Attention Mechanism
Abstract
This paper proposes AttnRNN, a novel recurrent neural network architecture that combines an attention mechanism with a gating mechanism to address the information-retention problem of traditional RNNs on long-sequence tasks. Experiments demonstrate that the model outperforms standard RNN, GRU, and LSTM baselines on the Adding Problem, the Copy Memory task, and IMDB text classification, while maintaining a low parameter count and computational cost.
1. Model Architecture
1.1 Core Components
```python
import torch
import torch.nn as nn

class EfficientAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Project only Q-values; K and V use the raw context (no KV projection)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, context):
        # Multi-head attention calculation (no KV projection)
        ...

class AttnRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads):
        super().__init__()
        # Attention mechanism over the {h_prev, x} context
        self.attn = EfficientAttention(hidden_size, num_heads)
        # Single-gate design
        self.update_gate = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.Sigmoid()
        )
        # Layer normalization used together with the residual connection
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, h_prev, x):
        # Context construction: combination of h_prev and x (4 tokens), e.g.
        #   context = torch.cat([h_exp, x_exp, h_exp + x_exp, h_exp * x_exp], dim=1)
        # Attention calculation → gated fusion → residual update
        ...
```
1.2 Design Innovations
- Attention Competition Mechanism: Direct competition for attention scores between historical information (h_prev) and new inputs (x) within a fixed-length context (4 tokens)
- Residual Dominance Principle: Residual connections ensure historical information always dominates
- Minimalist Gating Design: Single-gate structure (vs. GRU’s 3 gates/LSTM’s 4 gates) reduces parameters
- No KV Projection: The attention module projects only the queries, significantly lowering computation (a runnable sketch of the full cell follows this list)
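The class skeletons in Section 1.1 elide the forward passes. Below is a minimal runnable sketch of one way to fill them in that is consistent with the points above; the head layout, the gate input (h_prev concatenated with the attended vector), the residual placement, and the `input_proj`/`out_proj` layers (the "data alignment" and output projections counted in Section 2.1) are illustrative assumptions rather than the reference implementation in the linked repository.

```python
# Minimal runnable sketch of the AttnRNN cell described above. Layer names,
# the gate input, and the residual placement are assumptions for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)    # queries are projected...
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # ...keys/values are not

    def forward(self, query, context):
        # query: (B, 1, D), context: (B, L, D) with L = 4 fixed context tokens
        B, L, D = context.shape
        q = self.q_proj(query).view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
        k = context.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = k                                            # raw context as both K and V
        scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5   # (B, H, 1, L)
        out = (F.softmax(scores, dim=-1) @ v).transpose(1, 2).reshape(B, 1, D)
        return self.out_proj(out)

class AttnRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads=4):  # num_heads=4 is an arbitrary default
        super().__init__()
        self.input_proj = nn.Linear(input_size, hidden_size)   # data alignment: d -> h
        self.attn = EfficientAttention(hidden_size, num_heads)
        self.update_gate = nn.Sequential(                      # single gate
            nn.Linear(2 * hidden_size, hidden_size),
            nn.Sigmoid(),
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, h_prev, x):
        # h_prev: (B, h), x: (B, d)
        x = self.input_proj(x)
        h_exp, x_exp = h_prev.unsqueeze(1), x.unsqueeze(1)     # (B, 1, h)
        # Fixed 4-token context: history, input, and their interactions
        context = torch.cat([h_exp, x_exp, h_exp + x_exp, h_exp * x_exp], dim=1)
        attended = self.attn(h_exp, context).squeeze(1)        # (B, h)
        # Gated fusion of history and attended information (single gate)
        z = self.update_gate(torch.cat([h_prev, attended], dim=-1))
        # Residual update: h_prev always contributes directly to the new state
        return self.layer_norm(h_prev + z * attended)

# Quick shape check
cell = AttnRNNCell(input_size=64, hidden_size=128)
h, x = torch.zeros(8, 128), torch.randn(8, 64)
print(cell(h, x).shape)  # torch.Size([8, 128])
```

Because the context always has exactly four tokens, the per-step attention cost stays proportional to h², independent of sequence length, which is what the fixed-length design is meant to buy.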
2. Theoretical Analysis
2.1 Parameter Comparison
Model | Parameter Formula (d: input dim, h: hidden dim) | d=64, h=128 Example |
---|---|---|
Vanilla RNN | dh + h² + 2h | 24,832 |
GRU | 3(dh + h² + 2h) | 74,496 |
LSTM | 4(dh + h² + 2h) | 99,328 |
AttnRNN | dh + 4h² + 6h | 66,432 |
Parameter Distribution (summing to dh + 4h² + 6h):
- Data alignment projection layer: dh + h
- Attention module: 2h² + 2h (q_proj + out_proj)
- Gating layer: 2h² + h (Linear(2h → h))
- Normalization layer: 2h
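As a quick sanity check (not part of the original experiments), the baseline formulas can be verified directly against the parameter counts of PyTorch's built-in recurrent modules:

```python
# Sanity check of the baseline parameter formulas against PyTorch's built-in
# modules (d = 64, h = 128, single layer, biases included).
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

d, h = 64, 128
print(n_params(nn.RNN(d, h)))   # 24832 = dh + h^2 + 2h
print(n_params(nn.GRU(d, h)))   # 74496 = 3(dh + h^2 + 2h)
print(n_params(nn.LSTM(d, h)))  # 99328 = 4(dh + h^2 + 2h)
```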
2.2 Time Complexity
Model | Per-step Complexity | Sequence Complexity | Measured Time Ratio |
---|---|---|---|
RNN | O(B(dh + h²)) | O(TB(dh + h²)) | 1x |
GRU | O(3B(dh + h²)) | O(3TB(dh + h²)) | 1x |
AttnRNN | O(B(3h² + dh)) | O(TB(3h² + dh)) | 2.5-3x |
Note: The current implementation incurs Python-level loop overhead (one cell invocation per time step); a fused CUDA kernel is expected to improve wall-clock speed significantly.
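For reference, the note above refers to a step-by-step Python loop of the following shape. The wrapper class below is a hypothetical illustration (the name `AttnRNN` and its interface are assumptions), reusing the `AttnRNNCell` sketch from Section 1.2:

```python
# Minimal sequence-level wrapper illustrating the Python-level per-step loop
# (hypothetical wrapper; reuses the AttnRNNCell sketch from Section 1.2).
import torch
import torch.nn as nn

class AttnRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = AttnRNNCell(input_size, hidden_size, num_heads)

    def forward(self, x, h0=None):
        # x: (B, T, d) -> outputs: (B, T, h), final hidden state: (B, h)
        B, T, _ = x.shape
        h = x.new_zeros(B, self.hidden_size) if h0 is None else h0
        outputs = []
        for t in range(T):  # one cell call (and its kernel launches) per time step
            h = self.cell(h, x[:, t])
            outputs.append(h)
        return torch.stack(outputs, dim=1), h
```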
3. Experimental Validation
3.1 Adding Problem
Seq Length | RNN | GRU | LSTM | AttnRNN |
---|---|---|---|---|
50 | 0.3411 | 0.0111 | 0.0666 | 0.0069 |
100 | 0.3392 | 0.0108 | 0.0083 | 0.0110 |
200 | 0.3244 | 0.0048 | 0.0231 | 0.0014 |
400 | 0.3458 | 0.0208 | 0.0427 | 0.0011 |
Metric: Mean Absolute Error (MAE), lower is better
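For reproducibility, batches can be generated with the standard formulation of the task (a sketch; the exact sampling used in the original experiments may differ): each step carries a random value in [0, 1] plus a 0/1 marker, exactly two positions are marked, and the target is the sum of the two marked values.

```python
# Adding Problem data generation (standard formulation; sampling details are
# an assumption, not necessarily identical to the original experiments).
import torch

def make_adding_batch(batch_size, seq_len):
    values = torch.rand(batch_size, seq_len)
    markers = torch.zeros(batch_size, seq_len)
    for i in range(batch_size):
        idx = torch.randperm(seq_len)[:2]                # two marked positions
        markers[i, idx] = 1.0
    x = torch.stack([values, markers], dim=-1)           # (B, T, 2) inputs
    y = (values * markers).sum(dim=1, keepdim=True)      # (B, 1) target sum
    return x, y
```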
3.2 Copy Memory Task
Seq Length | RNN | GRU | LSTM | AttnRNN |
---|---|---|---|---|
30 | 0.1380 | 0.6620 | 0.5480 | 0.8140 |
60 | 0.2065 | 0.4550 | 0.3955 | 0.5780 |
90 | 0.1917 | 0.3877 | 0.3413 | 0.4720 |
120 | 0.1643 | 0.3463 | 0.3008 | 0.4383 |
Metric: Copy Accuracy (Acc), higher is better
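One common formulation of the task is sketched below (an assumption; the alphabet size, how the reported sequence length maps onto the delay, and how Acc is averaged may differ from the original setup): the first 10 steps carry random symbols, a delimiter marks where reproduction should begin, and accuracy is measured only on the final 10 output positions.

```python
# Copy Memory task data generation (one common formulation; details are
# assumptions, not necessarily identical to the original experiments).
import torch

def make_copy_batch(batch_size, seq_len, n_memorize=10, n_symbols=8):
    total_len = seq_len + 2 * n_memorize
    data = torch.zeros(batch_size, total_len, dtype=torch.long)
    target = torch.zeros_like(data)
    symbols = torch.randint(1, n_symbols + 1, (batch_size, n_memorize))
    data[:, :n_memorize] = symbols                       # symbols to remember
    data[:, seq_len + n_memorize - 1] = n_symbols + 1    # delimiter token
    target[:, -n_memorize:] = symbols                    # reproduce at the end
    return data, target                                  # Acc on last n_memorize steps
```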
3.3 IMDB Text Classification
Model | Params | Peak Acc | Final Acc |
---|---|---|---|
RNN | 33,024 | 50.72% | 50.72% |
GRU | 99,072 | 78.40% | 77.03% |
LSTM | 132,096 | 77.44% | 77.44% |
AttnRNN | 66,176 | 80.16% | 80.16% |
Transformer | 1,186,048 | 81.44% | 79.84% |
Training Dynamics: AttnRNN reached 79.40% at epoch 3, significantly faster than baselines
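For context, a recurrent IMDB classifier of this kind is typically assembled as below. This is a hypothetical sketch (embedding size, use of the final hidden state, and the classification head are assumptions), reusing the `AttnRNN` wrapper sketched in Section 2.2:

```python
# Hypothetical IMDB classifier sketch (reuses the AttnRNN wrapper from the
# Section 2.2 sketch; embedding size and head are assumptions, not the
# original experimental setup).
import torch.nn as nn

class IMDBClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_size=128, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = AttnRNN(embed_dim, hidden_size)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, tokens):
        # tokens: (B, T) integer ids -> logits: (B, num_classes)
        _, h_last = self.rnn(self.embed(tokens))
        return self.head(h_last)
```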
4. Design Philosophy & Biological Inspiration
4.1 Core Innovation Mechanism
```mermaid
graph LR
A[New Input x] --> B(Attention Competition)
C[History State h_prev] --> B
B --> D{Information Importance Evaluation}
D -->|Important| E[Reinforce Preservation]
D -->|Unimportant| F[Selective Ignorance]
E --> G[Gated Fusion]
F --> G
G --> H[Residual Connection]
H --> I[Updated State h_new]
```
4.2 Biological Correspondence
- Attention Competition → Synaptic information filtering in neurons
- Gated Fusion → Memory consolidation in hippocampus
- Residual Connection → Default activity patterns in neural circuits
“In human cognition, new information must compete with existing memories to enter long-term storage” - Cognitive Psychology Principle
5. Performance Advantage Analysis
5.1 Long-Sequence Processing Capability
- MAE of 0.0011 on the 400-step Adding Problem, roughly 1/20 of GRU's 0.0208
- 43.8% accuracy on the 120-step Copy Memory task, about 26% higher than GRU's 34.6% in relative terms
- Attributed to the attention mechanism directly filtering out noisy inputs
5.2 Convergence Speed Advantage
Task | AttnRNN Reaches Optimum | Baseline Lag |
---|---|---|
Adding Problem | 8-10 epochs | GRU: 15+ epochs |
Copy Memory Task | 5-7 epochs | LSTM: 12+ epochs |
IMDB Classification | 3 epochs (79.4%) | GRU: 8 epochs |
6. Conclusion & Future Work
6.1 Key Contributions
- Proposed Attention Competition Mechanism to solve RNN retention challenges
- Designed fixed-length context to avoid O(n²) computation
- Achieved a favorable parameter/performance trade-off (about 11% fewer parameters than GRU with 2-30% performance gains)
6.2 Future Directions
- Hardware-Level Optimization: CUDA kernel implementation to eliminate Python loop bottleneck
- Hierarchical Extension: Multi-layer AttnRNN stacks for complex sequences
- Transformer Integration: Using AttnRNN as Transformer decoder units
“Sequence processing should progress stepwise like human thought, not through undifferentiated global processing” - Design Philosophy
Code Repository: https://github.com/liluoyi666/Important_Memories-AttnRNN