数据集简介
MNIST(Modified National Institute of Standards and Technology database)是机器学习领域最著名、最常用的入门级图像分类数据集之一。它包含了大量手写数字的灰度图片,任务目标是正确识别出每张图片中的数字(0-9)。
由于其规模适中、任务直观、识别难度不高,它常被用作验证新算法或教学演示的“标准测试平台”,被誉为机器学习领域的 “Hello World”。
数据内容与结构
- 类别:10类,分别对应数字 0 到 9。
- 图像格式:每张图片是固定大小的 28×28 像素 的灰度图像。
- 像素值:每个像素是一个0到255之间的整数,0代表白色(背景),255代表纯黑色(笔迹),中间值是不同程度的灰色。
- 数据量:数据集通常被分为两部分:
- 训练集:60,000 张 图片,用于训练模型。
- 测试集:10,000 张 图片,用于评估模型的最终性能。
数据示例
每个样本通常由一个特征向量(图像数据)和一个标签(真实数字)组成。
- 原始图像:可以看作一个 28行 x 28列 的矩阵。
- 常用扁平化处理:在输入到许多经典机器学习模型(如逻辑回归、全连接神经网络)时,会将这个矩阵“展平”成一个长度为 784 (28*28) 的一维向量。
- 标签:一个 0-9 的整数,或通常被处理成 One-Hot 编码(例如,数字5表示为
[0,0,0,0,0,1,0,0,0,0])。本例中,我们将在代码中把整数5做One-Hot编码处理
代码详解
导入数据集
本例中使用keras的数据集包(keras.datasets)中的mnist数据集的load_data()函数直接加载mnist数据,其中:x_train 为训练集手写图片;y_train 为训练集标签;x_test 为测试集手写图片;y_test 为测试集标签
import keras
from keras.datasets import mnist
################################
# 加载mnist数据集
# x_train 为训练集手写图片
# y_train 为训练集标签
# x_test 为测试集手写图片
# y_test 为测试集标签
################################
(x_train, y_train), (x_test, y_test) = mnist.load_data()
查看数据集中的数据结构
下面的代码显示训练集图片数据和标签数据的形状,显示训练集索引0(第一张手写图片)的图片和标签所示数字“5”