from pyspark.sql import DataFrame
from pyspark.sql.functions import lit
from functools import wraps
def handle_spark_errors(func):
@wraps(func)
def wrapper(df, group_cols, agg_expr, *args, **kwargs):
try:
# 前置校验
if not isinstance(df, DataFrame):
raise ValueError("第一个参数必须是Spark DataFrame")
if not group_cols or len(group_cols) == 0:
raise ValueError("必须指定至少一个分组列")
missing_cols = [col for col in group_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"列不存在: {missing_cols}")
return func(df, group_cols, agg_expr, *args, **kwargs)
except Exception as e:
# 记录日志或上报监控
print(f"Error in {func.__name__}: {str(e)}")
raise
return wrapper
@handle_spark_errors
def spark_rollup(df: DataFrame, group_cols: list, agg_expr: dict) -> DataFrame:
"""
PySpark实现SQL Server的WITH ROLLUP功能
示例:spark_rollup(df, ["year", "month"], {"sales": "sum"})
"""
return df.rollup(*group_cols).agg(agg_expr)
@handle_spark_errors
def spark_cube(df: DataFrame, group_cols: list, agg_expr: dict) -> DataFrame:
"""
PySpark实现SQL Server的WITH CUBE功能
示例:spark_cube(df, ["category", "color"], {"price": "avg"})
"""
return df.cube(*group_cols).agg(agg_expr)
实现要点说明:
- 核心机制:
- 利用PySpark原生的
rollup()
和cube()
方法实现多维聚合 - 底层采用Spark的列式存储和Catalyst优化器保障性能
- 支持多列组合:
group_cols
参数接受字符串列表
- 异常处理:
- 装饰器
handle_spark_errors
统一处理常见错误:- 输入数据类型校验(确保DataFrame对象)
- 空分组列检查
- 列名存在性校验
- 错误信息包含具体缺失的列名
- 异常捕获后重新抛出保持堆栈跟踪
- 性能优化:
- 避免数据倾斜:依赖Spark内置的Shuffle优化策略
- 谓词下推:自动应用Spark的优化规则(如ConstantFolding)
- 内存管理:利用Tungsten引擎的堆外内存管理
- 支持并行执行:多个cube/rollup操作可并行化
- 扩展功能:
- 支持多种聚合表达式:
# 标准写法 {"sales": "sum", "price": "avg"} # 带别名 {"discount": expr("avg(discount)").alias("avg_discount")}
- 自动处理NULL聚合值(对应SQL Server的超级聚合行)
- 使用示例:
# 汽车销售数据示例
data = [("Beijing", "Model3", 100),
("Shanghai", "ModelY", 200),
("Beijing", "ModelY", 150)]
df = spark.createDataFrame(data, ["city", "model", "sales"])
# ROLLUP查询
rollup_result = spark_rollup(df, ["city", "model"], {"sales": "sum"})
rollup_result.show()
# CUBE查询
cube_result = spark_cube(df, ["city", "model"], {"sales": "sum"})
cube_result.show()
- 执行计划优化:
- 自动合并相同分组:相同分组条件的操作会被Spark优化器合并
- 延迟计算:直到调用action操作时才触发实际计算
- 自适应查询:Spark 3.0+版本支持AQE动态优化
与SQL Server的差异处理:
- 空值处理:Spark使用
null
表示超级聚合行,SQL Server有GROUPING()
函数 - 结果排序:Spark默认不保证结果顺序,需显式调用
orderBy()
- 性能差异:Spark分布式计算更适合大数据量场景
注意事项:
- 建议在聚合前执行
.persist()
缓存输入数据(大数据量时) - 可通过
spark.sql.retainGroupColumns
控制是否保留分组列 - 使用
.cube()
时注意组合爆炸问题(2^n种组合) - 推荐配合
analyze
命令检查数据分布:df.groupBy("city").agg(count("*").alias("cnt")).show()