
Decision Tree in Machine Learning

A decision tree is a flowchart-like structure in which each internal node represents a test on a feature (e.g. whether a coin flip comes up heads or tails), each leaf node represents a class label (the decision taken after evaluating all features), and branches represent conjunctions of features that lead to those class labels. The paths from root to leaf represent classification rules. The diagram below illustrates the basic flow of a decision tree for decision making with the labels Rain (Yes) and No Rain (No).

Decision Tree for Rain Forecasting

Decision trees are one of the predictive modelling approaches used in statistics, data mining and machine learning.

Decision trees are constructed via an algorithmic approach that identifies ways to split a data set based on different conditions. They are among the most widely used and practical methods for supervised learning, and they are a non-parametric method that can be used for both classification and regression tasks.

Tree models where the target variable can take a discrete set of values are called classification trees. Decision trees where the target variable can take continuous values (typically real numbers) are called regression trees. Classification And Regression Tree (CART) is the general term covering both.

Throughout this post I will explain the ideas using examples.

Data Format

Data comes in records of the form

(x, Y) = (x1, x2, x3, ..., xk, Y)

The dependent variable, Y, is the target variable that we are trying to understand, classify or generalize. The vector x is composed of the features, x1, x2, x3 etc., that are used for that task.


training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

# Column labels, used only when printing questions and the tree.
# The first two columns are features; the last column is the label.
header = ["Color", "diameter", "Label"]

Approach to Making a Decision Tree

While building a decision tree, we ask a different question about the data at each node. For each candidate question we calculate the corresponding information gain.

Information Gain

Information gain is used to decide which feature to split on at each step in building the tree. Simplicity is best, so we want to keep our tree small. To do so, at each step we should choose the split that results in the purest daughter nodes. A commonly used measure of purity is called information. For each node of the tree, the information value measures how much information a feature gives us about the class. The split with the highest information gain is taken as the first split, and the process continues until all child nodes are pure or until the information gain is 0.

Asking Questions

class Question:
    """A Question is used to partition a dataset.

    This class just records a 'column number' (e.g., 0 for Color) and a
    'column value' (e.g., Green). The 'match' method is used to compare
    the feature value in an example to the feature value stored in the
    question. See the demo below.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        # Compare the feature value in an example to the
        # feature value in this question.
        val = example[self.column]
        if is_numeric(val):
            return val >= self.value
        else:
            return val == self.value

    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        condition = "=="
        if is_numeric(self.value):
            condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))
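The Question class calls an is_numeric helper that is not listed in this post. A minimal sketch of what it might look like, assuming numeric features are plain Python ints or floats and everything else is treated as a category:

```python
def is_numeric(value):
    """Return True for ints and floats (assumed helper, not from the post)."""
    return isinstance(value, (int, float))


print(is_numeric(3))        # True
print(is_numeric("Green"))  # False
```

With this definition a numeric question is compared with >= and a categorical one with ==, which is the distinction the match method relies on.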

Let's try creating a few questions and looking at their output.

Question(1, 3)        ## Is diameter >= 3?
Question(0, "Green")  ## Is Color == Green?

Now we will partition the dataset based on the question asked. At each step the data is divided into two subsets: the rows that match the question and the rows that do not.

def partition(rows, question):
    """Partitions a dataset.

    For each row in the dataset, check if it matches the question. If
    so, add it to 'true rows', otherwise, add it to 'false rows'.
    """
    true_rows, false_rows = [], []
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
    return true_rows, false_rows


# Let's partition the training data based on whether rows are Red.
true_rows, false_rows = partition(training_data, Question(0, 'Red'))
# This will contain all the 'Red' rows.
true_rows   ## [['Red', 1, 'Grape'], ['Red', 1, 'Grape']]
# This will contain everything else.
false_rows  ## [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

Algorithms for constructing decision trees usually work top-down, choosing at each step the variable that best splits the set of items. Different algorithms use different metrics for measuring "best".

Gini Impurity

First let's understand the meaning of pure and impure.

Pure means that all the data in a selected sample of the dataset belongs to the same class.

Impure means that the data is a mixture of different classes.

Definition of Gini Impurity

Gini Impurity is a measurement of the likelihood of an incorrect classification of a new instance of a random variable, if that new instance were randomly classified according to the distribution of class labels from the data set.

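If our dataset is pure, the likelihood of incorrect classification is 0. If our sample is a mixture of different classes, the likelihood of incorrect classification is higher. Formally, if p_i is the fraction of rows carrying label i, the Gini impurity is 1 minus the sum of p_i squared over all labels.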

Calculating Gini Impurity.

def gini(rows):
    """Calculate the Gini Impurity for a list of rows.

    There are a few different ways to do this, I thought this one was
    the most concise.
    """
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl**2
    return impurity


# Demo 1:
# Let's look at some examples to understand how Gini Impurity works.
# First, we'll look at a dataset with no mixing.
no_mixing = [['Apple'],
             ['Apple']]
# this will return 0
gini(no_mixing)  ## output=0

# Demo 2:
# Now, we'll look at a dataset with a 50:50 apples:oranges ratio.
some_mixing = [['Apple'],
               ['Orange']]
# this will return 0.5 - meaning, there's a 50% chance of misclassifying
# a random example we draw from the dataset.
gini(some_mixing)  ## output=0.5

# Demo 3:
# Now, we'll look at a dataset with many different labels.
# (Five distinct labels give 0.8; the fruit names are illustrative.)
lots_of_mixing = [['Apple'],
                  ['Orange'],
                  ['Grape'],
                  ['Grapefruit'],
                  ['Blueberry']]
# This will return 0.8 (up to floating-point rounding)
gini(lots_of_mixing)  ## output=0.8
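The gini function depends on a class_counts helper that this post does not list. A minimal sketch, assuming (as in the toy dataset) that the label is always the last element of a row:

```python
from collections import Counter


def class_counts(rows):
    """Count how many rows carry each label (assumed helper).

    Assumes the label is the last element of every row.
    """
    return dict(Counter(row[-1] for row in rows))


toy_rows = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
print(class_counts(toy_rows))  # {'Apple': 2, 'Grape': 2, 'Lemon': 1}
```

Because it reads only the last column, the same helper works for the single-column demo rows such as [['Apple'], ['Orange']].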

Steps for Making a Decision Tree

  • Get the list of rows (dataset) under consideration at the current node (the procedure is applied recursively at each node).

  • Calculate the uncertainty of that dataset, i.e. its Gini impurity, which measures how mixed up the data is.

  • Generate the list of all questions that could be asked at that node.

  • Partition the rows into true rows and false rows for each question.

  • Calculate the information gain from the Gini impurity and the partition produced in the previous step.

  • Update the highest information gain seen so far.

  • Update the best question, i.e. the one with the highest information gain.

  • Divide the node on the best question, then repeat from the first step for each child until we reach pure nodes (leaf nodes).
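The information-gain step needs a helper, called info_gain in the code for these steps, which this post does not list. A minimal sketch, assuming the gain is the parent's Gini impurity minus the size-weighted impurity of the two child partitions. The compact gini here only restates the earlier function so that the block runs by itself:

```python
from collections import Counter


def gini(rows):
    """Gini impurity of rows whose last element is the label."""
    counts = Counter(row[-1] for row in rows)
    return 1 - sum((n / len(rows)) ** 2 for n in counts.values())


def info_gain(left, right, current_uncertainty):
    """Parent uncertainty minus the weighted impurity of the two
    child nodes (assumed definition, not taken from the post)."""
    p = len(left) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)


rows = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
# Split on "Is Color == Red?"
red = [r for r in rows if r[0] == 'Red']
not_red = [r for r in rows if r[0] != 'Red']
gain = info_gain(red, not_red, gini(rows))
print(round(gain, 3))  # 0.373
```

On the toy data the question "Is diameter >= 3?" separates exactly the same two groups of rows, so it yields the same gain.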

Code for the Above Steps

def find_best_split(rows):
    """Find the best question to ask by iterating over every feature / value
    and calculating the information gain."""
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature / value that produced it
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of columns

    for col in range(n_features):  # for each feature
        values = set([row[col] for row in rows])  # unique values in the column

        for val in values:  # for each value
            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the
            # dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # You actually can use '>' instead of '>=' here
            # but I wanted the tree to look a certain way for our
            # toy dataset.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question


# Demo:
# Find the best question to ask first for our toy dataset.
best_gain, best_question = find_best_split(training_data)
best_question
## output - Is diameter >= 3?

Now build the decision tree recursively, applying the steps discussed above at each node.

def build_tree(rows):
    """Builds the tree.

    Rules of recursion: 1) Believe that it works. 2) Start by checking
    for the base case (no further information gain). 3) Prepare for
    giant stack traces.
    """
    # Try partitioning the dataset on each unique attribute,
    # calculate the information gain,
    # and return the question that produces the highest gain.
    gain, question = find_best_split(rows)

    # Base case: no further info gain
    # Since we can ask no further questions,
    # we'll return a leaf.
    if gain == 0:
        return Leaf(rows)

    # If we reach here, we have found a useful feature / value
    # to partition on.
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    # Return a Question node.
    # This records the best feature / value to ask at this point,
    # as well as the branches to follow
    # depending on the answer.
    return Decision_Node(question, true_branch, false_branch)
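build_tree returns Leaf and Decision_Node objects, neither of which is defined in this post. A minimal sketch of what they might look like, assuming a leaf stores the label counts of the rows that reach it and a decision node stores a question together with its two child branches:

```python
from collections import Counter


class Leaf:
    """A leaf node: holds the label counts of the training rows that
    reached it (assumes the label is the last column of each row)."""

    def __init__(self, rows):
        self.predictions = dict(Counter(row[-1] for row in rows))


class Decision_Node:
    """An internal node: a question and the two subtrees it leads to."""

    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch


leaf = Leaf([['Red', 1, 'Grape'], ['Red', 1, 'Grape']])
print(leaf.predictions)  # {'Grape': 2}
```

Storing counts rather than a single label lets a leaf that is not pure report how its rows were distributed.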

Building the Decision Tree

Let's build a decision tree from the training data.

training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
# The first two columns are features; the last column is the label.
header = ["Color", "diameter", "Label"]

my_tree = build_tree(training_data)


Is diameter >= 3?
  --> True:
    Is Color == Yellow?
    --> True:
        Predict {'Lemon': 1, 'Apple': 1}
    --> False:
        Predict {'Apple': 1}
  --> False:
    Predict {'Grape': 2}

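From the above output we can see that at each step the data is divided into true and false rows. This process is repeated until we reach a leaf node, where the information gain is 0 and no further split is possible. That happens either because the node is pure or, as with the Lemon/Apple leaf, because the remaining rows have identical feature values.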
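Once a tree exists, a new row is classified by following the questions from the root down to a leaf. The post does not show this step. The sketch below assumes that leaves expose a predictions dictionary of label counts and that decision nodes expose question, true_branch and false_branch. The hand-built stand-in objects are only there so the block runs by itself:

```python
from types import SimpleNamespace


def classify(row, node):
    """Walk the tree from `node` to a leaf and return its label counts."""
    # Base case: a leaf is assumed to carry a `predictions` attribute.
    if hasattr(node, "predictions"):
        return node.predictions
    # Otherwise follow the branch chosen by the node's question.
    if node.question.match(row):
        return classify(row, node.true_branch)
    return classify(row, node.false_branch)


# Stand-in for the top of the toy tree: "Is diameter >= 3?"
question = SimpleNamespace(match=lambda row: row[1] >= 3)
toy_tree = SimpleNamespace(
    question=question,
    true_branch=SimpleNamespace(predictions={'Apple': 2, 'Lemon': 1}),
    false_branch=SimpleNamespace(predictions={'Grape': 2}),
)

print(classify(['Red', 1], toy_tree))    # {'Grape': 2}
print(classify(['Green', 3], toy_tree))  # {'Apple': 2, 'Lemon': 1}
```

Returning the counts rather than one label leaves it to the caller to decide how to handle an impure leaf, for example by taking the most frequent label.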

Advantages of Decision Trees

  • Easy to use and understand.

  • Can handle both categorical and numerical data.

  • Resistant to outliers, and hence require little data preprocessing.

Disadvantages of Decision Trees

  • Prone to overfitting.

  • Require some kind of measurement of how well they are doing.

  • Need careful parameter tuning.

  • Can create biased trees if some classes dominate.

How to avoid overfitting the Decision tree model

Overfitting is one of the major problems for every model in machine learning. An overfitted model generalizes poorly to new samples. To keep a decision tree from overfitting, we remove the branches that make use of features with low importance. This method is called pruning or post-pruning. It reduces the complexity of the tree and hence improves predictive accuracy by reducing overfitting.

Pruning should reduce the size of a learning tree without reducing predictive accuracy as measured by a cross-validation set. There are two major pruning techniques.

  • Minimum Error: The tree is pruned back to the point where the cross-validated error is a minimum.

  • Smallest Tree: The tree is pruned back slightly further than the minimum error. Technically the pruning creates a decision tree with cross-validation error within 1 standard error of the minimum error.
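The two rules can be illustrated on a table of cross-validation results. The sketch below is not a pruning implementation. It only shows how each rule would pick a tree, and the candidate sizes, errors and standard errors are made-up numbers for illustration:

```python
# Each candidate is (number of leaves, cross-validated error, standard error).
# These values are hypothetical, chosen only to illustrate the two rules.
candidates = [
    (2, 0.30, 0.03),
    (4, 0.21, 0.03),
    (7, 0.20, 0.02),
    (12, 0.23, 0.02),
    (20, 0.25, 0.03),
]


def minimum_error_tree(candidates):
    """Pick the candidate with the lowest cross-validated error."""
    return min(candidates, key=lambda c: c[1])


def smallest_tree_within_one_se(candidates):
    """Pick the smallest candidate whose error is within one standard
    error of the minimum cross-validated error."""
    best = minimum_error_tree(candidates)
    threshold = best[1] + best[2]
    eligible = [c for c in candidates if c[1] <= threshold]
    return min(eligible, key=lambda c: c[0])


print(minimum_error_tree(candidates))           # (7, 0.2, 0.02)
print(smallest_tree_within_one_se(candidates))  # (4, 0.21, 0.03)
```

Here the second rule accepts a slightly higher error in exchange for a tree with four leaves instead of seven.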

Early Stop or Pre-pruning

An alternative method to prevent overfitting is to try and stop the tree-building process early, before it produces leaves with very small samples. This heuristic is known as early stopping but is also sometimes known as pre-pruning decision trees.

At each stage of splitting the tree, we check the cross-validation error. If the error does not decrease significantly enough then we stop. Early stopping may underfit by stopping too early. The current split may be of little benefit, but having made it, subsequent splits more significantly reduce the error.
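A small sketch of this stopping rule and of the way it can stop too early. The threshold name min_improvement and the error values are hypothetical, chosen so that a weak first split hides a much better second one:

```python
def should_stop(error_before, error_after, min_improvement=0.01):
    """Return True when a split lowers the error by less than
    `min_improvement` (a hypothetical threshold)."""
    return (error_before - error_after) < min_improvement


# Hypothetical cross-validation errors after 0, 1 and 2 splits.
errors = [0.40, 0.395, 0.20]

splits_made = 0
for before, after in zip(errors, errors[1:]):
    if should_stop(before, after):
        break
    splits_made += 1

print(splits_made)  # 0: the first split gains too little, so we stop
```

The loop never reaches the second split, which would have cut the error roughly in half. That is the underfitting risk described above.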

Early stopping and pruning can be used together, separately, or not at all. Post-pruning decision trees is more mathematically rigorous, finding a tree at least as good as early stopping. Early stopping is a quick-fix heuristic. If used together with pruning, early stopping may save time. After all, why build a tree only to prune it back again?

Decision Tree in Real Life

1. Selecting a flight to travel

Suppose you need to select a flight for your next trip. How do you go about it? You first check whether a flight is available on that day. If it is not, you look at some other date. If it is, you might look next at the duration of the flight. If you only want direct flights, you then check whether the price is within your pre-defined budget. If it is too expensive, you look at other flights; otherwise you book it!
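The flight example is itself a small hand-written decision tree, and each root-to-leaf path is one rule. A sketch of it as a chain of checks follows. The function name, parameters and return strings are all made up for illustration, and flight duration is left out because the text gives no rule for it:

```python
def choose_flight(available, direct, price, budget):
    """Hand-written decision tree for the flight example (illustrative)."""
    if not available:
        return "try another date"
    if not direct:
        return "look at other flights"
    if price > budget:
        return "look at other flights"
    return "book it"


print(choose_flight(available=True, direct=True, price=180, budget=200))
# book it
print(choose_flight(available=False, direct=True, price=180, budget=200))
# try another date
```

A learned tree has the same shape. The difference is that the questions and their order are chosen from data by information gain rather than written by hand.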

2. Handling late night cravings

Source: Google

There are many more applications of decision trees in real life.

Source: Towards Data Science - Prince Yadav

The Tech Platform


