from pyspark.sql import DataFrame
from pyspark.sql.functions import when, col, expr
from pyspark.errors import AnalysisException


def spark_merge_into(
    target_df: DataFrame,
    source_df: DataFrame,
    merge_key: list,
    update_rules: dict = None,
    insert_columns: list = None,
    delete_condition: str = None
) -> DataFrame:
    """
    PySpark function that emulates SQL MERGE INTO behaviour.

    Args:
        target_df: target DataFrame
        source_df: source DataFrame
        merge_key: key column(s) used for matching (list)
        update_rules: update rules ({target column: source column expression})
        insert_columns: columns used for the insert step
        delete_condition: delete condition expression, evaluated against the merged result

    Returns:
        The merged DataFrame.
    """
    try:
        # Parameter validation: the merge keys must exist on both sides
        if not set(merge_key).issubset(target_df.columns) or not set(merge_key).issubset(source_df.columns):
            raise ValueError("Merge key columns missing in source/target dataframe")

        # Alias both sides so qualified references such as "source.age" resolve
        target_alias = "target"
        source_alias = "source"
        target = target_df.alias(target_alias)
        source = source_df.alias(source_alias)

        # Build the join condition on the merge keys
        join_cond = [col(f"{target_alias}.{k}") == col(f"{source_alias}.{k}") for k in merge_key]

        # UPDATE path: keep every target row and attach the matching source row, if any
        merged_df = target.join(source, join_cond, "left_outer")

        # Apply the update rules to matched rows, preserving the target column order
        select_exprs = []
        for c in target_df.columns:
            if update_rules and c in update_rules:
                select_exprs.append(
                    when(col(f"{source_alias}.{merge_key[0]}").isNotNull(),
                         expr(f"coalesce({update_rules[c]}, {target_alias}.{c})"))
                    .otherwise(col(f"{target_alias}.{c}"))
                    .alias(c)
                )
            else:
                select_exprs.append(col(f"{target_alias}.{c}").alias(c))
        merged_df = merged_df.select(*select_exprs)

        # INSERT path: source rows with no matching target row (left anti join)
        if insert_columns:
            insert_df = (
                source.join(target, join_cond, "left_anti")
                      .select(*insert_columns)
            )
            merged_df = merged_df.unionByName(insert_df, allowMissingColumns=True)

        # DELETE path: drop rows from the merged result that match the condition
        if delete_condition:
            merged_df = merged_df.filter(f"NOT ({delete_condition})")

        # Cache the result; add checkpoint() here if a checkpoint directory is configured
        return merged_df.cache()
    except AnalysisException as e:
        print(f"Spark SQL analysis error: {e}")
        raise
    except Exception as e:
        print(f"Merge operation failed: {e}")
        raise
Usage example:
# Create sample data
target_data = [(1, "Alice", 30), (2, "Bob", 25)]
source_data = [(1, "Alice", 35), (3, "Charlie", 28)]
target_df = spark.createDataFrame(target_data, ["id", "name", "age"])
source_df = spark.createDataFrame(source_data, ["id", "name", "age"])
# Run the merge
merged_df = spark_merge_into(
target_df=target_df,
source_df=source_df,
merge_key=["id"],
update_rules={"age": "source.age"},
insert_columns=["id", "name", "age"]
)
merged_df.show()
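For the sample data above, the merged result is expected to contain the following rows (derived from the merge logic rather than captured output; row order is not guaranteed):

# id | name    | age
#  1 | Alice   | 35   <- matched on id, age updated from source
#  2 | Bob     | 25   <- no match in source, kept unchanged
#  3 | Charlie | 28   <- source-only row, inserted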
Implementation highlights:
- Join strategy: a left outer join drives the update path, so every target row is preserved and matched source rows are attached
- Conditional updates: coalesce gives field-level control, falling back to the current target value when the source expression is null
- Insert handling: a left anti join selects source rows with no match in the target, avoiding duplicate inserts
- Caching: the merged result is cached with cache(); add checkpoint() for iterative workloads once a checkpoint directory is configured
- Error handling: Spark AnalysisException and generic exceptions are caught, logged, and re-raised
- Schema validation: the merge key columns are verified to exist in both DataFrames before any work is done
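The delete path (the delete_condition parameter) is not exercised by the example above. A minimal sketch, assuming a hypothetical rule that rows with age below 26 should be removed; the condition is evaluated against the merged result, so it uses the output column names:

# Same merge as before, but additionally dropping rows that match the delete rule
filtered_df = spark_merge_into(
    target_df=target_df,
    source_df=source_df,
    merge_key=["id"],
    update_rules={"age": "source.age"},
    insert_columns=["id", "name", "age"],
    delete_condition="age < 26"  # hypothetical business rule
)
filtered_df.show()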
Performance tuning suggestions:
- Pre-sort the inputs on the merge keys:
  source_df = source_df.sort(merge_key)
- Set a sensible number of shuffle partitions for your data volume:
  spark.conf.set("spark.sql.shuffle.partitions", "200")
- Use a broadcast join when one side is small enough to fit in executor memory (see the sketch after this list):
  spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "100MB")
- Enable Adaptive Query Execution (AQE):
  spark.conf.set("spark.sql.adaptive.enabled", "true")
Caveats:
- Adjust the join strategy (full outer vs. left outer) to match how your data changes and which MERGE branches you actually need
- For large-scale data, prefer Delta Lake, whose native MERGE provides atomic transaction guarantees (see the sketch after this list)
- Deduplicate both inputs on the merge keys beforehand (e.g. dropDuplicates(merge_key)) so that matches are unambiguous
- When many columns need updating, keep the rules in a single dict/map parameter (as update_rules does) rather than passing them individually
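For the Delta Lake route mentioned above, a minimal sketch using the delta-spark package; the path /tmp/target_delta is a hypothetical placeholder, and the target must already be stored as a Delta table:

from delta.tables import DeltaTable

# Assumes the target has previously been written as a Delta table at this path
target_table = DeltaTable.forPath(spark, "/tmp/target_delta")

(
    target_table.alias("target")
    .merge(source_df.alias("source"), "target.id = source.id")
    .whenMatchedUpdate(set={"age": "source.age"})
    .whenNotMatchedInsertAll()
    .execute()
)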