10.2. Decision Trees

Decision trees are a predictive modeling approach that learns a tree of simple decision rules from the data. The tree is used either to predict a continuous value (regression) or to predict the class a data point belongs to (classification).

Pros:

  • Easy to understand and visualize

  • Easy to figure out why the model is making a certain prediction

  • Doesn’t need as much data preparation as other prediction methods

Cons:

  • Sometimes decision trees can get too complex and overfit the data

  • Small variations in the data can produce a very different tree, which can drastically change the model’s output

10.2.1. Concepts

The goal of a decision tree algorithm is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. A tree can be seen as a piecewise constant approximation: every sample that ends up in the same leaf receives the same prediction.

10.2.1.1. Algorithm Formulation

At a given node, find the best split that minimizes some impurity or loss measure \(H\) after the split. Let \(Q_m\) be the data at node \(m\) with sample size \(n_m\). Let \(\theta\) be a candidate split (which may consist of a candidate threshold for a candidate feature). Suppose that after the split, \(Q_{m,l}\) is the left node data with sample size \(n_{m,l}\) and \(Q_{m,r}\) is the right node data with sample size \(n_{m,r}\). The quality of the split is measured by

\[\begin{equation*} G(Q_m, \theta) = \frac{n_{m, l}}{n_m} H(Q_{m, l}(\theta)) + \frac{n_{m,r}}{n_m} H(Q_{m, r}(\theta)). \end{equation*}\]

The algorithm sets

\[\begin{equation*} \theta^* = \arg\min_{\theta} G(Q_m, \theta). \end{equation*}\]

Recursively find the best split for each child node.

  • Stopping: the recursion stops when the maximum tree depth is reached or when a node’s sample size falls below a preset threshold.

  • Pruning: removes branches to reduce the complexity of the final classifier, which can improve predictive accuracy on unseen data by reducing overfitting.
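The recursion and stopping rules above can be sketched in plain Python. This is a minimal sketch, not sklearn’s implementation: it grows a regression tree on a single numeric feature, assumes MSE as the impurity \(H\), tries midpoints between consecutive distinct feature values as candidate thresholds \(\theta\), and does no pruning. The names mse, best_split, build, and predict are hypothetical helpers, and the threshold search is the naive quadratic version. Because each leaf stores the mean target of its samples, the fitted tree is the piecewise constant approximation described above.

```python
def mse(values):
    """Impurity H for regression: mean squared deviation from the node mean."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def best_split(xs, ys):
    """Hypothetical helper: return (theta, G) minimizing
    G = n_l / n * H(left) + n_r / n * H(right) over one feature.
    Returns (None, inf) when all feature values are equal."""
    n = len(ys)
    best_theta, best_g = None, float("inf")
    distinct = sorted(set(xs))
    for lo, hi in zip(distinct, distinct[1:]):
        theta = (lo + hi) / 2  # assumed: midpoint between neighboring values
        left = [y for x, y in zip(xs, ys) if x <= theta]
        right = [y for x, y in zip(xs, ys) if x > theta]
        g = len(left) / n * mse(left) + len(right) / n * mse(right)
        if g < best_g:
            best_theta, best_g = theta, g
    return best_theta, best_g


def build(xs, ys, depth=0, max_depth=2, min_samples_split=2):
    """Hypothetical helper: grow the tree recursively. Stops at max_depth,
    when a node has fewer than min_samples_split samples, or when no split
    is possible. Leaves store the mean target (piecewise constant)."""
    leaf = {"value": sum(ys) / len(ys)}
    if depth >= max_depth or len(ys) < min_samples_split:
        return leaf
    theta, _ = best_split(xs, ys)
    if theta is None:
        return leaf

    def child(keep):
        # recurse on the samples that satisfy this side's decision rule
        idx = [i for i, x in enumerate(xs) if keep(x)]
        return build([xs[i] for i in idx], [ys[i] for i in idx],
                     depth + 1, max_depth, min_samples_split)

    return {"theta": theta,
            "left": child(lambda x: x <= theta),
            "right": child(lambda x: x > theta)}


def predict(node, x):
    """Follow the decision rules from the root down to a leaf."""
    while "theta" in node:
        node = node["left"] if x <= node["theta"] else node["right"]
    return node["value"]


xs = [1, 2, 3, 10, 11, 12]
ys = [1.0, 1.5, 0.5, 5.0, 5.5, 4.5]
stump = build(xs, ys, max_depth=1)
print(stump["theta"], predict(stump, 2), predict(stump, 11))  # 6.5 1.0 5.0
```

With max_depth=1 the tree is a single split whose two leaves predict the group means 1.0 and 5.0. Raising max_depth to 2 lets each leaf be split again (the left leaf splits at 2.5), which fits the training data more closely and, on real data, raises the risk of overfitting.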

10.2.1.2. Metrics

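See sklearn documentation for details. In the classification metrics, \(p_{mk}\) is the proportion of class \(k\) observations in node \(m\).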

  • Classification

    • Gini: \(H(Q_m) = \sum_k p_{mk} (1 - p_{mk})\)

    • Entropy: \(H(Q_m) = - \sum_k p_{mk} \log p_{mk}\)

    • Misclassification: \(H(Q_m) = 1 - \max_k p_{mk}\)

  • Regression

    • Mean squared error

    • Half Poisson deviance (for count targets)

    • Mean absolute error (slower to fit than MSE, but more robust to outliers)
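The classification impurity measures can be evaluated directly from the class labels in a node. This is a minimal sketch, assuming \(p_{mk}\) is estimated as the fraction of node samples in class \(k\); the function names are illustrative, not sklearn’s. For a node with 50 samples of each of three classes (the root of the iris tree in the next section), Gini gives \(3 \cdot \tfrac{1}{3} \cdot \tfrac{2}{3} = 2/3 \approx 0.667\). The entropy below uses the natural log; changing the base only rescales \(H\) by a constant, so it does not change which split minimizes \(G\).

```python
import math
from collections import Counter


def class_proportions(labels):
    """p_mk for every class k present in the node."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini(labels):
    return sum(p * (1 - p) for p in class_proportions(labels))


def entropy(labels):
    # natural log; the base only rescales H
    return -sum(p * math.log(p) for p in class_proportions(labels))


def misclassification(labels):
    return 1 - max(class_proportions(labels))


root = [0] * 50 + [1] * 50 + [2] * 50  # 50 samples of each of 3 classes
pure = [0] * 47                        # a node containing a single class

print(round(gini(root), 3), round(misclassification(root), 3))  # 0.667 0.667
print(gini(pure), misclassification(pure))                      # 0.0 0.0
```

All three measures are 0 for a pure node and largest when the classes are evenly mixed, which is why minimizing them pushes splits toward purer children.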

10.2.1.3. Confusion Matrix

See the sklearn example for details.

See the definitions on Wikipedia.

A confusion matrix is a matrix layout of the results of a classification algorithm, where each row of the matrix represents the instances in an actual class while each column represents the instances in a predicted class, or vice versa.

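  • True positive (TP): an actual positive that is predicted positive

  • False positive (FP): an actual negative that is predicted positive

  • True negative (TN): an actual negative that is predicted negative

  • False negative (FN): an actual positive that is predicted negative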
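For binary labels, the four counts can be tallied in a single pass. This is a minimal sketch; confusion_counts is a hypothetical helper that treats label 1 as the positive class by assumption. The last lines arrange the counts the way scikit-learn’s confusion_matrix lays them out for labels [0, 1]: rows are actual classes and columns are predicted classes.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Hypothetical helper: return (TP, FP, TN, FN) for one positive label."""
    tp = fp = tn = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive:
            if actual == positive:
                tp += 1  # predicted positive, actually positive
            else:
                fp += 1  # predicted positive, actually negative
        elif actual == positive:
            fn += 1      # predicted negative, actually positive
        else:
            tn += 1      # predicted negative, actually negative
    return tp, fp, tn, fn


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 0, 0]
tp, fp, tn, fn = confusion_counts(y_true, y_pred)
matrix = [[tn, fp],   # actual 0: predicted 0, predicted 1
          [fn, tp]]   # actual 1: predicted 0, predicted 1
print(tp, fp, tn, fn)  # 2 1 3 2
print(matrix)          # [[3, 1], [2, 2]]
```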

Metrics for evaluating classification:

  • Precision: \(\frac{\text{TP}}{\text{TP} + \text{FP}}\)

  • Recall (sensitivity): \(\frac{\text{TP}}{\text{TP} + \text{FN}}\)

  • F-beta score: \((1 + \beta^2) \frac{1}{\frac{\beta^2}{\text{recall}} + \frac{1}{\text{precision}}}\), where \(\beta\) means that recall is considered \(\beta\) times as important as precision. When \(\beta = 1\), the two are considered equally important, and the score is the usual F1 score (the harmonic mean of precision and recall).
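The formulas translate directly into code. This is a minimal sketch with hypothetical helper names; returning 0.0 when a denominator is zero is an assumption, since precision and recall are undefined there. With TP = 2, FP = 1, FN = 2, precision is 2/3 and recall is 1/2, so F1 = 4/7 ≈ 0.571, while F2, which weights recall more heavily, drops to 10/19 ≈ 0.526 because recall is the weaker of the two.

```python
def precision(tp, fp):
    return tp / (tp + fp) if tp + fp else 0.0  # assumed 0.0 when undefined


def recall(tp, fn):
    return tp / (tp + fn) if tp + fn else 0.0  # assumed 0.0 when undefined


def f_beta(p, r, beta=1.0):
    """(1 + beta^2) / (beta^2 / recall + 1 / precision)."""
    if p == 0 or r == 0:
        return 0.0  # assumed: the score collapses to 0 if either part is 0
    return (1 + beta ** 2) / (beta ** 2 / r + 1 / p)


p, r = precision(tp=2, fp=1), recall(tp=2, fn=2)
print(round(p, 3), round(r, 3))                                # 0.667 0.5
print(round(f_beta(p, r), 3), round(f_beta(p, r, beta=2), 3))  # 0.571 0.526
```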

From Wikipedia:

In a classification task, a precision score of 1.0 for a class C means that every item labelled as belonging to class C does indeed belong to class C (but says nothing about the number of items from class C that were not labelled correctly) whereas a recall of 1.0 means that every item from class C was labelled as belonging to class C (but says nothing about how many items from other classes were incorrectly also labelled as belonging to class C).

10.2.2. Simple Classification Example

## configure the inline figures to use the SVG format
%config InlineBackend.figure_formats = ['svg']

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
X, y = iris.data, iris.target
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
tree.plot_tree(clf)
[Figure: the fitted iris decision tree. The root node splits on X[3] <= 0.8 (gini = 0.667, samples = 150, value = [50, 50, 50]).]

10.2.3. Data Cleaning

Import the data and add a binary column indicating whether anyone was injured in the crash (1) or not (0).

scikit-learn’s decision trees cannot use string columns as features, so we convert each borough to a number, e.g. Bronx = 1, Brooklyn = 2.

import pandas as pd

nyc_collisions = pd.read_csv("../data/nyc_mv_collisions_202201.csv")

# hour of the crash, e.g. "14:35" -> 14
nyc_collisions["time"] = nyc_collisions["CRASH TIME"].str.split(":").str[0].astype(int)

# 1 if at least one person was injured in the crash, else 0
nyc_collisions["injury_binary"] = (nyc_collisions["NUMBER OF PERSONS INJURED"] > 0).astype(int)

# encode each borough as a number in one mapping (chained .map calls would
# overwrite each other); missing or unknown boroughs become 0
borough_codes = {"BRONX": 1, "BROOKLYN": 2, "QUEENS": 3,
                 "MANHATTAN": 4, "STATEN ISLAND": 5}
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(borough_codes).fillna(0).astype(int)

nyc_collisions.rename(columns={"NUMBER OF PERSONS KILLED": "num_ppl_killed",
                               "NUMBER OF PERSONS INJURED": "num_ppl_injured"},
                      inplace=True)



nyc_collisions['injury_binary'].value_counts()
0    5311
1    2348
Name: injury_binary, dtype: int64
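These counts show that the classes are imbalanced, which matters when reading accuracy later: a trivial model that predicts “no injury” (0) for every crash would already be right about 69% of the time on this data. A quick check, using the counts printed above:

```python
# Class counts copied from the value_counts() output above
counts = {0: 5311, 1: 2348}

total = sum(counts.values())
majority = max(counts, key=counts.get)        # most frequent class
baseline_accuracy = counts[majority] / total  # accuracy of always predicting it

print(total, majority, round(baseline_accuracy, 3))  # 7659 0 0.693
```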

10.2.4. Classification Decision Tree Model

Select columns to use in our decision tree model.

We don’t use columns such as the number of pedestrians injured, because they largely determine the target variable and would leak the answer into the features.

We also split the dataset into a training set and a test set, so that the model is evaluated on data it did not see during training; this is how we detect overfitting.

We are trying to predict whether anyone was injured in a crash using the columns in feature_cols.

from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix

feature_cols = ['time', 'num_borough', 'NUMBER OF MOTORIST KILLED',
                'NUMBER OF CYCLIST KILLED',
                'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols] # Features
y = nyc_collisions.injury_binary # Target variable

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=12) # 80% training and 20% test
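As a rough check of what test_size=0.2 means here: scikit-learn rounds the test share up, so with the 7659 rows counted above the test set has ceil(0.2 * 7659) = 1532 rows and the training set 6127. The sketch below mimics that sizing with a shuffled index split. shuffle_split is a hypothetical helper built on Python’s random module, so it will not pick the same rows as train_test_split with random_state=12.

```python
import math
import random


def shuffle_split(rows, test_size=0.2, seed=12):
    """Hypothetical helper: shuffle row indices, then cut off the test set.
    Sizing follows ceil(test_size * n) for the test set; the row selection
    differs from sklearn's train_test_split."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # seeded, so the split is repeatable
    n_test = math.ceil(test_size * len(rows))
    test = [rows[i] for i in indices[:n_test]]
    train = [rows[i] for i in indices[n_test:]]
    return train, test


rows = list(range(5311 + 2348))  # stand-in for the 7659 collision records
train, test = shuffle_split(rows)
print(len(train), len(test))  # 6127 1532
```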

We then fit a decision tree on the training set and use it to predict on the test set. Because this is a classification example, the model predicts 0 or 1 depending on whether anyone was injured in the crash (1 = injured).

We can then print the model’s accuracy and a confusion matrix to compare its predictions with the actual labels.

clf = tree.DecisionTreeClassifier()

# Train Decision Tree Classifier
clf = clf.fit(x_train, y_train)

# Predict the response for the test dataset
y_pred = clf.predict(x_test)

print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

print(confusion_matrix(y_test, y_pred))
Accuracy: 0.6977806788511749
[[1064    3]
 [ 460    5]]
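Reading this output takes care: scikit-learn’s confusion_matrix puts actual classes in rows and predicted classes in columns, so the first row holds TN and FP and the second row FN and TP. The short computation below, using the numbers printed above, shows that the accuracy of about 0.698 is barely above the 0.696 a model would score by predicting 0 for every crash in this test set, and that the tree finds only about 1% of the crashes with injuries (recall).

```python
# Confusion matrix printed above; sklearn layout for labels [0, 1]:
# rows = actual class, columns = predicted class
cm = [[1064, 3],
      [460, 5]]
(tn, fp), (fn, tp) = cm

total = tn + fp + fn + tp
accuracy = (tp + tn) / total
always_zero = (tn + fp) / total  # accuracy of predicting 0 for every crash
precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(f"accuracy={accuracy:.4f} always-0 baseline={always_zero:.4f}")
# accuracy=0.6978 always-0 baseline=0.6965
print(f"precision={precision:.3f} recall={recall:.3f}")
# precision=0.625 recall=0.011
```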

We can plot this model using sklearn’s plot_tree function.

tree.plot_tree(clf)
[Figure: the fitted classification tree drawn with plot_tree. The root node splits on X[0] <= 16.5, i.e. time (gini = 0.426, samples = 6127, value = [4244, 1883]).]

We can plot the same model with the graphviz package, which produces a much more readable figure.

import graphviz
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=feature_cols,
                                class_names=None,
                                filled=True, rounded=True,
                                special_characters=True)

# Draw graph
graph = graphviz.Source(dot_data, format="svg") 
graph
[Figure: the same classification tree rendered with graphviz.]

We can fit the same decision tree again and try to improve its performance by changing parameters such as max_depth, which limits how many levels deep the tree can grow.

from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix

feature_cols = ['time', 'num_borough', 
                'NUMBER OF MOTORIST KILLED',
                'NUMBER OF CYCLIST KILLED',
                'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols] # Features
y = nyc_collisions.injury_binary # Target variable

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=12) # 80% training and 20% test

clf = tree.DecisionTreeClassifier(max_depth=3)

# Train Decision Tree Classifier
clf = clf.fit(x_train, y_train)

# Predict the response for the test dataset
y_pred = clf.predict(x_test)

print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

print(confusion_matrix(y_test, y_pred))
Accuracy: 0.6971279373368147
[[1067    0]
 [ 464    1]]

We can see the effect of this change by re-running the code that draws the tree.

dot_data = tree.export_graphviz(clf, out_file=None, 
                                feature_names=feature_cols,  
                                class_names=None,
                                filled=True, rounded=True,  
                                special_characters=True)

# Draw graph
graph = graphviz.Source(dot_data, format="svg") 
graph
[Figure: the max_depth=3 classification tree rendered with graphviz.]

10.2.5. Regression Decision Tree Model

In this example, instead of classifying a target variable as 0 or 1, we use a decision tree for regression to predict the number of people injured in each crash.

The setup is the same as for classification: we select feature columns and a target variable, then split the data into a training set and a test set.

from sklearn.model_selection import train_test_split 
from sklearn import tree
from sklearn import metrics

feature_cols2 = ['time', 'num_borough', 
                 'NUMBER OF MOTORIST KILLED',
                 'NUMBER OF CYCLIST KILLED',
                 'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols2] # Features
y = nyc_collisions.num_ppl_injured  #Target variable

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=12) # 80% training and 20% test

In a classification problem we can use accuracy and a confusion matrix to gauge a model’s effectiveness, but these do not apply to a regression problem, where the prediction is a number rather than a class.

Here we use mean absolute error (MAE), the average absolute difference between the actual and predicted values, to judge our model.

To show what the model is doing, we also add a column with the absolute difference between each actual and predicted value.

rgr = tree.DecisionTreeRegressor()

# Train Decision Tree Regressor
rgr = rgr.fit(x_train,y_train)

# Predict the response for the test dataset
y_pred = rgr.predict(x_test)

df = pd.DataFrame({'Actual':y_test, 'Predicted':y_pred})
df["diff"] = abs(df["Actual"] - df["Predicted"])
df.sort_values(by = ['diff'], inplace = True)

print('Mean Absolute Error:', 
      metrics.mean_absolute_error(y_test, y_pred))

df.head()
Mean Absolute Error: 0.5601416714129356
Actual Predicted diff
4880 0 0.0 0.0
4058 0 0.0 0.0
5954 0 0.0 0.0
2226 0 0.0 0.0
6532 0 0.0 0.0

Sorting in descending order instead shows which of our predictions were the worst.

df.sort_values(by=['diff'], ascending=False,inplace=True)
df.head()
Actual Predicted diff
7258 5 0.512077 4.487923
5363 5 0.518293 4.481707
1901 5 1.000000 4.000000
3360 4 0.339844 3.660156
6313 4 0.373596 3.626404
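The mean absolute error printed earlier is simply the average of the diff column over the whole test set. As a small worked example, the sketch below recomputes it for just the five worst rows shown above; mean_absolute_error here is a hand-written stand-in for sklearn.metrics.mean_absolute_error, not the library function. Note also that the predictions are fractions such as 0.512077 even though the target is a count: with the default squared-error criterion, each leaf of a regression tree predicts the mean of the training targets that fall into it.

```python
def mean_absolute_error(actual, predicted):
    """Hand-written MAE: mean of |actual - predicted|."""
    diffs = [abs(a - p) for a, p in zip(actual, predicted)]
    return sum(diffs) / len(diffs)


# The five worst rows from the table above
actual = [5, 5, 5, 4, 4]
predicted = [0.512077, 0.518293, 1.000000, 0.339844, 0.373596]

print(round(mean_absolute_error(actual, predicted), 6))  # 4.051238
```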

As in the classification example, we can plot the resulting decision tree.

dot_data = tree.export_graphviz(rgr, out_file=None,
                                feature_names=feature_cols2,
                                class_names=None,
                                filled=True, rounded=True,
                                special_characters=True)

# Draw graph
graph = graphviz.Source(dot_data, format="svg") 
graph
[Figure: the fitted regression tree rendered with graphviz.]