10.2. Decision Trees
Decision trees are a predictive modeling approach that uses a tree of decision rules to either predict a continuous value (regression) or predict the class an observation belongs to (classification).
Pros:
Easy to understand and visualize
Easy to figure out why the model is making a certain prediction
Doesn’t need as much data preparation as other prediction methods
Cons:
Sometimes decision trees can get too complex and overfit the data
Small variations in the data can produce very different trees, which can drastically change the model’s output
10.2.1. Concepts
The goal of a decision tree algorithm is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. A tree can be seen as a piecewise constant approximation.
10.2.1.1. Algorithm Formulation
At a given node, find the best split that minimizes some impurity or loss measure \(H\) after the split. Let \(Q_m\) be the data at node \(m\) with sample size \(n_m\). Let \(\theta\) be a candidate split (typically a candidate feature together with a candidate threshold). Suppose that after the split, \(Q_{m,l}\) is the left-child data with sample size \(n_{m,l}\) and \(Q_{m,r}\) is the right-child data with sample size \(n_{m,r}\). The quality of the split is measured by
\(G(Q_m, \theta) = \frac{n_{m,l}}{n_m} H(Q_{m,l}) + \frac{n_{m,r}}{n_m} H(Q_{m,r}).\)
The algorithm selects the split that minimizes this quantity, \(\theta^* = \operatorname{argmin}_\theta G(Q_m, \theta)\); a small numeric sketch follows this list.
Recursively find the best split for each child node.
Stopping: stop when the maximum tree depth is reached or when a node’s sample size falls below a preset threshold.
Pruning: reduces the complexity of the final tree, which improves predictive accuracy by reducing overfitting.
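To make the split criterion concrete, here is a minimal sketch (not sklearn’s implementation) that scores candidate thresholds on a single feature using Gini impurity; the toy feature values and labels are made up purely for illustration.
import numpy as np
def gini(y):
    """Gini impurity of a vector of class labels."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return np.sum(p * (1 - p))
def split_quality(x, y, threshold):
    """Weighted impurity G(Q_m, theta) of the split x <= threshold."""
    left, right = y[x <= threshold], y[x > threshold]
    n = len(y)
    return (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
# toy node: one numeric feature and binary labels
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])
# candidate thresholds are midpoints between consecutive feature values
thresholds = (x[:-1] + x[1:]) / 2
best = min(thresholds, key=lambda t: split_quality(x, y, t))
print(best, split_quality(x, y, best))  # 3.5 splits the node perfectly (impurity 0.0)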
10.2.1.2. Metrics
See the sklearn documentation for details; a short numeric check of the classification impurities follows the list below.
Classification
Gini: \(H(Q_m) = \sum_{k} p_{mk} (1 - p_{mk})\)
Entropy: \(H(Q_m) = - \sum_{k} p_{mk} \log p_{mk}\)
Misclassification: \(H(Q_m) = 1 - \max_k p_{mk}\)
Regression
Mean squared error
Half Poisson deviance (for count targets)
Mean absolute error (slower than MSE; more robust)
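As a quick numeric check on the classification impurities above, the following sketch (illustrative only; the class proportions are invented) evaluates Gini, entropy, and misclassification error for a node, and confirms that a pure node scores zero under all three.
import numpy as np
def gini_impurity(p):
    return np.sum(p * (1 - p))
def entropy_impurity(p):
    p = p[p > 0]  # skip empty classes to avoid log(0)
    return -np.sum(p * np.log2(p))  # base-2 logarithm used here
def misclassification_impurity(p):
    return 1 - np.max(p)
# hypothetical node with class proportions p_mk
p = np.array([0.7, 0.2, 0.1])
print(gini_impurity(p), entropy_impurity(p), misclassification_impurity(p))
# a pure node is scored 0 by all three measures
pure = np.array([1.0, 0.0, 0.0])
print(gini_impurity(pure), entropy_impurity(pure), misclassification_impurity(pure))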
10.2.1.3. Confusion Matrix
See the sklearn example and the definitions on Wikipedia for details.
A confusion matrix is a matrix layout of the results of a classification algorithm, where each row of the matrix represents the instances in an actual class while each column represents the instances in a predicted class, or vice versa.
True positive (TP)
False positive (FP)
True negative (TN)
False negative (FN)
Metrics for evaluating classification:
Precision: \(\frac{\text{TP}}{\text{TP} + \text{FP}}\)
Recall (sensitivity): \(\frac{\text{TP}}{\text{TP} + \text{FN}}\)
F-beta score: \((1 + \beta^2) \cdot \frac{1}{\frac{\beta^2}{\text{recall}} + \frac{1}{\text{precision}}}\), where \(\beta\) means that recall is considered \(\beta\) times as important as precision. When \(\beta = 1\), the two are considered equally important.
From Wiki:
In a classification task, a precision score of 1.0 for a class C means that every item labelled as belonging to class C does indeed belong to class C (but says nothing about the number of items from class C that were not labelled correctly) whereas a recall of 1.0 means that every item from class C was labelled as belonging to class C (but says nothing about how many items from other classes were incorrectly also labelled as belonging to class C).
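To tie the definitions together, here is a small sketch that computes precision, recall, and the F-beta score directly from hypothetical TP/FP/FN counts (the counts are made up for illustration).
# made-up confusion-matrix counts
tp, fp, fn = 40, 10, 20
precision = tp / (tp + fp)   # 0.80
recall = tp / (tp + fn)      # about 0.67
beta = 2                     # weight recall twice as heavily as precision
f_beta = (1 + beta**2) / (beta**2 / recall + 1 / precision)
print(precision, recall, f_beta)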
10.2.2. Simple Classification Example
## configure the inline figures to be of SVG format
%config InlineBackend.figure_formats = ['svg']
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
X, y = iris.data, iris.target
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
tree.plot_tree(clf)
[Figure: plot of the fitted iris decision tree; the root node splits on X[3] (petal width) <= 0.8, which isolates all 50 setosa samples, and the remaining splits separate versicolor from virginica.]
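As a quick usage check (the measurements below are made up), the fitted classifier can predict the species index and class probabilities for a new flower; features must be given in the order sepal length, sepal width, petal length, petal width.
sample = [[5.1, 3.5, 1.4, 0.2]]   # hypothetical flower measurements
print(clf.predict(sample))        # predicted class index, e.g. 0 (setosa)
print(clf.predict_proba(sample))  # per-class probabilities at the matching leaf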
10.2.3. Data Cleaning
Import the data and add a binary column indicating whether anyone was injured in the crash.
Decision trees in sklearn can’t split on string features directly, so we encode each borough as a number, e.g. Bronx = 1, Brooklyn = 2.
import pandas as pd
import numpy as np
nyc_collisions = pd.read_csv("../data/nyc_mv_collisions_202201.csv")
nyc_collisions["time"] = [x.split(":")[0] for x in nyc_collisions["CRASH TIME"]]
nyc_collisions["time"] = [int(x) for x in nyc_collisions["time"]]
nyc_collisions["injury_binary"] = nyc_collisions["NUMBER OF PERSONS INJURED"].map(lambda x: 1 if x>0 else 0)
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(lambda x: 1 if x=="BRONX" else 0)
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(lambda x: 2 if x=="BROOKLYN" else 0)
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(lambda x: 3 if x=="QUEENS" else 0)
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(lambda x: 4 if x=="MANHATTAN" else 0)
nyc_collisions["num_borough"] = nyc_collisions["BOROUGH"].map(lambda x: 5 if x=="STATEN ISLAND" else 0)
nyc_collisions.rename(columns={"NUMBER OF PERSONS KILLED": "num_ppl_killed"}, inplace=True)
nyc_collisions.rename(columns={"NUMBER OF PERSONS INJURED": "num_ppl_injured"}, inplace=True)
nyc_collisions['injury_binary'].value_counts()
0 5311
1 2348
Name: injury_binary, dtype: int64
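Note that the single num_borough column imposes an arbitrary ordering on the boroughs. A common alternative, not used in the rest of this section, is one-hot encoding, sketched here with pandas.
# one-hot encode BOROUGH into indicator columns instead of a single ordinal code
borough_dummies = pd.get_dummies(nyc_collisions["BOROUGH"], prefix="borough")
print(borough_dummies.head())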
10.2.4. Classification Decision Tree Model
Select columns to use in our decision tree model.
We avoid columns such as the number of pedestrians injured, since those directly determine our target variable
We also split the dataset into a training set and a test set so we can evaluate the model on data it has not seen and detect overfitting
We are trying to predict whether someone was injured in a crash using all of the columns in feature_cols
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix
feature_cols = ['time', 'num_borough', 'NUMBER OF MOTORIST KILLED',
'NUMBER OF CYCLIST KILLED',
'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols] # Features
y = nyc_collisions.injury_binary # Target variable
x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=0.2, random_state=12) # 80% training and 20% test
We then set up our decision tree and use it to predict on the test set. This is a classification example, so the model predicts 0 or 1 depending on whether someone was injured in the crash (1 = injured)
We can then print the accuracy of the model and a confusion matrix to inspect its predictions
clf = tree.DecisionTreeClassifier()
# Train Decision Tree Classifier
clf = clf.fit(x_train,y_train)
#Predict the response for test dataset
y_pred = clf.predict(x_test)
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))
Accuracy: 0.6977806788511749
[[1064 3]
[ 460 5]]
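The confusion matrix shows the model predicts the no-injury class almost exclusively, so accuracy alone is misleading here. classification_report (already imported above) breaks the result into per-class precision and recall.
# per-class precision, recall, and F1 for the same test predictions
print(classification_report(y_test, y_pred))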
We can plot this model using the sklearn plotting function
tree.plot_tree(clf)
[Figure: plot of the unpruned decision tree fitted to the collision data; the root node splits on X[0] (time) <= 16.5 and the tree continues for many more levels, making the raw matplotlib rendering hard to read.]
We can plot this exact same model using the graphviz package, which produces a much more readable figure
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=feature_cols,
class_names=None,
filled=True, rounded=True,
special_characters=True)
# Draw graph
graph = graphviz.Source(dot_data, format="svg")
graph
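If you want the figure as a standalone file rather than an inline output, a graphviz Source object can also be written to disk; the filename below is arbitrary.
# write the rendered SVG next to the notebook (produces decision_tree.svg)
graph.render("decision_tree")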
We can run this exact same decision tree but try to improve its performance by changing parameters such as max_depth, which limits how many levels deep the tree can grow
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix
feature_cols = ['time', 'num_borough',
'NUMBER OF MOTORIST KILLED',
'NUMBER OF CYCLIST KILLED',
'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols] # Features
y = nyc_collisions.injury_binary # Target variable
x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=0.2, random_state=12) # 80% training and 20% test
clf = tree.DecisionTreeClassifier(max_depth=3)
# Train Decision Tree Classifier
clf = clf.fit(x_train,y_train)
#Predict the response for test dataset
y_pred = clf.predict(x_test)
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))
Accuracy: 0.6971279373368147
[[1067 0]
[ 464 1]]
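Rather than picking max_depth=3 by hand, one rough way to compare settings is a small sweep over candidate depths scored on the held-out test set; in practice, cross-validation on the training set would be the more careful choice. A sketch:
# compare test accuracy across a few candidate depths (illustrative sweep only)
for depth in [2, 3, 5, 10, None]:
    model = tree.DecisionTreeClassifier(max_depth=depth, random_state=12)
    model.fit(x_train, y_train)
    print(depth, metrics.accuracy_score(y_test, model.predict(x_test)))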
We can see this change visually by re-running our code to graph the tree
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=feature_cols,
class_names=None,
filled=True, rounded=True,
special_characters=True)
# Draw graph
graph = graphviz.Source(dot_data, format="svg")
graph
10.2.5. Regression Decision Tree Model
In this example, instead of classifying a target variable as 0 or 1, we will use a decision tree for regression
The set-up is exactly the same as for classification: we select feature columns and a target variable, then split the data into training and test sets
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
feature_cols2 = ['time', 'num_borough',
'NUMBER OF MOTORIST KILLED',
'NUMBER OF CYCLIST KILLED',
'NUMBER OF PEDESTRIANS KILLED']
x = nyc_collisions[feature_cols2] # Features
y = nyc_collisions.num_ppl_injured #Target variable
x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=0.2, random_state=12) # 80% training and 20% test
In a classification problem we can use accuracy and a confusion matrix to gauge a model’s effectiveness but that is not the case in a regression problem.
Here we use mean absolute error to judge our model
To show what the model is doing, we create another variable for the absolute value of the difference between the actual and predicted values.
rgr = tree.DecisionTreeRegressor()
# Train Decision Tree Regressor
rgr = rgr.fit(x_train,y_train)
#Predict the response for test dataset
y_pred = rgr.predict(x_test)
df = pd.DataFrame({'Actual':y_test, 'Predicted':y_pred})
df["diff"] = abs(df["Actual"] - df["Predicted"])
df.sort_values(by = ['diff'], inplace = True)
print('Mean Absolute Error:',
metrics.mean_absolute_error(y_test, y_pred))
df.head()
Mean Absolute Error: 0.5601416714129356
|      | Actual | Predicted | diff |
|------|--------|-----------|------|
| 4880 | 0      | 0.0       | 0.0  |
| 4058 | 0      | 0.0       | 0.0  |
| 5954 | 0      | 0.0       | 0.0  |
| 2226 | 0      | 0.0       | 0.0  |
| 6532 | 0      | 0.0       | 0.0  |
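To put the MAE in context, it helps to compare against a trivial baseline that always predicts the mean number of injuries seen in training; a minimal sketch:
# baseline: predict the training-set mean for every test crash
baseline_pred = np.full(len(y_test), y_train.mean())
print('Baseline MAE:', metrics.mean_absolute_error(y_test, baseline_pred))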
Here we sort in descending order to show which of our predictions were the worst.
df.sort_values(by=['diff'], ascending=False,inplace=True)
df.head()
|      | Actual | Predicted | diff     |
|------|--------|-----------|----------|
| 7258 | 5      | 0.512077  | 4.487923 |
| 5363 | 5      | 0.518293  | 4.481707 |
| 1901 | 5      | 1.000000  | 4.000000 |
| 3360 | 4      | 0.339844  | 3.660156 |
| 6313 | 4      | 0.373596  | 3.626404 |
As with the classification problem, we can plot the resulting decision tree
dot_data = tree.export_graphviz(rgr, out_file=None,
feature_names=feature_cols2,
class_names=None,
filled=True, rounded=True,
special_characters=True)
# Draw graph
graph = graphviz.Source(dot_data, format="svg")
graph
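As with the classifier, the unconstrained regression tree above is very large; max_depth can be used to grow a shallower, more readable tree, at the possible cost of a slightly different MAE. A short sketch:
# a shallower regression tree is easier to read and plot
rgr_small = tree.DecisionTreeRegressor(max_depth=3)
rgr_small = rgr_small.fit(x_train, y_train)
print('MAE (max_depth=3):', metrics.mean_absolute_error(y_test, rgr_small.predict(x_test)))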