Decision Trees Explained
Decision Trees: Intuitive, Interpretable, Powerful
Decision trees are among the most interpretable machine learning models — you can literally draw them on a whiteboard and explain every prediction. They're also the foundation of some of the most powerful algorithms in ML: Random Forests and Gradient Boosting.
The Core Idea
A decision tree splits data by asking a series of yes/no questions, choosing the question at each step that best separates the classes or reduces prediction error.
Predicting loan default:
              [Income < $40k?]
               /            \
             Yes             No
              │               │
  [Has missed payment?]   [Debt ratio > 0.5?]
       /      \               /      \
     Yes       No           Yes       No
      │         │            │         │
   DEFAULT     SAFE       DEFAULT     SAFE
This is completely interpretable — you can follow the path for any individual and explain exactly why they were classified a certain way.
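The loan-default tree above can be read as plain nested conditions. The sketch below is a hand-written illustration, not a trained model; the function name `predict_default` and its argument names are assumptions made for this example.

```python
def predict_default(income, has_missed_payment, debt_ratio):
    """Follow the loan-default tree and return (label, path taken)."""
    if income < 40_000:
        path = ["Income < $40k: yes"]
        if has_missed_payment:
            path.append("Has missed payment: yes")
            return "DEFAULT", path
        path.append("Has missed payment: no")
        return "SAFE", path

    path = ["Income < $40k: no"]
    if debt_ratio > 0.5:
        path.append("Debt ratio > 0.5: yes")
        return "DEFAULT", path
    path.append("Debt ratio > 0.5: no")
    return "SAFE", path


# Hypothetical applicant: the debt ratio is never consulted on this branch
label, path = predict_default(income=35_000, has_missed_payment=False, debt_ratio=0.7)
print(label, "via", " -> ".join(path))
```

The returned path is the explanation: it lists exactly which questions were asked and how each was answered.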
How a Tree Decides Where to Split
The tree chooses splits that maximize "purity" — the more uniform the resulting groups, the better.
Gini Impurity (Classification)
Gini = 1 - Σ(pᵢ²)
Where pᵢ is the proportion of class i in the node.
Pure node (all one class): Gini = 1 - 1² = 0
Maximally impure (50/50 split): Gini = 1 - (0.5² + 0.5²) = 0.5
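As a minimal sketch, the formula can be computed directly with NumPy. The helper names `gini` and `weighted_gini` are illustrative, not scikit-learn functions; the weighted version scores a candidate split by averaging the impurity of the two child groups, weighted by their sizes.

```python
import numpy as np


def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over the class proportions p_i."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def weighted_gini(left, right):
    """Size-weighted average impurity of the two groups produced by a split."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


pure = ["safe"] * 4
mixed = ["safe", "default"] * 2
print("pure node: ", gini(pure))
print("50/50 node:", gini(mixed))

# Two candidate splits of the same six samples (lower is better)
clean_split = weighted_gini([0, 0, 0], [1, 1, 1])
messy_split = weighted_gini([0, 0, 1], [0, 1, 1])
print("clean split:", clean_split, "messy split:", round(messy_split, 3))
```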
Information Gain (Entropy)
Entropy = -Σ(pᵢ × log₂(pᵢ))
Pure node: Entropy = 0
50/50 split: Entropy = -(0.5×log₂(0.5) + 0.5×log₂(0.5)) = 1.0 bit
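Information gain is the parent node's entropy minus the size-weighted entropy of its children. The sketch below assumes that standard definition; `entropy` and `information_gain` are illustrative helper names.

```python
import numpy as np


def entropy(labels):
    """Shannon entropy in bits: -sum(p_i * log2(p_i)) over class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    children = len(left) / n * entropy(left) + len(right) / n * entropy(right)
    return entropy(parent) - children


parent = ["default"] * 4 + ["safe"] * 4

# A split that separates the classes perfectly
perfect_gain = information_gain(parent, parent[:4], parent[4:])

# A split that leaves both children as mixed as the parent
half = ["default", "default", "safe", "safe"]
useless_gain = information_gain(parent, half, half)

print("parent entropy:", entropy(parent))
print("perfect split gain:", perfect_gain)
print("useless split gain:", useless_gain)
```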
Mean Squared Error (Regression)
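For regression trees, each candidate split is scored by the mean squared error of the resulting groups, where each group predicts its own mean: MSE = (1/n) × Σ(yᵢ - ȳ)². Minimizing this is the same as minimizing the variance within each group.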
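To make the regression criterion concrete, here is a small sketch that searches for the best threshold on a single feature. The toy data and the helper `weighted_mse` are made up for illustration, and candidate thresholds are taken as midpoints between consecutive sorted feature values.

```python
import numpy as np


def weighted_mse(y_left, y_right):
    """Size-weighted variance of the two groups (each group predicts its mean)."""
    n = len(y_left) + len(y_right)
    return (len(y_left) * np.var(y_left) + len(y_right) * np.var(y_right)) / n


# Toy data: x is already sorted, and y jumps between x=3 and x=10
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([5.0, 6.0, 7.0, 20.0, 21.0, 22.0])

candidates = (x[:-1] + x[1:]) / 2  # midpoints between neighbouring x values
scores = [weighted_mse(y[x <= t], y[x > t]) for t in candidates]
best_threshold = candidates[int(np.argmin(scores))]

for t, s in zip(candidates, scores):
    print(f"split at x <= {t:4.1f}: weighted MSE = {s:.3f}")
print("best threshold:", best_threshold)
```

The lowest score falls at the midpoint of the gap in x, where the targets also jump.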
Building Your First Decision Tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
# Load data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train a decision tree
tree = DecisionTreeClassifier(
    max_depth=3,           # Limit depth to prevent overfitting
    min_samples_split=5,   # Minimum samples to split a node
    min_samples_leaf=2,    # Minimum samples at each leaf
    criterion='gini',      # 'gini' or 'entropy'
    random_state=42
)
tree.fit(X_train, y_train)
# Evaluate
y_pred = tree.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
# Visualize the tree (text)
print(export_text(tree, feature_names=iris.feature_names))
# Visualize the tree (plot)
plt.figure(figsize=(20, 10))
plot_tree(tree,
          feature_names=iris.feature_names,
          class_names=list(iris.target_names),  # some scikit-learn versions expect a list here
          filled=True, rounded=True)
plt.show()
Decision Trees for Regression
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
# Regression example
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
X, y = housing.data, housing.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # fixed seed for a reproducible split
tree_reg = DecisionTreeRegressor(max_depth=5, random_state=42)
tree_reg.fit(X_train, y_train)
y_pred = tree_reg.predict(X_test)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(f"RMSE: {rmse:.3f}")
# Feature importance
for name, importance in zip(housing.feature_names, tree_reg.feature_importances_):
    print(f"  {name}: {importance:.3f}")
Controlling Overfitting: The Key Hyperparameters
Decision trees will overfit if left unconstrained — a tree with no limits will memorize the training data completely.
import numpy as np
import matplotlib.pyplot as plt
# Switch back to the iris split: X_train and y_train still hold the housing regression data
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# How max_depth affects train vs test performance
max_depths = range(1, 20)
train_scores, test_scores = [], []
for depth in max_depths:
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)
    train_scores.append(tree.score(X_train, y_train))
    test_scores.append(tree.score(X_test, y_test))
plt.figure(figsize=(10, 6))
plt.plot(max_depths, train_scores, label='Train')
plt.plot(max_depths, test_scores, label='Test')
plt.xlabel('max_depth')
plt.ylabel('Accuracy')
plt.title('Decision Tree: Depth vs Performance')
plt.legend()
plt.show()
Key hyperparameters:
| Parameter | Effect | Typical Range |
|---|---|---|
| max_depth | Limits tree height | 3–10 |
| min_samples_split | Min samples to split a node | 5–20 |
| min_samples_leaf | Min samples at leaf node | 1–10 |
| max_features | Features considered per split | 'sqrt', 'log2' |
| max_leaf_nodes | Maximum number of leaves | 10–100 |
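A rough sketch of how these limits show up on a fitted tree, using the iris data. The specific values (`max_depth=3`, `min_samples_leaf=5`, `max_leaf_nodes=6`) are arbitrary choices for illustration, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

unconstrained = DecisionTreeClassifier(random_state=42).fit(X, y)
constrained = DecisionTreeClassifier(
    max_depth=3, min_samples_leaf=5, max_leaf_nodes=6, random_state=42
).fit(X, y)

for name, model in [("unconstrained", unconstrained), ("constrained", constrained)]:
    # In scikit-learn's tree arrays, children_left == -1 marks a leaf node
    is_leaf = model.tree_.children_left == -1
    smallest_leaf = model.tree_.n_node_samples[is_leaf].min()
    print(f"{name:14s} depth={model.get_depth()} "
          f"leaves={model.get_n_leaves()} smallest leaf={smallest_leaf}")
```

Every limit that is set has to hold at once, so the constrained tree stops growing as soon as any one of them would be violated.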
Pruning: Post-Training Simplification
Cost-complexity pruning removes branches that don't improve generalization enough.
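To see what the pruning strength does to tree size, the sketch below fits one tree per candidate `ccp_alpha` on the full iris data and records the node count. Clipping alpha at zero is a defensive assumption against tiny negative values from floating-point rounding.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Candidate alphas at which the fully grown tree loses another subtree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)

node_counts = []
for alpha in path.ccp_alphas:
    alpha = max(float(alpha), 0.0)  # guard against rounding below zero
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X, y)
    node_counts.append(pruned.tree_.node_count)
    print(f"alpha={alpha:.4f}  nodes={pruned.tree_.node_count}")
```

Larger alphas penalize complexity more heavily, so the tree shrinks as alpha grows.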
# Find optimal alpha via cross-validation
from sklearn.model_selection import cross_val_score
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # Drop the last alpha, which prunes the tree down to the root node alone
cv_scores = []
for alpha in ccp_alphas:
    tree = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    scores = cross_val_score(tree, X_train, y_train, cv=5, scoring='accuracy')
    cv_scores.append(scores.mean())
# Best alpha
best_alpha = ccp_alphas[np.argmax(cv_scores)]
print(f"Best alpha: {best_alpha:.5f}")
# Train final tree with best alpha
final_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha)
final_tree.fit(X_train, y_train)
Feature Importance
Decision trees give you feature importance scores based on how much each feature reduces impurity across all splits.
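import pandas as pd
# Use the fitted final_tree: cross_val_score only fits clones, so `tree` from the pruning loop is unfitted
importances = pd.Series(
    final_tree.feature_importances_,
    index=iris.feature_names
).sort_values(ascending=False)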
print("Feature Importances:")
print(importances)
# Visualize
importances.plot(kind='bar')
plt.title('Feature Importance')
plt.tight_layout()
plt.show()
Note: Feature importance from a single tree can be unstable, because small changes in the training data can change which features are chosen for splits. Random Forest importances, averaged over many trees, are generally more stable.
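One way to see this instability is to refit the same tree on bootstrap resamples of the data and compare the importances. This is a minimal sketch on iris; the number of resamples (5) and the seed are arbitrary assumptions.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
rng = np.random.default_rng(0)

rows = []
for _ in range(5):
    # Bootstrap resample: draw len(X) row indices with replacement
    idx = rng.integers(0, len(X), size=len(X))
    clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X[idx], y[idx])
    rows.append(clf.feature_importances_)

importances = np.vstack(rows)  # one row per resample, one column per feature
for name, col in zip(iris.feature_names, importances.T):
    print(f"{name:20s} min={col.min():.3f}  max={col.max():.3f}")
```

A wide gap between the minimum and maximum for a feature suggests its importance depends heavily on which samples the tree happened to see.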
When to Use Decision Trees
Good uses:
- You need an explainable model (medical, legal, financial)
- Data has non-linear relationships and interactions
- Quick baseline before trying complex models
- Feature selection (importance scores)
When to avoid:
- You need the best possible accuracy (Random Forest or GBM will usually be better)
- Data has many irrelevant features (a single tree can end up splitting on noise)
- You need probability calibration (tree probabilities are poorly calibrated)
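The calibration point can be illustrated with a fully grown tree. A tree's predicted probability is the class fraction in the leaf a sample lands in, so when every leaf is pure the output is only ever 0 or 1. The sketch below uses synthetic data from `make_classification` with some label noise; the dataset sizes and noise level are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary problem with 10% of labels flipped at random
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# No depth or leaf-size limits: the tree grows until every leaf is pure
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
proba = full_tree.predict_proba(X_test)

print("Distinct predicted probabilities:", np.unique(proba))
print("Test accuracy:", full_tree.score(X_test, y_test))
```

The tree reports complete certainty for every test sample, even though with noisy labels its test accuracy is typically well below 100%. Limiting depth or leaf size gives less extreme leaf fractions, but those are still coarse estimates.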
The Complete Decision Tree Workflow
from sklearn.model_selection import GridSearchCV
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 2, 5],
    'criterion': ['gini', 'entropy']
}
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)
grid_search.fit(X_train, y_train)
print("Best params:", grid_search.best_params_)
print("Best CV score:", grid_search.best_score_)
best_tree = grid_search.best_estimator_
print("Test accuracy:", best_tree.score(X_test, y_test))
Decision trees are essential to understand not just for their own merits, but because they're the building blocks of Random Forests and Gradient Boosting — the workhorses of tabular ML.
Next lesson: Random Forests — combining hundreds of decision trees to build a much stronger model.