LSTM vs Transformer: The Evolution of Sequence Learning in AI
LSTMs ruled NLP for a decade. Transformers replaced them in three years. This is the technical story of why — and what each architecture actually computes.
The story most people tell about transformers is wrong. It is not a story of a new idea replacing an old one. It is a story of a long-standing bottleneck — the inability to parallelize recurrent computation — finally being solved in a way that also happened to scale better.
LSTMs are not naive. They were the best tool available for sequence modeling from 1997 to roughly 2018. Understanding why they work, and where they break, is what makes the transformer's design choices legible.
The Fundamental Problem of Sequences
Text, audio, time series, and code are sequences. Their meaning depends on order. "The cat sat on the mat" means something different from "the mat sat on the cat."
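Standard fully-connected networks expect a fixed-size input and learn a separate weight for every input position, so they cannot handle variable-length text or reuse what they learn about a word at one position when it appears at another. Bag-of-words models that sum or average word vectors go further and discard order entirely. CNNs capture local order through convolution, but each layer sees only a small window, so relating tokens hundreds of positions apart requires stacking many layers. Sequences need architectures that explicitly model temporal or positional relationships.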
The key challenges:
- Variable length inputs: sentences have different numbers of words
- Long-range dependencies: the pronoun "it" in a sentence might refer to a noun from 20 words earlier
- Sequential processing vs. parallelism: processing token by token is inherently sequential, but training efficiently on large datasets requires parallel computation
LSTMs handle the first two challenges well. They fail at the third, and that failure is what transformers exploit.
How LSTMs Actually Work
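Hochreiter and Schmidhuber introduced the LSTM in 1997 to address the vanishing gradient problem that makes vanilla RNNs hard to train on long sequences. The core innovation is a separate memory cell c_t that information can be written to, read from, and erased from, controlled by learned gates. (The forget gate in the equations below was added shortly afterwards by Gers, Schmidhuber, and Cummins in 2000; it is part of every standard LSTM today.)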
At each timestep t, an LSTM receives the current input x_t and the previous hidden state h_{t-1}, then computes:
```
f_t = σ(W_f · [h_{t-1}, x_t] + b_f)     # forget gate: what to erase
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)     # input gate: what to write
g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)  # candidate values
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)     # output gate: what to expose
c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t         # update cell state
h_t = o_t ⊙ tanh(c_t)                   # compute hidden state
```
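Here σ is the logistic sigmoid, which squashes each value into the range (0, 1) so every gate acts as a soft, per-dimension switch, and ⊙ is element-wise multiplication.
The key is the cell state c_t. When the forget gate is near 1 and the input gate is near 0, the cell state passes through almost unchanged: c_t ≈ c_{t-1}. Along this direct path the gradient ∂c_t/∂c_{t-1} is simply f_t, rather than a product of weight matrices and squashing derivatives as in a vanilla RNN, so information and error signals can persist across many timesteps instead of vanishing. When the forget gate is near 0, the cell erases old information. The gates themselves are learned: the network figures out what to remember.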
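To make the gate equations concrete, here is a minimal sketch of a single LSTM step, checked against PyTorch's `nn.LSTMCell`. It assumes PyTorch's parameter layout: instead of one matrix applied to the concatenation [h_{t-1}, x_t], PyTorch stores an input matrix `W_ih` and a recurrent matrix `W_hh` (mathematically the same thing), with the four gates stacked in the order input, forget, candidate, output. The function name `lstm_step` is illustrative, not a library API.

```python
import torch
import torch.nn as nn


def lstm_step(x_t, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    """One LSTM timestep written out gate by gate.

    Assumes PyTorch's weight layout: the 4 * hidden rows of W_ih / W_hh are
    stacked in the order (input, forget, candidate, output).
    """
    # Same as W · [h_{t-1}, x_t] + b, split into input and recurrent parts
    gates = x_t @ W_ih.T + b_ih + h_prev @ W_hh.T + b_hh  # [batch, 4 * hidden]
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=1)
    i_t = torch.sigmoid(i_pre)  # input gate
    f_t = torch.sigmoid(f_pre)  # forget gate
    g_t = torch.tanh(g_pre)     # candidate values
    o_t = torch.sigmoid(o_pre)  # output gate
    c_t = f_t * c_prev + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=8, hidden_size=16)
x = torch.randn(3, 8)  # batch of 3 input vectors
h0 = torch.zeros(3, 16)
c0 = torch.zeros(3, 16)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h0, c0))
    h_ours, c_ours = lstm_step(x, h0, c0, cell.weight_ih, cell.weight_hh,
                               cell.bias_ih, cell.bias_hh)

print(torch.allclose(h_ours, h_ref, atol=1e-5), torch.allclose(c_ours, c_ref, atol=1e-5))
```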
```python
import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """
    Bidirectional LSTM for text classification.
    Input: padded token IDs [batch, seq_len] → Output: class logits [batch, num_classes]
    (apply softmax for probabilities, or pass the logits to nn.CrossEntropyLoss)
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        # padding_idx=0 keeps the embedding of the padding token fixed at zero
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM processes the sequence token by token
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,    # input shape: [batch, seq_len, features]
            dropout=0.3,         # applied between stacked layers only (needs num_layers > 1)
            bidirectional=True,  # process forward and backward
        )
        # Bidirectional: forward and backward states are concatenated, doubling the size
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, num_classes),
        )

    def forward(self, x):
        # x shape: [batch, seq_len]
        embedded = self.embedding(x)  # [batch, seq_len, embed_dim]
        # output: [batch, seq_len, hidden_dim * 2]
        # h_n: [num_layers * 2, batch, hidden_dim], final hidden state per layer and direction
        # Note: without packing, the LSTM also processes padding tokens,
        # so both directions' final states are influenced by them
        output, (h_n, c_n) = self.lstm(embedded)
        # Top layer's final states: h_n[-2] is the forward direction, h_n[-1] the backward
        hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [batch, hidden_dim * 2]
        return self.classifier(hidden)
```
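The classifier above feeds padded batches straight into the LSTM. A common refinement, sketched here under the assumption that true sequence lengths are known, is PyTorch's `pack_padded_sequence`, which makes the LSTM skip padding so a short sequence gets the same final states as it would on its own. The tiny sizes and tensor names are arbitrary choices for the demo.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True, bidirectional=True)

short = torch.randn(1, 3, 4)   # a 3-token sequence
longer = torch.randn(1, 5, 4)  # a 5-token sequence
# Pad the short sequence with two zero vectors so both fit in one [2, 5, 4] batch
batch = torch.cat([torch.cat([short, torch.zeros(1, 2, 4)], dim=1), longer], dim=0)
lengths = torch.tensor([3, 5])  # true lengths (must be on the CPU)

with torch.no_grad():
    packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=False)
    _, (h_packed, _) = lstm(packed)  # h_n: [2 directions, 2 sequences, 6]
    _, (h_alone, _) = lstm(short)    # the short sequence on its own, no padding

# The short sequence's final states ignore its padding in both directions
print(torch.allclose(h_packed[:, 0], h_alone[:, 0], atol=1e-5))
```

In the classifier, the same idea would mean packing `embedded` with the batch's lengths before calling `self.lstm`; `h_n[-2]` and `h_n[-1]` then hold each sequence's true final states.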
The Bottleneck: Sequential Computation
Here is the problem. To compute h_t, you need h_{t-1}. To compute h_{t-1}, you need h_{t-2}. The entire sequence must be processed left-to-right.
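On GPU hardware, which is built for parallel computation, this sequential dependency is a serious bottleneck. A sequence of 512 tokens requires 512 LSTM steps executed one after another: step t cannot start until step t-1 has finished. The GPU can still parallelize across examples in a batch and across hidden units within a step, but never across time.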
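A fused `nn.LSTM` call does not escape this dependency. The sketch below (arbitrary sizes, assumed for illustration) copies one layer's weights into an `nn.LSTMCell` and shows that the fused call computes the same outputs as an explicit Python loop over time. Optimized backends such as cuDNN can batch the input projections for all timesteps, but the recurrent part still runs step by step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden = 12, 2, 5, 7
lstm = nn.LSTM(input_size, hidden, batch_first=True)  # single layer, one direction
cell = nn.LSTMCell(input_size, hidden)

# Give the cell exactly the same parameters as the LSTM's only layer
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(batch, seq_len, input_size)

with torch.no_grad():
    out_fused, _ = lstm(x)
    h = torch.zeros(batch, hidden)
    c = torch.zeros(batch, hidden)
    steps = []
    for t in range(seq_len):  # step t needs (h, c) from step t - 1
        h, c = cell(x[:, t], (h, c))
        steps.append(h)
    out_loop = torch.stack(steps, dim=1)  # [batch, seq_len, hidden]

print(torch.allclose(out_fused, out_loop, atol=1e-5))
```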
The practical consequence: training a large LSTM on a large dataset is painfully slow. The parallelism that makes CNNs train fast on image data is not available for recurrent models on sequence data.
There is a second problem: the information bottleneck. Everything the network needs to remember must be compressed into a fixed-size state (h_t and c_t). For a long document, by the time the LSTM reaches the end, information from the beginning has passed through hundreds of sequential updates. The forget gate cannot perfectly preserve everything, so information fades.
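A toy calculation shows how fast that fading compounds. This sketch assumes the simplest possible case: the input gate stays closed (i_t = 0) and the forget gate is a fixed constant, whereas in a real LSTM f_t is computed from h_{t-1} and x_t. Under that assumption the value of c_0 that survives 100 steps, and the gradient dc_T/dc_0, are both f^100: about 0.37 for f = 0.99 but about 2.7e-5 for f = 0.9.

```python
import torch


def final_cell_state(c0, forget_gates):
    """Apply c_t = f_t * c_{t-1} + i_t * g_t with the input gate closed (i_t = 0)."""
    c = c0
    for f_t in forget_gates:
        c = f_t * c  # nothing new is written; old content is only scaled
    return c


results = {}
for f in (0.99, 0.9):
    c0 = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
    gates = torch.full((100,), f, dtype=torch.float64)  # 100 timesteps, constant forget gate
    c_T = final_cell_state(c0, gates)
    c_T.backward()
    # With c_0 = 1, both the surviving value and dc_T/dc_0 equal f ** 100
    results[f] = (c_T.item(), c0.grad.item())
    print(f"f = {f}: c_T = {c_T.item():.3g}, dc_T/dc_0 = {c0.grad.item():.3g}")
```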
Attention: The Key Idea
Before the full transformer, attention mechanisms were added on top of encoder-decoder LSTMs for machine translation (Bahdanau et al., 2015). Instead of compressing the entire source sentence into a single vector, the decoder could look at all encoder hidden states:
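```python
import torch
import torch.nn as nn


class BahdanauAttention(nn.Module):
    """Bahdanau-style (additive) attention; assumes encoder and decoder share hidden_dim."""

    def __init__(self, hidden_dim, attn_dim):
        super().__init__()
        self.linear_decoder = nn.Linear(hidden_dim, attn_dim, bias=False)
        self.linear_encoder = nn.Linear(hidden_dim, attn_dim, bias=False)
        self.linear_v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden: [1, hidden]
        # encoder_outputs: [seq_len, hidden]
        # Alignment scores: v^T tanh(W_d s + W_e h_j) for every source position j
        scores = torch.tanh(
            self.linear_decoder(decoder_hidden)    # [1, attn_dim], broadcasts over seq_len
            + self.linear_encoder(encoder_outputs)  # [seq_len, attn_dim]
        )
        scores = self.linear_v(scores).squeeze(-1)  # [seq_len]
        # Normalize to weights summing to 1
        weights = torch.softmax(scores, dim=0)  # [seq_len]
        # Context vector: weighted sum of encoder outputs
        context = (weights.unsqueeze(-1) * encoder_outputs).sum(0)  # [hidden]
        return context, weights
```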
This improved translation quality substantially, especially for long sentences. The attention weights are also interpretable — you can visualize which source words the decoder attends to when generating each target word.
The natural question: if attention lets you access any position directly, why keep the LSTM at all?
"Attention Is All You Need"
Vaswani et al. (2017) answered that question. The original transformer paper showed that attention alone, without any recurrence, could match or beat the LSTM+attention state-of-the-art for machine translation. More importantly, it could be fully parallelized.
The key mechanism is scaled dot-product self-attention:
```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [batch, heads, seq_len, d_k]  queries
    K: [batch, heads, seq_len, d_k]  keys
    V: [batch, heads, seq_len, d_v]  values
    mask: optional, broadcastable to [batch, heads, seq_len, seq_len];
          1 = may attend, 0 = blocked
    """
    d_k = Q.size(-1)
    # Compute attention scores: how compatible is each query with each key?
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores shape: [batch, heads, seq_len, seq_len]
    # Blocked positions get a large negative score, so softmax gives them ~0 weight
    # (e.g. a causal mask hides future positions in decoder attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Convert scores to probabilities over the keys
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of values
    output = torch.matmul(weights, V)
    # output shape: [batch, heads, seq_len, d_v]
    return output, weights
```
In self-attention, the same sequence produces queries, keys, and values. Every token can attend to every other token in one shot — no sequential dependency. The entire operation is a batch of matrix multiplications, which GPUs execute in parallel.
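As a sanity check, the sketch below computes causal self-attention for every position in one batched operation and compares it with PyTorch's fused `F.scaled_dot_product_attention` (available in PyTorch 2.0 or later). It assumes random Q, K, V of arbitrary sizes, builds the causal mask with `torch.tril`, and uses `-inf` instead of `-1e9`; the two are equivalent whenever each row has at least one visible key.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 6, 8
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Causal mask: 1 where query i may see key j (j <= i), 0 for future positions.
# Shape [seq_len, seq_len] broadcasts over batch and heads.
mask = torch.tril(torch.ones(seq_len, seq_len))

# Every position is handled in one batched computation, with no loop over time
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
out_manual = weights @ V

# PyTorch's fused kernel with a built-in causal mask
out_builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(out_manual, out_builtin, atol=1e-5))
print(torch.triu(weights, diagonal=1).abs().max().item())  # total weight on future tokens
```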
Multi-Head Attention
A single attention head learns one type of relationship between tokens. Multi-head attention runs several heads in parallel, each with its own Q, K, V projections:
```python
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Single projection matrices for all heads (split into heads afterwards)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # [batch, seq, d_model] → [batch, heads, seq, d_k]
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Self-attention: queries, keys, and values are all projections of x
        Q = self.split_heads(self.W_q(x), batch_size)
        K = self.split_heads(self.W_k(x), batch_size)
        V = self.split_heads(self.W_v(x), batch_size)
        # Uses scaled_dot_product_attention defined above
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Recombine heads: [batch, heads, seq, d_k] → [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        return self.W_o(attn_output)
```
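PyTorch also ships a built-in module, `nn.MultiheadAttention`, that packages the same projections, head splitting, and output projection. This usage sketch assumes arbitrary sizes (d_model = 64, 8 heads). Note the mask convention: for this module a boolean `True` marks a position that may NOT be attended, the opposite of the 0/1 mask used by `scaled_dot_product_attention` above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 10, 64, 8

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
x = torch.randn(batch, seq_len, d_model)

# True above the diagonal: query i is blocked from seeing keys j > i
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Self-attention: the same tensor supplies queries, keys, and values
out, attn = mha(x, x, x, attn_mask=causal)

print(out.shape)   # same shape as the input
print(attn.shape)  # attention weights, averaged over heads by default
```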
Different heads can learn different relationships. In a trained model, one head might capture syntactic dependencies (subject-verb agreement), another semantic similarity, another coreference (pronouns pointing to nouns). The heads operate independently; their outputs are concatenated and then mixed by the output projection W_o.
Positional Encoding: Giving Transformers a Sense of Order
Self-attention is permutation-equivariant — shuffle the tokens and the outputs shuffle correspondingly. A pure attention model has no sense of word order.
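A quick check of that property, as a minimal sketch with random projection matrices standing in for learned weights (an assumption for the demo): shuffle the input tokens, and the self-attention output is exactly the original output with its rows shuffled the same way.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 5, 16
x = torch.randn(seq_len, d_model)  # one sequence, no positional encoding
# Random stand-ins for learned projections, scaled to keep values O(1)
W_q, W_k, W_v = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(3))


def self_attention(x):
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = F.softmax(Q @ K.T / math.sqrt(d_model), dim=-1)
    return weights @ V


perm = torch.randperm(seq_len)  # shuffle the token order
out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Shuffling the input only shuffles the output rows: word order is invisible
print(torch.allclose(out_shuffled, out[perm], atol=1e-5))
```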
Positional encodings inject order information:
```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute sinusoidal encodings once
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # Frequencies 1 / 10000^(2i / d_model), computed in log space
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cos
        # A buffer moves with the module (.to(device)) but is not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
```
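The sinusoidal encoding is elegant: for any fixed offset k, the encoding of position pos + k is a linear function (a rotation of each sin/cos pair) of the encoding of position pos, which should make relative positions easy for the model to use. The original paper hypothesized that this would also let the model extrapolate to sequence lengths longer than those seen during training. Many modern models (GPT-NeoX, LLaMA) instead use rotary position embeddings (RoPE), which encode relative position by rotating the query and key vectors inside attention.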
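The linearity claim can be verified numerically. This sketch rebuilds the same sin/cos table (in float64, with arbitrary sizes) and applies a 2×2 rotation to each (sin, cos) pair that depends only on the offset k, never on the starting position. The helper names `sinusoidal_pe` and `shift` are illustrative.

```python
import math

import torch


def sinusoidal_pe(max_len, d_model):
    """Same construction as PositionalEncoding: sin on even dims, cos on odd dims."""
    position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * (-math.log(10000.0) / d_model))  # frequencies ω_i
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe, div_term


def shift(pe_row, k, div_term):
    """Map PE(pos) to PE(pos + k) with a fixed rotation per (sin, cos) pair.

    sin(ω(p+k)) =  sin(ωp)cos(ωk) + cos(ωp)sin(ωk)
    cos(ω(p+k)) = -sin(ωp)sin(ωk) + cos(ωp)cos(ωk)
    """
    s, c = pe_row[0::2], pe_row[1::2]
    cos_k, sin_k = torch.cos(k * div_term), torch.sin(k * div_term)
    out = torch.empty_like(pe_row)
    out[0::2] = s * cos_k + c * sin_k
    out[1::2] = -s * sin_k + c * cos_k
    return out


pe, div_term = sinusoidal_pe(max_len=100, d_model=32)
pos, k = 10, 7
print(torch.allclose(shift(pe[pos], k, div_term), pe[pos + k]))
```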
Architecture Comparison
Head-to-Head Comparison
| Property | LSTM | Transformer |
|---|---|---|
| Sequential computation | Yes (token by token) | No (all tokens in parallel) |
| Long-range dependency | Limited by hidden state size | Direct attention across all positions |
| Memory scaling (sequence length n) | O(1) recurrent state at inference; O(n) activations stored for training | O(n²) attention matrix per layer (naive implementation) |
| Training speed | Slow (sequential) | Fast (parallel on GPU) |
| Inference on long sequences | Constant cost per new token (streaming) | Per-token cost grows with context length |
| Position information | Implicit (order of processing) | Explicit (positional encoding) |
| Interpretability | Hidden state is opaque | Attention weights are visualizable |
| Scale | Harder to scale: sequential training limits throughput | Scales to hundreds of billions of parameters |
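The memory row is easy to make concrete with a back-of-the-envelope calculation. This sketch assumes a naive implementation that materializes the full [heads, n, n] score matrix in fp16 (2 bytes per value), 32 heads, and an LSTM hidden size of 1024; all of these numbers are illustrative choices, not measurements of any particular model.

```python
def attention_score_bytes(seq_len, num_heads, bytes_per_value=2):
    """One layer's [heads, seq_len, seq_len] score matrix for one sequence.
    Assumes the full matrix is materialized (fp16 = 2 bytes per value)."""
    return num_heads * seq_len * seq_len * bytes_per_value


def lstm_state_bytes(hidden_dim, bytes_per_value=2):
    """One layer's recurrent state (h_t and c_t) at inference; independent of seq_len."""
    return 2 * hidden_dim * bytes_per_value


for n in (512, 4096, 32768):
    print(n, "tokens:", attention_score_bytes(n, num_heads=32) / 2**20, "MiB")
print("LSTM state:", lstm_state_bytes(1024), "bytes")
```

Under these assumptions, 4,096 tokens already need 1,024 MiB (1 GiB) of scores per layer per sequence, while the LSTM state stays at 4,096 bytes regardless of length. Memory-efficient kernels such as FlashAttention avoid storing the full matrix, but compute still grows quadratically.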
NLP Benchmark Evolution
These results on GLUE (General Language Understanding Evaluation) show how quickly transformers surpassed LSTMs:
| Model | Architecture | GLUE Score | Year |
|---|---|---|---|
| ELMo | Bidirectional LSTM | 68.7 | 2018 |
| GPT | Transformer decoder | 72.8 | 2018 |
| BERT-base | Transformer encoder | 79.6 | 2018 |
| BERT-large | Transformer encoder | 80.5 | 2018 |
| RoBERTa-large | Transformer encoder | 88.1 | 2019 |
| GPT-3 175B (few-shot) | Transformer decoder | ~88.0 | 2020 |
| DeBERTa-XXL | Transformer encoder | 91.4 | 2021 |
ELMo, the strongest LSTM-based model in this table, and BERT were both released in 2018, and BERT-large's GLUE score is almost 12 points higher (80.5 vs. 68.7). The LSTM frontier was overtaken within a single year, not gradually.
When to Use Each Architecture Today
Use an LSTM when:
- You need streaming inference: process one token at a time as it arrives, with constant memory and compute per token
- Deploying on hardware with strict memory limits (microcontrollers, edge devices)
- Working with very long sequences where O(n²) attention is computationally infeasible
- The dataset is small and you lack the compute to fine-tune a large transformer
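Streaming with an LSTM amounts to carrying the (h, c) state between calls. In this minimal sketch (arbitrary sizes, with a random tensor standing in for data that arrives over time), a 20-step stream is processed in chunks of 5, and the result matches processing the whole sequence at once. The only thing kept between chunks is a fixed-size state of shape [num_layers, batch, hidden].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
stream = torch.randn(1, 20, 4)  # 20 timesteps that arrive over time

with torch.no_grad():
    full_out, _ = lstm(stream)  # reference: whole sequence at once

    state = None  # None means zero initial (h_0, c_0)
    chunks = []
    for chunk in stream.split(5, dim=1):  # receive 5 timesteps at a time
        out, state = lstm(chunk, state)   # carry (h_n, c_n) into the next call
        chunks.append(out)
    streamed_out = torch.cat(chunks, dim=1)

print(torch.allclose(full_out, streamed_out, atol=1e-5))
```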
Use a Transformer when:
- Building any modern NLP application — classification, generation, translation, QA
- You want to fine-tune a pre-trained model (BERT, T5, GPT family)
- Working with multimodal data — vision-language models are all transformer-based
- Scale matters — transformers scale better with both data and compute
For practical NLP projects, the advice is simple: start with a pre-trained transformer (BERT for classification, GPT-2 for generation). LSTMs are worth understanding for the theory, but in most production systems they have been replaced.
State Space Models: The Next Chapter
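An interesting development in 2023-2024 is the emergence of state space models (SSMs), particularly Mamba (Gu & Dao, 2023). Like an LSTM, an SSM carries a fixed-size state forward through the sequence, so inference needs constant memory per token. Unlike an LSTM, the state update is linear in the previous state, which lets training run in parallel across the sequence (Mamba uses a hardware-aware parallel scan). The Mamba paper reports quality competitive with similarly sized transformers, with cost that grows linearly rather than quadratically in sequence length.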
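A toy sketch of why a linear recurrence parallelizes when an LSTM does not. The decay a_t and input term b_t are random stand-ins (an assumption), not Mamba's actual parameterization. Because a_t does not depend on h_{t-1}, the recurrence h_t = a_t · h_{t-1} + b_t has a closed form built from cumulative products and sums, with no loop over time. An LSTM's forget gate is computed from h_{t-1}, so no such shortcut exists.

```python
import torch

torch.manual_seed(0)
T, d = 50, 4
# Toy decay a_t in (0.8, 1.0) and input term b_t.
# In a selective SSM these would be computed from x_t alone, never from h_{t-1}.
a = 0.8 + 0.2 * torch.rand(T, d, dtype=torch.float64)
b = torch.randn(T, d, dtype=torch.float64)

# Sequential form: h_t = a_t * h_{t-1} + b_t, like a recurrent network
h = torch.zeros(d, dtype=torch.float64)
seq = []
for t in range(T):
    h = a[t] * h + b[t]
    seq.append(h)
h_sequential = torch.stack(seq)

# Closed form with no loop over time:
# h_t = P_t * sum_{s<=t} b_s / P_s, where P_t = a_1 * ... * a_t
P = torch.cumprod(a, dim=0)
h_parallel = P * torch.cumsum(b / P, dim=0)

print(torch.allclose(h_sequential, h_parallel))
```

The cumprod/cumsum form is only a demonstration: dividing by P_t becomes numerically unstable over long sequences, which is one reason practical implementations use an associative parallel scan instead.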
The field is moving fast. Understanding both LSTMs and transformers gives you the conceptual foundation to follow this evolution.
AiTechWorlds Team