简单线性回归问题的优化(SGD)R语言
本编博客继续分享简单的机器学习的R语言实现。
今天是关于简单的线性回归方程问题的优化问题
常用方法,我们会考虑随机梯度递降,好处是,我们不需要遍历数据集中的所有元素,这样可以大幅度的减少运算量。
具体的算法参考下面:
首先我们先定义我们需要的参数的Notation
上述算法中,为了避免过拟合,我们采用了L2的正则化,在更新步骤中,我们会发现,这个正则项目,对参数更新的影响
下面是代码部分:
## Load Library
library(ggplot2)
library(reshape2)
library(mvtnorm) ## Function for reading the data
read_data <- function(fname, sc) {
data <- read.csv(file=fname,head=TRUE,sep=",")
nr = dim(data)[1]
nc = dim(data)[2]
x = data[1:nr,1:(nc-1)]
y = data[1:nr,nc]
if (isTRUE(sc)) {
x = scale(x) ## Scale x
y = scale(y) ## Scale y
}
return (list("x" = x, "y" = y))
}
我们定义了一个读取数据的方程,这里,我们会把数据集给scale一下,可以保证进一步提高运算速度
## Matrix Product Function
predict_func <- function(Phi, w){
return(Phi%*%w)
} ## Function to compute the cost function
train_obj_func <- function (Phi, w, label, lambda){
# Cost funtion including the L2 norm regulaztion
return(.5 * mean((predict_func(Phi, w) - label)^2) + .5 * lambda * t(w) %*% w)
} ## Return the errors for each iteration
get_errors <- function(data, label, W) {
n_weights = dim(W)[1]
Phi <- cbind('X0' = 1, data)
errors = matrix(,nrow=n_weights, ncol=2)
for (tau in 1:n_weights) {
errors[tau,1] = tau
errors[tau,2] = train_obj_func(Phi, W[tau,],label, 0) ## Get the errors, set the lambda to 0
}
return(errors)
}
同时,我们定义了计算矩阵乘法,计算目标函数以及求误差的方程。
sgd_train <- function(train_x, train_y, lambda, eta, epsilon, max_epoch) { ## Prepare the traindata
## Attach the 1 for X0
Phi <- as.matrix(cbind('X0'=1, train.data)) ## Calculate the max iteration time for the SGD
train_len = dim(train_x)[1]
tau_max = max_epoch * train_len W <- matrix(,nrow=tau_max, ncol=ncol(Phi))
set.seed(1234)
## Random Generate the start parameter
W[1,] <- runif(ncol(Phi)) tau = 1 # counter
## Create a dateframe to store the value of cost function for each iteration
obj_func_val <-matrix(,nrow=tau_max, ncol=1)
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda) while (tau <= tau_max){ # check termination criteria
if (obj_func_val[tau,1]<=epsilon) {break} # shuffle data:
train_index <- sample(1:train_len, train_len, replace = FALSE) # loop over each datapoint
for (i in train_index) {
# increment the counter
tau <- tau + 1
if (tau > tau_max) {break} # make the weight update
y_pred <- predict_func(Phi[i,], W[tau-1,])
W[tau,] <- sgd_update_weight(W[tau-1,], Phi[i,], train_y[i], y_pred, lambda, eta) # keep track of the objective funtion
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
}
}
# resulting values for the training objective function as well as the weights
return(list('vals'=obj_func_val,'W'=W))
} # updating the weight vector
sgd_update_weight <- function(W_prev, x, y_true, y_pred, lambda, eta) {
## Computer the Gradient
grad = - (y_true-y_pred) * x
## Here I just combine the regularisation term with prev w
return(W_prev*(1-eta * lambda) - eta * grad)
}
根据上述我们写的计算更新目标函数和参数的方法,讲算法用R实现
下面是实验部分
## Load the train data and train label
train.data <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$x
train.label <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$y
## Load the testdata and test label
test.data <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$x
test.label <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$y # Set MAX EPOCH max_epoch = 18 ## Implement SGD with Ridge regression
options(warn=-1) ## Initilize
## Set the related parameters
epsilon = .001 ## Terimation threshold
eta = .01 ## Learning Rate
lambda= 0.5 ## Regularisation parmater ## Run SGD
## Cost function values
train_res2 = sgd_train(train.data, train.label, lambda, eta, epsilon, max_epoch)
## Calulate the errors
## To be mentioned here, we will only visulisation for the train error to check the converge result
errors2 = get_errors(train.data, train.label, train_res2$W)
接着,我们把SGD的error plot给绘制出来
## Visulastion for SGD
plot(train_res2$val, main="SGD", type="l", col="blue",
xlab="iteration", ylab="training objective function")
这里我们的方程比较简单,可以看到,目标函数很快就收敛了。
简单线性回归问题的优化(SGD)R语言的更多相关文章
- 【数据分析】线性回归与逻辑回归(R语言实现)
文章来源:公众号-智能化IT系统. 回归模型有多种,一般在数据分析中用的比较常用的有线性回归和逻辑回归.其描述的是一组因变量和自变量之间的关系,通过特定的方程来模拟.这么做的目的也是为了预测,但有时也 ...
- 一个简单文本分类任务-EM算法-R语言
一.问题介绍 概率分布模型中,有时只含有可观测变量,如单硬币投掷模型,对于每个测试样例,硬币最终是正面还是反面是可以观测的.而有时还含有不可观测变量,如三硬币投掷模型.问题这样描述,首先投掷硬币A,如 ...
- R语言
什么是R语言编程? R语言是一种用于统计分析和为此目的创建图形的编程语言.不是数据类型,它具有用于计算的数据对象.它用于数据挖掘,回归分析,概率估计等领域,使用其中可用的许多软件包. R语言中的不同数 ...
- R语言:用简单的文本处理方法优化我们的读书体验
博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html 前言 延续之前的用R语言读琅琊榜小说,继续讲一下利用R语言做一些简单的文本处理.分词的事情.其实 ...
- R语言-简单线性回归图-方法
目标:利用R语言统计描绘50组实验对比结果 第一步:导入.csv文件 X <- read.table("D:abc11.csv",header = TRUE, sep = & ...
- R 语言中的简单线性回归
... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...
- R语言解读一元线性回归模型
转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...
- R语言解读多元线性回归模型
转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...
- 多元线性回归公式推导及R语言实现
多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...
随机推荐
- Windows 8.1 app 界面设计
大纲: Windows 应用商店应用 UI 详细信息 http://msdn.microsoft.com/zh-cn/library/windows/apps/xaml/dn263191.aspx 快 ...
- Viewer.js 是一款强大的 jQuery 图像浏览插件。
https://blog.csdn.net/qq_29132907/article/details/80136023 一.效果图 二.代码<!DOCTYPE html><html ...
- 【Linux】OpenSSL 安装
OpenSSL 简介 OpenSSL 是一个安全套接字层密码库,囊括主要的密码算法.常用的密钥和证书封装管理功能及SSL协议,并提供丰富的应用程序供测试或其它目的使用. OpenSSL 安装 环境:L ...
- 协程的NullReferenceException 错误
public void loadPic(string url) { WWW www = new WWW(url); StartCoroutine(WaitForRequest(www)); } IEn ...
- kbmmw 5.0 中的REST 服务
目前关于REST 服务的话题越来越热,kbmmw 在5.0 里面开始支持rest.今天我就试一下kbmmw 的 rest 服务.闲话少说,开始. 老规矩,放上两个kbmMWServer1和 kbmMW ...
- AOP (切点表达式讲解)
Spring EL表达式:: 1.execution 表达式 语法格式: execution(返回类型.包名.类名.方法名(参数表)) exection(*.com.xxx.AService.*(.. ...
- Web中的四大作用域对象
request:请求对象 类型:HttpServletRequest session:表示一次会话,可以处理一个用户多个页面之间的请求 application:标识web应用上下文,类型:Servle ...
- 2018.10.25 uoj#308. 【UNR #2】UOJ拯救计划(排列组合)
传送门 有一个显然的式子:Ans=∑A(n,i)∗用i种颜色的方案数Ans=\sum A(n,i)*用i种颜色的方案数Ans=∑A(n,i)∗用i种颜色的方案数 这个东西貌似是个NPCNPCNPC. ...
- poj-1328(贪心+思维)
题目链接:传送门 思路:找最少有几个点,先求出每个点的覆盖范围,就是一个点最大可以到达的地方是y相同的地方而且直线距离是d, 所以x范围在[x-sqrt(d*d-y*y),x+sqrt(d*d-y*y ...
- HTML5 通过 FileReader 实现文件上传
概述 在页面中上传时,之前一般都是需要使用form表单进行上传.html5 中提供了FileReader 可以将文件转换成Base64编码字符串,因此就可以直接使用 AJAX实现文件上传. 实现代码 ...