功能:通过样本进行训练,让线性单元自己找到(这就是所谓机器学习)工资计算的规律,然后用两组数据进行测试机器是否真的get到了其中的规律。

  原文链接在文尾,文章中的代码为了演示起见,仅根据工作年限来预测工资,参数是一维的,最后绘制的图也是平面图。本着学习的态度,我将代码改为能根据两个参数来预测工资,两个参数分别是工作年限和级别,并且用3D图绘制出拟合的效果。原作者的代码是适用于Python2.7的,我的代码适用于Python3,谨供参考。

  注意:绘图代码需要安装matplotlib。

代码:

 #!/usr/bin/env python
# -*- coding: UTF-8 -*- from Perceptron import Perceptron #定义激活函数f
f = lambda x: x class LinearUnit(Perceptron):
def __init__(self, input_num):
'''初始化线性单元,设置输入参数的个数'''
Perceptron.__init__(self, input_num, f) def get_training_dataset():
'''
捏造5个人的收入数据
'''
# 构建训练数据
# 输入向量列表,每一项的第一个是工作年限,第二个是级别
# 构造这些数据所用的公式是:工资=1000*年限 + 500*级别,看机器是否能猜出来
input_vecs = [[5,1], [3, 7], [8,2], [1.5,5], [10,6]]
# 期望的输出列表,月薪,注意要与输入一一对应。【注意! 我故意让结果不太准确,这也会导致预测的结果有偏差】
labels = [5200, 6700, 9300, 3500, 15500]
return input_vecs, labels def train_linear_unit():
'''
使用数据训练线性单元
'''
# 创建感知器,输入参数的特征数为2(工作年限,级别)
lu = LinearUnit(2)
# 训练,迭代10轮, 学习速率为0.005
input_vecs, labels = get_training_dataset()
lu.train(input_vecs, labels, 10, 0.005)
#返回训练好的线性单元
return lu def plot(linear_unit):
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
input_vecs, labels = get_training_dataset()
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(list(map(lambda x: x[0], input_vecs)),
list(map(lambda x: x[1], input_vecs)),
labels) weights = linear_unit.weights
bias = linear_unit.bias
x = range(0,12,1) # work age
y = range(0,12,1) # level
x, y = np.meshgrid(x, y)
z = weights[0] * x + weights[1] * y + bias
ax.plot_surface(x, y, z, cmap=plt.cm.winter) plt.show() if __name__ == '__main__':
'''训练线性单元'''
linear_unit = train_linear_unit()
# 打印训练获得的权重
#print (linear_unit)
# 测试
print ('预测:')
print ('Work 3.4 years, level 2, monthly salary = %.2f' % linear_unit.predict([3.4,2]))
print ('Work 15 years, level 6, monthly salary = %.2f' % linear_unit.predict([15,6]))
plot(linear_unit)

为了代码的正常运行,你可能还需要下面这个感知机的类文件,另存为Perceptron.py(注意大小写),和上面的代码放在同一个目录下即可。

 #coding=utf-8

 from functools import reduce  # for py3

 class Perceptron(object):
