1. 1 #CS231n中线性、非线性分类器举例(Softmax)
  2. #注意其中反向传播的计算
  3.  
  4. # -*- coding: utf-8 -*-
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. N = 100 # number of points per class
  8. D = 2 # dimensionality
  9. K = 3 # number of classes
  10. X = np.zeros((N*K,D)) # data matrix (each row = single example)
  11. y = np.zeros(N*K, dtype='uint8') # class labels
  12. for j in xrange(K):
  13. ix = range(N*j,N*(j+1))
  14. r = np.linspace(0.0,1,N) # radius
  15. t = np.linspace(j*4,(j+1)*4,N) + np.random.randn(N)*0.2 # theta
  16. X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  17. y[ix] = j
  18. # lets visualize the data:
  19. plt.xlim([-1, 1])
  20. plt.ylim([-1, 1])
  21. plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
  22. plt.show()
  23.  
  24. # initialize parameters randomly
  25. # 线性分类器
  26. W = 0.01 * np.random.randn(D,K)
  27. b = np.zeros((1,K))
  28.  
  29. # some hyperparameters
  30. step_size = 1e-0
  31. reg = 1e-3 # regularization strength
  32.  
  33. # gradient descent loop
  34. num_examples = X.shape[0]
  35. for i in xrange(200):
  36.  
  37. # evaluate class scores, [N x K]
  38. scores = np.dot(X, W) + b
  39.  
  40. # compute the class probabilities
  41. exp_scores = np.exp(scores)
  42. probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True) # [N x K]
  43.  
  44. # compute the loss: average cross-entropy loss and regularization
  45. corect_logprobs = -np.log(probs[range(num_examples),y])
  46. data_loss = np.sum(corect_logprobs)/num_examples
  47. reg_loss = 0.5*reg*np.sum(W*W)
  48. loss = data_loss + reg_loss
  49. if i % 10 == 0:
  50. print "iteration %d: loss %f" % (i, loss)
  51.  
  52. # compute the gradient on scores
  53. dscores = probs
  54. dscores[range(num_examples),y] -= 1
  55. dscores /= num_examples
  56.  
  57. # backpropate the gradient to the parameters (W,b)
  58. dW = np.dot(X.T, dscores)
  59. db = np.sum(dscores, axis=0, keepdims=True)
  60.  
  61. dW += reg*W # regularization gradient
  62.  
  63. # perform a parameter update
  64. W += -step_size * dW
  65. b += -step_size * db
  66.  
  67. # evaluate training set accuracy
  68. scores = np.dot(X, W) + b
  69. predicted_class = np.argmax(scores, axis=1)
  70. print 'training accuracy: %.2f' % (np.mean(predicted_class == y))
  71.  
  72. # plot the resulting classifier
  73. h = 0.02
  74. x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
  75. y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
  76. xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  77. np.arange(y_min, y_max, h))
  78. Z = np.dot(np.c_[xx.ravel(), yy.ravel()], W) + b
  79. Z = np.argmax(Z, axis=1)
  80. Z = Z.reshape(xx.shape)
  81. fig = plt.figure()
  82. plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.8)
  83. plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
  84. plt.xlim(xx.min(), xx.max())
  85. plt.ylim(yy.min(), yy.max())
  86.  
  87. ## initialize parameters randomly
  88. # 含一个隐层的非线性分类器 使用ReLU
  89. h = 100 # size of hidden layer
  90. W = 0.01 * np.random.randn(D,h)
  91. b = np.zeros((1,h))
  92. W2 = 0.01 * np.random.randn(h,K)
  93. b2 = np.zeros((1,K))
  94.  
  95. # some hyperparameters
  96. step_size = 1e-0
  97. reg = 1e-3 # regularization strength
  98.  
  99. # gradient descent loop
  100. num_examples = X.shape[0]
  101. for i in xrange(10000):
  102.  
  103. # evaluate class scores, [N x K]
  104. hidden_layer = np.maximum(0, np.dot(X, W) + b) # note, ReLU activation
  105. scores = np.dot(hidden_layer, W2) + b2
  106.  
  107. # compute the class probabilities
  108. exp_scores = np.exp(scores)
  109. probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True) # [N x K]
  110.  
  111. # compute the loss: average cross-entropy loss and regularization
  112. corect_logprobs = -np.log(probs[range(num_examples),y])
  113. data_loss = np.sum(corect_logprobs)/num_examples
  114. reg_loss = 0.5*reg*np.sum(W*W) + 0.5*reg*np.sum(W2*W2)
  115. loss = data_loss + reg_loss
  116. if i % 1000 == 0:
  117. print "iteration %d: loss %f" % (i, loss)
  118.  
  119. # compute the gradient on scores
  120. dscores = probs
  121. dscores[range(num_examples),y] -= 1
  122. dscores /= num_examples
  123.  
  124. # backpropate the gradient to the parameters
  125. # first backprop into parameters W2 and b2
  126. dW2 = np.dot(hidden_layer.T, dscores)
  127. db2 = np.sum(dscores, axis=0, keepdims=True)
  128. # next backprop into hidden layer
  129. dhidden = np.dot(dscores, W2.T)
  130. # backprop the ReLU non-linearity
  131. dhidden[hidden_layer <= 0] = 0
  132. # finally into W,b
  133. dW = np.dot(X.T, dhidden)
  134. db = np.sum(dhidden, axis=0, keepdims=True)
  135.  
  136. # add regularization gradient contribution
  137. dW2 += reg * W2
  138. dW += reg * W
  139.  
  140. # perform a parameter update
  141. W += -step_size * dW
  142. b += -step_size * db
  143. W2 += -step_size * dW2
  144. b2 += -step_size * db2
  145. # evaluate training set accuracy
  146. hidden_layer = np.maximum(0, np.dot(X, W) + b)
  147. scores = np.dot(hidden_layer, W2) + b2
  148. predicted_class = np.argmax(scores, axis=1)
  149. print 'training accuracy: %.2f' % (np.mean(predicted_class == y))
  150. # plot the resulting classifier
  151. h = 0.02
  152. x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
  153. y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
  154. xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  155. np.arange(y_min, y_max, h))
  156. Z = np.dot(np.maximum(0, np.dot(np.c_[xx.ravel(), yy.ravel()], W) + b), W2) + b2
  157. Z = np.argmax(Z, axis=1)
  158. Z = Z.reshape(xx.shape)
  159. fig = plt.figure()
  160. plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.8)
  161. plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
  162. plt.xlim(xx.min(), xx.max())
  163. plt.ylim(yy.min(), yy.max())

