Python 机器学习实战:鸢尾花分类项目详解

一、项目概述

本教程将带领你从零开始构建一个完整的机器学习项目 —— 鸢尾花数据集分类任务。我们将使用 Python 语言结合 scikit-learn 库,完成从数据加载、探索性分析、特征工程、模型构建、评估优化到最终部署的全流程实践。

通过这个项目,你将掌握:

机器学习基础工作流程数据预处理与可视化方法经典分类算法的应用与调优模型评估指标的理解与使用简单的模型部署方法

二、环境准备

2.1 安装必要库

python

运行



# 安装核心库
!pip install numpy pandas matplotlib seaborn scikit-learn joblib jupyter
 
# 导入所需库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import joblib
import warnings
warnings.filterwarnings('ignore')
 
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

三、数据加载与探索

3.1 数据加载

python

运行



# 加载鸢尾花数据集
iris = load_iris()
 
# 转换为DataFrame格式便于分析
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['target'] = iris.target
df['species'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
 
# 查看数据基本信息
print("数据集形状:", df.shape)
print("
数据集前5行:")
print(df.head())
print("
数据集基本统计信息:")
print(df.describe())
print("
目标变量分布:")
print(df['species'].value_counts())

3.2 数据可视化分析

3.2.1 特征分布直方图

python

运行



# 设置画布大小
plt.figure(figsize=(15, 10))
 
# 绘制每个特征的直方图
features = iris.feature_names
for i, feature in enumerate(features):
    plt.subplot(2, 2, i+1)
    for species in df['species'].unique():
        plt.hist(df[df['species']==species][feature], 
                 alpha=0.7, 
                 label=species,
                 bins=15)
    plt.xlabel(feature)
    plt.ylabel('频数')
    plt.title(f'{feature}分布')
    plt.legend()
 
plt.tight_layout()
plt.savefig('feature_distribution.png', dpi=300, bbox_inches='tight')
plt.show()
3.2.2 特征散点图矩阵

python

运行



# 绘制散点图矩阵
plt.figure(figsize=(12, 10))
sns.pairplot(df, hue='species', markers=['o', 's', 'D'])
plt.suptitle('鸢尾花特征散点图矩阵', y=1.02)
plt.savefig('pairplot.png', dpi=300, bbox_inches='tight')
plt.show()
3.2.3 特征相关性分析

python

运行



# 计算特征相关性
numeric_df = df.drop(['target', 'species'], axis=1)
correlation_matrix = numeric_df.corr()
 
# 绘制热力图
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
plt.title('特征相关性热力图')
plt.savefig('correlation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()
3.2.4 箱线图分析

python

运行



# 绘制箱线图
plt.figure(figsize=(15, 8))
for i, feature in enumerate(features):
    plt.subplot(1, 4, i+1)
    sns.boxplot(x='species', y=feature, data=df)
    plt.title(f'{feature}箱线图')
    plt.tight_layout()
 
plt.savefig('boxplot.png', dpi=300, bbox_inches='tight')
plt.show()

四、数据预处理

4.1 数据分割

python

运行



# 准备特征和目标变量
X = iris.data
y = iris.target
 
# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
 
print(f"训练集大小: {X_train.shape}")
print(f"测试集大小: {X_test.shape}")
print(f"训练集类别分布: {np.bincount(y_train)}")
print(f"测试集类别分布: {np.bincount(y_test)}")

4.2 特征标准化

python

运行



# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
 
# 查看标准化后的数据统计
print("训练集标准化后统计:")
print(f"均值: {np.mean(X_train_scaled, axis=0)}")
print(f"标准差: {np.std(X_train_scaled, axis=0)}")

五、模型构建与训练

5.1 多种模型比较

python

运行



# 定义要比较的模型
models = {
    'K近邻': KNeighborsClassifier(),
    '支持向量机': SVC(),
    '决策树': DecisionTreeClassifier(),
    '随机森林': RandomForestClassifier()
}
 
# 存储模型性能
results = {}
 
# 训练并评估每个模型
print("各模型交叉验证准确率:")
print("-" * 30)
for name, model in models.items():
    # 使用5折交叉验证
    cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5)
    results[name] = {
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std(),
        'cv_scores': cv_scores
    }
    print(f"{name}: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
 
# 可视化交叉验证结果
plt.figure(figsize=(10, 6))
model_names = list(results.keys())
means = [results[name]['cv_mean'] for name in model_names]
stds = [results[name]['cv_std'] for name in model_names]
 
plt.bar(model_names, means, yerr=stds, alpha=0.7, capsize=10)
plt.ylabel('准确率')
plt.title('各模型交叉验证性能比较')
plt.ylim(0.8, 1.0)
plt.grid(axis='y', linestyle='--', alpha=0.7)
 
# 添加数值标签
for i, (mean, std) in enumerate(zip(means, stds)):
    plt.text(i, mean + 0.01, f'{mean:.3f}', ha='center')
 
plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

5.2 模型调优(以随机森林为例)

python

运行



# 定义参数网格
param_grid = {
    'n_estimators': [50, 100, 150, 200],
    'max_depth': [None, 5, 10, 15],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}
 
# 创建模型
rf = RandomForestClassifier(random_state=42)
 
# 网格搜索
grid_search = GridSearchCV(
    estimator=rf,
    param_grid=param_grid,
    cv=5,
    n_jobs=-1,
    verbose=1
)
 
# 训练模型
grid_search.fit(X_train_scaled, y_train)
 
# 输出最佳参数
print("最佳参数组合:")
print(grid_search.best_params_)
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
 
# 获取最佳模型
best_rf = grid_search.best_estimator_

5.3 特征重要性分析

python

运行



# 分析特征重要性
feature_importance = best_rf.feature_importances_
feature_names = iris.feature_names
 
# 创建特征重要性DataFrame
importance_df = pd.DataFrame({
    '特征': feature_names,
    '重要性': feature_importance
}).sort_values('重要性', ascending=False)
 
print("特征重要性排序:")
print(importance_df)
 
# 可视化特征重要性
plt.figure(figsize=(10, 6))
sns.barplot(x='重要性', y='特征', data=importance_df)
plt.title('随机森林特征重要性')
plt.xlabel('重要性得分')
plt.tight_layout()
plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

六、模型评估

6.1 模型预测

python

运行



# 使用最佳模型进行预测
y_pred = best_rf.predict(X_test_scaled)
y_pred_proba = best_rf.predict_proba(X_test_scaled)
 
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"测试集准确率: {accuracy:.4f}")

6.2 分类报告

python

运行



# 生成分类报告
print("分类报告:")
print("-" * 50)
print(classification_report(y_test, y_pred, target_names=iris.target_names))

6.3 混淆矩阵

python

运行



# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
 
# 可视化混淆矩阵
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names,
            yticklabels=iris.target_names)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

6.4 概率预测分析

python

运行



# 分析预测概率
plt.figure(figsize=(12, 8))
 
# 为每个类别绘制概率分布
for i, species in enumerate(iris.target_names):
    plt.subplot(3, 1, i+1)
    # 获取真实为第i类的样本的预测概率
    mask = y_test == i
    probs = y_pred_proba[mask, i]
    
    plt.hist(probs, bins=10, alpha=0.7, color=f'C{i}')
    plt.axvline(x=0.5, color='red', linestyle='--')
    plt.title(f'{species}类别的预测概率分布')
    plt.xlabel('预测概率')
    plt.ylabel('频数')
    plt.xlim(0, 1)
    plt.grid(alpha=0.3)
 
plt.tight_layout()
plt.savefig('probability_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

6.5 错误案例分析

python

运行



# 找出错误预测的样本
errors = y_test != y_pred
error_indices = np.where(errors)[0]
 
if len(error_indices) > 0:
    print(f"
错误预测样本数: {len(error_indices)}")
    print("
错误预测详情:")
    for idx in error_indices:
        true_label = iris.target_names[y_test[idx]]
        pred_label = iris.target_names[y_pred[idx]]
        prob = y_pred_proba[idx]
        print(f"样本索引 {idx}: 真实={true_label}, 预测={pred_label}, 概率={prob}")
        
        # 显示错误样本的特征值
        features = X_test_scaled[idx]
        feature_dict = dict(zip(iris.feature_names, features))
        print(f"特征值: {feature_dict}")
        print("-" * 30)
else:
    print("
模型在测试集上预测全部正确!")

七、模型部署

7.1 保存模型

python

运行



# 保存最佳模型和标准化器
joblib.dump(best_rf, 'iris_classifier_rf.pkl')
joblib.dump(scaler, 'iris_scaler.pkl')
 
print("模型和标准化器已保存!")

7.2 加载模型并预测

python

运行



# 加载保存的模型和标准化器
loaded_rf = joblib.load('iris_classifier_rf.pkl')
loaded_scaler = joblib.load('iris_scaler.pkl')
 
# 创建新样本进行预测
new_samples = np.array([
    [5.1, 3.5, 1.4, 0.2],  # setosa
    [6.2, 2.9, 4.3, 1.3],  # versicolor
    [7.3, 2.9, 6.3, 1.8]   # virginica
])
 
# 标准化新样本
new_samples_scaled = loaded_scaler.transform(new_samples)
 
# 预测
predictions = loaded_rf.predict(new_samples_scaled)
predictions_proba = loaded_rf.predict_proba(new_samples_scaled)
 
# 输出预测结果
print("新样本预测结果:")
for i, (sample, pred, proba) in enumerate(zip(new_samples, predictions, predictions_proba)):
    print(f"
样本 {i+1}: {sample}")
    print(f"预测类别: {iris.target_names[pred]} (代码: {pred})")
    print("类别概率:")
    for j, species in enumerate(iris.target_names):
        print(f"  {species}: {proba[j]:.4f}")

7.3 创建简单的预测函数

python

运行



def predict_iris_species(sepal_length, sepal_width, petal_length, petal_width):
    """
    预测鸢尾花种类
    
    参数:
    sepal_length: 花萼长度 (cm)
    sepal_width: 花萼宽度 (cm)
    petal_length: 花瓣长度 (cm)
    petal_width: 花瓣宽度 (cm)
    
    返回:
    预测结果字典
    """
    # 创建特征数组
    features = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    
    # 标准化
    features_scaled = loaded_scaler.transform(features)
    
    # 预测
    prediction = loaded_rf.predict(features_scaled)[0]
    probabilities = loaded_rf.predict_proba(features_scaled)[0]
    
    # 构建结果
    result = {
        'species': iris.target_names[prediction],
        'species_code': int(prediction),
        'probabilities': {
            iris.target_names[i]: float(probabilities[i]) 
            for i in range(len(iris.target_names))
        },
        'features': {
            'sepal_length': sepal_length,
            'sepal_width': sepal_width,
            'petal_length': petal_length,
            'petal_width': petal_width
        }
    }
    
    return result
 
# 测试预测函数
test_result = predict_iris_species(6.4, 2.8, 5.6, 2.2)
print("
预测函数测试结果:")
print(f"预测种类: {test_result['species']}")
print(f"置信度: {max(test_result['probabilities'].values()):.2%}")

八、项目总结与扩展

8.1 项目流程图



flowchart TD
    A[数据加载] --> B[探索性数据分析]
    B --> C[数据预处理]
    C --> D[特征工程]
    D --> E[模型选择]
    E --> F[模型训练]
    F --> G[模型评估]
    G --> H{模型是否达标?}
    H -- 否 --> I[参数调优]
    I --> F
    H -- 是 --> J[模型保存]
    J --> K[模型部署]
    K --> L[预测服务]
    
    B --> B1[数据可视化]
    B --> B2[统计分析]
    B --> B3[相关性分析]
    
    C --> C1[数据清洗]
    C --> C2[数据分割]
    C --> C3[特征标准化]
    
    E --> E1[K近邻]
    E --> E2[SVM]
    E --> E3[决策树]
    E --> E4[随机森林]
    
    G --> G1[准确率评估]
    G --> G2[混淆矩阵]
    G --> G3[分类报告]
    G --> G4[错误分析]

8.2 Prompt 示例

以下是可以用于向 AI 助手提问的 Prompt 示例:

Prompt 1: 数据探索

plaintext



我正在分析鸢尾花数据集,已经加载了数据并转换为DataFrame。请帮我生成代码来:
1. 检查数据是否有缺失值
2. 绘制每个特征的直方图,按类别着色
3. 计算特征之间的相关性并绘制热力图

Prompt 2: 模型构建

plaintext



我已经将鸢尾花数据集分割为训练集和测试集,并进行了标准化。请帮我:
1. 实现K近邻、SVM、决策树和随机森林模型
2. 使用5折交叉验证比较它们的性能
3. 绘制模型性能对比图

Prompt 3: 模型调优

plaintext



我的随机森林模型在鸢尾花分类任务上表现不错,但我想进一步优化。请提供:
1. 随机森林的网格搜索参数配置
2. 实现网格搜索的代码
3. 分析最佳参数和特征重要性的方法

Prompt 4: 模型评估

plaintext



我已经训练了一个随机森林模型,现在需要全面评估它的性能。请提供代码来:
1. 计算各种评估指标(准确率、精确率、召回率、F1分数)
2. 绘制混淆矩阵热力图
3. 分析预测概率分布
4. 找出并分析错误预测的样本

Prompt 5: 模型部署

plaintext



我已经训练好了鸢尾花分类模型,现在需要将其部署。请提供:
1. 保存和加载模型的代码
2. 创建预测函数的示例
3. 设计一个简单的API接口来提供预测服务

8.3 项目扩展方向

尝试更多算法

逻辑回归梯度提升(XGBoost, LightGBM)神经网络

高级特征工程

特征选择特征组合非线性特征转换

模型解释性

SHAP 值分析LIME 解释部分依赖图

部署优化

Flask/FastAPI 构建 API 服务Docker 容器化模型监控和更新

数据集扩展

添加噪声数据处理类别不平衡更大规模的数据集

九、常见问题解答

Q1: 为什么要对特征进行标准化?

A1: 许多机器学习算法(如 KNN、SVM、神经网络)对特征的尺度敏感。标准化可以使所有特征具有相同的尺度,避免某些特征因为数值范围大而主导模型训练。

Q2: 如何选择合适的模型?

A2: 通常需要尝试多种算法,通过交叉验证比较它们的性能。同时要考虑模型的复杂度、解释性和计算资源需求。对于小型数据集,简单模型往往表现更好且更稳定。

Q3: 过拟合如何处理?

A3: 处理过拟合的方法包括:增加数据量、正则化、简化模型、交叉验证、特征选择、集成方法等。在本项目中,我们通过网格搜索选择合适的参数来控制模型复杂度。

Q4: 特征重要性的意义是什么?

A4: 特征重要性告诉我们哪些特征对模型的预测贡献最大。这有助于理解数据、简化模型、发现特征之间的关系,并可能指导进一步的数据收集。

Q5: 如何评估分类模型的好坏?

A5: 除了准确率外,还需要关注精确率、召回率、F1 分数、混淆矩阵等指标。对于类别不平衡的问题,准确率可能会有误导性,需要综合考虑多个指标。

十、完整代码清单

python

运行



# 鸢尾花分类项目完整代码
 
# 1. 导入库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import joblib
import warnings
warnings.filterwarnings('ignore')
 
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
 
# 2. 数据加载与探索
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['target'] = iris.target
df['species'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
 
# 3. 数据预处理
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
 
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
 
# 4. 模型训练与调优
# 模型比较
models = {
    'K近邻': KNeighborsClassifier(),
    '支持向量机': SVC(),
    '决策树': DecisionTreeClassifier(),
    '随机森林': RandomForestClassifier()
}
 
# 网格搜索优化随机森林
param_grid = {
    'n_estimators': [50, 100, 150],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5],
    'min_samples_leaf': [1, 2]
}
 
rf = RandomForestClassifier(random_state=42)
grid_search = GridSearchCV(rf, param_grid, cv=5, n_jobs=-1)
grid_search.fit(X_train_scaled, y_train)
best_rf = grid_search.best_estimator_
 
# 5. 模型评估
y_pred = best_rf.predict(X_test_scaled)
print(f"准确率: {accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
 
# 6. 模型保存与部署
joblib.dump(best_rf, 'iris_classifier_rf.pkl')
joblib.dump(scaler, 'iris_scaler.pkl')
 
# 预测函数
def predict_iris_species(sepal_length, sepal_width, petal_length, petal_width):
    features = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    features_scaled = scaler.transform(features)
    prediction = best_rf.predict(features_scaled)[0]
    probabilities = best_rf.predict_proba(features_scaled)[0]
    
    return {
        'species': iris.target_names[prediction],
        'probabilities': dict(zip(iris.target_names, probabilities))
    }
 
# 测试预测
print("
预测示例:")
print(predict_iris_species(6.4, 2.8, 5.6, 2.2))

总结

本教程详细介绍了一个完整的机器学习项目流程,从数据加载、探索性分析、预处理、模型构建、评估到部署的每一步都进行了详细讲解。通过鸢尾花分类这个经典案例,你不仅学习了具体的代码实现,更重要的是理解了机器学习项目的思维方式和工作流程。

这个项目虽然简单,但包含了机器学习实践的核心要素。掌握这些基础后,你可以将同样的方法论应用到更复杂的问题和更大规模的数据集上。机器学习是一个实践性很强的领域,不断练习和尝试是提升技能的关键。

© 版权声明
THE END
如果内容对您有所帮助,就支持一下吧!
点赞0 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容