def __init__(self, input_num, activator):
'''
初始化感知器,设置输入参数的个数,以及激活函数。
激活函数的类型为double -> double
'''
self.activator = activator
# 权重向量初始化为0
self.weights = [0.0 for _ in range(input_num)]
# 偏置项初始化为0
self.bias = 0.0
def __str__(self):
'''
打印学习到的权重、偏置项
'''
return 'weights\t:%s\nbias\t:%f\n' % (self.weights, self.bias) def predict(self, input_vec):
'''
输入向量,输出感知器的计算结果
'''
# 把input_vec[x1,x2,x3...]和weights[w1,w2,w3,...]打包在一起
# 变成[(x1,w1),(x2,w2),(x3,w3),...]
# 然后利用map函数计算[x1*w1, x2*w2, x3*w3]
# 最后利用reduce求和 #list1 = list(self.weights)
#print ("predict self.weights:", list1) return self.activator(
reduce(lambda a, b: a + b,
list(map(lambda tp: tp[0] * tp[1],
zip(input_vec, self.weights)))
, 0.0) + self.bias)
def train(self, input_vecs, labels, iteration, rate):
'''
输入训练数据:一组向量、与每个向量对应的label;以及训练轮数、学习率
'''
for i in range(iteration):
self._one_iteration(input_vecs, labels, rate) def _one_iteration(self, input_vecs, labels, rate):
'''
一次迭代,把所有的训练数据过一遍
'''
# 把输入和输出打包在一起,成为样本的列表[(input_vec, label), ...]
# 而每个训练样本是(input_vec, label)
samples = zip(input_vecs, labels)
# 对每个样本,按照感知器规则更新权重
for (input_vec, label) in samples:
# 计算感知器在当前权重下的输出
output = self.predict(input_vec)
# 更新权重
self._update_weights(input_vec, output, label, rate) def _update_weights(self, input_vec, output, label, rate):
'''
按照感知器规则更新权重
'''
# 把input_vec[x1,x2,x3,...]和weights[w1,w2,w3,...]打包在一起
# 变成[(x1,w1),(x2,w2),(x3,w3),...]
# 然后利用感知器规则更新权重
delta = label - output
self.weights = list(map( lambda tp: tp[1] + rate * delta * tp[0], zip(input_vec, self.weights)) ) # 更新bias
self.bias += rate * delta print("_update_weights() -------------")
print("label - output = delta:" ,label, output, delta)
print("weights ", self.weights)
print("bias", self.bias) def f(x):
'''
定义激活函数f
'''
return 1 if x > 0 else 0 def get_training_dataset():
'''
基于and真值表构建训练数据
'''
# 构建训练数据
# 输入向量列表
input_vecs = [[1,1], [0,0], [1,0], [0,1]]
# 期望的输出列表,注意要与输入一一对应
# [1,1] -> 1, [0,0] -> 0, [1,0] -> 0, [0,1] -> 0
labels = [1, 0, 0, 0]
return input_vecs, labels def train_and_perceptron():
'''
使用and真值表训练感知器
'''
# 创建感知器,输入参数个数为2(因为and是二元函数),激活函数为f
p = Perceptron(2, f)
# 训练,迭代10轮, 学习速率为0.1
input_vecs, labels = get_training_dataset()
p.train(input_vecs, labels, 10, 0.1)
#返回训练好的感知器
return p if __name__ == '__main__':
# 训练and感知器
and_perception = train_and_perceptron()
# 打印训练获得的权重 # 测试
print (and_perception)
print ('1 and 1 = %d' % and_perception.predict([1, 1]))
print ('0 and 0 = %d' % and_perception.predict([0, 0]))
print ('1 and 0 = %d' % and_perception.predict([1, 0]))
print ('0 and 1 = %d' % and_perception.predict([0, 1]))

正常运行的话,输出的预测结果是这样的:

预测:
Work 3.4 years, level 2, monthly salary = 5125.02
Work 15 years, level 6, monthly salary = 20815.01

由上可见,本例中两个输入一个输出的线性单元拟合出来的是一个平面(因为预设的工资公式是线性的)。在旋转一个角度后看的更清楚:

原文链接:

https://www.zybuluo.com/hanbingtao/note/448086

文章写的很好,代码也漂亮,墙裂推荐大家看看原文。

