MLP实现minist数据集分类任务
1. 数据集
minist手写体数字数据集
2. 代码
- '''
- Description:
- Author: zhangyh
- Date: 2024-05-04 15:21:49
- LastEditTime: 2024-05-04 22:36:26
- LastEditors: zhangyh
- '''
- import numpy as np
- class MlpClassifier:
- def __init__(self, input_size, hidden_size1, hidden_size2, output_size, learning_rate=0.01):
- self.input_size = input_size
- self.hidden_size1 = hidden_size1
- self.hidden_size2 = hidden_size2
- self.output_size = output_size
- self.learning_rate = learning_rate
- self.W1 = np.random.randn(input_size, hidden_size1) * 0.01
- self.b1 = np.zeros((1, hidden_size1))
- self.W2 = np.random.randn(hidden_size1, hidden_size2) * 0.01
- self.b2 = np.zeros((1, hidden_size2))
- self.W3 = np.random.randn(hidden_size2, output_size) * 0.01
- self.b3 = np.zeros((1, output_size))
- def softmax(self, x):
- exps = np.exp(x - np.max(x, axis=1, keepdims=True))
- return exps / np.sum(exps, axis=1, keepdims=True)
- def relu(self, x):
- return np.maximum(x, 0)
- def relu_derivative(self, x):
- return np.where(x > 0, 1, 0)
- def cross_entropy_loss(self, y_true, y_pred):
- m = y_true.shape[0]
- return -np.sum(y_true * np.log(y_pred + 1e-8)) / m
- def forward(self, X):
- self.Z1 = np.dot(X, self.W1) + self.b1
- self.A1 = self.relu(self.Z1)
- self.Z2 = np.dot(self.A1, self.W2) + self.b2
- self.A2 = self.relu(self.Z2)
- self.Z3 = np.dot(self.A2, self.W3) + self.b3
- self.A3 = self.softmax(self.Z3)
- return self.A3
- def backward(self, X, y):
- m = X.shape[0]
- dZ3 = self.A3 - y
- dW3 = np.dot(self.A2.T, dZ3) / m
- db3 = np.sum(dZ3, axis=0, keepdims=True) / m
- dA2 = np.dot(dZ3, self.W3.T)
- dZ2 = dA2 * self.relu_derivative(self.Z2)
- dW2 = np.dot(self.A1.T, dZ2) / m
- db2 = np.sum(dZ2, axis=0, keepdims=True) / m
- dA1 = np.dot(dZ2, self.W2.T)
- dZ1 = dA1 * self.relu_derivative(self.Z1)
- dW1 = np.dot(X.T, dZ1) / m
- db1 = np.sum(dZ1, axis=0, keepdims=True) / m
- # Update weights and biases
- self.W3 -= self.learning_rate * dW3
- self.b3 -= self.learning_rate * db3
- self.W2 -= self.learning_rate * dW2
- self.b2 -= self.learning_rate * db2
- self.W1 -= self.learning_rate * dW1
- self.b1 -= self.learning_rate * db1
- # 计算精确度
- def accuracy(self, y_pred, y):
- predictions = np.argmax(y_pred, axis=1)
- correct_predictions = np.sum(predictions == np.argmax(y, axis=1))
- return correct_predictions / y.shape[0]
- def train(self, X, y, epochs=100, batch_size=64):
- print('Training...')
- m = X.shape[0]
- for epoch in range(epochs):
- for i in range(0, m, batch_size):
- X_batch = X[i:i+batch_size]
- y_batch = y[i:i+batch_size]
- # Forward propagation
- y_pred = self.forward(X_batch)
- # Backward propagation
- self.backward(X_batch, y_batch)
- if (epoch+1) % 10 == 0:
- loss = self.cross_entropy_loss(y, self.forward(X))
- acc = self.accuracy(y_pred, y_batch)
- print(f'Epoch {epoch+1}/{epochs}, Loss: {loss}, Training-Accuracy: {acc}')
- def test(self, X, y):
- print('Testing...')
- y_pred = self.forward(X)
- acc = self.accuracy(y_pred, y)
- return acc
- if __name__ == '__main__':
- import tensorflow as tf
- # 加载MNIST数据集
- (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
- # 将图像转换为向量形式
- X_train = X_train.reshape(X_train.shape[0], -1) / 255.0
- X_test = X_test.reshape(X_test.shape[0], -1) / 255.0
- # 将标签进行 one-hot 编码
- num_classes = 10
- y_train = tf.keras.utils.to_categorical(y_train, num_classes)
- y_test = tf.keras.utils.to_categorical(y_test, num_classes)
- # 打印转换后的结果
- # 训练集维度: (60000, 784) (60000, 10)
- # 测试集维度: (10000, 784) (10000, 10)
- model = MlpClassifier(784, 128, 128, 10)
- model.train(X_train, y_train)
- test_acc = model.test(X_test, y_test)
- print(f'Test-Accuracy: {test_acc}')
3. 运行结果
- Training...
- Epoch 10/100, Loss: 0.3617846299623725, Training-Accuracy: 0.9375
- Epoch 20/100, Loss: 0.1946690996652946, Training-Accuracy: 1.0
- Epoch 30/100, Loss: 0.13053815227522408, Training-Accuracy: 1.0
- Epoch 40/100, Loss: 0.09467908427578901, Training-Accuracy: 1.0
- Epoch 50/100, Loss: 0.07120217251250453, Training-Accuracy: 1.0
- Epoch 60/100, Loss: 0.055233734086591456, Training-Accuracy: 1.0
- Epoch 70/100, Loss: 0.04369171830999816, Training-Accuracy: 1.0
- Epoch 80/100, Loss: 0.03469674775956587, Training-Accuracy: 1.0
- Epoch 90/100, Loss: 0.027861857647949812, Training-Accuracy: 1.0
- Epoch 100/100, Loss: 0.0225212692988995, Training-Accuracy: 1.0
- Testing...
- Test-Accuracy: 0.9775
MLP实现minist数据集分类任务的更多相关文章
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 用CNN及MLP等方法识别minist数据集
用CNN及MLP等方法识别minist数据集 2017年02月13日 21:13:09 hnsywangxin 阅读数:1124更多 个人分类: 深度学习.keras.tensorflow.cnn ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
- Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression
Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...
- Python实现鸢尾花数据集分类问题——基于skearn的SVM
Python实现鸢尾花数据集分类问题——基于skearn的SVM 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = 'Xiaoli ...
- BP算法在minist数据集上的简单实现
BP算法在minist上的简单实现 数据:http://yann.lecun.com/exdb/mnist/ 参考:blog,blog2,blog3,tensorflow 推导:http://www. ...
- TensorFlow笔记三:从Minist数据集出发 两种经典训练方法
Minist数据集:MNIST_data 包含四个数据文件 一.方法一:经典方法 tf.matmul(X,w)+b import tensorflow as tf import numpy as np ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
随机推荐
- HarmonyOS网络管理开发—Socket连接
简介 Socket连接主要是通过Socket进行数据传输,支持TCP/UDP/TLS协议. 基本概念 ● Socket:套接字,就是对网络中不同主机上的应用进程之间进行双向通信的端点的抽象. ● ...
- IntelliJ IDEA 配置类注释模板
菜单栏依次点击 File > Settings 在弹出窗口中找到 Editor >File and Code Templates 在右侧中 Files 选项卡中找到 Class. 在右侧输 ...
- mongodb基础整理篇————设计[四]
前言 简单整理一下mongodb的设计. 正文 设计三步曲: 第一步:建立基础文档模型 例子: 1对1建模: 1 对多建模: N对N模型: 第二步: 根据读写况细化 遇到的问题: 解决: 查询连表: ...
- Greenplum Jdbc 调用 SETOF refcursor
最近公司需要用Greenplum,在调用 jdbc的时候遇到了一些问题.由于我们前提的业务都是使用 sqlserver,sqlserver的 procedure 在前端展示做数据源的时候才用的非常多, ...
- python flashtext字符串快速替换,自然语言处理加速
在自然语言处理当中,经常对数据集进行一些数据字符的替换,表情的替换,以便在tokenizer的时候不被识别成[unk],造成信息的缺失 常规方法使用python自带的replace方法实现,但数据量很 ...
- 我们为什么需要操作系统(Operating System)?
我们为什么需要操作系统(Operating System)? a) 从计算机体系的角度,OS向下统筹了所有硬件资源(1),向上为所有软件提供API调用(2),使得软件程序员不必知晓硬件的具体细节,实现 ...
- HL7消息类型
HL7消息有很多不同的类型,每种都有其自己的独特用途和消息内容.以下是常见的HL7消息类型的列表. Message Type Description HL7 ADT Admit, Discharge ...
- 力扣283(java)-移动零(简单)
题目: 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序. 请注意 ,必须在不复制数组的情况下原地对数组进行操作. 示例 1: 输入: nums = [0, ...
- 力扣636(java)-函数的独占时间(中等)
题目: 有一个 单线程 CPU 正在运行一个含有 n 道函数的程序.每道函数都有一个位于 0 和 n-1 之间的唯一标识符. 函数调用 存储在一个 调用栈 上 :当一个函数调用开始时,它的标识符将会 ...
- 从0到1使用Webpack5 + React + TS构建标准化应用
简介: 本篇文章主要讲解如何从一个空目录开始,建立起一个基于webpack + react + typescript的标准化前端应用. 作者 | 刘皇逊(恪语)来源 | 阿里开发者公众号 前言 本篇文 ...