Implementing LEFT OUTER APPLY, CROSS JOIN, and CROSS APPLY in PySpark

Published: 2025-02-20

1. Simulating LEFT OUTER APPLY

A UDF generates the associated rows for each left-table record, and explode_outer flattens them while keeping every left-table row: rows whose generator returns nothing get NULLs in the generated columns.

from pyspark.sql import functions as F
from pyspark.sql.types import (ArrayType, StructType, StructField,
                               IntegerType, StringType)
from pyspark.sql import DataFrame

def left_outer_apply(
    left_df: DataFrame,
    generator: callable,
    input_cols: list,
    output_schema: StructType,
    partition_num: int = 200
) -> DataFrame:
    """
    Simulate a LEFT OUTER APPLY operation.
    :param left_df: left-side DataFrame
    :param generator: function producing the right-side rows; receives the
                      values of input_cols and returns an iterable of rows
    :param input_cols: left-table column names passed to the generator
    :param output_schema: schema of the rows the generator returns
    :param partition_num: number of shuffle partitions for the result
    :return: the combined DataFrame
    """
    try:
        array_schema = ArrayType(output_schema)
        
        @F.udf(array_schema)
        def apply_udf(*args):
            try:
                return list(generator(*args))
            except Exception:
                # On error, return an empty array so the left row survives
                return []
        
        # explode_outer emits one row per array element and keeps rows whose
        # array is empty or null, filling the generated columns with NULL
        return (left_df
                .withColumn('__temp_array', apply_udf(*[F.col(c) for c in input_cols]))
                .withColumn('__temp_exploded', F.explode_outer(F.col('__temp_array')))
                .select('*', *[F.col('__temp_exploded')[f].alias(f) 
                             for f in output_schema.fieldNames()])
                .drop('__temp_array', '__temp_exploded')
                .repartition(partition_num))
    
    except Exception as e:
        raise ValueError(f"LEFT OUTER APPLY failed: {str(e)}")

2. Simulating CROSS APPLY

Reuses the LEFT OUTER APPLY logic, but drops left rows with empty results to get an INNER JOIN effect; using explode instead of explode_outer already does exactly that.

def cross_apply(
    left_df: DataFrame,
    generator: callable,
    input_cols: list,
    output_schema: StructType,
    partition_num: int = 200
) -> DataFrame:
    """
    Simulate a CROSS APPLY operation.
    Parameters are the same as for left_outer_apply.
    """
    try:
        array_schema = ArrayType(output_schema)
        
        @F.udf(array_schema)
        def apply_udf(*args):
            try:
                return list(generator(*args))
            except Exception:
                return []  # empty arrays are dropped by explode below
            
        # explode (unlike explode_outer) discards rows whose array is empty
        # or null, so no extra filter step is needed
        return (left_df
                .withColumn('__temp_array', apply_udf(*[F.col(c) for c in input_cols]))
                .withColumn('__temp_exploded', F.explode(F.col('__temp_array')))
                .select('*', *[F.col('__temp_exploded')[f].alias(f) 
                             for f in output_schema.fieldNames()])
                .drop('__temp_array', '__temp_exploded')
                .repartition(partition_num))
    
    except Exception as e:
        raise ValueError(f"CROSS APPLY failed: {str(e)}")

3. An enhanced CROSS JOIN

Wraps the native crossJoin, adding column-name conflict handling and a broadcast optimization for small right-side tables.

def cross_join(
    left_df: DataFrame,
    right_df: DataFrame,
    suffix: str = '_right',
    partition_num: int = 200
) -> DataFrame:
    """
    Enhanced CROSS JOIN.
    :param left_df: left-side DataFrame
    :param right_df: right-side DataFrame
    :param suffix: suffix appended to right-side columns on name conflicts
    :param partition_num: number of partitions for the result
    :return: the Cartesian product
    """
    try:
        # Resolve column-name conflicts by suffixing the right-side columns
        common_cols = set(left_df.columns) & set(right_df.columns)
        if common_cols:
            right_df = right_df.select(
                *[F.col(c).alias(f"{c}{suffix}") 
                  if c in common_cols else F.col(c) 
                  for c in right_df.columns]
            )
        
        # Broadcast small right-side tables; note that count() triggers a job
        if right_df.count() < 10000:
            result = left_df.crossJoin(right_df.hint("broadcast"))
        else:
            result = left_df.crossJoin(right_df)
        return result.repartition(partition_num)
    
    except Exception as e:
        raise ValueError(f"CROSS JOIN failed: {str(e)}")

4. Usage examples

# Sample data (assumes a running SparkSession plus the imports and
# functions defined above)
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

orders = spark.createDataFrame([(1, "A"), (2, "B")], ["order_id", "order_name"])

# Generator function: order 1 yields two items, order 2 yields nothing
def item_generator(order_id, order_name):
    if order_id == 1:
        return [{"item_id": i, "item_name": f"{order_name}{i}"} for i in range(2)]
    return []

# Schema of the generated rows
item_schema = StructType([
    StructField("item_id", IntegerType()),
    StructField("item_name", StringType())
])

# Run LEFT OUTER APPLY
left_outer_result = left_outer_apply(
    orders, 
    item_generator, 
    ["order_id", "order_name"], 
    item_schema
)

# Run CROSS APPLY
cross_apply_result = cross_apply(
    orders,
    item_generator,
    ["order_id", "order_name"],
    item_schema
)

# Run CROSS JOIN
products = spark.createDataFrame([(1, "X"), (2, "Y")], ["prod_id", "prod_name"])
cross_join_result = cross_join(orders, products)
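
For reference, showing the three results should produce output along these lines (row order may vary after repartition, and null rendering differs slightly across Spark versions):

left_outer_result.show()
# +--------+----------+-------+---------+
# |order_id|order_name|item_id|item_name|
# +--------+----------+-------+---------+
# |       1|         A|      0|       A0|
# |       1|         A|      1|       A1|
# |       2|         B|   null|     null|
# +--------+----------+-------+---------+

cross_apply_result.show()
# Same as above minus the order 2 row: its generator returned an empty list

cross_join_result.show()
# 2 x 2 = 4 rows, the full Cartesian product of orders and products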

5. Performance considerations

  1. UDF optimization: replace the plain row-at-a-time UDF with a vectorized pandas-based one for higher throughput (a sketch follows this list)
  2. Partition control: tune the shuffle partition count via partition_num to match the data volume
  3. Broadcast hint: cross_join automatically broadcasts small right-side tables
  4. Fault isolation: a failure on a single row yields an empty array instead of aborting the whole job
  5. Memory management: repartition the result to mitigate data skew
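
A minimal sketch of item 1, assuming Spark 3.x and reusing orders and item_generator from the usage example. Because a scalar pandas_udf is awkward for returning arrays of structs, this sketch uses mapInPandas, a related vectorized API that processes whole Arrow batches instead of calling into Python once per row; the schema string and the helper name apply_batches are illustrative assumptions, not part of the functions above:

import pandas as pd

def apply_batches(batches):
    # One output row per (order, generated item) pair; orders whose
    # generator yields nothing contribute no rows (CROSS APPLY semantics)
    for pdf in batches:
        rows = []
        for _, row in pdf.iterrows():
            for item in item_generator(row["order_id"], row["order_name"]):
                rows.append({**row.to_dict(), **item})
        yield pd.DataFrame(rows, columns=["order_id", "order_name",
                                          "item_id", "item_name"])

cross_apply_fast = orders.mapInPandas(
    apply_batches,
    schema="order_id long, order_name string, item_id long, item_name string",
)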