本编博客继续分享简单的机器学习的R语言实现。

今天是关于简单的线性回归方程问题的优化问题

常用方法,我们会考虑随机梯度递降,好处是,我们不需要遍历数据集中的所有元素,这样可以大幅度的减少运算量。

具体的算法参考下面:

首先我们先定义我们需要的参数的Notation

上述算法中,为了避免过拟合,我们采用了L2的正则化,在更新步骤中,我们会发现,这个正则项目,对参数更新的影响

下面是代码部分:

## Load Library
library(ggplot2)
library(reshape2)
library(mvtnorm) ## Function for reading the data
read_data <- function(fname, sc) {
data <- read.csv(file=fname,head=TRUE,sep=",")
nr = dim(data)[1]
nc = dim(data)[2]
x = data[1:nr,1:(nc-1)]
y = data[1:nr,nc]
if (isTRUE(sc)) {
x = scale(x) ## Scale x
y = scale(y) ## Scale y
}
return (list("x" = x, "y" = y))
}

我们定义了一个读取数据的方程,这里,我们会把数据集给scale一下,可以保证进一步提高运算速度

## Matrix Product Function
predict_func <- function(Phi, w){
return(Phi%*%w)
} ## Function to compute the cost function
train_obj_func <- function (Phi, w, label, lambda){
# Cost funtion including the L2 norm regulaztion
return(.5 * mean((predict_func(Phi, w) - label)^2) + .5 * lambda * t(w) %*% w)
} ## Return the errors for each iteration
get_errors <- function(data, label, W) {
n_weights = dim(W)[1]
Phi <- cbind('X0' = 1, data)
errors = matrix(,nrow=n_weights, ncol=2)
for (tau in 1:n_weights) {
errors[tau,1] = tau
errors[tau,2] = train_obj_func(Phi, W[tau,],label, 0) ## Get the errors, set the lambda to 0
}
return(errors)
}

 同时,我们定义了计算矩阵乘法,计算目标函数以及求误差的方程。

sgd_train <- function(train_x, train_y, lambda, eta, epsilon, max_epoch) {

    ## Prepare the traindata
## Attach the 1 for X0
Phi <- as.matrix(cbind('X0'=1, train.data)) ## Calculate the max iteration time for the SGD
train_len = dim(train_x)[1]
tau_max = max_epoch * train_len W <- matrix(,nrow=tau_max, ncol=ncol(Phi))
set.seed(1234)
## Random Generate the start parameter
W[1,] <- runif(ncol(Phi)) tau = 1 # counter
## Create a dateframe to store the value of cost function for each iteration
obj_func_val <-matrix(,nrow=tau_max, ncol=1)
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda) while (tau <= tau_max){ # check termination criteria
if (obj_func_val[tau,1]<=epsilon) {break} # shuffle data:
train_index <- sample(1:train_len, train_len, replace = FALSE) # loop over each datapoint
for (i in train_index) {
# increment the counter
tau <- tau + 1
if (tau > tau_max) {break} # make the weight update
y_pred <- predict_func(Phi[i,], W[tau-1,])
W[tau,] <- sgd_update_weight(W[tau-1,], Phi[i,], train_y[i], y_pred, lambda, eta) # keep track of the objective funtion
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
}
}
# resulting values for the training objective function as well as the weights
return(list('vals'=obj_func_val,'W'=W))
} # updating the weight vector
sgd_update_weight <- function(W_prev, x, y_true, y_pred, lambda, eta) {
## Computer the Gradient
grad = - (y_true-y_pred) * x
## Here I just combine the regularisation term with prev w
return(W_prev*(1-eta * lambda) - eta * grad)
}

  根据上述我们写的计算更新目标函数和参数的方法,讲算法用R实现

下面是实验部分

## Load the train data and train label
train.data <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$x
train.label <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$y
## Load the testdata and test label
test.data <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$x
test.label <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$y # Set MAX EPOCH max_epoch = 18 ## Implement SGD with Ridge regression
options(warn=-1) ## Initilize
## Set the related parameters
epsilon = .001 ## Terimation threshold
eta = .01 ## Learning Rate
lambda= 0.5 ## Regularisation parmater ## Run SGD
## Cost function values
train_res2 = sgd_train(train.data, train.label, lambda, eta, epsilon, max_epoch)
## Calulate the errors
## To be mentioned here, we will only visulisation for the train error to check the converge result
errors2 = get_errors(train.data, train.label, train_res2$W)  

 接着,我们把SGD的error plot给绘制出来

