Pruning Strategies in Decision Trees

Decision trees are powerful tools for classification and regression tasks, but they are prone to overfitting, especially when they are deep and complex. Pruning is a technique used to reduce the size of a decision tree and improve its generalization by removing parts of the tree that do not provide significant power in predicting outcomes.

Introduction to Pruning

Pruning simplifies the model, making it smaller and easier to interpret. By cutting off nodes that contribute little to the model's predictive power, we can reduce variance without significantly increasing bias.
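As a rough illustration of this variance reduction, the sketch below compares a fully grown tree with a depth-limited one on a small synthetic dataset; the dataset and the depth limit of 3 are arbitrary choices for demonstration, not part of the examples later in this section:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, slightly noisy data chosen purely for illustration
X, y = make_classification(n_samples=500, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Fully grown tree: tends to memorize the training data (high variance)
deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# Depth-limited tree: simpler model, usually a smaller train/test gap
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

for name, model in [('unpruned', deep), ('max_depth=3', shallow)]:
    print(f'{name}: train={model.score(X_train, y_train):.3f}, '
          f'test={model.score(X_test, y_test):.3f}')
```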

There are two main types of pruning strategies:

1. Pre-Pruning (also known as early stopping)
2. Post-Pruning

Pre-Pruning

Pre-pruning involves halting the growth of the tree before it reaches its full depth. This can be done by setting conditions such as:

- Minimum number of samples required to split a node.
- Maximum depth of the tree.
- Minimum impurity decrease required for a split.

Example of Pre-Pruning in Python

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load dataset
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Create a Decision Tree Classifier with pre-pruning
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=5)
clf.fit(X_train, y_train)

# Evaluate the model
accuracy = clf.score(X_test, y_test)
print(f'Accuracy of pre-pruned tree: {accuracy}')
```

In this example, the decision tree is restricted to a maximum depth of 3 and requires at least 5 samples to make a split. This simple setup helps reduce the likelihood of overfitting.
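To confirm that these constraints actually limited the fitted tree, scikit-learn exposes the tree's structure directly; a quick check reusing `clf` from the block above:

```python
# Inspect the fitted pre-pruned tree (clf from the example above)
print(f'Tree depth: {clf.get_depth()}')          # will not exceed max_depth=3
print(f'Number of leaves: {clf.get_n_leaves()}')
```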

Post-Pruning

Post-pruning is performed after the tree has been fully grown. This approach involves the following steps:

1. First, grow the entire tree without restriction.
2. Evaluate each node and its subtree using a validation dataset.
3. Remove nodes that do not provide a significant improvement in predictive performance.

Example of Post-Pruning in Python

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load dataset
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Create a fully grown Decision Tree
clf_full = DecisionTreeClassifier()
clf_full.fit(X_train, y_train)

# Use cost complexity pruning (post-pruning)
path = clf_full.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

# Prune the tree for different values of ccp_alpha
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

# Evaluate all pruned trees and choose the best one
accuracies = [accuracy_score(y_test, clf.predict(X_test)) for clf in clfs]
optimal_alpha = ccp_alphas[accuracies.index(max(accuracies))]
final_clf = DecisionTreeClassifier(ccp_alpha=optimal_alpha)
final_clf.fit(X_train, y_train)

# Final evaluation
final_accuracy = final_clf.score(X_test, y_test)
print(f'Accuracy of post-pruned tree: {final_accuracy}')
```

In this code example, we grow a full decision tree and then apply cost complexity pruning based on the alpha values derived from the training set. Each pruned version is evaluated on the test set, and the alpha that maximizes accuracy is kept. Note that choosing alpha on the same test set used for the final score is a form of data leakage; in practice, alpha should be selected with a separate validation set or cross-validation, keeping the test set for the final evaluation only.
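A common way to choose ccp_alpha without peeking at the test set is cross-validation on the training data. Here is a minimal sketch using scikit-learn's GridSearchCV, reusing ccp_alphas, X_train, y_train, X_test, and y_test from the example above (the 5-fold setting and random_state are arbitrary choices for illustration):

```python
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Cross-validate over the candidate alphas from cost_complexity_pruning_path,
# selecting the pruning strength using only the training data
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={'ccp_alpha': ccp_alphas},
    cv=5,
)
grid.fit(X_train, y_train)

print(f"Best ccp_alpha: {grid.best_params_['ccp_alpha']}")
print(f"Held-out test accuracy: {grid.best_estimator_.score(X_test, y_test)}")
```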

Conclusion

Pruning is an essential part of managing decision trees and combating overfitting. By employing either pre-pruning or post-pruning strategies, we can create more robust models that generalize better to unseen data. Understanding when to apply these techniques is crucial for anyone looking to build effective machine learning models.
