Lesson 12 of 31 (Core Algorithms), 20 min

Decision Trees Explained

Decision Trees: Intuitive, Interpretable, Powerful

Decision trees are among the most interpretable machine learning models — you can literally draw them on a whiteboard and explain every prediction. They're also the foundation of some of the most powerful algorithms in ML: Random Forests and Gradient Boosting.

The Core Idea

A decision tree splits data by asking a series of yes/no questions, choosing the question at each step that best separates the classes or reduces prediction error.

Predicting loan default:

                [Income < $40k?]
               /                \
             Yes                 No
              │                  │
     [Has missed payment?]  [Debt ratio > 0.5?]
         /       \               /         \
       Yes        No           Yes           No
        │          │            │             │
     DEFAULT     SAFE       DEFAULT          SAFE

This is completely interpretable — you can follow the path for any individual and explain exactly why they were classified a certain way.
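To make that traceability concrete, here is a minimal sketch of the example tree above written as plain Python. The function name `classify_applicant` and its argument names are hypothetical, chosen only for this illustration; a trained tree would learn its questions and thresholds from data rather than have them hard-coded.

```python
def classify_applicant(income, has_missed_payment, debt_ratio):
    """Walk the example loan tree and return (label, path taken)."""
    path = []
    if income < 40_000:
        path.append("Income < $40k: yes")
        if has_missed_payment:
            path.append("Has missed payment: yes")
            return "DEFAULT", path
        path.append("Has missed payment: no")
        return "SAFE", path
    path.append("Income < $40k: no")
    if debt_ratio > 0.5:
        path.append("Debt ratio > 0.5: yes")
        return "DEFAULT", path
    path.append("Debt ratio > 0.5: no")
    return "SAFE", path


# The returned path is the explanation for the prediction
label, path = classify_applicant(income=35_000, has_missed_payment=True, debt_ratio=0.2)
print(label, "because", " -> ".join(path))
```

Every prediction comes with the exact sequence of answers that produced it, which is what makes the model easy to audit.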

How a Tree Decides Where to Split

The tree chooses splits that maximize "purity" — the more uniform the resulting groups, the better.

Gini Impurity (Classification)

Gini = 1 - Σ(pᵢ²)

Where pᵢ is the proportion of class i in the node.

Pure node (all one class): Gini = 1 - 1² = 0
Maximally impure node with two classes (50/50 split): Gini = 1 - (0.5² + 0.5²) = 0.5
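The formula can be checked with a few lines of Python. This is a small sketch, not scikit-learn's internal implementation; the helper name `gini` is an assumption made for the example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_i^2)."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


pure_node = ["safe"] * 10
mixed_node = ["safe"] * 5 + ["default"] * 5

print(gini(pure_node))   # all one class
print(gini(mixed_node))  # 50/50 split
```

The two calls reproduce the hand-computed values above: 0 for the pure node and 0.5 for the 50/50 node.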

Information Gain (Entropy)

Entropy = -Σ(pᵢ × log₂(pᵢ))

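Pure node: Entropy = 0
50/50 split: Entropy = -(0.5×log₂(0.5) + 0.5×log₂(0.5)) = 1.0 bit

Information gain of a split = entropy of the parent node - weighted average entropy of the child nodes. The tree picks the split with the highest gain.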
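As a rough illustration, the sketch below computes entropy and the information gain of a candidate split, where gain is the parent's entropy minus the size-weighted entropy of the two children. The helper names `entropy` and `information_gain` and the toy label lists are assumptions for this example.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy in bits: -sum(p_i * log2(p_i))."""
    n = len(labels)
    proportions = [count / n for count in Counter(labels).values()]
    return sum(-p * math.log2(p) for p in proportions)


def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    children = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - children


parent = ["default"] * 5 + ["safe"] * 5

# A perfect split separates the classes completely
perfect_gain = information_gain(parent, ["default"] * 5, ["safe"] * 5)

# A weaker split leaves both children mixed
weak_gain = information_gain(
    parent,
    ["default"] * 4 + ["safe"] * 1,
    ["default"] * 1 + ["safe"] * 4,
)

print(f"Perfect split gain: {perfect_gain:.3f} bits")
print(f"Weaker split gain:  {weak_gain:.3f} bits")
```

The tree would prefer the first split, since it removes all of the parent's uncertainty.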

Mean Squared Error (Regression)

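For regression trees, each leaf predicts the mean target value of its training samples. Splits are chosen to minimize the mean squared error around that mean, which is the same as minimizing the variance within each resulting group.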
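The sketch below shows one way this search could look for a single numeric feature: try every midpoint between neighbouring values and keep the threshold with the lowest size-weighted variance. The function name `best_split` and the toy arrays are assumptions for illustration; real implementations are considerably more optimized.

```python
import numpy as np


def best_split(x, y):
    """Return (threshold, weighted_mse) of the best single split on feature x."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best_threshold, best_score = None, np.inf
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # no threshold fits between two equal values
        left, right = y[:i], y[i:]
        # Size-weighted average of the within-group variances
        score = (len(left) * left.var() + len(right) * right.var()) / len(y)
        if score < best_score:
            best_threshold, best_score = (x[i - 1] + x[i]) / 2, score
    return best_threshold, best_score


# Toy data: targets jump when the feature goes from about 3 to about 10
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([5.0, 6.0, 5.5, 20.0, 21.0, 19.5])

threshold, score = best_split(x, y)
print(f"Best threshold: {threshold}, weighted MSE: {score:.3f}")
```

On this toy data the chosen threshold falls in the gap between the two clusters, which is where the within-group variance is smallest.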

Building Your First Decision Tree

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

# Load data
iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a decision tree
tree = DecisionTreeClassifier(
    max_depth=3,           # Limit depth to prevent overfitting
    min_samples_split=5,   # Minimum samples to split a node
    min_samples_leaf=2,    # Minimum samples at each leaf
    criterion='gini',      # 'gini' or 'entropy'
    random_state=42
)
tree.fit(X_train, y_train)

# Evaluate
y_pred = tree.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# Visualize the tree (text)
print(export_text(tree, feature_names=iris.feature_names))

# Visualize the tree (plot)
plt.figure(figsize=(20, 10))
plot_tree(tree, 
          feature_names=iris.feature_names,
          class_names=iris.target_names,
          filled=True, rounded=True)
plt.show()

Decision Trees for Regression

from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
import numpy as np

# Regression example (the dataset is downloaded on first use)
housing = fetch_california_housing()
X_reg, y_reg = housing.data, housing.target

# Separate variable names keep the iris split above intact for later sections
X_reg_train, X_reg_test, y_reg_train, y_reg_test = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

tree_reg = DecisionTreeRegressor(max_depth=5, random_state=42)
tree_reg.fit(X_reg_train, y_reg_train)

y_reg_pred = tree_reg.predict(X_reg_test)
rmse = np.sqrt(mean_squared_error(y_reg_test, y_reg_pred))
print(f"RMSE: {rmse:.3f}")

# Feature importance
for name, importance in zip(housing.feature_names, tree_reg.feature_importances_):
    print(f"  {name}: {importance:.3f}")

Controlling Overfitting: The Key Hyperparameters

Decision trees will overfit if left unconstrained — a tree with no limits will memorize the training data completely.
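As a quick illustration of that memorization, the sketch below fits an unconstrained tree and a depth-limited tree on synthetic data. The dataset comes from `make_classification` with 10% of labels flipped (`flip_y=0.1`); the data, sizes, and variable names are assumptions for this example only.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with deliberate label noise, so memorizing it cannot generalize
X_demo, y_demo = make_classification(
    n_samples=500, n_features=10, n_informative=4, flip_y=0.1, random_state=0
)
Xd_train, Xd_test, yd_train, yd_test = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=0
)

unconstrained = DecisionTreeClassifier(random_state=0).fit(Xd_train, yd_train)
limited = DecisionTreeClassifier(max_depth=3, random_state=0).fit(Xd_train, yd_train)

for name, model in [("No limits", unconstrained), ("max_depth=3", limited)]:
    print(
        f"{name:12s} depth={model.get_depth():2d} "
        f"train={model.score(Xd_train, yd_train):.3f} "
        f"test={model.score(Xd_test, yd_test):.3f}"
    )
```

The unconstrained tree keeps splitting until every training sample is classified correctly, noisy labels included, so its training accuracy says little about how it will behave on new data.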

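from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import matplotlib.pyplot as plt

# Use the iris classification split (a classifier cannot be fit on the housing targets)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# How max_depth affects train vs test performance
max_depths = range(1, 20)
train_scores, test_scores = [], []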

for depth in max_depths:
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)
    train_scores.append(tree.score(X_train, y_train))
    test_scores.append(tree.score(X_test, y_test))

plt.figure(figsize=(10, 6))
plt.plot(max_depths, train_scores, label='Train')
plt.plot(max_depths, test_scores, label='Test')
plt.xlabel('max_depth')
plt.ylabel('Accuracy')
plt.title('Decision Tree: Depth vs Performance')
plt.legend()
plt.show()

Key hyperparameters:

Parameter            Effect                          Typical Range
max_depth            Limits tree height              3–10
min_samples_split    Min samples to split a node     5–20
min_samples_leaf     Min samples at leaf node        1–10
max_features         Features considered per split   'sqrt', 'log2'
max_leaf_nodes       Maximum number of leaves        10–100

Pruning: Post-Training Simplification

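Cost-complexity pruning grows the full tree first, then removes the branches whose impurity reduction does not justify the extra leaves. The ccp_alpha parameter sets the penalty per leaf, so larger values produce smaller trees.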

# Find optimal alpha via cross-validation
from sklearn.model_selection import cross_val_score

path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # Drop the largest alpha, which prunes the tree down to just the root

cv_scores = []
for alpha in ccp_alphas:
    tree = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    scores = cross_val_score(tree, X_train, y_train, cv=5, scoring='accuracy')
    cv_scores.append(scores.mean())

# Best alpha
best_alpha = ccp_alphas[np.argmax(cv_scores)]
print(f"Best alpha: {best_alpha:.5f}")

# Train final tree with best alpha
final_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha)
final_tree.fit(X_train, y_train)
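To see what the alpha values do to the tree itself, the sketch below fits iris trees at a handful of hand-picked `ccp_alpha` values and prints their sizes. The specific alphas and variable names are arbitrary choices for illustration, not recommended settings.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)

# Hand-picked penalties, from no pruning to very aggressive pruning
alphas = [0.0, 0.005, 0.01, 0.02, 0.05, 0.5]
node_counts = []

for alpha in alphas:
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    pruned.fit(X_iris, y_iris)
    node_counts.append(pruned.tree_.node_count)
    print(f"ccp_alpha={alpha:<6} nodes={pruned.tree_.node_count:2d} depth={pruned.get_depth()}")
```

Because each larger alpha prunes a subtree of the previous result, the node count never increases as alpha grows. Cross-validation, as in the code above, is what decides which of these sizes generalizes best.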

Feature Importance

Decision trees give you feature importance scores based on how much each feature reduces impurity across all splits.

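import pandas as pd

# Use the fitted final_tree from the pruning step; the loop variable `tree`
# above was only cross-validated on clones and was never fitted itself
importances = pd.Series(
    final_tree.feature_importances_,
    index=iris.feature_names
).sort_values(ascending=False)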

print("Feature Importances:")
print(importances)

# Visualize
importances.plot(kind='bar')
plt.title('Feature Importance')
plt.tight_layout()
plt.show()

Note: Feature importance from a single tree can be unstable, since a small change in the training data can produce a different tree. Random Forest importances (averaged over many trees) are usually more stable.
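For readers curious where these scores come from, the sketch below recomputes them by hand from a fitted tree's internal arrays: each split contributes its sample-weighted impurity decrease to the feature it splits on, and the totals are normalized to sum to 1. It relies on the `tree_` attribute of a fitted scikit-learn tree; the variable names and the `max_depth=3` setting are assumptions for this example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_iris, y_iris)
t = clf.tree_

manual = np.zeros(X_iris.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:
        continue  # leaf nodes do not split, so they add nothing
    # Sample-weighted impurity of the node minus that of its two children
    decrease = (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right]
    )
    manual[t.feature[node]] += decrease

manual /= manual.sum()  # normalize so the importances sum to 1

print("Manual:      ", manual.round(3))
print("scikit-learn:", clf.feature_importances_.round(3))
```

A feature that is never used in a split gets an importance of exactly 0, which is one reason scores from a single shallow tree should be read with care.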

When to Use Decision Trees

Good uses:

  • You need an explainable model (medical, legal, financial)
  • Data has non-linear relationships and interactions
  • Quick baseline before trying complex models
  • Feature selection (importance scores)

When to avoid:

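  • You need the best possible accuracy (Random Forest or GBM will usually be better)
  • Data has many irrelevant or noisy features (a single tree can split on spurious patterns and overfit)
  • You need well-calibrated probabilities (a leaf's probability is just its training class fraction, which is often exactly 0 or 1 in deep trees)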
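The calibration point can be seen directly in a small experiment. The sketch below compares the predicted probabilities of a fully grown tree and a shallow one on synthetic data; the dataset from `make_classification`, the 300/100 split, and the variable names are assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(
    n_samples=400, n_features=8, flip_y=0.1, random_state=1
)
X_fit, y_fit = X_demo[:300], y_demo[:300]
X_new = X_demo[300:]

deep = DecisionTreeClassifier(random_state=1).fit(X_fit, y_fit)
shallow = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X_fit, y_fit)

# Probability of class 1 for unseen samples
deep_proba = deep.predict_proba(X_new)[:, 1]
shallow_proba = shallow.predict_proba(X_new)[:, 1]

print("Distinct probabilities, fully grown tree:", np.unique(deep_proba))
print("Distinct probabilities, max_depth=3 tree:", np.unique(shallow_proba.round(2)))
```

The fully grown tree only ever answers 0 or 1 because all of its leaves are pure, so its "probabilities" carry no information about confidence. A shallow tree gives a few coarse values, one per leaf.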

The Complete Decision Tree Workflow

from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 2, 5],
    'criterion': ['gini', 'entropy']
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)
grid_search.fit(X_train, y_train)

print("Best params:", grid_search.best_params_)
print("Best CV score:", grid_search.best_score_)

best_tree = grid_search.best_estimator_
print("Test accuracy:", best_tree.score(X_test, y_test))

Decision trees are essential to understand not just for their own merits, but because they're the building blocks of Random Forests and Gradient Boosting — the workhorses of tabular ML.

Next lesson: Random Forests — combining hundreds of decision trees to build a much stronger model.
