How to Tune a Decision Tree to Avoid Overfitting

Decision Trees are among the most intuitive machine learning algorithms. They mimic human decision-making by splitting data into branches based on feature values until a prediction is reached.

However, Decision Trees have a major weakness: overfitting.

A tree can become so specialized to the training data that it memorizes noise rather than learning useful patterns. 

When this happens, the model performs exceptionally well on training data but poorly on new, unseen data.

In this tutorial, you will learn how to tune a Decision Tree using Scikit-Learn to improve generalization and avoid overfitting.


What Does Overfitting Look Like?

Imagine you are predicting whether a citizen trusts their government using survey responses.

A fully grown Decision Tree might create hundreds of tiny branches that perfectly classify every training observation.

Training Accuracy: 99%

Test Accuracy: 62%


This large gap is a classic sign of overfitting.

The model has learned the training data too well and cannot generalize to new examples.


Why Decision Trees Overfit Easily

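Decision Trees continue splitting data until a stopping condition is reached. By default, Scikit-Learn keeps expanding nodes until every leaf is pure (contains a single class) or holds fewer than min_samples_split samples, which defaults to 2.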


Without constraints, the algorithm can:

  • Create extremely deep trees

  • Build branches from very small groups of observations

  • Learn random fluctuations in the data

The result is a complex model that captures noise instead of meaningful patterns.
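To see this concretely, here is a minimal sketch that fits an unconstrained tree to synthetic data in which some labels are deliberately random. The dataset (make_classification with flip_y=0.2) and every setting below are illustrative assumptions, not part of this tutorial's main example. Because no stopping condition is set, the tree keeps splitting until it fits every training label, noise included.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Illustrative synthetic data (assumption): flip_y=0.2 assigns a random class
# to about 20% of the samples, so part of the label set is pure noise.
X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=4,
    flip_y=0.2,
    random_state=0,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# No growth limits: splitting continues until every leaf is pure
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

print("Depth: ", full_tree.get_depth())
print("Leaves:", full_tree.get_n_leaves())
print("Training accuracy:", full_tree.score(X_train, y_train))
print("Test accuracy:    ", full_tree.score(X_test, y_test))
```

The training score comes out perfect because the tree carves out leaves for the randomly labeled points too; those leaves cannot help on new data, which is why the test score is lower.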


Parameters That Control Tree Complexity

The good news is that Scikit-Learn provides several parameters that help limit tree growth.


1. max_depth

This parameter sets the maximum number of levels between the root and the deepest leaf. The default, None, places no limit on depth.

A smaller depth creates a simpler model.

DecisionTreeClassifier(max_depth=5)

Typical values to test:

3, 5, 7, 10, 15

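A shallower tree has less capacity to memorize noise and often generalizes better, but a tree that is too shallow underfits, so compare several depths on validation data rather than assuming smaller is always better.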
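One way to compare the candidate depths is Scikit-Learn's validation_curve, which cross-validates the same estimator for each value of a single parameter. This is a sketch assuming the Breast Cancer dataset and 80/20 split used later in this tutorial; only the training portion is used, so the test set stays untouched.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
# Keep the test set aside; compare depths with cross-validation on training data only
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

depths = [3, 5, 7, 10, 15]

# For each depth: 5-fold CV, returning per-fold train and validation accuracy
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42),
    X_train,
    y_train,
    param_name="max_depth",
    param_range=depths,
    cv=5,
    scoring="accuracy",
)

# Average over folds and report the train/validation gap for each depth
for depth, tr, va in zip(depths, train_scores.mean(axis=1), val_scores.mean(axis=1)):
    print(f"max_depth={depth:>2}  train={tr:.3f}  validation={va:.3f}  gap={tr - va:.3f}")
```

Pick the depth where validation accuracy peaks; a growing gap at larger depths is the overfitting signal described above.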


2. min_samples_split

This specifies the minimum number of samples a node must contain before it can be split. The default is 2.

DecisionTreeClassifier(min_samples_split=20)

Higher values prevent the model from creating branches based on very small subsets of data.

Common values:

5, 10, 20, 50
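The effect can be checked directly on a fitted tree: every node that was split should hold at least min_samples_split training samples. This sketch assumes the Breast Cancer split used in the Practical Example and reads the fitted structure through the tree_ attribute, where leaves are marked by a left child of -1.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

leaf_counts = {}
smallest_split = {}
for mss in [2, 5, 10, 20, 50]:  # 2 is the default
    clf = DecisionTreeClassifier(min_samples_split=mss, random_state=42)
    clf.fit(X_train, y_train)

    t = clf.tree_
    # A node is internal (was split) when it has a left child; leaves store -1
    internal = t.children_left != -1
    smallest_split[mss] = int(t.n_node_samples[internal].min())
    leaf_counts[mss] = clf.get_n_leaves()

    print(
        f"min_samples_split={mss:>2}  leaves={leaf_counts[mss]:>2}  "
        f"smallest split node={smallest_split[mss]:>3}  "
        f"test acc={clf.score(X_test, y_test):.3f}"
    )
```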

3. min_samples_leaf

This determines the minimum number of observations allowed in a leaf node.

DecisionTreeClassifier(min_samples_leaf=10)

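With min_samples_leaf=10, a split is only considered if it leaves at least 10 training samples in each of the two child nodes; any split that would create a smaller leaf is rejected.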

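This is often one of the most effective ways to reduce overfitting, because every prediction is then based on at least that many training observations rather than on one or two possibly noisy points.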

Common values:

1, 5, 10, 20
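A quick sketch to confirm the guarantee, again assuming the Practical Example's Breast Cancer split: for each setting, it records the size of the smallest leaf in the fitted tree, which should never fall below min_samples_leaf.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

smallest_leaf = {}
for msl in [1, 5, 10, 20]:  # 1 is the default
    clf = DecisionTreeClassifier(min_samples_leaf=msl, random_state=42)
    clf.fit(X_train, y_train)

    t = clf.tree_
    # Leaves are the nodes without a left child (stored as -1)
    is_leaf = t.children_left == -1
    smallest_leaf[msl] = int(t.n_node_samples[is_leaf].min())

    print(
        f"min_samples_leaf={msl:>2}  leaves={clf.get_n_leaves():>2}  "
        f"smallest leaf={smallest_leaf[msl]:>2}  "
        f"test acc={clf.score(X_test, y_test):.3f}"
    )
```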


4. max_leaf_nodes

This parameter limits the total number of leaf nodes.

DecisionTreeClassifier(max_leaf_nodes=30)

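When max_leaf_nodes is set, Scikit-Learn grows the tree best-first, expanding the nodes that give the largest impurity reduction first, and stops once it reaches the specified number of leaves.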

This keeps the model simple and interpretable.
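A limit only matters if the unconstrained tree would exceed it, so it helps to check the unconstrained leaf count first. This sketch assumes the Practical Example's Breast Cancer split; the candidate limits are illustrative choices.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Reference point: how many leaves does the unconstrained tree grow?
full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
print("Unconstrained leaves:", full.get_n_leaves())

n_leaves = {}
for mln in [5, 10, 20, 30]:
    # With max_leaf_nodes set, the tree is grown best-first up to the limit
    clf = DecisionTreeClassifier(max_leaf_nodes=mln, random_state=42)
    clf.fit(X_train, y_train)
    n_leaves[mln] = clf.get_n_leaves()
    print(
        f"max_leaf_nodes={mln:>2}  leaves={n_leaves[mln]:>2}  "
        f"test acc={clf.score(X_test, y_test):.3f}"
    )
```

If the unconstrained tree already has fewer leaves than a given limit, that setting leaves the tree unchanged.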



5. ccp_alpha (Cost Complexity Pruning)

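Scikit-Learn implements minimal cost-complexity pruning: after the tree is grown, branches whose reduction in impurity is too small to justify the extra leaves they add are collapsed. ccp_alpha sets that trade-off, and its default, 0.0, performs no pruning.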

DecisionTreeClassifier(ccp_alpha=0.01)

Larger values produce smaller trees.

Pruning is often highly effective because it removes unnecessary complexity after the tree is built.
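Rather than guessing ccp_alpha, you can ask Scikit-Learn for the candidate values with cost_complexity_pruning_path and cross-validate each one. This is a sketch assuming the Practical Example's Breast Cancer split; picking the alpha with the highest mean 5-fold accuracy is one reasonable selection rule, not the only one.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Effective alphas at which subtrees of the fully grown tree are pruned away
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
# Drop the largest alpha: it prunes the tree all the way down to a single node
ccp_alphas = path.ccp_alphas[:-1]

# Mean 5-fold CV accuracy for each candidate alpha (training data only)
cv_means = [
    cross_val_score(
        DecisionTreeClassifier(ccp_alpha=alpha, random_state=42),
        X_train, y_train, cv=5,
    ).mean()
    for alpha in ccp_alphas
]

best_alpha = ccp_alphas[int(np.argmax(cv_means))]
pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
pruned.fit(X_train, y_train)

print("Candidate alphas:", len(ccp_alphas))
print("Best ccp_alpha:", best_alpha)
print("Leaves after pruning:", pruned.get_n_leaves())
print("Test accuracy:", pruned.score(X_test, y_test))
```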



Practical Example

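Let's use the Breast Cancer Wisconsin (diagnostic) dataset included in Scikit-Learn: 569 tumor samples described by 30 numeric features, each labeled malignant or benign.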


from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Load data
data = load_breast_cancer()

X = data.data
y = data.target

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)

# Baseline model
tree = DecisionTreeClassifier(random_state=42)

tree.fit(X_train, y_train)

train_acc = accuracy_score(
    y_train,
    tree.predict(X_train)
)

test_acc = accuracy_score(
    y_test,
    tree.predict(X_test)
)

print("Training Accuracy:", train_acc)
print("Test Accuracy:", test_acc)



In the output, you will often observe perfect or nearly perfect training accuracy but noticeably lower test accuracy.


Now tune the model:

tree = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=42
)

tree.fit(X_train, y_train)

train_acc = accuracy_score(
    y_train,
    tree.predict(X_train)
)

test_acc = accuracy_score(
    y_test,
    tree.predict(X_test)
)

print("Training Accuracy:", train_acc)
print("Test Accuracy:", test_acc)



You may see training accuracy decrease slightly while test accuracy improves.

This is exactly what we want.


Using Grid Search to Find the Best Parameters

Instead of guessing parameter values, use Grid Search.

from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, 10],
    'min_samples_split': [5, 10, 20],
    'min_samples_leaf': [1, 5, 10]
}

grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy'
)

grid.fit(X_train, y_train)

print(grid.best_params_)
print(grid.best_score_)



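GridSearchCV evaluates every combination in param_grid (4 × 3 × 3 = 36 candidates, each trained and scored on 5 folds) and selects the one with the highest mean cross-validated accuracy, which is the value best_score_ reports. Because refit=True by default, it then retrains that combination on the full training set.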
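The cross-validated score is still an estimate from the training data, so the natural last step is a single check on the held-out test set. This sketch reruns the same grid on the Practical Example's split and evaluates best_estimator_, the model GridSearchCV refits with the winning parameters.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

param_grid = {
    "max_depth": [3, 5, 7, 10],
    "min_samples_split": [5, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
)
grid.fit(X_train, y_train)

# With refit=True (the default), the best combination is retrained on all of X_train
best_tree = grid.best_estimator_
test_acc = best_tree.score(X_test, y_test)

print("Candidates tried:", len(grid.cv_results_["params"]))
print("Best parameters:", grid.best_params_)
print("Best CV accuracy:", round(grid.best_score_, 3))
# The test set is used once, after all tuning decisions are made
print("Test accuracy:", round(test_acc, 3))
```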



How to Tell If Tuning Worked

A tuned Decision Tree should show:

  • Smaller difference between training and test accuracy

  • Better cross-validation performance

  • Simpler tree structure

  • More reliable predictions on unseen data
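The checklist above can be measured side by side. This sketch compares the unconstrained baseline with the tuned settings from the Practical Example (max_depth=5, min_samples_split=20, min_samples_leaf=10) on the same Breast Cancer split, reporting the train/test gap, mean 5-fold cross-validation accuracy, depth, and leaf count for each.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

models = {
    "baseline": DecisionTreeClassifier(random_state=42),
    "tuned": DecisionTreeClassifier(
        max_depth=5, min_samples_split=20, min_samples_leaf=10, random_state=42
    ),
}

report = {}
for name, model in models.items():
    # cross_val_score fits clones, so this uses the training data only
    cv_acc = cross_val_score(model, X_train, y_train, cv=5).mean()
    model.fit(X_train, y_train)
    train_acc = model.score(X_train, y_train)
    test_acc = model.score(X_test, y_test)
    report[name] = {
        "train": train_acc,
        "gap": train_acc - test_acc,
        "cv": cv_acc,
        "depth": model.get_depth(),
        "leaves": model.get_n_leaves(),
    }
    print(
        f"{name:<8}  train={train_acc:.3f}  test={test_acc:.3f}  "
        f"gap={train_acc - test_acc:.3f}  cv={cv_acc:.3f}  "
        f"depth={model.get_depth()}  leaves={model.get_n_leaves()}"
    )
```

A successful tuning run should show a smaller gap and a smaller tree for the tuned model without a drop in cross-validation accuracy.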


The goal is not to maximize training accuracy.

The goal is to maximize performance on new data.


Decision Trees are powerful because they are easy to understand and explain. However, their flexibility makes them highly susceptible to overfitting.


The most important tuning parameters are:

  • max_depth

  • min_samples_split

  • min_samples_leaf

  • max_leaf_nodes

  • ccp_alpha


By controlling tree complexity and validating performance with cross-validation, you can build Decision Tree models that generalize well and make reliable predictions in real-world applications.

