连续变量离散化(Scala版本)
这段代码是Spark MLLib中决策树算法中的代码片段,是将连续的变量离散化。输入参数为:
1、featureSamples:在原始版本中,某一个特征下的数据可能很多,为了避免对过多的数据进行离散化导致速度上太慢,于是对于大于10000条数据的特征进行了抽样,这里就是抽样结果的数据,是一个double类型的数组
2、metadata:元数据,就是描述特征情况的,包括是离散的还是抽样的等
3、featureIndex:特征索引,表明是第几个特征
大致思路如下:
首先需要统计出不同变量值的个数,以及总的不同特征值的数量。假设总共有N个不同的值,那么一般情况下是将N个不同的值按照从小到大排序好,总共需要N-1个点将其分开(例如将1.2、2.3、3.2切分成三个区间,只需要1.75和2.75两个数字即可,结果为(\infty,1.75]、(1.75,2.75]、(2.75,\infty])。
但是在此之前已经给该特征设置了划分的区间数,当期望分割数大于不重复数据点数量,那么直接按照不同变量个数来分即可。如果期望分割的区间数小于这个值,那么按照期望分割的数量来切分。其切分逻辑如下图所示:
i-1是上一个数据点的值,如1.2,previousCount表明到第i-1这个值为止的数据点总数(假设1.2这个数据点有4个,那么previousCount=4)。i是当前数据点(例如之前的2.3),currentCount是previousCount+count(i)的结果(假设2.3数据有3个,那么currentCount=7)。middlePoint是二者中间的位置(即1.75)。targetCount是我们离散化数据点的分割点,这个值是用数据点总数除以(期望分割数+1)得到的(也就是下面代码中的stride)。其中:
previousGap = |targetCount - previousCount|
currentGap = |targetCount - currentCount|
当targetCount离currentCount更近的时候,那么把middlePoint当做一个切分点,也就是上图红色区域(其实这个而红色区域也不太准确,但差不多这个意思,也就是要以count为坐标轴算)。所以整个逻辑就是一直循环所有的数据点,当targetCount落入红色区域之后,就把middlePoint作为切分点。
1、计算数据点间隔数:val stride: Double = numSamples.toDouble / (numSplits + 1) 它表明的是每个区间需要包含的数据点个数,注意这里是用数据点总数除以区间数,不是不同数据点除的。我们定义一个targetCount作为每个区间应该有的数据点的数量。即最开始是一个stride,然后是两个stride,一直到最后应该是numSamples。
2、比较当前数据点的数量和下一个数据点的数量与targetCount的距离,targetCount如果离当前数据点近,那么继续循环,知道离下一个数据点近的时候确定分割点位于这两个数据之间。
我们一起看一下:
/**
* Find splits for a continuous feature
* NOTE: Returned number of splits is set based on `featureSamples` and
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
*
* @param featureSamples feature values of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Iterable[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
val splits: Array[Double] = if (featureSamples.isEmpty) { //如果数据点为空,返回空数组
Array.empty[Double]
} else {
val numSplits = metadata.numSplits(featureIndex) //期望分割数
// get count for each distinct value
// 这里是计算所有不同的数据点及其数量,最后按照数据点从小到大进行排序(注意,不是按照数据点的数量排序)
val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
case ((m, cnt), x) =>
(m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// 这里计算可能的分割数,如果分割数为0返回空数组
val possibleSplits = valueCounts.length - 1
if (possibleSplits == 0) {
// constant feature
Array.empty[Double]
} else if (possibleSplits <= numSplits) { //可能的分割数小于期望分割数,那么就按照可能分割数来,一个数据点落入一个区间
// if possible splits is not enough or just enough, just return all possible splits
(1 to possibleSplits)
.map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
.toArray
} else {
// 如果可能分割数大于期望分割数,按照期望分割数来,这就表明有区间包含了至少两个数据点,切分逻辑按上面说的来。
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
logDebug("stride = " + stride)
// iterate `valueCount` to find splits,这个值是一个数组,返回的是切分点坐标
val splitsBuilder = mutable.ArrayBuilder.make[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
// 第一个数据点的数量作为第一个previousCount
var currentCount = valueCounts(0)._2
// targetCount: target value for `currentCount`.
// If `currentCount` is closest value to `targetCount`,
// then current value is a split threshold.
// After finding a split threshold, `targetCount` is added by stride.
// 第一个目标数据量就是一个区间
var targetCount = stride
// 循环所有的数据点,找出符合上述要求的点作为切分点
while (index < valueCounts.length) {
val previousCount = currentCount
currentCount += valueCounts(index)._2
val previousGap = math.abs(previousCount - targetCount)
val currentGap = math.abs(currentCount - targetCount)
// If adding count of current value to currentCount
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
// 如果targetCount离currentCount更近,那么就把当前两个点之间作为切分点
if (previousGap < currentGap) {
splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
targetCount += stride
}
index += 1
}
// 这里就是最终的切分点数组了
splitsBuilder.result()
}
}
splits
}
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
