
spark-ml-source-analysis's Introduction


Spark machine learning algorithms: research and source code analysis

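  This project explains the principles behind the algorithms in the Spark ml package and analyzes their code implementations in detail. The goal is to deepen my own understanding of machine learning algorithms and to become familiar with how they are implemented in a distributed setting.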

Spark versions covered by this series

  Most of the algorithms in this series are based on Spark 1.6.1; a few are based on Spark 2.x.

Table of contents

  The contents of this series are as follows:

Notes

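  Most of the content of this series comes from the Spark source code and the official Spark documentation, and it is not intended for commercial use. If you repost it, please cite the address of this series. Wherever content from others is quoted, the references are listed; if anything infringes your rights, please notify the author by email. Email: [email protected]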

  Some articles in this series use LaTeX for mathematical formulas. To render them, you can install the MathJax plugin in your browser.

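  My knowledge is limited, so the analysis inevitably contains errors and misunderstandings. Corrections are very welcome and greatly appreciated.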

License

  The license for this work is given in LICENSE.

spark-ml-source-analysis's People

Contributors

endymecy


spark-ml-source-analysis's Issues

GBDT leaf node feature ids

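Hello. With the mllib GBDT model, I can get each DecisionTreeModel through the trees method of GradientBoostedTreesModel, then get the root node through the topNode method of DecisionTreeModel, and finally get the leaf node ids through each node's id method.
However, the Node of the ml GBDT model has no id member, so how can I get the indices of the leaf nodes? Are you familiar with this? Could you help me? The relevant code is as follows: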

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.Node

// Collect the ids of all leaf nodes under `node`, left subtree first.
def getLeafNodes(node: Node): Array[Int] = {
  if (node.isLeaf) {
    Array(node.id)
  } else {
    // Internal nodes of a trained mllib tree always have both children.
    getLeafNodes(node.leftNode.get) ++ getLeafNodes(node.rightNode.get)
  }
}

// `data` is an RDD[LabeledPoint] prepared elsewhere.
val numTrees = 100
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = numTrees // one tree is trained per boosting iteration
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 4
boostingStrategy.learningRate = 0.01
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val gbdtModel = GradientBoostedTrees.train(data, boostingStrategy)

println("gbdt train ok")
// print(gbdtModel.toDebugString)

// From the GBDT model, get the leaf node ids (leaf features) of every tree.
// Iterating over gbdtModel.trees avoids depending on numTrees matching the model.
val treeLeafArray: Array[Array[Int]] =
  gbdtModel.trees.map(tree => getLeafNodes(tree.topNode))
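For the ml API, one possible workaround (a sketch under assumptions, not a Spark API) is to number the leaves yourself: traverse each tree in a fixed left-to-right order, so every leaf gets a stable 0-based index even though the ml Node carries no id. To stay self-contained and runnable without Spark, the example uses hypothetical stand-in types `Leaf` and `Split` that only mimic the shape of ml's `LeafNode` and `InternalNode` (with `leftChild`/`rightChild`) for continuous splits; `LeafIndex`, `numLeaves`, `leafIndex` and `leafFeatures` are illustrative names as well. A sample goes left when its feature value is `<= threshold`, which is how ml routes continuous splits.

```scala
// Hypothetical stand-in for a binary decision tree (NOT Spark's classes).
sealed trait TreeNode
final case class Leaf(prediction: Double) extends TreeNode
final case class Split(featureIndex: Int, threshold: Double,
                       left: TreeNode, right: TreeNode) extends TreeNode

object LeafIndex {
  // Number of leaves in the subtree rooted at `node`.
  def numLeaves(node: TreeNode): Int = node match {
    case Leaf(_)           => 1
    case Split(_, _, l, r) => numLeaves(l) + numLeaves(r)
  }

  // 0-based, left-to-right index of the leaf that `features` lands in.
  // Going right skips over every leaf of the left subtree.
  def leafIndex(node: TreeNode, features: Array[Double], offset: Int = 0): Int =
    node match {
      case Leaf(_) => offset
      case Split(f, t, l, r) =>
        if (features(f) <= t) leafIndex(l, features, offset)
        else leafIndex(r, features, offset + numLeaves(l))
    }

  // One-hot encode the leaf reached in each tree and concatenate the
  // per-tree blocks into a single vector (GBDT leaf features).
  def leafFeatures(trees: Seq[TreeNode], features: Array[Double]): Array[Double] = {
    val sizes = trees.map(numLeaves)
    val out = Array.fill(sizes.sum)(0.0)
    var base = 0
    trees.zip(sizes).foreach { case (tree, size) =>
      out(base + leafIndex(tree, features)) = 1.0
      base += size
    }
    out
  }
}
```

For example, with `t = Split(0, 0.5, Leaf(0.0), Split(1, 2.0, Leaf(1.0), Leaf(0.0)))`, `LeafIndex.leafIndex(t, Array(0.9, 3.0))` returns 2, and `LeafIndex.leafFeatures(Seq(t, t), Array(0.9, 3.0))` gives `[0, 0, 1, 0, 0, 1]`. These are leaf positions, not the node ids that mllib assigns, so they are only comparable between trees numbered by the same traversal.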

Question about Random Forest

You mentioned in this section https://github.com/endymecy/spark-ml-source-analysis/blob/master/%E5%88%86%E7%B1%BB%E5%92%8C%E5%9B%9E%E5%BD%92/%E7%BB%84%E5%90%88%E6%A0%91/%E9%9A%8F%E6%9C%BA%E6%A3%AE%E6%9E%97/random-forests.md that there seem to be two scoring methods: predictBySumming and predictByVoting.

But in the ml package, I only found:

override protected def predictRaw(features: Vector): Vector = {
  // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
  // Classifies using majority votes.
  // Ignore the tree weights since all are 1.0 for now.
  val votes = Array.fill[Double](numClasses)(0.0)
  _trees.view.foreach { tree =>
    val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
    val total = classCounts.sum
    if (total != 0) {
      var i = 0
      while (i < numClasses) {
        votes(i) += classCounts(i) / total
        i += 1
      }
    }
  }
  Vectors.dense(votes)
}

Does this mean that the ml random forest only supports voting?

Thank you in advance.
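Judging only from the predictRaw quoted above, the "majority votes" comment is loose: each tree does not cast a single vote. Instead it adds its normalized class distribution (`classCounts(i) / total`), so the raw prediction is a sum of per-tree class probabilities, often called soft voting. The self-contained sketch below contrasts that with hard voting, where each tree votes once for its most frequent class. `ForestVoting`, `softVotes`, `hardVotes` and `argmax` are hypothetical names, and the per-tree class counts are illustrative inputs, not output from Spark.

```scala
object ForestVoting {
  // Soft voting: add each tree's class counts normalized to sum to 1,
  // mirroring the loop in the quoted ml predictRaw.
  def softVotes(perTreeCounts: Seq[Array[Double]], numClasses: Int): Array[Double] = {
    val votes = Array.fill(numClasses)(0.0)
    perTreeCounts.foreach { counts =>
      val total = counts.sum
      if (total != 0) {
        for (i <- 0 until numClasses) votes(i) += counts(i) / total
      }
    }
    votes
  }

  // Hard voting: each tree casts one vote for its most frequent class
  // (ties go to the lowest class index).
  def hardVotes(perTreeCounts: Seq[Array[Double]], numClasses: Int): Array[Double] = {
    val votes = Array.fill(numClasses)(0.0)
    perTreeCounts.foreach { counts =>
      votes(counts.indexOf(counts.max)) += 1.0
    }
    votes
  }

  // Index of the largest entry (first one on ties).
  def argmax(v: Array[Double]): Int = v.indexOf(v.max)
}
```

With three trees whose leaf class counts are `[0, 10]`, `[6, 4]` and `[6, 4]`, soft voting yields about `[1.2, 1.8]` and predicts class 1, while hard voting yields `[2, 1]` and predicts class 0. The two schemes can therefore disagree when one tree is very confident and the others are only slightly in favor of another class.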
