决策树相关案例
全流程
以下是一个更复杂、全流程的决策树和随机森林示例,不仅包括模型训练和预测,还涵盖了数据预处理、超参数调优以及模型评估的可视化。我们依旧使用鸢尾花数据集,并额外引入 GridSearchCV 进行超参数调优,使用 matplotlib 进行简单的可视化。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.externals.six import StringIO
from IPython.display import Image
from graphviz import Source
import pydotplus
# 1. 加载鸢尾花数据集
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.Series(iris.target)
# 2. 数据预处理
# 2.1 特征标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 3. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)
# 4. 决策树模型
# 4.1 定义超参数搜索空间
dtc_param_grid = {
'max_depth': [3, 5, 7, 10],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4]
}
# 4.2 使用GridSearchCV进行超参数调优
dtc_grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), dtc_param_grid, cv=5)
dtc_grid_search.fit(X_train, y_train)
# 4.3 输出最佳超参数
print("决策树最佳超参数:", dtc_grid_search.best_params_)
# 4.4 使用最佳超参数构建决策树模型
dtc_best = dtc_grid_search.best_estimator_
# 4.5 预测并评估
y_pred_dtc = dtc_best.predict(X_test)
dtc_accuracy = accuracy_score(y_test, y_pred_dtc)
print(f"决策树模型的准确率: {dtc_accuracy}")
print("决策树分类报告:\n", classification_report(y_test, y_pred_dtc))
print("决策树混淆矩阵:\n", confusion_matrix(y_test, y_pred_dtc))
# 4.6 可视化决策树(需要graphviz工具支持)
dot_data = StringIO()
export_graphviz(dtc_best, out_file=dot_data,
filled=True, rounded=True,
special_characters=True, feature_names=iris.feature_names, class_names=iris.target_names)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
# 5. 随机森林模型
# 5.1 定义超参数搜索空间
rfc_param_grid = {
'n_estimators': [50, 100, 200],
'max_depth': [3, 5, 7, 10],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4]
}
# 5.2 使用GridSearchCV进行超参数调优
rfc_grid_search = GridSearchCV(RandomForestClassifier(random_state=42), rfc_param_grid, cv=5)
rfc_grid_search.fit(X_train, y_train)
# 5.3 输出最佳超参数
print("随机森林最佳超参数:", rfc_grid_search.best_params_)
# 5.4 使用最佳超参数构建随机森林模型
rfc_best = rfc_grid_search.best_estimator_
# 5.5 预测并评估
y_pred_rfc = rfc_best.predict(X_test)
rfc_accuracy = accuracy_score(y_test, y_pred_rfc)
print(f"随机森林模型的准确率: {rfc_accuracy}")
print("随机森林分类报告:\n", classification_report(y_test, y_pred_rfc))
print("随机森林混淆矩阵:\n", confusion_matrix(y_test, y_pred_rfc))
# 5.6 绘制ROC曲线(以二分类为例,这里简单取其中一类演示)
fpr_rfc, tpr_rfc, thresholds_rfc = roc_curve(y_test == 0, rfc_best.predict_proba(X_test)[:, 0])
roc_auc_rfc = auc(fpr_rfc, tpr_rfc)
plt.figure()
plt.plot(fpr_rfc, tpr_rfc, label='Random Forest (area = %0.2f)' % roc_auc_rfc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
代码解释:
1. 数据加载与预处理:加载鸢尾花数据集,将其转换为 DataFrame 和 Series 形式,并对特征进行标准化处理。
2. 数据划分:将数据集划分为训练集和测试集。
3. 决策树模型:定义超参数搜索空间,使用 GridSearchCV 进行超参数调优,得到最佳超参数后构建决策树模型,进行预测和评估,并可视化决策树。
4. 随机森林模型:类似地,定义随机森林的超参数搜索空间,进行超参数调优,构建模型,预测评估,并绘制ROC曲线进行可视化。