

本文先简要回顾拟牛顿法 L-BFGS 的由来,然后对其在 Spark MLlib 中的实现进行源码走读。







  /**
   * Runs Breeze's L-BFGS over `data` and reports the final weights together with the
   * loss observed at every optimizer state.
   *
   * @param data             (label, features) pairs; counted once up front to obtain the
   *                         number of examples handed to the cost function
   * @param gradient         per-example loss/gradient computation, forwarded to CostFun
   * @param updater          regularization logic, forwarded to CostFun
   * @param numCorrections   forwarded to the Breeze L-BFGS constructor (second argument)
   * @param convergenceTol   forwarded to the Breeze L-BFGS constructor (third argument)
   * @param maxNumIterations forwarded to the Breeze L-BFGS constructor (first argument);
   *                         also used as the initial capacity of the loss buffer
   * @param regParam         regularization parameter, forwarded to CostFun
   * @param initialWeights   starting point, converted to a dense Breeze vector
   * @return the weights of the last optimizer state and the loss of every state, in order
   */
  def runLBFGS(
      data: RDD[(Double, Vector)],
      gradient: Gradient,
      updater: Updater,
      numCorrections: Int,
      convergenceTol: Double,
      maxNumIterations: Int,
      regParam: Double,
      initialWeights: Vector): (Vector, Array[Double]) = {
    val lossTrace = new ArrayBuffer[Double](maxNumIterations)

    val exampleCount = data.count()
    val objective = new CostFun(data, gradient, updater, regParam, exampleCount)
    val optimizer =
      new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
    val stateIter =
      optimizer.iterations(new CachedDiffFunction(objective), initialWeights.toBreeze.toDenseVector)

    // NOTE (carried over from the original author): the loss reported by a state is
    // computed using the weights from the previous iteration, and so is the
    // regularization value folded into it.
    //
    // Record the loss of each state before pulling the next one from the iterator,
    // and keep hold of the last state for its weights.
    val firstState = stateIter.next()
    lossTrace += firstState.value
    val finalState = stateIter.foldLeft(firstState) { (_, current) =>
      lossTrace += current.value
      current
    }

    val finalWeights = Vectors.fromBreeze(finalState.x)
    logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
      lossTrace.takeRight(10).mkString(", ")))
    (finalWeights, lossTrace.toArray)
  }


  /**
   * Differentiable objective handed to Breeze: for a given weight vector it returns the
   * regularized average loss over `data` and the matching gradient.
   *
   * @param data        (label, features) pairs to aggregate over
   * @param gradient    computes the per-example loss and accumulates the per-example gradient
   * @param updater     used only to derive the regularization value and its gradient
   * @param regParam    regularization parameter passed through to `updater`
   * @param numExamples divisor used to average the summed loss and gradient
   *                    (NOTE(review): a value of 0 is not guarded against here)
   */
  private class CostFun(
      data: RDD[(Double, Vector)],
      gradient: Gradient,
      updater: Updater,
      regParam: Double,
      numExamples: Long) extends DiffFunction[BDV[Double]] {

    /**
     * Evaluates the objective at `weights`.
     *
     * @param weights current point of the optimization
     * @return (loss, gradient) where loss = lossSum / numExamples + regVal and
     *         gradient = gradientSum / numExamples + regularization gradient
     */
    override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
      // Have a local copy to avoid the serialization of CostFun object which is not serializable.
      val localData = data
      val localGradient = gradient

      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
          // `grad` is handed to compute as the cumulative-gradient argument and then
          // returned unchanged by reference, so this presumably relies on compute
          // adding into that vector's storage in place — TODO confirm against Gradient.
          val l = localGradient.compute(
            features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
          (grad, loss + l)
        },
        combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
          (grad1 += grad2, loss1 + loss2)
        })

      /**
       * regVal is sum of weight squares if it's L2 updater;
       * for other updater, the same logic is followed.
       */
      // A zero gradient and a step size of 0 are passed, and only the second
      // component of the updater's result (the regularization value) is used.
      val regVal = updater.compute(
        Vectors.fromBreeze(weights),
        Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2

      val loss = lossSum / numExamples + regVal

      /**
       * It will return the gradient part of regularization using updater.
       *
       * Given the input parameters, the updater basically does the following,
       *
       * w' = w - thisIterStepSize * (gradient + regGradient(w))
       * Note that regGradient is function of w
       *
       * If we set gradient = 0, thisIterStepSize = 1, then
       *
       * regGradient(w) = w - w'
       *
       * TODO: We need to clean it up by separating the logic of regularization out
       * from updater to regularizer.
       */
      // The following gradientTotal is actually the regularization part of gradient.
      // Will add the gradientSum computed from the data with weights in the next step.
      val gradientTotal = weights - updater.compute(
        Vectors.fromBreeze(weights),
        Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze

      // gradientTotal = gradientSum / numExamples + gradientTotal
      axpy(1.0 / numExamples, gradientSum, gradientTotal)

      (loss, gradientTotal)
    }
  }
  55. }

