连续变量离散化(Scala版本)

标签:#scala##编程# 时间:2018/08/01 10:07:34 作者:小木

这段代码是Spark MLLib中决策树算法中的代码片段,是将连续的变量离散化。输入参数为:

1、featureSamples:在原始版本中,某一个特征下的数据可能很多,为了避免对过多的数据进行离散化导致速度上太慢,于是对于大于10000条数据的特征进行了抽样,这里就是抽样结果的数据,是一个double类型的数组
2、metadata:元数据,就是描述特征情况的,包括是离散的还是抽样的等
3、featureIndex:特征索引,表明是第几个特征

大致思路如下:
首先需要统计出不同变量值的个数,以及总的不同特征值的数量。假设总共有N个不同的值,那么一般情况下是将N个不同的值按照从小到大排序好,总共需要N-1个点将其分开(例如将1.2、2.3、3.2切分成三个区间,只需要1.75和2.75两个数字即可,结果为(\infty,1.75](1.75,2.75](2.75,\infty])。

但是在此之前已经给该特征设置了划分的区间数,当期望分割数大于不重复数据点数量,那么直接按照不同变量个数来分即可。如果期望分割的区间数小于这个值,那么按照期望分割的数量来切分。其切分逻辑如下图所示:

i-1是上一个数据点的值,如1.2,previousCount表明到第i-1这个值为止的数据点总数(假设1.2这个数据点有4个,那么previousCount=4)。i是当前数据点(例如之前的2.3),currentCount是previousCount+count(i)的结果(假设2.3数据有3个,那么currentCount=7)。middlePoint是二者中间的位置(即1.75)。targetCount是我们离散化数据点的分割点,这个值是用数据点总数除以(期望分割数+1)得到的(也就是下面代码中的stride)。其中:

previousGap = |targetCount - previousCount|
currentGap = |targetCount - currentCount|

当targetCount离currentCount更近的时候,那么把middlePoint当做一个切分点,也就是上图红色区域(其实这个而红色区域也不太准确,但差不多这个意思,也就是要以count为坐标轴算)。所以整个逻辑就是一直循环所有的数据点,当targetCount落入红色区域之后,就把middlePoint作为切分点。

1、计算数据点间隔数:val stride: Double = numSamples.toDouble / (numSplits + 1) 它表明的是每个区间需要包含的数据点个数,注意这里是用数据点总数除以区间数,不是不同数据点除的。我们定义一个targetCount作为每个区间应该有的数据点的数量。即最开始是一个stride,然后是两个stride,一直到最后应该是numSamples。
2、比较当前数据点的数量和下一个数据点的数量与targetCount的距离,targetCount如果离当前数据点近,那么继续循环,知道离下一个数据点近的时候确定分割点位于这两个数据之间。

我们一起看一下:

  1. /**
  2. * Find splits for a continuous feature
  3. * NOTE: Returned number of splits is set based on `featureSamples` and
  4. * could be different from the specified `numSplits`.
  5. * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
  6. *
  7. * @param featureSamples feature values of each sample
  8. * @param metadata decision tree metadata
  9. * NOTE: `metadata.numbins` will be changed accordingly
  10. * if there are not enough splits to be found
  11. * @param featureIndex feature index to find splits
  12. * @return array of split thresholds
  13. */
  14. private[tree] def findSplitsForContinuousFeature(
  15. featureSamples: Iterable[Double],
  16. metadata: DecisionTreeMetadata,
  17. featureIndex: Int): Array[Double] = {
  18. require(metadata.isContinuous(featureIndex),
  19. "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
  20. val splits: Array[Double] = if (featureSamples.isEmpty) { //如果数据点为空,返回空数组
  21. Array.empty[Double]
  22. } else {
  23. val numSplits = metadata.numSplits(featureIndex) //期望分割数
  24. // get count for each distinct value
  25. // 这里是计算所有不同的数据点及其数量,最后按照数据点从小到大进行排序(注意,不是按照数据点的数量排序)
  26. val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
  27. case ((m, cnt), x) =>
  28. (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
  29. }
  30. // sort distinct values
  31. val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
  32. // 这里计算可能的分割数,如果分割数为0返回空数组
  33. val possibleSplits = valueCounts.length - 1
  34. if (possibleSplits == 0) {
  35. // constant feature
  36. Array.empty[Double]
  37. } else if (possibleSplits <= numSplits) { //可能的分割数小于期望分割数,那么就按照可能分割数来,一个数据点落入一个区间
  38. // if possible splits is not enough or just enough, just return all possible splits
  39. (1 to possibleSplits)
  40. .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
  41. .toArray
  42. } else {
  43. // 如果可能分割数大于期望分割数,按照期望分割数来,这就表明有区间包含了至少两个数据点,切分逻辑按上面说的来。
  44. // stride between splits
  45. val stride: Double = numSamples.toDouble / (numSplits + 1)
  46. logDebug("stride = " + stride)
  47. // iterate `valueCount` to find splits,这个值是一个数组,返回的是切分点坐标
  48. val splitsBuilder = mutable.ArrayBuilder.make[Double]
  49. var index = 1
  50. // currentCount: sum of counts of values that have been visited
  51. // 第一个数据点的数量作为第一个previousCount
  52. var currentCount = valueCounts(0)._2
  53. // targetCount: target value for `currentCount`.
  54. // If `currentCount` is closest value to `targetCount`,
  55. // then current value is a split threshold.
  56. // After finding a split threshold, `targetCount` is added by stride.
  57. // 第一个目标数据量就是一个区间
  58. var targetCount = stride
  59. // 循环所有的数据点,找出符合上述要求的点作为切分点
  60. while (index < valueCounts.length) {
  61. val previousCount = currentCount
  62. currentCount += valueCounts(index)._2
  63. val previousGap = math.abs(previousCount - targetCount)
  64. val currentGap = math.abs(currentCount - targetCount)
  65. // If adding count of current value to currentCount
  66. // makes the gap between currentCount and targetCount smaller,
  67. // previous value is a split threshold.
  68. // 如果targetCount离currentCount更近,那么就把当前两个点之间作为切分点
  69. if (previousGap < currentGap) {
  70. splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
  71. targetCount += stride
  72. }
  73. index += 1
  74. }
  75. // 这里就是最终的切分点数组了
  76. splitsBuilder.result()
  77. }
  78. }
  79. splits
  80. }
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
Back to Top