用线性单元(LinearUnit)实现工资预测的Python3代码的更多相关文章

  1. (2)Deep Learning之线性单元和梯度下降

    往期回顾 在上一篇文章中,我们已经学会了编写一个简单的感知器,并用它来实现一个线性分类器.你应该还记得用来训练感知器的『感知器规则』.然而,我们并没有关心这个规则是怎么得到的.本文通过介绍另外一种『感 ...

  2. 感知机和线性单元的C#版本

    本文的原版Python代码参考了以下文章: 零基础入门深度学习(1) - 感知器 零基础入门深度学习(2) - 线性单元和梯度下降 在机器学习如火如荼的时代,Python大行其道,几乎所有的机器学习的 ...

  3. ReLu(修正线性单元)、sigmoid和tahh的比较

    不多说,直接上干货! 最近,在看论文,提及到这个修正线性单元(Rectified linear unit,ReLU). Deep Sparse Rectifier Neural Networks Re ...

  4. 修正线性单元(Rectified linear unit,ReLU)

    修正线性单元(Rectified linear unit,ReLU) Rectified linear unit 在神经网络中,常用到的激活函数有sigmoid函数f(x)=11+exp(−x).双曲 ...

  5. 量化投资_MATLAB在时间序列建模预测及程序代码

    1 ARMA时间序列机器特性 下面介绍一种重要的平稳时间序列——ARMA时间序列. ARMA时间序列分为三种: AR模型,auto regressiv model MA模型,moving averag ...

  6. 修正剑桥模型预测-用python3.4

    下面是预测结果: #!/usr/bin/env python # -*- coding:utf-8 -*- # __author__ = "blzhu" ""& ...

  7. 基于深度学习方法的dota2游戏数据分析与胜率预测(python3.6+keras框架实现)

    很久以前就有想过使用深度学习模型来对dota2的对局数据进行建模分析,以便在英雄选择,出装方面有所指导,帮助自己提升天梯等级,但苦于找不到数据源,该计划搁置了很长时间.直到前些日子,看到社区有老哥提到 ...

  8. kaggle预测房价的代码步骤

    # -*- coding: utf-8 -*- """ Created on Sat Oct 20 14:03:05 2018 @author: 12958 " ...

  9. 用python实现MRO算法

    引子: 如图反映了python3中,几个类的继承关系和查找顺序.对于类A,其查找顺序为:A,B,E,C,F,D,G,(Object),这并不是一个简单的深度优先或广度优先的规律.那么这个顺序到底是如何 ...

随机推荐

  1. redis数据类型-列表类型

    列表类型 列表类型(list)可以存储一个有序的字符串列表,常用的操作是向列表两端添加元素,或者获得列表的某一个片段. 列表类型内部是使用双向链表(double linked list)实现的,所以向 ...

  2. mysql基本了解

    数据库--MySql 数据仓库.就与我们之前学过的纯文本,properties这些技术一样.用来保存数据.并提供对数据进行增删改查的操作.我们以后做项目时,项目中的数据都是保存在数据库中的.//--- ...

  3. 关于Scanner类

    Scanner类 1.常用的两个方法:         public int nextInt():获取一个int类型的值         public String nextLine():获取一个St ...

  4. python数据类型之元组、字典、集合

    python数据类型元组.字典.集合 元组 python的元组与列表类似,不同的是元组是不可变的数据类型.元组使用小括号,列表使用方括号.当元组里只有一个元素是必须要加逗号: >>> ...

  5. Ubuntu14.04 安装vmware虚拟机

    下载VMware 链接:VMware 14  密码:5okh 移动VMware14 到 /opt #mv VMware-Workstation-Full-14.0.0-6661328.x86_64.b ...

  6. ABP官方文档翻译 4.6 审计日志

    审计日志 介绍 关于IAuditingStore 配置 通过特性启用/禁用 注意事项 介绍 维基百科:“审计追踪(也称为审计日志)是与安全相关的按时间先后的记录.记录集合.记录的目的地和源,提供一系列 ...

  7. Maven文件配置

    Maven文件路径的配置 默认设置 修改之后的设置 Maven文件内容的配置 对于Maven 的 settings.xml 文件,需要注意. <mirror>镜像元素之间是互斥的,优先级是 ...

  8. 在CentOS 7中安装Jetty服务器

    Jetty 是一款纯Java的HTTP (Web) 服务器和Java Servlet容器. 通常在更大的网络框架中,Jetty经常用于设备间的通信,而其他Web服务器通常给"人类" ...

  9. Redis 实践3-操作

    string常用操作 set key1  aminglinux get key1   set key1  aming //一个key对应一个value,多次赋值,会覆盖前面的value setnx k ...

  10. css为第几个倍数元素添加样式

    //3n就是3的倍数都加这个样式*/.list li:nth-child(3n){ border-bottom:1px;}