Pytorch线性规划模型 学习笔记(一)
Pytorch线性规划模型 学习笔记(一)
Pytorch视频学习资料参考:《PyTorch深度学习实践》完结合集
Pytorch搭建神经网络的四大部分
1. 准备数据 Prepare dataset
准备数据包括数据的读取加载并转换为torch框架下识别的tensor格式,注意数据的dtype为float32格式
2. 设计模型 Design model using class
网络的基本框架部分,包括自定义的网络layer结构,注意维度的变换要一致,另外,该类中还应包括forward部分
3. 构建损失和优化器 Construct loss and optimizer
根据处理的问题和模型设置合适的损失,或自己构建损失函数。优化器为梯度下降的解决方案,可选择合适的优化器进行梯度下降
4. 重复训练 Training cycle
重复训练部分可以后续设置batchsize的大小,按batch进行随机梯度下降(此代码中暂无设置),注意优化器的清零迭代操作
数据部分
X.csv,y.csv链接: https://pan.baidu.com/s/1dJD8zBewCS86fRgv0nL7kQ 密码: 0us0
下载后与程序放置在同一文件夹下
代码部分
# import
import torch
import numpy as np
## 1. prepare dataset
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
print(y.shape)
print(x.shape)
## 2. design model using class
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear1 = torch.nn.Linear(10, 6)
self.linear2 = torch.nn.Linear(6, 6)
self.linear3 = torch.nn.Linear(6, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.sigmoid(x)
return x
model = LinearModel()
## 3. construct loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
## 4. training cycle
for epoch in range(500):
y_hat = model(x)
loss = criterion(y_hat, y)
print('epoch', epoch, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
Pytorch线性规划模型 学习笔记(一)的更多相关文章
- 概率图模型学习笔记:HMM、MEMM、CRF
作者:Scofield链接:https://www.zhihu.com/question/35866596/answer/236886066来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商 ...
- NIO模型学习笔记
NIO模型学习笔记 简介 Non-blocking I/O 或New I/O 自JDK1.4开始使用 应用场景:高并发网络服务器支持 概念理解 模型:对事物共性的抽象 编程模型:对编程共性的抽象 BI ...
- LDA主题模型学习笔记5:C源代码理解
1.说明 本文对LDA原始论文的作者所提供的C代码中LDA的主要逻辑部分做凝视,原代码可在这里下载到:https://github.com/Blei-Lab/lda-c 这份代码实现论文<Lat ...
- GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)
我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...
- HMM模型学习笔记(前向算法实例)
HMM算法想必大家已经听说了好多次了,完全看公式一头雾水.但是HMM的基本理论其实很简单.因为HMM是马尔科夫链中的一种,只是它的状态不能直接被观察到,但是可以通过观察向量间接的反映出来,即每一个观察 ...
- Note | PyTorch官方教程学习笔记
目录 1. 快速入门PYTORCH 1.1. 什么是PyTorch 1.1.1. 基础概念 1.1.2. 与NumPy之间的桥梁 1.2. Autograd: Automatic Differenti ...
- 微软CodeDom模型学习笔记(全)
CodeDomProvider MSDN描述 CodeDomProvider可用于创建和检索代码生成器和代码编译器的实例.代码生成器可用于以特定的语言生成代码,而代码编译器可用于将代码编译为程序集. ...
- OSI七层模型学习笔记
1.简介 什么是OSI模型呢? OSI模型全名Open System InterConnect 即开放式系统互联,是国际标准化组织(ISO)提出的一个试图使各种计算机在世界范围内互连为网络的标准框架, ...
- 深度学习在美团点评推荐平台排序中的应用&& wide&&deep推荐系统模型--学习笔记
写在前面:据说下周就要xxxxxxxx, 吓得本宝宝赶紧找些广告的东西看看 gbdt+lr的模型之前是知道怎么搞的,dnn+lr的模型也是知道的,但是都没有试验过 深度学习在美团点评推荐平台排序中的运 ...
随机推荐
- 逆向 stdio.h 函数库 fopen 函数(调试版本)
0x01 fopen 函数 函数原型:FILE *fopen(const char *filename, const char *mode) 返回值为 FILE 类型 函数功能:使用给定的模式 mod ...
- 使用 WinAFL 图片解析软件进行模糊测试 - FreeImage 图片解析库
看雪链接:https://bbs.pediy.com/thread-255162.htm
- 【python】Leetcode每日一题-删除排序链表中的重复元素2
[python]Leetcode每日一题-删除排序链表中的重复元素2 [题目描述] 存在一个按升序排列的链表,给你这个链表的头节点 head ,请你删除链表中所有存在数字重复情况的节点,只保留原始链表 ...
- Linux(CentOS-8)安装MySQL8.0.11
CentOS安装MySQL8.0.11 总的思路就是:安装MySQL,编写配置文件,配置环境变量,成功开启服务,登陆并修改ROOT密码 开启远程访问的思路就是:授权用户所有IP都可以访问,系统的数据库 ...
- 关于Java处理串口二进制数据的问题 byte的范围 一个字节8bits
前置知识点 byte的范围[-128~127] 内存里表现为 0x00~0xFF 刚好是一个8bits的字节 问题 byte[] hexData = new byte[] {0x01, 0x03, 0 ...
- 容器进阶:OCI与容器运行时
Blog:博客园 个人 什么是容器运行时(Container Runtime) Kubernetes节点的底层由一个叫做容器运行时的软件进行支撑,它负责比如启停容器 这样的事情.最广为人知的容器运行时 ...
- QFNU-ACM 2020.04.05个人赛补题
A.CodeForces-124A (简单数学题) #include<cstdio> #include<algorithm> #include<iostream> ...
- 爬虫:获取动态加载数据(selenium)(某站)
如果网站数据是动态加载,需要不停往下拉进度条才能显示数据,用selenium模拟浏览器下拉进度条可以实现动态数据的抓取. 本文希望找到某乎某话题下讨论较多的问题,以此再寻找每一问题涉及的话题关键词(侵 ...
- 【BUAA软工】Beta阶段事后分析
设想与目标 我们的软件要解决什么问题?是否定义得很清楚?是否对典型用户和典型场景有清晰的描述? 解决的问题 总体解决的问题:新手编程者配置编程环境难.本地编写的代码跨设备同步难.本地ide安装使用过程 ...
- OO随笔之魔鬼的第一单元——多项式求导
OO是个借助Java交我们面向对象的课,可是萌新们总是喜欢带着面向过程的脑子去写求导,然后就是各种一面(main)到底.各种方法杂糅,然后就是被hack的很惨. 第一次作业:萌新入门面向对象 题目分析 ...