Source Information¶


Created by:

Last updated: October 01, 2024


Goal¶

This notebook shows how to use the scikit-learn machine learning package, working through a classic decision tree example.

Decision trees¶

Decision trees are a supervised machine learning method for classification:

  • Supervised - use labeled data for training
  • Classification - predict the category for a given input

To explore decision trees, we'll use the Python scikit-learn machine learning (sklearn) package and the famous iris dataset that was described in R.A. Fisher's classic 1936 paper “The Use of Multiple Measurements in Taxonomic Problems”. The goal is straightforward - given four measurements of a flower (sepal width, sepal length, petal width and petal length), predict which of three species (setosa, versicolor, virginica) the flower belongs to.

We'll dive right in first and then go back and examine the steps in more detail.

Required Modules for the Jupyter Notebook¶

Before running the notebook, we need to load the following modules.

Modules: sklearn (tree, load_iris) and graphviz

In [1]:
!pip install --user scikit-learn graphviz
Requirement already satisfied: scikit-learn in /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages (0.23.2)
Requirement already satisfied: graphviz in /home/amehrotra1/.local/lib/python3.8/site-packages (0.20.3)
Requirement already satisfied: joblib>=0.11 in /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages (from scikit-learn) (0.17.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages (from scikit-learn) (2.1.0)
Requirement already satisfied: scipy>=0.19.1 in /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages (from scikit-learn) (1.5.2)
Requirement already satisfied: numpy>=1.13.3 in /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages (from scikit-learn) (1.19.2)
In [2]:
from sklearn import tree
from sklearn.datasets import load_iris
import graphviz 

Decision Tree Classifier with Iris Dataset Visualization using Graphviz¶

Load the data set, create the decision tree classifier, and train the classifier.

In [3]:
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

Create a DOT representation of the decision tree, render it using Graphviz, and display the result.

In [4]:
dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
graph = graphviz.Source(dot_data)  
graph 
Out[4]:
[Graphviz rendering of the trained iris decision tree]

We start by importing sklearn's tree module. Since the entire sklearn package is very large, we import just the functionality we need.

We then import and load a data set. In this case, we use one of the data sets provided with sklearn, but more generally you may need to read in and process one or more files.

In [5]:
iris = load_iris()

Let's take a look at the type and content of iris.

In [6]:
type(iris)
Out[6]:
sklearn.utils.Bunch

sklearn.utils.Bunch objects aggregate multiple arrays and strings. Using a Bunch is not required, but it is a convenient way to package everything needed for a machine learning project. Let's now look at the contents of iris.

In [7]:
iris
Out[7]:
{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'frame': None,
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'filename': '/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/anaconda3-2020.11-da3i7hmt6bdqbmuzq6pyt7kbm47wyrjp/lib/python3.8/site-packages/sklearn/datasets/data/iris.csv'}
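
Since a Bunch is a dictionary subclass whose entries are also exposed as attributes, the fields above can be reached either way. A quick sketch (the keys match the dump above):

list(iris.keys())     # ['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename']
iris['target_names']  # dictionary-style access
iris.target_names     # equivalent attribute-style access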

We're primarily concerned with the data and target arrays. The data array contains the measurements (sepal length, sepal width, petal length and petal width), with one row for each sample. The target array contains the data labels (0 = setosa, 1 = versicolor, 2 = virginica).

In [8]:
iris.data[0:5]
Out[8]:
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])
In [9]:
iris.target[0:5]
Out[9]:
array([0, 0, 0, 0, 0])
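
As a quick sanity check, the two arrays should agree on the number of samples: 150 rows of four measurements, and 150 matching labels.

iris.data.shape    # (150, 4): 150 flowers, 4 measurements each
iris.target.shape  # (150,): one integer label per flower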

The next two lines create a new classifier object and perform the fitting.

In [10]:
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
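
As an aside, we fit on all 150 samples here, so there is no held-out data left to estimate accuracy on. A minimal sketch of the usual refinement, using sklearn's train_test_split (the 25% test fraction and random_state below are arbitrary choices):

from sklearn.model_selection import train_test_split

# Hold out a quarter of the samples; random_state makes the split repeatable
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.25, random_state=0)

clf_holdout = tree.DecisionTreeClassifier().fit(X_train, y_train)
clf_holdout.score(X_test, y_test)  # fraction of held-out flowers classified correctly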

At this point, we're done with the hard work. We've trained the decision tree and can use it to predict the class from input data. Let's try it with a few flower measurements.

In [11]:
clf.predict([[ 5. ,  3.6,  1.4,  0.2],
             [ 5.9,  3. ,  5.1,  1.8],
             [ 6.7,  3. ,  5. ,  1.7]])
Out[11]:
array([0, 2, 1])

In the above example, we see that the predicted classes for the three flowers are 0, 2, and 1. We can get a little more detail with the predict_proba method. For each row of input measurements, it gives the probability that the item belongs to each class.

In [12]:
clf.predict_proba([[ 5. ,  3.6,  1.4,  0.2],
                   [ 5.9,  3. ,  5.1,  1.8],
                   [ 6.7,  3. ,  5. ,  1.7]])
Out[12]:
array([[1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.]])
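
The probabilities are all exactly 0 or 1 because this unconstrained tree was grown until every leaf is pure (each leaf in the DOT output below has gini = 0.0), so every prediction comes from a single-class leaf. If you would rather see species names than integer labels, one small convenience sketch is to index target_names with the predictions:

predictions = clf.predict([[5. , 3.6, 1.4, 0.2],
                           [5.9, 3. , 5.1, 1.8],
                           [6.7, 3. , 5. , 1.7]])
iris.target_names[predictions]  # array(['setosa', 'virginica', 'versicolor'], dtype='<U10')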

Sometimes we may want to inspect the decision tree. This is a two-step process. First we create a representation of the graph using the DOT language (http://www.graphviz.org/content/dot-language) and then we render an image of the graph using the graphviz Source method. To take a little of the mystery out of the process, we show a simple example below.

In [13]:
simple_dot_data = 'digraph Tree {\n0; \n1; \n2; 0 -> 1; 0 -> 2}'
In [14]:
graph = graphviz.Source(simple_dot_data)  
graph 
Out[14]:
[Graphviz rendering of the simple hand-written graph: node 0 with edges to nodes 1 and 2]
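
The DOT language also lets us attach labels to nodes and edges. A slightly richer hand-written example (the label text is arbitrary, loosely echoing the root split of our tree):

labeled_dot_data = ('digraph Tree {\n'
                    '0 [label="petal width <= 0.8?"];\n'
                    '1 [label="setosa"];\n'
                    '2 [label="not setosa"];\n'
                    '0 -> 1 [label="True"];\n'
                    '0 -> 2 [label="False"];\n'
                    '}')
graphviz.Source(labeled_dot_data)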

Although we can generate the DOT description of a graph by hand, this is tedious and error-prone for all but the simplest graphs. Fortunately, the sklearn tree module provides export_graphviz to generate the DOT description directly from the classifier.

In [15]:
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)  
graph 
Out[15]:
[Graphviz rendering of the decision tree with default, unlabeled node text]

Although export_graphviz only requires the classifier, the resulting graph is much more useful if we label the features (sepal length, ...) and classes (setosa, ...).

In [16]:
dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
graph = graphviz.Source(dot_data)  
graph 
Out[16]:
[Graphviz rendering of the decision tree with feature and class names, nodes filled by class]
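
As an aside, newer scikit-learn releases (0.21 and later) can also draw the tree directly with matplotlib via tree.plot_tree, avoiding the Graphviz dependency entirely. A minimal sketch, assuming matplotlib is installed:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(clf, feature_names=iris.feature_names,
               class_names=iris.target_names, filled=True, ax=ax)
plt.show()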

It's worth taking a quick look at the DOT representation for this more complex graph. It makes you fully appreciate that sklearn provides a method for generating it from the classifier.

In [17]:
dot_data
Out[17]:
'digraph Tree {\nnode [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;\nedge [fontname=helvetica] ;\n0 [label=<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>samples = 150<br/>value = [50, 50, 50]<br/>class = setosa>, fillcolor="#ffffff"] ;\n1 [label=<gini = 0.0<br/>samples = 50<br/>value = [50, 0, 0]<br/>class = setosa>, fillcolor="#e58139"] ;\n0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;\n2 [label=<petal width (cm) &le; 1.75<br/>gini = 0.5<br/>samples = 100<br/>value = [0, 50, 50]<br/>class = versicolor>, fillcolor="#ffffff"] ;\n0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;\n3 [label=<petal length (cm) &le; 4.95<br/>gini = 0.168<br/>samples = 54<br/>value = [0, 49, 5]<br/>class = versicolor>, fillcolor="#4de88e"] ;\n2 -> 3 ;\n4 [label=<petal width (cm) &le; 1.65<br/>gini = 0.041<br/>samples = 48<br/>value = [0, 47, 1]<br/>class = versicolor>, fillcolor="#3de684"] ;\n3 -> 4 ;\n5 [label=<gini = 0.0<br/>samples = 47<br/>value = [0, 47, 0]<br/>class = versicolor>, fillcolor="#39e581"] ;\n4 -> 5 ;\n6 [label=<gini = 0.0<br/>samples = 1<br/>value = [0, 0, 1]<br/>class = virginica>, fillcolor="#8139e5"] ;\n4 -> 6 ;\n7 [label=<petal width (cm) &le; 1.55<br/>gini = 0.444<br/>samples = 6<br/>value = [0, 2, 4]<br/>class = virginica>, fillcolor="#c09cf2"] ;\n3 -> 7 ;\n8 [label=<gini = 0.0<br/>samples = 3<br/>value = [0, 0, 3]<br/>class = virginica>, fillcolor="#8139e5"] ;\n7 -> 8 ;\n9 [label=<petal length (cm) &le; 5.45<br/>gini = 0.444<br/>samples = 3<br/>value = [0, 2, 1]<br/>class = versicolor>, fillcolor="#9cf2c0"] ;\n7 -> 9 ;\n10 [label=<gini = 0.0<br/>samples = 2<br/>value = [0, 2, 0]<br/>class = versicolor>, fillcolor="#39e581"] ;\n9 -> 10 ;\n11 [label=<gini = 0.0<br/>samples = 1<br/>value = [0, 0, 1]<br/>class = virginica>, fillcolor="#8139e5"] ;\n9 -> 11 ;\n12 [label=<petal length (cm) &le; 4.85<br/>gini = 0.043<br/>samples = 46<br/>value = [0, 1, 45]<br/>class = virginica>, fillcolor="#843de6"] ;\n2 -> 12 ;\n13 [label=<sepal width (cm) &le; 3.1<br/>gini = 0.444<br/>samples = 3<br/>value = [0, 1, 2]<br/>class = virginica>, fillcolor="#c09cf2"] ;\n12 -> 13 ;\n14 [label=<gini = 0.0<br/>samples = 2<br/>value = [0, 0, 2]<br/>class = virginica>, fillcolor="#8139e5"] ;\n13 -> 14 ;\n15 [label=<gini = 0.0<br/>samples = 1<br/>value = [0, 1, 0]<br/>class = versicolor>, fillcolor="#39e581"] ;\n13 -> 15 ;\n16 [label=<gini = 0.0<br/>samples = 43<br/>value = [0, 0, 43]<br/>class = virginica>, fillcolor="#8139e5"] ;\n12 -> 16 ;\n}'
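
Finally, to save the rendered tree to disk rather than just displaying it inline, graphviz Source objects provide a render method (the filename and format here are arbitrary choices):

# Writes iris_tree.png in the current directory; cleanup=True removes the intermediate .dot file
graph.render('iris_tree', format='png', cleanup=True)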

Submit Ticket¶

If you find anything that needs to be changed, edited, or if you would like to provide feedback or contribute to the notebook, please submit a ticket by contacting us at:

Email: consult@sdsc.edu

We appreciate your input and will review your suggestions promptly!
