Task3.PyTorch实现Logistic regression
1.PyTorch基础实现代码
- import torch
- from torch.autograd import Variable
- torch.manual_seed(2)
- x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
- y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]]))
- #初始化
- w = Variable(torch.Tensor([-1]), requires_grad=True)
- b = Variable(torch.Tensor([0]), requires_grad=True)
- epochs = 100
- costs = []
- lr = 0.1
- print("before training, predict of x = 1.5 is:")
- print("y_pred = ", float(w.data*1.5 + b.data > 0))
- #模型训练
- for epoch in range(epochs):
- #计算梯度
- A = 1/(1+torch.exp(-(w*x_data+b))) #逻辑回归函数
- J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) #逻辑回归损失函数
- #J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) +alpha*w**2
- #基础类进行正则化,加上L2范数
- costs.append(J.data)
- J.backward() #自动反向传播
- #参数更新
- w.data = w.data - lr*w.grad.data
- w.grad.data.zero_()
- b.data = b.data - lr*b.grad.data
- b.grad.data.zero_()
- print("after training, predict of x = 1.5 is:")
- print("y_pred =", float(w.data*1.5+b.data > 0))
- print(w.data, b.data)
2.用PyTorch类实现Logistic regression,torch.nn.module写网络结构
- import torch
- from torch.autograd import Variable
- x_data = Variable(torch.Tensor([[0.6], [1.0], [3.5], [4.0]]))
- y_data = Variable(torch.Tensor([[0.], [0.], [1.], [1.]]))
- class Model(torch.nn.Module):
- def __init__(self):
- super(Model, self).__init__()
- self.linear = torch.nn.Linear(1, 1)
- self.sigmoid = torch.nn.Sigmoid() ###### **sigmoid**
- def forward(self, x):
- y_pred = self.sigmoid(self.linear(x))
- return y_pred
- model = Model()
- criterion = torch.nn.BCELoss(size_average=True) #损失函数
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降
- for epoch in range(500):
- # Forward pass
- y_pred = model(x_data)
- loss = criterion(y_pred, y_data)
- if epoch % 20 == 0:
- print(epoch, loss.item())
- #梯度归零
- optimizer.zero_grad()
- # 反向传播
- loss.backward()
- # update weights
- optimizer.step()
- hour_var = Variable(torch.Tensor([[0.5]]))
- print("predict (after training)", 0.5, model.forward(hour_var).data[0][0])
- hour_var = Variable(torch.Tensor([[7.0]]))
- print("predict (after training)", 7.0, model.forward(hour_var).data[0][0])
参考:https://blog.csdn.net/ZZQsAI/article/details/90216593
Task3.PyTorch实现Logistic regression的更多相关文章
- 逻辑回归 Logistic Regression
逻辑回归(Logistic Regression)是广义线性回归的一种.逻辑回归是用来做分类任务的常用算法.分类任务的目标是找一个函数,把观测值匹配到相关的类和标签上.比如一个人有没有病,又因为噪声的 ...
- logistic regression与SVM
Logistic模型和SVM都是用于二分类,现在大概说一下两者的区别 ① 寻找最优超平面的方法不同 形象点说,Logistic模型找的那个超平面,是尽量让所有点都远离它,而SVM寻找的那个超平面,是只 ...
- Logistic Regression - Formula Deduction
Sigmoid Function \[ \sigma(z)=\frac{1}{1+e^{(-z)}} \] feature: axial symmetry: \[ \sigma(z)+ \sigma( ...
- SparkMLlib之 logistic regression源码分析
最近在研究机器学习,使用的工具是spark,本文是针对spar最新的源码Spark1.6.0的MLlib中的logistic regression, linear regression进行源码分析,其 ...
- [OpenCV] Samples 06: [ML] logistic regression
logistic regression,这个算法只能解决简单的线性二分类,在众多的机器学习分类算法中并不出众,但它能被改进为多分类,并换了另外一个名字softmax, 这可是深度学习中响当当的分类算法 ...
- Stanford机器学习笔记-2.Logistic Regression
Content: 2 Logistic Regression. 2.1 Classification. 2.2 Hypothesis representation. 2.2.1 Interpretin ...
- Logistic Regression vs Decision Trees vs SVM: Part II
This is the 2nd part of the series. Read the first part here: Logistic Regression Vs Decision Trees ...
- Logistic Regression Vs Decision Trees Vs SVM: Part I
Classification is one of the major problems that we solve while working on standard business problem ...
- Logistic Regression逻辑回归
参考自: http://blog.sina.com.cn/s/blog_74cf26810100ypzf.html http://blog.sina.com.cn/s/blog_64ecfc2f010 ...
随机推荐
- leetcode 374猜数字大小
// Forward declaration of guess API. // @param num, your guess // @return -1 if my number is lower, ...
- JS 引擎
最早的 JS 引擎是纯解释器,现代 JS 引擎已经使用 JIT(Just-in-time compilation:结合预编译(ahead-of-time compilation AOT)和解释器的优点 ...
- Java中使用MATLAB作图 .
最近做一个项目,需要很多进行很多信号处理——小魏就是学软件的,对信号处理简直是个小白,最简单的实现就是傻瓜似的调用MATLAB的各种工具箱,达到目的就行. 同时,MATLAB是种解释性语言,执行效率比 ...
- 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_01 File类_5_File类获取功能的方法
获取的方法 GetAbsolutepath 传递一个相对路径进去,查看输出的结果 输出的还是绝对的路径 getPath 获取的就是构造方法中传递的路径,可以传递绝对路径也可以传递相对路径 实际上toS ...
- Spring学习01——HelloSpring
这是一个spring入门demo: package com.su.test; public class HelloWorld { public void say(){ System.out.print ...
- 读取资源中的GIF文件相应像素宽高度
代码参考了如下网页的实现: https://www.cnblogs.com/zy791976083/p/9921069.html 整理成一个函数: BOOL GetResGifSize(long nR ...
- SpringBoot错误经验
1.在application.properties 添加 debug=true,可以看见项目的执行流程有助于调bug 2.如果错误显示端口号被占用 cmd 步骤1 查看端口号应用情况:netstat ...
- Cassandra commands
Common commands: describe keyspaces // 列出所有db use your_db; // 进去db describe tables; // 列出所有table ...
- Vue2+VueRouter2+webpack 构建项目实战(四)接通api,先渲染个列表
Vue2+VueRouter2+webpack 构建项目实战(四)接通api,先渲染个列表: Vue2+VueRouter2+Webpack+Axios 构建项目实战2017重制版(一)基础知识概述
- Java学习day6数组
---恢复内容开始--- Java数组 Java 语言中提供的数组是用来存储固定大小的同类型元素.你可以声明一个数组变量,如 numbers[100] 来代替直接声明 100 个独立变量 number ...