返回介绍

交叉验证 2 Cross-validation

发布于 2025-05-02 13:36:27 字数 1878 浏览 0 评论 0 收藏 0

作者: Bhan 编辑: Morvan

sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度,对比发现有没有 overfitting 的问题. 然后我们可以对我们的 model 进行调整,克服 overfitting 的问题。

例子 - 藉由学习曲线(Learning curve) 来检视过拟合(Overfitting) 的问题

加载对应模块:

from sklearn.learning_curve import learning_curve #学习曲线模块
from sklearn.datasets import load_digits #digits 数据集
from sklearn.svm import SVC #Support Vector Classifier
import matplotlib.pyplot as plt #可视化模块
import numpy as np

加载 digits 数据集,其包含的是手写体的数字,从 0 到 9。数据集总共有 1797 个样本,每个样本由 64 个特征组成, 分别为其手写体对应的 8×8 像素表示,每个特征取值 0~16。

digits = load_digits()
X = digits.data
y = digits.target

观察样本由小到大的学习曲线变化,采用 K 折交叉验证 cv=10 , 选择平均方差检视模型效能 scoring='mean_squared_error' , 样本由小到大分成 5 轮检视学习曲线 (10%, 25%, 50%, 75%, 100%) :

train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.001), X, y, cv=10, scoring='mean_squared_error',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1])

#平均每一轮所得到的平均方差(共 5 轮,分别为样本 10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

可视化图形:

plt.plot(train_sizes, train_loss_mean, 'o-', color="r",
         label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",
        label="Cross-validation")

plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

如果你觉得这篇文章或视频对你的学习很有帮助,请你也分享它,让它能再次帮助到更多的需要学习的人。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。