1.PyTorch基础实现代码

  1. import torch
  2. from torch.autograd import Variable
  3.  
  4. torch.manual_seed(2)
  5. x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
  6. y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]]))
  7.  
  8. #初始化
  9. w = Variable(torch.Tensor([-1]), requires_grad=True)
  10. b = Variable(torch.Tensor([0]), requires_grad=True)
  11. epochs = 100
  12. costs = []
  13. lr = 0.1
  14. print("before training, predict of x = 1.5 is:")
  15. print("y_pred = ", float(w.data*1.5 + b.data > 0))
  16.  
  17. #模型训练
  18. for epoch in range(epochs):
  19. #计算梯度
  20. A = 1/(1+torch.exp(-(w*x_data+b))) #逻辑回归函数
  21. J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) #逻辑回归损失函数
  22. #J = -torch.mean(y_data*torch.log(A) + (1-y_data)*torch.log(1-A)) +alpha*w**2
  23. #基础类进行正则化,加上L2范数
  24. costs.append(J.data)
  25. J.backward() #自动反向传播
  26.  
  27. #参数更新
  28. w.data = w.data - lr*w.grad.data
  29. w.grad.data.zero_()
  30. b.data = b.data - lr*b.grad.data
  31. b.grad.data.zero_()
  32.  
  33. print("after training, predict of x = 1.5 is:")
  34. print("y_pred =", float(w.data*1.5+b.data > 0))
  35. print(w.data, b.data)

2.用PyTorch类实现Logistic regression,torch.nn.module写网络结构

  1. import torch
  2. from torch.autograd import Variable
  3.  
  4. x_data = Variable(torch.Tensor([[0.6], [1.0], [3.5], [4.0]]))
  5. y_data = Variable(torch.Tensor([[0.], [0.], [1.], [1.]]))
  6.  
  7. class Model(torch.nn.Module):
  8. def __init__(self):
  9. super(Model, self).__init__()
  10. self.linear = torch.nn.Linear(1, 1)
  11. self.sigmoid = torch.nn.Sigmoid() ###### **sigmoid**
  12.  
  13. def forward(self, x):
  14. y_pred = self.sigmoid(self.linear(x))
  15. return y_pred
  16.  
  17. model = Model()
  18.  
  19. criterion = torch.nn.BCELoss(size_average=True) #损失函数
  20. optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降
  21.  
  22. for epoch in range(500):
  23. # Forward pass
  24. y_pred = model(x_data)
  25.  
  26. loss = criterion(y_pred, y_data)
  27. if epoch % 20 == 0:
  28. print(epoch, loss.item())
  29.  
  30. #梯度归零
  31. optimizer.zero_grad()
  32. # 反向传播
  33. loss.backward()
  34. # update weights
  35. optimizer.step()
  36.  
  37. hour_var = Variable(torch.Tensor([[0.5]]))
  38. print("predict (after training)", 0.5, model.forward(hour_var).data[0][0])
  39. hour_var = Variable(torch.Tensor([[7.0]]))
  40. print("predict (after training)", 7.0, model.forward(hour_var).data[0][0])

参考:https://blog.csdn.net/ZZQsAI/article/details/90216593

Task3.PyTorch实现Logistic regression的更多相关文章

  1. 逻辑回归 Logistic Regression

    逻辑回归(Logistic Regression)是广义线性回归的一种.逻辑回归是用来做分类任务的常用算法.分类任务的目标是找一个函数,把观测值匹配到相关的类和标签上.比如一个人有没有病,又因为噪声的 ...

  2. logistic regression与SVM

    Logistic模型和SVM都是用于二分类,现在大概说一下两者的区别 ① 寻找最优超平面的方法不同 形象点说,Logistic模型找的那个超平面,是尽量让所有点都远离它,而SVM寻找的那个超平面,是只 ...

  3. Logistic Regression - Formula Deduction

    Sigmoid Function \[ \sigma(z)=\frac{1}{1+e^{(-z)}} \] feature: axial symmetry: \[ \sigma(z)+ \sigma( ...

  4. SparkMLlib之 logistic regression源码分析

    最近在研究机器学习,使用的工具是spark,本文是针对spar最新的源码Spark1.6.0的MLlib中的logistic regression, linear regression进行源码分析,其 ...

  5. [OpenCV] Samples 06: [ML] logistic regression

    logistic regression,这个算法只能解决简单的线性二分类,在众多的机器学习分类算法中并不出众,但它能被改进为多分类,并换了另外一个名字softmax, 这可是深度学习中响当当的分类算法 ...

  6. Stanford机器学习笔记-2.Logistic Regression

    Content: 2 Logistic Regression. 2.1 Classification. 2.2 Hypothesis representation. 2.2.1 Interpretin ...

  7. Logistic Regression vs Decision Trees vs SVM: Part II

    This is the 2nd part of the series. Read the first part here: Logistic Regression Vs Decision Trees ...

  8. Logistic Regression Vs Decision Trees Vs SVM: Part I

    Classification is one of the major problems that we solve while working on standard business problem ...

  9. Logistic Regression逻辑回归

    参考自: http://blog.sina.com.cn/s/blog_74cf26810100ypzf.html http://blog.sina.com.cn/s/blog_64ecfc2f010 ...

随机推荐

  1. leetcode 374猜数字大小

    // Forward declaration of guess API. // @param num, your guess // @return -1 if my number is lower, ...

  2. JS 引擎

    最早的 JS 引擎是纯解释器,现代 JS 引擎已经使用 JIT(Just-in-time compilation:结合预编译(ahead-of-time compilation AOT)和解释器的优点 ...

  3. Java中使用MATLAB作图 .

    最近做一个项目,需要很多进行很多信号处理——小魏就是学软件的,对信号处理简直是个小白,最简单的实现就是傻瓜似的调用MATLAB的各种工具箱,达到目的就行. 同时,MATLAB是种解释性语言,执行效率比 ...

  4. 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_01 File类_5_File类获取功能的方法

    获取的方法 GetAbsolutepath 传递一个相对路径进去,查看输出的结果 输出的还是绝对的路径 getPath 获取的就是构造方法中传递的路径,可以传递绝对路径也可以传递相对路径 实际上toS ...

  5. Spring学习01——HelloSpring

    这是一个spring入门demo: package com.su.test; public class HelloWorld { public void say(){ System.out.print ...

  6. 读取资源中的GIF文件相应像素宽高度

    代码参考了如下网页的实现: https://www.cnblogs.com/zy791976083/p/9921069.html 整理成一个函数: BOOL GetResGifSize(long nR ...

  7. SpringBoot错误经验

    1.在application.properties 添加 debug=true,可以看见项目的执行流程有助于调bug 2.如果错误显示端口号被占用 cmd 步骤1 查看端口号应用情况:netstat ...

  8. Cassandra commands

      Common commands:   describe keyspaces // 列出所有db use your_db; // 进去db describe tables; // 列出所有table ...

  9. Vue2+VueRouter2+webpack 构建项目实战(四)接通api,先渲染个列表

    Vue2+VueRouter2+webpack 构建项目实战(四)接通api,先渲染个列表:  Vue2+VueRouter2+Webpack+Axios 构建项目实战2017重制版(一)基础知识概述

  10. Java学习day6数组

    ---恢复内容开始--- Java数组 Java 语言中提供的数组是用来存储固定大小的同类型元素.你可以声明一个数组变量,如 numbers[100] 来代替直接声明 100 个独立变量 number ...