## Visulastion for SGD
plot(train_res2$val, main="SGD", type="l", col="blue",
xlab="iteration", ylab="training objective function")

  

  

这里我们的方程比较简单,可以看到,目标函数很快就收敛了。

简单线性回归问题的优化(SGD)R语言的更多相关文章

  1. 【数据分析】线性回归与逻辑回归(R语言实现)

    文章来源:公众号-智能化IT系统. 回归模型有多种,一般在数据分析中用的比较常用的有线性回归和逻辑回归.其描述的是一组因变量和自变量之间的关系,通过特定的方程来模拟.这么做的目的也是为了预测,但有时也 ...

  2. 一个简单文本分类任务-EM算法-R语言

    一.问题介绍 概率分布模型中,有时只含有可观测变量,如单硬币投掷模型,对于每个测试样例,硬币最终是正面还是反面是可以观测的.而有时还含有不可观测变量,如三硬币投掷模型.问题这样描述,首先投掷硬币A,如 ...

  3. R语言

    什么是R语言编程? R语言是一种用于统计分析和为此目的创建图形的编程语言.不是数据类型,它具有用于计算的数据对象.它用于数据挖掘,回归分析,概率估计等领域,使用其中可用的许多软件包. R语言中的不同数 ...

  4. R语言:用简单的文本处理方法优化我们的读书体验

    博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html 前言 延续之前的用R语言读琅琊榜小说,继续讲一下利用R语言做一些简单的文本处理.分词的事情.其实 ...

  5. R语言-简单线性回归图-方法

    目标:利用R语言统计描绘50组实验对比结果 第一步:导入.csv文件 X <- read.table("D:abc11.csv",header = TRUE, sep = & ...

  6. R 语言中的简单线性回归

    ... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...

  7. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

  8. R语言解读多元线性回归模型

    转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...

  9. 多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

随机推荐

  1. 数的划分(NOIP2001&水题测试2017082401)

    题目链接:数的划分 这题直接搜索就行了.给代码,思路没什么好讲的,要讲的放在代码后面: #include<bits/stdc++.h> using namespace std; int d ...

  2. Anaconda 3中配置OpenCV

    平台:win10 x64+Anaconda 3(64-bit)+opencv_python-3.4.5+contrib-cp37-cp37m-win_amd64 一.OpenCV下载 Python环境 ...

  3. 2018.11.24 poj3693Maximum repetition substring(后缀数组)

    传送门 后缀数组好题. 考虑枚举循环节长度lenlenlen. 然后考虑枚举循环节的起点来更新答案. 但是直接枚举每次O(n)O(n)O(n). 考虑枚举len∗k+1len*k+1len∗k+1作为 ...

  4. [转]图解CSS的padding,margin,border属性(详细介绍及举例说明)

    图解CSS的padding,margin,border属性 W3C组织建议把所有网页上的对像都放在一个盒(box)中,设计师可以通过创建定义来控制这个盒的属性,这些对像包括段落.列表.标题.图片以及层 ...

  5. idea下启动tomcat时,打印的日志中文乱码

    idea2018.2+tomcat8+java8+win10 异常:将编码方式全都修改为UTF-8后,且tomcat的VM启动参数中配置了:-Dfile.encoding=UTF-8.导致控制台日志打 ...

  6. vue中的路由嵌套

    1定义组件 const Index = { template:` <div>首页</div> ` } const Order = { template:` <div> ...

  7. js判断软键盘是否开启弹出

    移动端关于页面布局,如果底部有position:fixed的盒子,又有input,当软键盘弹出收起都会影响页面布局.这时候Android可以监听resize事件,代码如下,而ios没有相关事件. va ...

  8. [置顶] AngularJS实战之路由ui-sref-active使用

    当我们使用angularjs的路由时,时常会出现一个需求,当选中菜单时把当前菜单的样式设置为选中状态(多数就是改变颜色) 接下来就看看Angular-UI-Router里的指令ui-sref-acti ...

  9. 回文(palindrome)

    如果一个字符串忽略标点符号.大小写和空格,正着读和反着读一模一样,那么这个字符串就是palindrome(回文).

  10. hdu 4915 括号匹配+巧模拟

    http://acm.hdu.edu.cn/showproblem.php?pid=4915 给定一个序列,由()?组成,其中?可以表示(或者),问说有一种.多种或者不存在匹配. 从左向右,优先填满n ...