Spark Source Code Analysis: RandomForest and DecisionTree under the RDD API

Tags: #scala# #spark# #decision tree# #machine learning#  Time: 2018/08/03 16:23:37  Author: 小木

DecisionTreeMetadata Analysis

In Spark, the decision tree algorithm is implemented together with the random forest method; there is no separate implementation.

First, the overall logic of the decision tree algorithm (a minimal Scala sketch follows the list):

1. Feature processing, mainly discretizing continuous features.
2. Choose a feature as the root feature and partition the dataset according to that feature's discretized values.
3. Check whether the resulting partitions satisfy the stopping condition (i.e., every partition contains data of a single class); if so, stop.
4. Otherwise, for each partition, choose a new feature and keep splitting, until the maximum depth is reached or every partition contains only one class.
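The following is a minimal, in-memory sketch of the recursive logic above, written in plain Scala for illustration only. The Sample and Node types, the Gini helper, and the exhaustive threshold search are simplified assumptions; Spark's distributed implementation, discussed below, works differently (it bins features and aggregates statistics per partition).

    // A tiny, self-contained decision tree: NOT Spark code, just the recursive idea above.
    object TinyTree {
      case class Sample(label: Double, features: Array[Double])

      sealed trait Node
      case class Leaf(prediction: Double) extends Node
      case class Internal(feature: Int, threshold: Double, left: Node, right: Node) extends Node

      // Gini impurity of a set of samples.
      private def gini(samples: Seq[Sample]): Double = {
        val counts = samples.groupBy(_.label).values.map(_.size.toDouble)
        1.0 - counts.map(c => math.pow(c / samples.size, 2)).sum
      }

      // Recursively split a (non-empty) sample set until it is pure or maxDepth is reached.
      def train(samples: Seq[Sample], depth: Int, maxDepth: Int): Node = {
        def majorityLabel: Double = samples.groupBy(_.label).maxBy(_._2.size)._1
        // Stopping condition: maximum depth reached or the node is already pure.
        if (depth >= maxDepth || samples.map(_.label).distinct.size == 1) {
          Leaf(majorityLabel)
        } else {
          val numFeatures = samples.head.features.length
          // Exhaustively evaluate candidate (feature, threshold) splits by weighted Gini impurity.
          val candidates = for {
            f <- 0 until numFeatures
            t <- samples.map(_.features(f)).distinct
          } yield (f, t)
          val (bestF, bestT) = candidates.minBy { case (f, t) =>
            val (l, r) = samples.partition(_.features(f) <= t)
            if (l.isEmpty || r.isEmpty) Double.MaxValue
            else (l.size * gini(l) + r.size * gini(r)) / samples.size
          }
          val (left, right) = samples.partition(_.features(bestF) <= bestT)
          if (left.isEmpty || right.isEmpty) Leaf(majorityLabel)
          else Internal(bestF, bestT, train(left, depth + 1, maxDepth), train(right, depth + 1, maxDepth))
        }
      }
    }

For instance, `TinyTree.train(samples, depth = 0, maxDepth = 5)` grows a tree of depth at most 5 over an in-memory dataset.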

Next, let's look at the decision tree implementation in Spark MLlib. The implementation relies on quite a few data structures; see https://www.datalearner.com/blog/1051533093494419 for details.
First, the calling code is as follows:

    package rdd.ml.classification.tree

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.mllib.util.MLUtils
    import rdd.ml.DFUtils

    /**
     * A decision tree model for classification.
     * Created by Du Fei on 2018/7/31.
     */
    object DecisionTreeForClassification {

      def main(args: Array[String]): Unit = {
        // DFUtils is a small helper that initializes the Spark environment;
        // any ordinary SparkContext setup works just as well
        val sc = DFUtils.getSparkContext()

        // Load and parse the data file using the built-in LibSVM loader
        val data = MLUtils.loadLibSVMFile(sc, "file:/d:/data/sample_libsvm_data.txt")

        // Split the data into training and test sets (30% held out for testing)
        val splits = data.randomSplit(Array(0.7, 0.3))
        val (trainingData, testData) = (splits(0), splits(1))

        // Train a DecisionTree model.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        val numClasses = 2                            // number of class labels; this is a binary classification problem
        val categoricalFeaturesInfo = Map[Int, Int]()
        val impurity = "gini"                         // use Gini impurity to evaluate splits
        val maxDepth = 5                              // maximum tree depth is 5
        val maxBins = 32                              // each feature is discretized into at most 32 bins

        // Train the model
        val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
          impurity, maxDepth, maxBins)

        // Evaluate model on test instances and compute test error
        val labelAndPreds = testData.map { point =>
          val prediction = model.predict(point.features)
          (point.label, prediction)
        }
        val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
        println(s"Test Error = $testErr")
        println(s"Learned classification tree model:\n ${model.toDebugString}")

        // Save and load model
        model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
        val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
      }
    }

In Spark MLlib, the core decision tree code is implemented by calling RandomForest: the number of trees in the forest is simply set to 1 and training proceeds as usual. In the source code the call passes through several layers before reaching RandomForest; to make the core logic easier to follow, the call chain is condensed and rewritten below:

    @Since("1.1.0")
    def trainClassifier(
        input: RDD[LabeledPoint],
        numClasses: Int,
        categoricalFeaturesInfo: Map[Int, Int],
        impurity: String,
        maxDepth: Int,
        maxBins: Int): DecisionTreeModel = {
      // Resolve the impurity measure from its name ("gini" here)
      val impurityType = Impurities.fromString(impurity)
      // Wrap all parameters into a Strategy object
      // (in the original call chain: algo = Classification, quantileCalculationStrategy = Sort, seed = 0)
      val strategy = new Strategy(algo, impurityType, maxDepth, numClasses, maxBins,
        quantileCalculationStrategy, categoricalFeaturesInfo)
      // Train by calling the random forest algorithm with a single tree
      val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
      val rfModel = rf.run(input)
      rfModel.trees(0) // the first (and only) tree is the trained decision tree
    }

Next, let's focus on the training code inside RandomForest:

    /**
     * Train a random forest.
     *
     * @param input Training data: RDD of `LabeledPoint`
     * @return an unweighted set of trees
     */
    def run(
        input: RDD[LabeledPoint],
        strategy: OldStrategy,
        numTrees: Int,
        featureSubsetStrategy: String,
        seed: Long,
        instr: Option[Instrumentation[_]],
        parentUID: Option[String] = None): Array[DecisionTreeModel] = {

      // timer only records running time and can be ignored
      val timer = new TimeTracker()
      timer.start("total")
      timer.start("init")

      // retag re-tags the RDD's element type for Java compatibility;
      // without it, interoperating with Java code may cause problems
      val retaggedInput = input.retag(classOf[LabeledPoint])

      // Step 1: build the decision tree metadata. Much of it overlaps with strategy,
      // which arguably makes this more complicated than necessary
      val metadata =
        DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)

      // instrumentation / monitoring
      instr match {
        case Some(instrumentation) =>
          instrumentation.logNumFeatures(metadata.numFeatures)
          instrumentation.logNumClasses(metadata.numClasses)
        case None =>
          logInfo("numFeatures: " + metadata.numFeatures)
          logInfo("numClasses: " + metadata.numClasses)
      }

      // Find the splits and the corresponding bins (interval between the splits) using a sample
      // of the input data.
      timer.start("findSplits")
      // Step 2: discretize (bin) the dataset's features. This mainly targets continuous features,
      // but unordered and ordered categorical features are handled as well (three kinds in total).
      // See: https://www.datalearner.com/blog/1051533040913424
      val splits = findSplits(retaggedInput, metadata, seed)
      // The following lines are just logging
      timer.stop("findSplits")
      logDebug("numBins: feature: number of bins")
      logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
      }.mkString("\n"))

      // Bin feature values (TreePoint representation).
      // Cache input RDD for speedup during multiple passes.
      val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)

      val withReplacement = numTrees > 1

      val baggedInput = BaggedPoint
        .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
        .persist(StorageLevel.MEMORY_AND_DISK)

      // depth of the decision tree: limited to at most 30, presumably for performance reasons
      val maxDepth = strategy.maxDepth
      require(maxDepth <= 30,
        s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

      // Max memory usage for aggregates
      // TODO: Calculate memory usage more precisely.
      val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
      logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")

      /*
       * The main idea here is to perform group-wise training of the decision tree nodes thus
       * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
       * Each data sample is handled by a particular node (or it reaches a leaf and is not used
       * in lower levels).
       */

      // Create an RDD of node Id cache.
      // At first, all the rows belong to the root nodes (node Id == 1).
      val nodeIdCache = if (strategy.useNodeIdCache) {
        Some(NodeIdCache.init(
          data = baggedInput,
          numTrees = numTrees,
          checkpointInterval = strategy.checkpointInterval,
          initVal = 1))
      } else {
        None
      }

      /*
        Stack of nodes to train: (treeIndex, node)
        The reason this is a stack is that we train many trees at once, but we want to focus on
        completing trees, rather than training all simultaneously. If we are splitting nodes from
        1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
        training the same tree in the next iteration. This focus allows us to send fewer trees to
        workers on each iteration; see topNodesForGroup below.
       */
      // A mutable array stack holds the nodes still to be trained, as (tree index, LearningNode)
      // pairs. Since we are training a single decision tree here, there is only one tree.
      val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]

      val rng = new Random()
      rng.setSeed(seed)

      // Allocate and queue root nodes.
      // Push the root nodes onto the stack; with a single tree there is only one root node.
      val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
      Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))

      timer.stop("init")

      // Step 3: loop until all nodes have been trained. Each iteration splits one group of
      // nodes; even for a single decision tree the loop runs once per group, level by level.
      while (nodeStack.nonEmpty) {
        // Collect some nodes to split, and choose features for each node (if subsampling).
        // Each group of nodes may come from one or multiple trees, and at multiple levels.
        // Step 3.1: select a group of nodes to split (bounded by the memory budget);
        // with a single decision tree, all selected nodes come from that one tree.
        val (nodesForGroup, treeToNodeToIndexInfo) =
          RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
        // Sanity check (should never occur):
        assert(nodesForGroup.nonEmpty,
          s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

        // Only send trees to worker if they contain nodes being split this iteration.
        val topNodesForGroup: Map[Int, LearningNode] =
          nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap

        // Choose node splits, and enqueue new nodes as needed.
        timer.start("findBestSplits")
        // Step 3.2: find the best splits for the current group of nodes
        RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
          treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
        timer.stop("findBestSplits")
      }

      baggedInput.unpersist()

      timer.stop("total")

      logInfo("Internal timing for DecisionTree:")
      logInfo(s"$timer")

      // Delete any remaining checkpoints used for node Id cache.
      if (nodeIdCache.nonEmpty) {
        try {
          nodeIdCache.get.deleteAllCheckpoints()
        } catch {
          case e: IOException =>
            logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
        }
      }

      val numFeatures = metadata.numFeatures

      parentUID match {
        case Some(uid) =>
          if (strategy.algo == OldAlgo.Classification) {
            topNodes.map { rootNode =>
              new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
                strategy.getNumClasses)
            }
          } else {
            topNodes.map { rootNode =>
              new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
            }
          }
        case None =>
          if (strategy.algo == OldAlgo.Classification) {
            topNodes.map { rootNode =>
              new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
                strategy.getNumClasses)
            }
          } else {
            topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
          }
      }
    }
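Before moving on, it helps to picture what step 2 (findSplits) produces for a continuous feature: a small set of candidate thresholds and a mapping from raw values to bin indices. The sketch below only illustrates that idea; `candidateSplits` and `binIndex` are made-up helpers, not Spark's API, and the real findSplits works on a sample of the data and also handles categorical features (see the linked post).

    // Illustration only: what "binning" a continuous feature roughly means.
    object BinningSketch {
      /** Pick up to (maxBins - 1) candidate thresholds from the sorted distinct values. */
      def candidateSplits(values: Array[Double], maxBins: Int): Array[Double] = {
        val sorted = values.distinct.sorted
        val step = math.max(1, sorted.length / maxBins)
        sorted.indices.collect { case i if i % step == 0 && i > 0 => sorted(i) }
          .take(maxBins - 1).toArray
      }

      /** Map a raw feature value to its bin index: the number of thresholds it exceeds. */
      def binIndex(value: Double, thresholds: Array[Double]): Int =
        thresholds.count(value > _)

      def main(args: Array[String]): Unit = {
        val feature = Array(0.1, 0.4, 0.35, 0.8, 0.9, 0.05, 0.6, 0.75)
        val thresholds = candidateSplits(feature, maxBins = 4)
        // Every raw value is replaced by a small integer bin id.
        feature.foreach(v => println(s"$v -> bin ${binIndex(v, thresholds)}"))
      }
    }

With every feature value reduced to a small bin id (the TreePoint representation), the split search only has to aggregate statistics per bin rather than per raw value, which is exactly what the code below exploits.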

As the code above shows, training really consists of two main parts: feature processing (discretization) first, and then the search for the best splits. The extra steps exist because the decision tree is implemented through the random forest code path, which also has to construct nodes and, in the general case, subsample features or data. Next, let's look at the core tree-learning code:

    /**
     * Given a group of nodes, this finds the best split for each node.
     *
     * @param input Training data: RDD of [[TreePoint]]
     * @param metadata Learning and dataset metadata
     * @param topNodesForGroup For each tree in group, tree index -> root node.
     *                         Used for matching instances with nodes.
     * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
     * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
     *                              where nodeIndexInfo stores the index in the group and the
     *                              feature subsets (if using feature subsets).
     * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
     * @param nodeStack Queue of nodes to split, with values (treeIndex, node).
     *                  Updated with new non-leaf nodes which are created.
     * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
     *                    each value in the array is the data point's node Id
     *                    for a corresponding tree. This is used to prevent the need
     *                    to pass the entire tree to the executors during
     *                    the node stat aggregation phase.
     */
    private[tree] def findBestSplits(
        input: RDD[BaggedPoint[TreePoint]],
        metadata: DecisionTreeMetadata,
        topNodesForGroup: Map[Int, LearningNode],
        nodesForGroup: Map[Int, Array[LearningNode]],
        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
        splits: Array[Array[Split]],
        nodeStack: mutable.ArrayStack[(Int, LearningNode)],
        timer: TimeTracker = new TimeTracker,
        nodeIdCache: Option[NodeIdCache] = None): Unit = {

      /*
       * The high-level descriptions of the best split optimizations are noted here.
       *
       * *Group-wise training*
       * We perform bin calculations for groups of nodes to reduce the number of
       * passes over the data. Each iteration requires more computation and storage,
       * but saves several iterations over the data.
       *
       * *Bin-wise computation*
       * We use a bin-wise best split computation strategy instead of a straightforward best split
       * computation strategy. Instead of analyzing each sample for contribution to the left/right
       * child node impurity of every split, we first categorize each feature of a sample into a
       * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
       * to calculate information gain for each split.
       *
       * *Aggregation over partitions*
       * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
       * the number of splits in advance. Thus, we store the aggregates (at the appropriate
       * indices) in a single array for all bins and rely upon the RDD aggregate method to
       * drastically reduce the communication overhead.
       */

      // numNodes: Number of nodes in this group
      val numNodes = nodesForGroup.values.map(_.length).sum
      logDebug("numNodes = " + numNodes)
      logDebug("numFeatures = " + metadata.numFeatures)
      logDebug("numClasses = " + metadata.numClasses)
      logDebug("isMulticlass = " + metadata.isMulticlass)
      logDebug("isMulticlassWithCategoricalFeatures = " +
        metadata.isMulticlassWithCategoricalFeatures)
      logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)

      /**
       * Performs a sequential aggregation over a partition for a particular tree and node.
       *
       * For each feature, the aggregate sufficient statistics are updated for the relevant
       * bins.
       *
       * @param treeIndex Index of the tree that we want to perform aggregation for.
       * @param nodeInfo The node info for the tree node.
       * @param agg Array storing aggregate calculation, with a set of sufficient statistics
       *            for each (node, feature, bin).
       * @param baggedPoint Data point being aggregated.
       */
      def nodeBinSeqOp(
          treeIndex: Int,
          nodeInfo: NodeIndexInfo,
          agg: Array[DTStatsAggregator],
          baggedPoint: BaggedPoint[TreePoint]): Unit = {
        if (nodeInfo != null) {
          val aggNodeIndex = nodeInfo.nodeIndexInGroup
          val featuresForNode = nodeInfo.featureSubset
          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
          if (metadata.unorderedFeatures.isEmpty) {
            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
          } else {
            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
              metadata.unorderedFeatures, instanceWeight, featuresForNode)
          }
          agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
        }
      }

      /**
       * Performs a sequential aggregation over a partition.
       *
       * Each data point contributes to one node. For each feature,
       * the aggregate sufficient statistics are updated for the relevant bins.
       *
       * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
       *            each (node, feature, bin).
       * @param baggedPoint Data point being aggregated.
       * @return agg
       */
      def binSeqOp(
          agg: Array[DTStatsAggregator],
          baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
        treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
          val nodeIndex =
            topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
          nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
        }
        agg
      }

      /**
       * Do the same thing as binSeqOp, but with nodeIdCache.
       */
      def binSeqOpWithNodeIdCache(
          agg: Array[DTStatsAggregator],
          dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
        treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
          val baggedPoint = dataPoint._1
          val nodeIdCache = dataPoint._2
          val nodeIndex = nodeIdCache(treeIndex)
          nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
        }
        agg
      }

      /**
       * Get node index in group --> features indices map,
       * which is a short cut to find feature indices for a node given node index in group.
       */
      def getNodeToFeatures(
          treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
        if (!metadata.subsamplingFeatures) {
          None
        } else {
          val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
          treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
            nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
              assert(nodeIndexInfo.featureSubset.isDefined)
              mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
            }
          }
          Some(mutableNodeToFeatures.toMap)
        }
      }

      // array of nodes to train indexed by node index in group
      val nodes = new Array[LearningNode](numNodes)
      nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
        nodesForTree.foreach { node =>
          nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
        }
      }

      // Calculate best splits for all nodes in the group
      timer.start("chooseSplits")

      // In each partition, iterate all instances and compute aggregate stats for each node,
      // yield a (nodeIndex, nodeAggregateStats) pair for each node.
      // After a `reduceByKey` operation,
      // stats of a node will be shuffled to a particular partition and be combined together,
      // then best splits for nodes are found there.
      // Finally, only best Splits for nodes are collected to driver to construct decision tree.
      val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
      val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

      val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
        input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
          // Construct a nodeStatsAggregators array to hold node aggregate stats,
          // each node will have a nodeStatsAggregator
          val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
            val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
              nodeToFeatures(nodeIndex)
            }
            new DTStatsAggregator(metadata, featuresForNode)
          }

          // iterator all instances in current partition and update aggregate stats
          points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

          // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
          // which can be combined with other partition using `reduceByKey`
          nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
        }
      } else {
        input.mapPartitions { points =>
          // Construct a nodeStatsAggregators array to hold node aggregate stats,
          // each node will have a nodeStatsAggregator
          val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
            val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
              Some(nodeToFeatures(nodeIndex))
            }
            new DTStatsAggregator(metadata, featuresForNode)
          }

          // iterator all instances in current partition and update aggregate stats
          points.foreach(binSeqOp(nodeStatsAggregators, _))

          // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
          // which can be combined with other partition using `reduceByKey`
          nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
        }
      }

      val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
        case (nodeIndex, aggStats) =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }

          // find best split for each node
          val (split: Split, stats: ImpurityStats) =
            binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
          (nodeIndex, (split, stats))
      }.collectAsMap()

      timer.stop("chooseSplits")

      val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
        Array.fill[mutable.Map[Int, NodeIndexUpdater]](
          metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
      } else {
        null
      }
      // Iterate over all nodes in this group.
      nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
        nodesForTree.foreach { node =>
          val nodeIndex = node.id
          val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
          val aggNodeIndex = nodeInfo.nodeIndexInGroup
          val (split: Split, stats: ImpurityStats) =
            nodeToBestSplits(aggNodeIndex)
          logDebug("best split = " + split)

          // Extract info for this node. Create children if not leaf.
          val isLeaf =
            (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
          node.isLeaf = isLeaf
          node.stats = stats
          logDebug("Node = " + node)

          if (!isLeaf) {
            node.split = Some(split)
            val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
            val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
            val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
            node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
              leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
            node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
              rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))

            if (nodeIdCache.nonEmpty) {
              val nodeIndexUpdater = NodeIndexUpdater(
                split = split,
                nodeIndex = nodeIndex)
              nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
            }

            // enqueue left child and right child if they are not leaves
            if (!leftChildIsLeaf) {
              nodeStack.push((treeIndex, node.leftChild.get))
            }
            if (!rightChildIsLeaf) {
              nodeStack.push((treeIndex, node.rightChild.get))
            }

            logDebug("leftChildIndex = " + node.leftChild.get.id +
              ", impurity = " + stats.leftImpurity)
            logDebug("rightChildIndex = " + node.rightChild.get.id +
              ", impurity = " + stats.rightImpurity)
          }
        }
      }

      if (nodeIdCache.nonEmpty) {
        // Update the cache if needed.
        nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
      }
    }
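The heart of findBestSplits is the bin-wise aggregation described in the comment block at the top: within each partition, every instance only increments per-(node, feature, bin) label counts, the per-partition aggregators are merged with reduceByKey, and the best split is then chosen from the merged counts. The standalone sketch below mimics that idea for a single node with simple nested arrays; `BinStats` is a hypothetical class for illustration, not DTStatsAggregator's actual API or flat-array layout.

    // A toy model of bin-wise sufficient statistics for one node (classification case).
    // stats(feature)(bin)(label) accumulates weighted instance counts.
    class BinStats(numFeatures: Int, numBins: Int, numClasses: Int) {
      val stats: Array[Array[Array[Double]]] =
        Array.fill(numFeatures, numBins, numClasses)(0.0)

      /** Sequential op: add one (already binned) instance to the statistics. */
      def add(binnedFeatures: Array[Int], label: Int, weight: Double = 1.0): this.type = {
        var f = 0
        while (f < numFeatures) {
          stats(f)(binnedFeatures(f))(label) += weight
          f += 1
        }
        this
      }

      /** Combine op: merge statistics computed on another partition. */
      def merge(other: BinStats): this.type = {
        for (f <- 0 until numFeatures; b <- 0 until numBins; c <- 0 until numClasses) {
          stats(f)(b)(c) += other.stats(f)(b)(c)
        }
        this
      }

      /** Gini impurity of a class-count vector. */
      private def gini(counts: Array[Double]): Double = {
        val total = counts.sum
        if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
      }

      /** For one feature, evaluate "bin <= s goes left" splits from the aggregates alone
       *  and return (best threshold bin, weighted child impurity). Assumes non-empty stats. */
      def bestSplitForFeature(f: Int): (Int, Double) = {
        (0 until numBins - 1).map { s =>
          val left = Array.fill(numClasses)(0.0)
          val right = Array.fill(numClasses)(0.0)
          for (b <- 0 until numBins; c <- 0 until numClasses) {
            if (b <= s) left(c) += stats(f)(b)(c) else right(c) += stats(f)(b)(c)
          }
          (s, (left.sum * gini(left) + right.sum * gini(right)) / (left.sum + right.sum))
        }.minBy(_._2)
      }
    }

In the real code the aggregators for a whole group of nodes are filled inside mapPartitions and merged with reduceByKey, so each group of nodes costs only a single pass over the data.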

In other words, the nodes are split level by level (group-wise), one group of nodes per pass over the data, rather than one node at a time.
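This level-by-level growth relies on the 1-based, heap-style node numbering used by LearningNode: the root has index 1 (as seen in `emptyNode(nodeIndex = 1)` and `initVal = 1` above), a node i has children 2i and 2i+1, and a node's level can be read directly off its index. The object below is only a small illustration of that indexing scheme, not the actual LearningNode code.

    // Heap-style node indexing, as used to track which level a node belongs to.
    object NodeIndexing {
      def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1         // 2 * i
      def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1  // 2 * i + 1
      // Level of a node: the root (index 1) is level 0, its children are level 1, and so on.
      def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)

      def main(args: Array[String]): Unit = {
        // Root 1 -> children 2 and 3 (level 1) -> grandchildren 4..7 (level 2), etc.
        println((1 to 7).map(i => s"$i: level=${indexToLevel(i)}").mkString(", "))
      }
    }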
