Python SVM Handwritten Digit Recognition
Python: MNIST handwritten digit recognition with sklearn SVM
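1. Dataset: MNIST
Data source: http://yann.lecun.com/exdb/mnist/
Training data: the 60,000 MNIST training images of handwritten digits 0-9
Test data: the 10,000 MNIST test images of handwritten digits 0-9
Note: the training and test code reads the ubyte files directly. The original archives were only decompressed and were not converted to png/jpg first. Code for converting the data to png is included as well.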
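All of the scripts rely on the IDX layout of the ubyte files.

* A 4-byte magic number comes first: two zero bytes, a type code (0x08 means unsigned byte), and the number of dimensions.
* One big-endian 32-bit size follows for each dimension.
* The raw data comes last.

This is why the training-image file is 16 + 60000 × 28 × 28 = 47,040,016 bytes and the training-label file is 8 + 60000 = 60,008 bytes. The sketch below is a minimal reader that derives the header length from the magic number. parse_idx and read_idx are hypothetical helper names, not part of any library.

```python
import struct

import numpy as np


def parse_idx(data):
    """Parse the bytes of an unsigned-byte IDX file into a NumPy array.

    Assumes the standard IDX header: 2 zero bytes, a type code
    (0x08 = unsigned byte), the number of dimensions, then one
    big-endian uint32 per dimension.
    """
    zero, type_code, ndim = struct.unpack_from('>HBB', data, 0)
    if zero != 0 or type_code != 0x08:
        raise ValueError('not an unsigned-byte IDX file')
    dims = struct.unpack_from('>' + 'I' * ndim, data, 4)
    # 16-byte header for images (3 dims), 8-byte header for labels (1 dim)
    offset = 4 + 4 * ndim
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(dims)


def read_idx(path):
    """Read an IDX file from disk (hypothetical convenience wrapper)."""
    with open(path, 'rb') as f:
        return parse_idx(f.read())
```

On the standard files, read_idx('train-images.idx3-ubyte') should give an array of shape (60000, 28, 28). read_idx('train-labels.idx1-ubyte') should give one of shape (60000,).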
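Format conversion: from ubyte to png. The images are saved as mnist_train/<label>/<label>_<n>.png, with mnist_test used for the test set. The code is as follows:
Tip: PIL is no longer maintained for newer Python versions, so install the Pillow library instead (pip install Pillow). It is still imported with from PIL import Image.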
import os
import struct

import numpy as np
from PIL import Image  # provided by the Pillow package


def convert(data_file, label_file, datas_root):
    # Image file: 16-byte header (magic, count, rows, cols), then one unsigned
    # byte per pixel, e.g. 16 + 60000*28*28 = 47040016 bytes for the training set
    with open(data_file, 'rb') as f:
        data_buf = f.read()
    magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', data_buf, 0)
    datas = np.frombuffer(data_buf, dtype=np.uint8, offset=struct.calcsize('>IIII'))
    datas = datas.reshape(numImages, 1, numRows, numColumns)

    # Label file: 8-byte header (magic, count), then one byte per label,
    # e.g. 8 + 60000 = 60008 bytes for the training set
    with open(label_file, 'rb') as f:
        label_buf = f.read()
    magic, numLabels = struct.unpack_from('>II', label_buf, 0)
    labels = np.frombuffer(label_buf, dtype=np.uint8,
                           offset=struct.calcsize('>II')).astype(np.int64)

    # One sub-folder per digit: datas_root/0 ... datas_root/9
    for i in range(10):
        os.makedirs(os.path.join(datas_root, str(i)), exist_ok=True)

    # Per-digit counters, so files are named <label>_1.png, <label>_2.png, ...
    count = np.ones(10, dtype=np.int64)
    for ii in range(numLabels):
        img = Image.fromarray(datas[ii, 0, 0:28, 0:28])
        label = labels[ii]
        file_name = os.path.join(datas_root, str(label),
                                 '%d_%d.png' % (label, count[label]))
        count[label] += 1
        img.save(file_name)


if __name__ == '__main__':
    convert('train-images.idx3-ubyte', 'train-labels.idx1-ubyte', 'mnist_train')
    convert('t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte', 'mnist_test')
The converted data is shown in the figure below.
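To train from the PNG files instead of the ubyte files, the folder tree can be read back into the same (N, 784) array layout that the SVM code uses. The sketch below assumes the layout produced by the conversion script: one folder per digit, each holding 28x28 grayscale <label>_<n>.png files. load_png_tree is a hypothetical helper name.

```python
import os

import numpy as np
from PIL import Image


def load_png_tree(root):
    """Load <root>/<label>/*.png into (X, y) with X of shape (N, 784).

    Assumes one sub-folder per digit 0-9 containing 28x28 grayscale PNGs.
    """
    images, labels = [], []
    for label in range(10):
        folder = os.path.join(root, str(label))
        if not os.path.isdir(folder):
            continue
        for name in sorted(os.listdir(folder)):
            if not name.endswith('.png'):
                continue
            with Image.open(os.path.join(folder, name)) as img:
                # convert('L') guarantees a single 8-bit grayscale channel
                pixels = np.asarray(img.convert('L'), dtype=np.uint8)
            images.append(pixels.reshape(-1))
            labels.append(label)
    return np.array(images), np.array(labels, dtype=np.int64)
```

For example, X, y = load_png_tree('mnist_train') would give arrays that can replace train_images and train_labels in the training script.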
2. Training the Model
import pickle
import struct

import numpy as np
from sklearn import svm
from sklearn import preprocessing  # used for data preprocessing


# Read the dataset
def load_mnist_train(labels_path, images_path):
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))  # 8-byte label header
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))  # 16-byte image header
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels


if __name__ == '__main__':
    # Read the training data
    labels_path = "train-labels.idx1-ubyte"
    images_path = "train-images.idx3-ubyte"
    train_images, train_labels = load_mnist_train(labels_path, images_path)

    # Standardize each pixel to zero mean and unit variance
    # (the fitted scaler itself is not saved)
    X = preprocessing.StandardScaler().fit_transform(train_images)
    X_train = X[0:60000]
    y_train = train_labels[0:60000]

    # Define and train the model (default SVC, RBF kernel)
    model_svc = svm.SVC()
    model_svc.fit(X_train, y_train)

    # Save the model
    with open("model.pickle", "wb") as file:
        pickle.dump(model_svc, file)
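In the training script, the StandardScaler is fitted on the training images but only the SVC is pickled. The test script therefore has to fit a second scaler on the test images. One way to keep the two steps together is a scikit-learn Pipeline, which stores the scaler's training-set mean and standard deviation alongside the SVC in one pickle.

The sketch below makes a few assumptions:

* It uses sklearn's bundled load_digits dataset (8x8 digits) as a stand-in for MNIST, so it runs without downloads.
* The explicit SVC arguments are the defaults in scikit-learn 0.22 and later.
* It pickles to bytes in memory. With the article's files you would pickle.dump the pipeline to model.pickle instead.

```python
import pickle

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# load_digits stands in for MNIST so the sketch runs offline
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

# The scaler learns mean/std from the training split only;
# the pipeline keeps the scaler and the SVC together
model = make_pipeline(StandardScaler(), SVC(kernel='rbf', C=1.0, gamma='scale'))
model.fit(X_train, y_train)

# Save and reload the whole pipeline (scaler + SVC) as one object
blob = pickle.dumps(model)
restored = pickle.loads(blob)

# Raw test pixels are scaled with the *training* statistics inside predict/score
accuracy = restored.score(X_test, y_test)
print("Accuracy:", accuracy)
```

With this approach, the test script would load the pipeline and call predict on the raw test pixels, without fitting a new scaler.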
3. Testing the Model
import pickle
import struct

import numpy as np
from sklearn import preprocessing  # used for data preprocessing


def test(images_path, labels_path, modelPath):
    # Read the test images and labels
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        test_labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        test_images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(test_labels), 784)

    # Load the model
    with open(modelPath, "rb") as file:
        model_svc = pickle.load(file)

    # Standardize, predict and score.
    # Note: this fits a new scaler on the test images; strictly, the scaler
    # fitted on the training images should be reused.
    x = preprocessing.StandardScaler().fit_transform(test_images)
    x_test = x[0:10000]
    y_test = test_labels[0:10000]
    num = model_svc.predict(x_test)
    for i in range(10000):
        print("Real:", y_test[i], "Predict:", num[i])
    print("Accuracy:", model_svc.score(x_test, y_test))
    return num


if __name__ == '__main__':
    images_path = "t10k-images.idx3-ubyte"
    labels_path = "t10k-labels.idx1-ubyte"
    modelPath = "model.pickle"
    num = test(images_path, labels_path, modelPath)
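model_svc.score() runs prediction over the whole test set again internally. Printing 10,000 "Real/Predict" lines is also hard to scan. The sketch below takes a different approach:

* It predicts once and derives the accuracy from those predictions.
* It summarizes errors per digit with a confusion matrix.
* It classifies a single image, which must be passed as a 2-D array of shape (1, n_features).

It again uses load_digits as a stand-in for MNIST. The 1500/297 train/test split is an arbitrary choice for illustration.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# load_digits (1797 samples) stands in for MNIST
X, y = load_digits(return_X_y=True)
X_train, y_train = X[:1500], y[:1500]
X_test, y_test = X[1500:], y[1500:]

model = make_pipeline(StandardScaler(), SVC())
model.fit(X_train, y_train)

# Predict the whole test set once and reuse the result
pred = model.predict(X_test)
acc = accuracy_score(y_test, pred)
# Rows: true digit, columns: predicted digit
cm = confusion_matrix(y_test, pred, labels=np.arange(10))

# A single image must be reshaped to (1, n_features) before predict()
one = X_test[0].reshape(1, -1)
digit = model.predict(one)[0]

print("Accuracy:", acc)
print(cm)
print("First test image predicted as:", digit)
```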
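4. References
Image format conversion: "Converting the MNIST dataset from ubyte to png", haoji007's blog, CSDN
Model training and testing: "Study notes on basic image-processing libraries 2: training and predicting on the MNIST dataset with SVM, MATLAB and TensorFlow", 灰信网 (software development blog aggregator)
sklearn SVM parameter settings: "Machine learning notes (3): sklearn support vector machine (SVM)", 简书 (Jianshu)
Saving and loading the model: "Saving and loading an SVM model with sklearn", hellosonny's blog, CSDN
Single-image testing: "Handwritten digit recognition with SVM machine learning", Brinshy's blog, CSDN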