Scala 中累加器的创建与使用格式详解

发布于:2025-05-13 ⋅ 阅读:(5) ⋅ 点赞:(0)

1. 内置累加器的创建与使用格式

1.1 创建内置累加器
// 通过 SparkContext 创建
val acc = sc.longAccumulator("累加器名称")   // Long 类型(默认初始值 0)
val accDouble = sc.doubleAccumulator("累加器名称") // Double 类型(初始值 0.0)
1.2 在任务中更新累加器
// 只能在行动操作(如 foreach、collect)中更新累加器
rdd.foreach { element =>
  if (满足条件) {
    acc.add(1)      // 累加整数
    accDouble.add(5.5) // 累加浮点数
  }
}
1.3 在 Driver 端读取结果
println(s"累加器结果: ${acc.value}")

2. 自定义累加器的创建与使用格式
2.1 定义自定义累加器类
import org.apache.spark.util.AccumulatorV2

// 定义输入类型和输出类型
class CustomAccumulator extends AccumulatorV2[输入类型, 输出类型] {

  private var _value: 输出类型 = 初始值

  // 判断累加器是否为空
  override def isZero: Boolean = _value == 初始值

  // 创建副本
  override def copy(): AccumulatorV2[输入类型, 输出类型] = {
    val newAcc = new CustomAccumulator
    newAcc._value = this._value
    newAcc
  }

  // 重置累加器
  override def reset(): Unit = {
    _value = 初始值
  }

  // 添加元素(Executor 调用)
  override def add(v: 输入类型): Unit = {
    // 自定义累加逻辑(如将 v 合并到 _value)
    _value += v
  }

  // 合并其他累加器的值(Driver 调用)
  override def merge(other: AccumulatorV2[输入类型, 输出类型]): Unit = {
    _value += other.value
  }

  // 获取最终结果
  override def value: 输出类型 = _value
}
2.2 注册并使用自定义累加器
// 创建实例并注册
val customAcc = new CustomAccumulator()
sc.register(customAcc, "自定义累加器名称(可选)")

// 在行动操作中更新
rdd.foreach { element =>
  customAcc.add(元素)
}

// 读取结果
println(s"自定义累加器结果: ${customAcc.value}")
3. 完整示例:统计单词长度分布
3.1 代码实现
import org.apache.spark.{SparkConf, SparkContext}

object WordLengthAccumulatorDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("WordLengthAccumulator")
      .setMaster("local[*]")
    val sc = new SparkContext(conf)

    // 创建内置累加器
    val shortWordAcc = sc.longAccumulator("ShortWords")  // 统计短单词(长度 <=3)
    val longWordAcc = sc.longAccumulator("LongWords")    // 统计长单词(长度 >3)

    // 读取数据并处理
    val textRDD = sc.textFile("hdfs://path/to/textfile.txt")
    textRDD.flatMap(_.split(" "))
      .foreach { word =>
        if (word.nonEmpty) {
          if (word.length <= 3) shortWordAcc.add(1)
          else longWordAcc.add(1)
        }
      }

    // 输出结果
    println(s"短单词数量: ${shortWordAcc.value}")
    println(s"长单词数量: ${longWordAcc.value}")

    sc.stop()
  }
}
3.2 输出示例
短单词数量: 120
长单词数量: 350

4. 关键注意事项

注意事项 正确做法
只能在行动操作中更新累加器 确保在 foreachcollect 等行动操作中调用 add(),而非 mapfilter 等转换操作。
避免多次计算 RDD 对 RDD 调用 persist() 或 cache(),防止重复计算导致累加器重复累加。
自定义累加器需注册 通过 sc.register() 注册自定义累加器,否则可能引发序列化错误。
合并逻辑必须幂等 确保 merge() 方法正确处理重复数据(如集合合并用 addAll)。

5. 自定义累加器示例:统计唯一单词

5.1 定义累加器
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable.HashSet

class UniqueWordsAccumulator extends AccumulatorV2[String, HashSet[String]] {

  private val _words = HashSet[String]()

  override def isZero: Boolean = _words.isEmpty

  override def copy(): AccumulatorV2[String, HashSet[String]] = {
    val newAcc = new UniqueWordsAccumulator
    newAcc._words ++= this._words
    newAcc
  }

  override def reset(): Unit = _words.clear()

  override def add(word: String): Unit = _words.add(word)

  override def merge(other: AccumulatorV2[String, HashSet[String]]): Unit = {
    _words ++= other.value
  }

  override def value: HashSet[String]] = _words
}
5.2 使用累加器
val uniqueWordsAcc = new UniqueWordsAccumulator()
sc.register(uniqueWordsAcc, "UniqueWords")

val wordsRDD = sc.parallelize(List("apple", "banana", "apple", "orange"))
wordsRDD.foreach(word => uniqueWordsAcc.add(word))

println(s"去重后的单词: ${uniqueWordsAcc.value.mkString(", ")}")
// 输出: apple, banana, orange

总结

  • 创建格式

    • 内置累加器:sc.longAccumulator("name")

    • 自定义累加器:继承 AccumulatorV2 并实现方法,然后注册 sc.register(acc)

  • 使用格式

    • 在行动操作中调用 add()

    • 通过 value 属性在 Driver 端读取结果

  • 核心原则
    只在行动操作中更新累加器,避免重复计算和序列化问题。


网站公告

今日签到

点亮在社区的每一天
去签到