"""
torch.float64对应torch.DoubleTensor
torch.float32对应torch.FloatTensor
将真实函数的数据点能够拟合成一个多项式
eg:y = 0.9 +0.5×x + 3×x*x + 2.4 ×x*x*x
"""
import torch from torch import nn def make_features(x):
x = x.unsqueeze(1)#在原来的基础上扩充了一维
return torch.cat([x ** i for i in range(1,4)], 1) def get_batch(batch_size=32): random = torch.randn(batch_size)
# print('random')
# print(random) #32个数 x = make_features(random)#进行维度扩充,扩充后32*1,又进行1,2,3次幂运算,拼接后32*3 '''Compute the actual results'''
y = f(x) # 32*3 *3*1
if torch.cuda.is_available():
return torch.autograd.Variable(x).cuda(), torch.autograd.Variable(y).cuda()
else:
return torch.autograd.Variable(x), torch.autograd.Variable(y) w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1)#三行一列
b_target = torch.FloatTensor([0.9]) def f(x):
return x.mm(w_target)+b_target[0] class poly_model(nn.Module):
def __init__(self):
super(poly_model, self).__init__()
self.poly = nn.Linear(3, 1)# 输入是3维,输出是1维 def forward(self, x):
out = self.poly(x)
return out if torch.cuda.is_available():
model = poly_model().cuda()
else:
model = poly_model() criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) epoch = 0
for epoch in range(20):
batch_x,batch_y = get_batch()#batch_x 和get_batch里面的x是一样的
output = model(batch_x)
loss = criterion(output,batch_y)
print_loss = loss
print(loss.item()) # 0.4版本之后使用loss.item()从标量中获得Python number
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('finished')

torch_02_多项式回归的更多相关文章

  1. R语言多项式回归

    含有x和y这两个变量的线性回归是所有回归分析中最常见的一种:而且,在描述它们关系的时候,也是最有效.最容易假设的一种模型.然而,有些时候,它的实际情况下某些潜在的关系是非常复杂的,不是二元分析所能解决 ...

  2. 机器学习:scipy和sklearn中普通最小二乘法与多项式回归的使用对

    相关内容连接: 机器学习:Python中如何使用最小二乘法(以下简称文一) 机器学习:形如抛物线的散点图在python和R中的非线性回归拟合方法(以下简称文二) 有些内容已经在上面两篇博文中提到了,所 ...

  3. Machine Learning--week2 多元线性回归、梯度下降改进、特征缩放、均值归一化、多项式回归、正规方程与设计矩阵

    对于multiple features 的问题(设有n个feature),hypothesis 应该改写成 \[ \mathit{h} _{\theta}(x) = \theta_{0} + \the ...

  4. 线性回归,多项式回归(P2)

    回归问题 回归问题包含有线性回归和多项式回归 简单来说,线性回归就是用多元一次方程拟合数据,多项式回归是用多元多次来拟合方程 在几何意义上看,线性回归拟合出的是直线,平面.多项式拟合出来的是曲线,曲面 ...

  5. python 机器学习多项式回归

    现实世界的曲线关系都是通过增加多项式实现的,现在解决多项式回归问题 住房价格样本 样本图像 import matplotlib.font_manager as fm import matplotlib ...

  6. 【机器学习】多项式回归sklearn实现

    [机器学习]多项式回归原理介绍 [机器学习]多项式回归python实现 [机器学习]多项式回归sklearn实现 使用sklearn框架实现多项式回归.使用框架更方便,可以少写很多代码. 使用一个简单 ...

  7. 【机器学习】多项式回归python实现

    [机器学习]多项式回归原理介绍 [机器学习]多项式回归python实现 [机器学习]多项式回归sklearn实现 使用python实现多项式回归,没有使用sklearn等机器学习框架,目的是帮助理解算 ...

  8. 机器学习:多项式回归(scikit-learn中的多项式回归和 Pipeline)

    一.scikit-learn 中的多项式回归 1)实例过程 模拟数据 import numpy as np import matplotlib.pyplot as plt x = np.random. ...

  9. Matlab多项式回归实现

    多项式回归也称多元非线性回归,是指包含两个以上变量的非线性回归模型.对于多元非线性回归模型求解的传统解决方案,仍然是想办法把它转化成标准的线性形式的多元回归模型来处理. 多元非线性回归分析方程 如果自 ...

随机推荐

  1. My time is limited

    Your time is limited, so don't waste it living someone else's life. Don't be trapped by dogma - whic ...

  2. 宣布Visual Studio Code Installer for Java

    自从第一个Java语言服务器在微软苏黎世办公室的一个小型会议室的黑客马拉松中开发已经差不多3年了,该会议室的人员来自Red Hat,IBM,Codenvy和Microsoft,后来成为Visual S ...

  3. [笔记] vs code 设置终端

    设置文件: setting.json 1 设置自定义终端 cmd "terminal.integrated.shell.windows": "C:\\WINDOWS\\S ...

  4. Java生鲜电商平台-定时器,定时任务quartz的设计与架构

    Java生鲜电商平台-定时器,定时任务quartz的设计与架构 说明:任何业务有时候需要系统在某个定点的时刻执行某些任务,比如:凌晨2点统计昨天的报表,早上6点抽取用户下单的佣金. 对于Java开源生 ...

  5. jieba分词原理-DAG(NO HMM)

    最近公司在做一个推荐系统,让我给论坛上的帖子找关键字,当时给我说让我用jieba分词,我周末回去看了看,感觉不错,还学习了一下具体的原理 首先,通过正则表达式,将文章内容切分,形成一个句子数组,这个比 ...

  6. 「白帽挖洞技能提升」ThinkPHP5 远程代码执行漏洞-动态分析

    ThinkPHP是为了简化企业级应用开发和敏捷WEB应用开发而诞生的,在保持出色的性能和至简代码的同时,也注重易用性.但是简洁易操作也会出现漏洞,之前ThinkPHP官方修复了一个严重的远程代码执行漏 ...

  7. 修改源代码时不需要重启tomcat服务器

    我们在写JSP + Servlet 的时修改了Java代码就要重新启动服务器.十分麻烦. 为了解决这个问题我们可以将服务器改成debug 模式.就是按调试状态这样修改Java代码就不用再重新启动服务器 ...

  8. 消息中间件Kafaka - PHP操作使用Kafka

    Centos版本:Centos6.4,PHP版本:PHP7. 在上一篇文章中使用IP为192.168.9.154的机器安装并开启了Kafka进行了简单测试,充当了Kafka服务器. 本篇文章新开启一台 ...

  9. F5部署SSL证书

    查找中间证书 为了保证可以兼容所有浏览器,我们必须在服务器上安装中间证书,请到 中间证书下载工具,输入您的Server.cer,然后下载中间证书,请将中间证书保存为Chain.cer. 证书文件的上传 ...

  10. 关于微信小程序中遇到的各种问题汇总(持续更新)

    1.关于 <input />标签容易忽略的问题: 使用<input />标签时容易忘记绑定bindblur()方法(输入框失去焦点时触发),因为用户用键盘输入时不一定会点击完成 ...