LinerProgression
手动实现线性回归
点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data
构造一个人造数据集
点击查看代码
def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
每次读取一个batch数据量
点击查看代码
def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回
batch_size = 10
for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0
定义模型
点击查看代码
def linreg(x, w, b):
return torch.matmul(x, w) + b
损失函数 均方误差
点击查看代码
def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
优化算法 小批量下降
点击查看代码
def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
实现
点击查看代码
lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')
给笔者点个赞呀!
随机推荐
- AT_arc113_c 题解
洛谷链接&Atcoder 链接 本篇题解为此题较简单做法及较少码量,并且码风优良,请放心阅读. 题目简述 现在有一个字符串 \(S\),每一次你可以选择一个 \(i(1 \le i \le | ...
- 使用uWSGI+nginx部署Django项目(Ubuntu)
对于uwsgi+nginx的部署方式,它的访问关系大概是: 1 the web client <-> the web server <-> the socket <-&g ...
- 配置Sprig security后Post请求无法使用
在学习过程中发现在配置完Spring security后,Post请求失效,无法增删改数据,这里可以通过在Spring Security 的Config类中增加 也可以自定义csrf,不过目前还不是很 ...
- PHP 高性能框架 Workerman 凭什么能硬刚 Swoole ?
大家好,我是码农先森. 一次偶然看到了国外某机构针对 PHP 周边生态框架及扩展的性能测试排行榜,看到 Workerman 竟遥遥领先 Swoole.在我们 PHP 程序员现有的认知里,Swoole ...
- mysql 忘记root密码怎么办?
忘记root可以跳过grant table来登录 1.打开命令行输入以下命令 mysqld -nt --grant-skip-tables 2.在打开一个新命令行,输入以下命令可以登录, mysql ...
- nginx的一些功能
一.四层(tcp/udp)代理 由于nginx默认是不支持四层代理的因此在安装的时候需要加上对应的模块with-stream ./configure --with-stream # 查看当前nginx ...
- web3 产品介绍: safe --多签钱包 多人审批更放心
Safe是一款由Gnosis团队开发的多签钱包,它提供了一种安全.灵活和易于使用的方式来管理加密资产.在本文中,我们将介绍Safe的主要特点以及如何使用Safe来保护您的数字资产. 一.Safe的特点 ...
- 【Axure RP】Axure RP 9 下载安装及汉化
本体及破解机下载: http://www.sd173.com/soft/7951.html 汉化补丁教程见: https://blog.csdn.net/weixin_74457789/article ...
- 【MySQL】Windows-5.7.30 解压版 下载安装
1.Download 下载 mysql官网: https://dev.mysql.com/ 找到download点击进入下载页面: https://dev.mysql.com/downloads/ 找 ...
- python科学计算:加速库numba —— 安装和试用
安装(anaconda环境下) conda install numba Demo代码: from numba import jit from numpy import arange import nu ...