AttnRNN: An Efficient Recurrent Neural Network Model Based on Attention Mechanism

Abstract

This paper proposes AttnRNN, a novel recurrent neural network architecture that combines an attention mechanism with a gating mechanism to address the information-retention problem traditional RNNs face on long-sequence tasks. Experiments demonstrate that the model significantly outperforms standard RNN, GRU, and LSTM baselines on the Adding Problem, the Copy Memory task, and IMDB text classification, while maintaining a low parameter count and low computational complexity.

1. Model Architecture

1.1 Core Components

import torch
import torch.nn as nn

class EfficientAttention(nn.Module):
    """Multi-head attention that projects only the query; K and V reuse the raw context (no KV projection)."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # output projection (see the parameter breakdown in Section 2.1)

    def forward(self, query, context):
        # query: (B, 1, E); context: (B, S, E); K = V = context, so no KV weights are learned
        B, S, E = context.shape
        d = E // self.num_heads
        q = self.q_proj(query).view(B, -1, self.num_heads, d).transpose(1, 2)
        k = context.view(B, S, self.num_heads, d).transpose(1, 2)
        weights = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
        out = (weights @ k).transpose(1, 2).reshape(B, -1, E)
        return self.out_proj(out)

class AttnRNNCell(nn.Module):
    """One AttnRNN step: attention competition -> gated fusion -> residual update.
    The forward pass below is a sketch consistent with the description in this paper;
    see the repository for the reference implementation."""
    def __init__(self, input_size, hidden_size, num_heads=4):
        super().__init__()
        # Data alignment projection: maps the raw input to the hidden dimension
        self.x_proj = nn.Linear(input_size, hidden_size)

        # Attention mechanism over the fixed 4-token context
        self.attn = EfficientAttention(hidden_size, num_heads)

        # Single-gate design
        self.update_gate = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.Sigmoid()
        )

        # Normalization applied after the residual connection
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, h_prev, x):
        # Context construction: h_prev and x compete inside a fixed-length 4-token context
        x = self.x_proj(x)
        context = torch.stack([h_prev, x, h_prev + x, h_prev * x], dim=1)  # (B, 4, H)

        # Attention calculation -> gated fusion -> residual update
        attn_out = self.attn(h_prev.unsqueeze(1), context).squeeze(1)
        gate = self.update_gate(torch.cat([h_prev, attn_out], dim=-1))
        return self.layer_norm(h_prev + gate * attn_out)
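
For completeness, a minimal sequence-level wrapper over AttnRNNCell is sketched below; the wrapper name and interface are illustrative rather than taken from the repository. The per-step Python loop is exactly the overhead discussed in the complexity note of Section 2.2.

class AttnRNN(nn.Module):
    """Illustrative wrapper: runs AttnRNNCell over a (batch, seq_len, input_size) tensor step by step."""
    def __init__(self, input_size, hidden_size, num_heads=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = AttnRNNCell(input_size, hidden_size, num_heads)

    def forward(self, x, h0=None):
        B, T, _ = x.shape
        h = x.new_zeros(B, self.hidden_size) if h0 is None else h0
        outputs = []
        for t in range(T):  # Python-level recurrence loop
            h = self.cell(h, x[:, t])
            outputs.append(h)
        return torch.stack(outputs, dim=1), h  # per-step states (B, T, H) and the final state

For example, AttnRNN(64, 128)(torch.randn(8, 50, 64)) returns per-step states of shape (8, 50, 128) together with the final hidden state.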

1.2 Design Innovations

  • Attention Competition Mechanism: Direct competition for attention scores between historical information (h_prev) and new inputs (x) within a fixed-length context (4 tokens)
  • Residual Dominance Principle: Residual connections ensure historical information always dominates
  • Minimalist Gating Design: Single-gate structure (vs. the three gate/candidate weight blocks of GRU and four of LSTM) reduces parameters
  • No KV Projection: Attention module projects only Q-values, significantly lowering computation

2. Theoretical Analysis

2.1 Parameter Comparison

Model         Parameter Formula (d: input dim, h: hidden dim)   Example (d=64, h=128)
Vanilla RNN   dh + h² + 2h                                       24,832
GRU           3(dh + h² + 2h)                                    74,496
LSTM          4(dh + h² + 2h)                                    99,328
AttnRNN       dh + 4h² + 6h                                      66,432

Parameter Distribution:

  1. Data alignment projection layer: dh
  2. Attention module: 3h² (q_proj + out_proj)
  3. Gating layer: h² + h
  4. Normalization layer: 2h
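
As a quick sanity check, the parameter count of the Section 1.1 sketch can be measured directly with PyTorch; because that sketch's exact layer shapes are assumptions, the measured number need not match the example column above.

cell = AttnRNNCell(input_size=64, hidden_size=128, num_heads=4)
n_params = sum(p.numel() for p in cell.parameters())
print(n_params, 64 * 128 + 4 * 128 ** 2 + 6 * 128)  # for this sketch, both expressions agree (dh + 4h^2 + 6h)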

2.2 Time Complexity

Model     Per-step Complexity   Sequence Complexity   Measured Time Ratio
RNN       O(B(dh + h²))         O(TB(dh + h²))        1x
GRU       O(3B(dh + h²))        O(3TB(dh + h²))       1x
AttnRNN   O(B(3h² + dh))        O(TB(3h² + dh))       2.5-3x

Note: the current implementation incurs Python-level loop overhead; CUDA-level optimization is expected to significantly improve speed.
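
To illustrate the note above, a rough timing harness is sketched below, comparing a cuDNN-backed nn.GRU with the Python-loop AttnRNN wrapper from Section 1.1 (assumed importable from a hypothetical attn_rnn module); absolute numbers depend heavily on hardware, batch size, and sequence length.

import time
import torch
import torch.nn as nn

from attn_rnn import AttnRNN  # hypothetical module containing the Section 1.1 sketches

def mean_forward_time(model, x, repeats=20):
    # Wall-clock average over several forward passes (add torch.cuda.synchronize() when timing on GPU)
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(repeats):
            model(x)
    return (time.perf_counter() - start) / repeats

x = torch.randn(32, 200, 64)  # (batch, seq_len, input_dim)
print("GRU:    ", mean_forward_time(nn.GRU(64, 128, batch_first=True), x))
print("AttnRNN:", mean_forward_time(AttnRNN(64, 128), x))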

3. Experimental Validation

3.1 Adding Problem

Seq Length   RNN      GRU      LSTM     AttnRNN
50           0.3411   0.0111   0.0666   0.0069
100          0.3392   0.0108   0.0083   0.0110
200          0.3244   0.0048   0.0231   0.0014
400          0.3458   0.0208   0.0427   0.0011

Metric: Mean Absolute Error (MAE), lower is better
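
For reference, a common formulation of the Adding Problem is sketched below: each step carries a random value and a 0/1 marker, and the target is the sum of the two marked values. The sampling details are illustrative assumptions, not necessarily the paper's exact setup.

import torch

def make_adding_batch(batch_size, seq_len):
    # Channel 0: uniform random values; channel 1: marks exactly two positions per sequence
    values = torch.rand(batch_size, seq_len)
    markers = torch.zeros(batch_size, seq_len)
    for i in range(batch_size):
        markers[i, torch.randperm(seq_len)[:2]] = 1.0
    x = torch.stack([values, markers], dim=-1)       # (B, T, 2) input sequence
    y = (values * markers).sum(dim=1, keepdim=True)  # (B, 1) regression target
    return x, y

x, y = make_adding_batch(batch_size=32, seq_len=400)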

3.2 Copy Memory Task

Seq Length   RNN      GRU      LSTM     AttnRNN
30           0.1380   0.6620   0.5480   0.8140
60           0.2065   0.4550   0.3955   0.5780
90           0.1917   0.3877   0.3413   0.4720
120          0.1643   0.3463   0.3008   0.4383

Metric: Copy Accuracy (Acc), higher is better
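
A common construction of the Copy Memory task is sketched below: the model reads a short symbol string, a stretch of blanks, and a recall delimiter, and must reproduce the symbols at the end, with accuracy measured on the copied positions. Alphabet size and copy length here are illustrative assumptions.

import torch

def make_copy_batch(batch_size, blank_len, n_symbols=8, copy_len=10):
    # Tokens: 0 = blank, 1..n_symbols = data symbols, n_symbols + 1 = recall delimiter
    data = torch.randint(1, n_symbols + 1, (batch_size, copy_len))
    blanks = torch.zeros(batch_size, blank_len, dtype=torch.long)
    delim = torch.full((batch_size, 1), n_symbols + 1, dtype=torch.long)
    pad = torch.zeros(batch_size, copy_len - 1, dtype=torch.long)
    x = torch.cat([data, blanks, delim, pad], dim=1)
    y = torch.cat([torch.zeros(batch_size, blank_len + copy_len, dtype=torch.long), data], dim=1)
    return x, y  # accuracy is evaluated on the final copy_len positions of y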

3.3 IMDB Text Classification

Model         Params      Peak Acc   Final Acc
RNN           33,024      50.72%     50.72%
GRU           99,072      78.40%     77.03%
LSTM          132,096     77.44%     77.44%
AttnRNN       66,176      80.16%     80.16%
Transformer   1,186,048   81.44%     79.84%

Training Dynamics: AttnRNN reached 79.40% accuracy by epoch 3, converging significantly faster than the baselines.
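
One way to wire the recurrent core into an IMDB classifier is sketched below: token embedding, the AttnRNN wrapper from the Section 1.1 sketch, and a linear head on the final hidden state. Vocabulary size, embedding width, and the choice of last-state pooling are assumptions, not the paper's reported configuration.

import torch.nn as nn

class SentimentClassifier(nn.Module):
    # Hypothetical wiring: embedding -> AttnRNN (Section 1.1 sketch) -> linear head on the last hidden state
    def __init__(self, vocab_size, embed_dim=128, hidden_size=128, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = AttnRNN(embed_dim, hidden_size)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, token_ids):  # token_ids: (B, T) integer tensor
        _, h_last = self.rnn(self.embed(token_ids))
        return self.head(h_last)   # (B, num_classes) logits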

4. Design Philosophy & Biological Inspiration

4.1 Core Innovation Mechanism

graph LR
A[New Input x] --> B(Attention Competition)
C[History State h_prev] --> B
B --> D{Information Importance Evaluation}
D -->|Important| E[Reinforce Preservation]
D -->|Unimportant| F[Selective Ignorance]
E --> G[Gated Fusion]
F --> G
G --> H[Residual Connection]
H --> I[Updated State h_new]

4.2 Biological Correspondence

  1. Attention Competition → Synaptic information filtering in neurons
  2. Gated Fusion → Memory consolidation in hippocampus
  3. Residual Connection → Default activity patterns in neural circuits

“In human cognition, new information must compete with existing memories to enter long-term storage” - Cognitive Psychology Principle

5. Performance Advantage Analysis

5.1 Long-Sequence Processing Capability

  • MAE of 0.0011 on the 400-step Adding Problem, roughly 1/20 of GRU's 0.0208
  • 43.8% accuracy on the 120-step Copy Memory task, a 26% relative improvement over GRU's 34.6%
  • Attributed to the attention mechanism directly filtering out noisy inputs

5.2 Convergence Speed Advantage

Task                  AttnRNN Reaches Optimum   Baseline Lag
Adding Problem        8-10 epochs               GRU: 15+ epochs
Copy Memory Task      5-7 epochs                LSTM: 12+ epochs
IMDB Classification   3 epochs (79.4%)          GRU: 8 epochs

6. Conclusion & Future Work

6.1 Key Contributions

  1. Proposed Attention Competition Mechanism to solve RNN retention challenges
  2. Designed fixed-length context to avoid O(n²) computation
  3. Achieved optimal parameter/performance balance (11% fewer parameters than GRU, 2-30% performance gain)

6.2 Future Directions

  1. Hardware-Level Optimization: CUDA kernel implementation to eliminate Python loop bottleneck
  2. Hierarchical Extension: Multi-layer AttnRNN stacks for complex sequences
  3. Transformer Integration: Using AttnRNN as Transformer decoder units

“Sequence processing should progress stepwise like human thought, not through undifferentiated global processing” - Design Philosophy

Code Repository: https://github.com/liluoyi666/Important_Memories-AttnRNN
