1.PyTorch基础实现代码

 import torch
from torch.autograd import Variable torch.manual_seed(2)
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]])) #初始化
w = Variable(torch.Tensor([-1]), requires_grad=True)
b = Variable(torch.Tensor([0]), requires_grad=True)
epochs = 100
costs = []
lr = 0.1
print("before training, predict of x = 1.5 is:")
print("y_pred = ", float(w.data*1.5 + b.data > 0)) #模型训练
for epoch in range(epochs):
#计算梯度
A = 1/(1+torch.exp(-(w*x_data+b))) #逻辑回归函数
J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) #逻辑回归损失函数
#J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) +alpha*w**2
#基础类进行正则化,加上L2范数
costs.append(J.data)
J.backward() #自动反向传播 #参数更新
w.data = w.data - lr*w.grad.data
w.grad.data.zero_()
b.data = b.data - lr*b.grad.data
b.grad.data.zero_() print("after training, predict of x = 1.5 is:")
print("y_pred =", float(w.data*1.5+b.data > 0))
print(w.data, b.data)

2.用PyTorch类实现Logistic regression,torch.nn.module写网络结构

 import torch
from torch.autograd import Variable x_data = Variable(torch.Tensor([[0.6], [1.0], [3.5], [4.0]]))
y_data = Variable(torch.Tensor([[0.], [0.], [1.], [1.]])) class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(1, 1)
self.sigmoid = torch.nn.Sigmoid() ###### **sigmoid** def forward(self, x):
y_pred = self.sigmoid(self.linear(x))
return y_pred model = Model() criterion = torch.nn.BCELoss(size_average=True) #损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降 for epoch in range(500):
# Forward pass
y_pred = model(x_data) loss = criterion(y_pred, y_data)
if epoch % 20 == 0:
print(epoch, loss.item()) #梯度归零
optimizer.zero_grad()
# 反向传播
loss.backward()
# update weights
optimizer.step() hour_var = Variable(torch.Tensor([[0.5]]))
print("predict (after training)", 0.5, model.forward(hour_var).data[0][0])
hour_var = Variable(torch.Tensor([[7.0]]))
print("predict (after training)", 7.0, model.forward(hour_var).data[0][0])

参考:https://blog.csdn.net/ZZQsAI/article/details/90216593

Task3.PyTorch实现Logistic regression的更多相关文章

  1. 逻辑回归 Logistic Regression

    逻辑回归(Logistic Regression)是广义线性回归的一种.逻辑回归是用来做分类任务的常用算法.分类任务的目标是找一个函数,把观测值匹配到相关的类和标签上.比如一个人有没有病,又因为噪声的 ...

  2. logistic regression与SVM

    Logistic模型和SVM都是用于二分类,现在大概说一下两者的区别 ① 寻找最优超平面的方法不同 形象点说,Logistic模型找的那个超平面,是尽量让所有点都远离它,而SVM寻找的那个超平面,是只 ...

  3. Logistic Regression - Formula Deduction

    Sigmoid Function \[ \sigma(z)=\frac{1}{1+e^{(-z)}} \] feature: axial symmetry: \[ \sigma(z)+ \sigma( ...

  4. SparkMLlib之 logistic regression源码分析

    最近在研究机器学习,使用的工具是spark,本文是针对spar最新的源码Spark1.6.0的MLlib中的logistic regression, linear regression进行源码分析,其 ...

  5. [OpenCV] Samples 06: [ML] logistic regression

    logistic regression,这个算法只能解决简单的线性二分类,在众多的机器学习分类算法中并不出众,但它能被改进为多分类,并换了另外一个名字softmax, 这可是深度学习中响当当的分类算法 ...

  6. Stanford机器学习笔记-2.Logistic Regression

    Content: 2 Logistic Regression. 2.1 Classification. 2.2 Hypothesis representation. 2.2.1 Interpretin ...

  7. Logistic Regression vs Decision Trees vs SVM: Part II

    This is the 2nd part of the series. Read the first part here: Logistic Regression Vs Decision Trees ...

  8. Logistic Regression Vs Decision Trees Vs SVM: Part I

    Classification is one of the major problems that we solve while working on standard business problem ...

  9. Logistic Regression逻辑回归

    参考自: http://blog.sina.com.cn/s/blog_74cf26810100ypzf.html http://blog.sina.com.cn/s/blog_64ecfc2f010 ...

随机推荐

  1. AppStore IPv6-only 解决--看我就够了

    自2016年6月1日起,苹果要求所有提交App Store的iOS应用必须支持IPv6-only环境,背景也是众所周知的,IPv4地址已基本分配完毕,同时IPv6比IPv4也更加高效,向IPv6过渡是 ...

  2. Oracle-优化SQL语句

    建议不使用(*)来代替所有列名 用truncate代替delete 在SQL*Plus环境中直接使用truncate table即可:要在PL/SQL中使用,如: 创建一个存储过程,实现使用trunc ...

  3. 使用spring配置类代替xml配置文件注册bean类

    spring配置类,即在类上加@Configuration注解,使用这种配置类来注册bean,效果与xml文件是完全一样的,只是创建springIOC容器的方式不同: //通过xml文件创建sprin ...

  4. import * as 用法

  5. 阶段1 语言基础+高级_1-3-Java语言高级_04-集合_10 斗地主案例(双列)_1_斗地主案例的需求分析

    之前做的斗地主的版本,没有从小到大进行排序 一个存储牌的花色,一个存储牌的序号. 放牌的容器.使用Map 再创建一个集合进行洗牌. 调用shuffer方法洗牌.生成后就是随即的索引了.

  6. 腾讯重磅开源分布式NoSQL存储系统DCache

    当你在电商平台秒杀商品或者在社交网络刷热门话题的时候,可以很明显感受到当前网络数据流量的恐怖,几十万商品刚开抢,一秒都不到就售罄:哪个大明星出轨的消息一出现,瞬间阅读与转发次数可以达到上亿.作为终端用 ...

  7. 应用安全 - 工具使用 - Nmap

    TCP端口扫描类型 TCP connect扫描 三次握手完成/全连接/速度慢/易被检测到 TCP SYN扫描 半开扫描/发送SYN包启动TCP会话 TCP FIN扫描 半开扫描/发送SYN包启动TCP ...

  8. IDEA-包层级结构显示(三)

    IntelliJ IDEA包层级结构显示 如:A.B.C,在项目中希望以如下形式显示: A B C 效果: 再更换为A.B.C形式显示

  9. xmake-vscode插件开发过程记录

    最近打算给xmake写一些IDE和编辑器的集成插件,发现vscode的编辑器插件比较容易上手的,就先研究了下vscode的插件开发流程,并且完成了xmake-vscode插件的开发. 我们先来看几张最 ...

  10. MySQL-第十五篇使用连接池管理连接

    1.数据库连接池的解决方案是: 当应用程序启动时,系统主动建立足够的数据库连接,并将这些连接组成一个连接池.每次应用程序请求数据库连接时,无需重新打开连接,而是从连接池中取出已有的连接使用,使用完后不 ...