How to Visualise a Decision Tree So Anyone Can Understand It

One of the biggest advantages of Decision Trees is that they are naturally explainable.




Unlike Neural Networks, which often operate as "black boxes," Decision Trees show exactly how a prediction is made. Every split, branch, and leaf can be inspected and explained.


This transparency makes Decision Trees popular in healthcare, finance, government, education, and any industry where stakeholders need to understand why a model made a particular decision.


In this tutorial, you will learn how to visualize a Decision Tree using Python and how to explain the resulting diagram in plain language.


Why Visualising a Decision Tree Matters

Imagine you have built a model that predicts whether a tumor is malignant or benign.

A stakeholder asks:

            "Why did the model predict this patient is high risk?"

With many machine learning models, answering that question is difficult.


With a Decision Tree, you can literally point to the path the model followed.

For example (the thresholds here are illustrative, not taken from a trained model):

Radius > 15.3?
    Yes
        Texture > 20?
            Yes → Malignant
            No → Benign
    No → Benign

This makes the model much easier to trust and validate.
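Every fitted Decision Tree is equivalent to a set of nested if/else rules. The sketch below writes the illustrative tree above as a plain Python function. The function name is hypothetical, and the thresholds (15.3 and 20) come from the example diagram, not from a trained model.

```python
def illustrative_tree(radius: float, texture: float) -> str:
    """Hypothetical tree from the example above, written as nested if/else."""
    if radius > 15.3:
        # Right branch of the root: check texture next
        if texture > 20:
            return "Malignant"
        return "Benign"
    # Radius at or below the threshold goes straight to a leaf
    return "Benign"


print(illustrative_tree(radius=18.0, texture=24.0))  # Malignant
print(illustrative_tree(radius=12.0, texture=24.0))  # Benign
```

Reading the rules this way is exactly what a stakeholder does when following the branches of the diagram.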


The Dataset

We will use the Breast Cancer dataset included with Scikit-Learn.


from sklearn.datasets import load_breast_cancer
import pandas as pd

data = load_breast_cancer()

X = pd.DataFrame(
    data.data,
    columns=data.feature_names
)

y = data.target


The target variable is:

  • 0 = Malignant

  • 1 = Benign
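You can confirm this mapping, and see how many patients fall into each class, directly from the dataset object. This is a small sanity check that uses only `load_breast_cancer` and NumPy's `bincount`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

# target_names is indexed by the numeric label: position 0 is label 0, etc.
for label, name in enumerate(data.target_names):
    print(label, name)

# Count how many samples carry each label
counts = np.bincount(data.target)
print(counts)  # [212 357]
```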


Train a Simple Decision Tree

Before visualizing the tree, let's train one.

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(
    max_depth=3,
    random_state=42
)

tree.fit(X, y)


Notice that we limit the depth to 3. A tree of depth 3 has at most 2^3 = 8 leaves, which fits comfortably in one diagram.

Deeper trees quickly grow too large for the visualization to stay readable.
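To see how much the depth limit matters, you can compare the fitted size of a depth-3 tree with an unconstrained one. This sketch uses the estimator's `get_depth()` and `get_n_leaves()` methods; the exact size of the unconstrained tree depends on the data and seed, so no numbers are claimed for it here.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X, y = data.data, data.target

shallow = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
full = DecisionTreeClassifier(random_state=42).fit(X, y)

# get_depth() and get_n_leaves() report the size of the fitted tree
print("depth-3 tree:", shallow.get_depth(), "levels,", shallow.get_n_leaves(), "leaves")
print("unlimited tree:", full.get_depth(), "levels,", full.get_n_leaves(), "leaves")
```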


Visualise the Tree

Scikit-Learn includes a built-in plotting function, plot_tree, which draws the fitted tree with Matplotlib.

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

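plt.figure(figsize=(18, 10))

# Names are passed as plain lists; some scikit-learn releases
# only accept lists (not NumPy arrays) for these arguments
plot_tree(
    tree,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
    filled=True,   # colour each box by its majority class
    rounded=True   # rounded box corners
)

plt.show()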




This produces a tree diagram where:

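  • Each internal box is a decision node showing the split rule

  • Branches lead to a child node: left when the rule is true, right when it is false

  • Leaf nodes at the bottom contain the final predictions

  • Colors indicate each node's majority class, with stronger shades marking purer nodes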


Understanding a Tree Node

A typical node may look like this:

worst radius <= 16.8

gini = 0.47

samples = 569

value = [212, 357]

class = Benign


Let's decode each part.

1. Split Condition

worst radius <= 16.8

This is the rule used to divide observations.

If the condition is true:

  • Go left

If false:

  • Go right


2. Gini Impurity

gini = 0.47

Gini measures how mixed the classes are within a node.

  • 0 = perfectly pure (all samples belong to one class)

  • Higher values = more mixed (with two classes, the maximum is 0.5)

Lower Gini values indicate better separation.
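Gini impurity is 1 minus the sum of the squared class proportions in a node. A minimal sketch of that formula, applied to the root node counts used in this section, reproduces the 0.47 shown in the diagram (plot_tree rounds it to 0.468). The function name `gini` is our own, not a scikit-learn API.

```python
def gini(counts):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)


print(round(gini([212, 357]), 3))  # 0.468, the root node of this dataset
print(gini([50, 0]))               # 0.0, a perfectly pure node
print(gini([50, 50]))              # 0.5, an evenly mixed two-class node
```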


3. Samples

samples = 569

The number of training observations that reach that node.


4. Value

value = [212, 357]

The number of observations in each class, listed in the order of data.target_names (malignant first, then benign).

In this example:

  • 212 malignant

  • 357 benign


5. Class

class = Benign

The majority class prediction for that node.
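The same numbers shown in each box can be read programmatically from the fitted tree's low-level `tree_` attribute, which is useful when you want to quote them in a report. This sketch reads the root node (node 0); it uses `value` only to find the majority class rather than printing it.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(data.data, data.target)

t = tree.tree_
root = 0  # node 0 is always the root

# 1. Split condition: which feature and threshold the root uses
feature_name = data.feature_names[t.feature[root]]
print(f"{feature_name} <= {t.threshold[root]:.3f}")

# 2-3. Gini impurity and number of training samples at the root
print(f"gini = {t.impurity[root]:.3f}")
print(f"samples = {t.n_node_samples[root]}")

# 5. Class: the largest entry in the node's value array
print("class =", data.target_names[t.value[root].argmax()])
```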


Following a Prediction Path

Suppose a patient has:

worst radius = 18
worst texture = 24

The tree may follow this path:

worst radius <= 16.8?
      No

worst texture <= 20?
      No

Prediction:
Malignant


This path provides a complete explanation for the prediction.


Instead of saying:

"The algorithm predicts malignant."

You can say:

"The patient's radius and texture measurements placed them into a branch historically associated with malignant tumors."

That explanation is understandable to most stakeholders.
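For a real patient you do not need to trace the diagram by hand. The sketch below uses scikit-learn's `decision_path` to list the rules one sample follows from the root to its leaf. The helper name `explain_prediction` is hypothetical, and it assumes the depth-3 tree trained earlier (refitted here so the block runs on its own).

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X, y = data.data, data.target
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)


def explain_prediction(model, sample, feature_names, class_names):
    """Return the split rules one sample follows from the root to its leaf."""
    row = sample.reshape(1, -1)
    t = model.tree_
    # decision_path marks every node the sample visits; child nodes always
    # have larger ids than their parents, so sorting gives root-to-leaf order
    path = np.sort(model.decision_path(row).indices)
    rules = []
    for node, next_node in zip(path[:-1], path[1:]):
        feature = t.feature[node]
        # The sample went left exactly when the split condition was true
        sign = "<=" if next_node == t.children_left[node] else ">"
        rules.append(
            f"{feature_names[feature]} = {sample[feature]:.2f} "
            f"({sign} {t.threshold[node]:.2f})"
        )
    # Labels 0/1 line up with the positions in class_names
    prediction = class_names[model.predict(row)[0]]
    return rules, prediction


rules, prediction = explain_prediction(tree, X[0], data.feature_names, data.target_names)
for rule in rules:
    print(rule)
print("Prediction:", prediction)
```

Each printed rule can be turned directly into the kind of plain-language sentence shown above.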


Exporting a High-Quality Tree Diagram

For reports and presentations, export the tree as an image.

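plt.figure(figsize=(20, 12))

plot_tree(
    tree,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
    filled=True,
    rounded=True,
    fontsize=10   # larger text stays legible when the image is scaled down
)

# Save before plt.show(): showing the figure can leave it blank for saving
plt.savefig(
    "decision_tree.png",
    dpi=300,               # print-quality resolution
    bbox_inches="tight"    # trim surrounding whitespace
)

plt.show()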



The resulting image can be inserted into:

  • Research reports

  • Executive presentations

  • Regulatory documentation

  • Academic publications
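When an image is impractical, for example in an email or a plain-text appendix, scikit-learn's `export_text` produces the same rules as indented text. Without class names, the leaves show the numeric labels (0 = malignant, 1 = benign). The output file name is an arbitrary choice.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(data.data, data.target)

# export_text returns the tree's rules as an indented plain-text string
report = export_text(tree, feature_names=list(data.feature_names))
print(report)

# Save it alongside the PNG for documents that need plain text
with open("decision_tree.txt", "w") as f:
    f.write(report)
```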


Simplifying Large Trees

Real-world trees often contain hundreds of nodes, which makes them practically impossible to explain visually.

To keep trees understandable, you can do the following:

1. Limit Depth

DecisionTreeClassifier(max_depth=3)

2. Increase Minimum Leaf Size

DecisionTreeClassifier(
    min_samples_leaf=20
)

3. Prune the Tree

DecisionTreeClassifier(
    ccp_alpha=0.01
)

Smaller trees are usually easier to interpret and often generalize better.
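To see the effect of each option, you can fit one tree per setting and compare their sizes. This sketch reuses the three parameter values above and adds `cost_complexity_pruning_path`, which lists the effective alphas at which subtrees get pruned. It offers a data-driven starting grid for ccp_alpha instead of guessing 0.01.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X, y = data.data, data.target

settings = {
    "unconstrained": {},
    "max_depth=3": {"max_depth": 3},
    "min_samples_leaf=20": {"min_samples_leaf": 20},
    "ccp_alpha=0.01": {"ccp_alpha": 0.01},
}

leaf_counts = {}
for label, params in settings.items():
    model = DecisionTreeClassifier(random_state=42, **params).fit(X, y)
    leaf_counts[label] = model.get_n_leaves()
    print(f"{label:>20}: {model.get_n_leaves()} leaves, depth {model.get_depth()}")

# Effective alphas for this data; larger values prune more aggressively
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)
print(path.ccp_alphas[-5:])
```

In practice you would pick ccp_alpha from these candidates with cross-validation rather than on the training data alone.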


Feature Importance Visualisation

Sometimes stakeholders care more about which variables matter most than about the entire tree.

You can rank the features by importance:

import pandas as pd

importance = pd.DataFrame({
    "Feature": data.feature_names,
    "Importance": tree.feature_importances_
})

importance = importance.sort_values(
    by="Importance",
    ascending=False
)

print(importance.head(10))



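This ranks the variables by how much their splits reduced Gini impurity across the tree (scikit-learn's impurity-based importance, normalized to sum to 1). Features the tree never splits on score 0, so a depth-3 tree, with at most 7 splits, gives nonzero importance to at most 7 of the 30 features.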

For many business audiences, feature importance charts are easier to understand than a complete tree diagram.
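The ranked table can also be drawn as a horizontal bar chart, which is usually the form business audiences see. This sketch keeps only features with nonzero importance and uses pandas' built-in `plot.barh`; the output file name and figure size are arbitrary choices.

```python
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(data.data, data.target)

importance = pd.Series(tree.feature_importances_, index=data.feature_names)

# Keep only features the tree split on; sort ascending so the largest bar is on top
used = importance[importance > 0].sort_values()

fig, ax = plt.subplots(figsize=(8, 4))
used.plot.barh(ax=ax)
ax.set_xlabel("Impurity-based importance")
ax.set_title("Features used by the depth-3 tree")
fig.tight_layout()
fig.savefig("feature_importance.png", dpi=300)
```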


Common Mistakes

Avoid these common visualization mistakes:

  • Visualizing extremely deep trees

  • Using tiny unreadable fonts

  • Including hundreds of nodes in presentations

  • Explaining Gini scores before explaining the decision path

  • Focusing on technical metrics instead of business logic


Remember that stakeholders usually care about why a prediction was made, not the mathematics behind the algorithm.



Decision Trees are one of the few machine learning models that can be fully visualized and explained.

By plotting the tree structure, following prediction paths, and highlighting key features, you can transform a machine learning model from a black box into a transparent decision-making tool.


When presenting a Decision Tree to non-technical audiences, focus on the sequence of decisions that lead to a prediction. 


If a stakeholder can follow the branches and understand the logic, then the visualization has done its job.