运行结果

【Python 代码】CS231n中Softmax线性分类器、非线性分类器对比举例(含python绘图显示结果)的更多相关文章

  1. Python代码样例列表

    扫描左上角二维码,关注公众账号 数字货币量化投资,回复“1279”,获取以下600个Python经典例子源码 ├─algorithm│       Python用户推荐系统曼哈顿算法实现.py│    ...

  2. ROS系统python代码测试之rostest

    ROS系统中提供了测试框架,可以实现python/c++代码的单元测试,python和C++通过不同的方式实现, 之后的两篇文档分别详细介绍各自的实现步骤,以及测试结果和覆盖率的获取. ROS系统中p ...

  3. [转] Python 代码性能优化技巧

    选择了脚本语言就要忍受其速度,这句话在某种程度上说明了 python 作为脚本的一个不足之处,那就是执行效率和性能不够理想,特别是在 performance 较差的机器上,因此有必要进行一定的代码优化 ...

  4. Python代码性能优化技巧

    摘要:代码优化能够让程序运行更快,可以提高程序的执行效率等,对于一名软件开发人员来说,如何优化代码,从哪里入手进行优化?这些都是他们十分关心的问题.本文着重讲了如何优化Python代码,看完一定会让你 ...

  5. Python 代码性能优化技巧(转)

    原文:Python 代码性能优化技巧 Python 代码优化常见技巧 代码优化能够让程序运行更快,它是在不改变程序运行结果的情况下使得程序的运行效率更高,根据 80/20 原则,实现程序的重构.优化. ...

  6. Python 代码性能优化技巧

    选择了脚本语言就要忍受其速度,这句话在某种程度上说明了 python 作为脚本的一个不足之处,那就是执行效率和性能不够理想,特别是在 performance 较差的机器上,因此有必要进行一定的代码优化 ...

  7. 利用Python代码编写计算器小程序

    import tkinter import tkinter.messagebox import math class JSQ: def __init__(self): #创建主界面 self.root ...

  8. python 代码检测工具

    对于我这种习惯了 Java 这种编译型语言,在使用 Python 这种动态语言的时候,发现错误经常只能在执行的时候发现,总感觉有点不放心. 而且有一些错误由于隐藏的比较深,只有特定逻辑才会触发,往往导 ...

  9. 随机森林入门攻略(内含R、Python代码)

    随机森林入门攻略(内含R.Python代码) 简介 近年来,随机森林模型在界内的关注度与受欢迎程度有着显著的提升,这多半归功于它可以快速地被应用到几乎任何的数据科学问题中去,从而使人们能够高效快捷地获 ...

随机推荐

  1. HTML—链接

    怎么看都觉得链接太神奇了,尤其是创建电子邮件的链接,于是决定单独写一篇关于HTML链接的内容,同时加深记忆 一.首先,超链接可以是一个字,一个词,或者一组词,也可以是一幅图像,通过点击这些内容来跳转到 ...

  2. python代码工具小结

    目录: 1.with读.写文件 (1)with读文件 (2)with写文件 2.requests爬虫 (1)get请求 (2)post请求 1.with读.写文件 (1)with读文件 (2)with ...

  3. SqlServer中-char varchar nvarchar的区别

    说说nvarchar和varchar的区别:的区别: varchar:  可变长度,存储ANSI字符,根据数据长度自动变化. nvarchar: 可变长度,存储Unicode字符,根据数据长度自动变化 ...

  4. javascript_19-DOM初体验

    DOM DOM: 文档对象模型(Document Object Model),又称为文档树模型.是一套操作HTML和XML文档的API. DOM可以把HTML和XML描述为一个文档树.树上的每一个分支 ...

  5. CPNtools 模拟工具适合分析什么样的协议

    最近梳理和CPNtools和Scyther之间的性能和差别.方便后面整理使用 1.库所的托肯值是什么? 托肯值也叫作令牌, 即网络系统中的资源,托肯的数目值代表了网络赋予的资源大小.在一个活的网络系统 ...

  6. tengine编译安装及nginx高并发内核参数优化

    Tengine Tengine介绍 Tengine是由淘宝网发起的Web服务器项目.它在Nginx的基础上,针对大访问量网站的需求,添加了很多高级功能和特性. Tengine的性能和稳定性已经在大型的 ...

  7. VLAN实验2:配置Trunk接口

    实验环境公司规模较大200多个员工.内部网络是一个较大的局域网,有两台交换机S1和S2来负责员工网络的接入,接入交换机之间通过汇聚交换机S3相连.公司通过划分VLAN来隔离广播域,由于员工较多,同部门 ...

  8. Minio对象存储

    目录 Minio对象存储 1.概述 2.功能特性 3.2.多节点 3.3.分布式 4.分布式minio集群搭建 4.1.集群规划 4.3.编写集群启动脚本(所有节点) 4.4.编写服务脚本(所有节点) ...

  9. XSLT知识点【一】

    XSL 指扩展样式表语言(EXtensible Stylesheet Language). 它起始于 XSL,结束于 XSLT.XPath 以及 XSL-FO. 起始于 XSL------CSS = ...

  10. K3 Cloud的数据中心加载异常处理

    以前一直是财务维护的K3  Cloud突然说不能登录,用的SQL 2008的数据库,运维也搞不定,找帮忙,因为是部署在阿里云上,上去看看数据库,这个K3数据库占了600多G,想看看这个表结构,就是打不 ...