dask.dataframe.shuffle.set_index中获取 divisions 的步骤分析

发布于:2025-09-07 ⋅ 阅读:(30) ⋅ 点赞:(0)

dask.dataframe.shuffle.set_index 中获取 divisions 的步骤分析

主要流程概述

set_index 函数中,当 divisions=None 时,系统需要通过分析数据来动态计算分区边界。这个过程分为以下几个关键步骤:

1. 初始检查和准备

if divisions is None:
    sizes = df.map_partitions(sizeof) if repartition else []
    divisions = index2._repartition_quantiles(npartitions, upsample=upsample)
    mins = index2.map_partitions(M.min)
    maxes = index2.map_partitions(M.max)
    divisions, sizes, mins, maxes = base.compute(divisions, sizes, mins, maxes)

步骤说明:

  • 计算每个分区的大小(如果启用重新分区)
  • 调用 _repartition_quantiles 计算近似分位数
  • 并行计算每个分区的最小值和最大值
  • 使用 base.compute 触发实际计算

2. 分位数计算过程 (_repartition_quantiles)

_repartition_quantiles 方法调用 partition_quantiles 函数,该函数执行以下步骤:

2.1 生成采样策略
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):
    # 计算随机百分位比例
    random_percentage = 1 / (1 + (4 * num_new / num_old) ** 0.5)
    # 生成等间距和随机百分位
2.2 创建计算图
# 1. 数据类型信息
dtype_dsk = {(name0, 0): (dtype_info, df_keys[0])}

# 2. 每个分区的百分位摘要
val_dsk = {
    (name1, i): (percentiles_summary, key, df.npartitions, npartitions, upsample, state)
    for i, (state, key) in enumerate(zip(state_data, df_keys))
}

# 3. 合并和压缩摘要
merge_dsk = create_merge_tree(merge_and_compress_summaries, sorted(val_dsk), name2)

# 4. 最终处理
last_dsk = {
    (name3, 0): (pd.Series, (process_val_weights, merged_key, npartitions, (name0, 0)), qs, None, df.name)
}

3. 数据后处理

divisions = methods.tolist(divisions)
if type(sizes) is not list:
    sizes = methods.tolist(sizes)
mins = methods.tolist(mins)
maxes = methods.tolist(maxes)

4. 空数据检测和重新分区

empty_dataframe_detected = pd.isnull(divisions).all()
if repartition or empty_dataframe_detected:
    total = sum(sizes)
    npartitions = max(math.ceil(total / partition_size), 1)
    npartitions = min(npartitions, df.npartitions)
    # 插值生成新的分界点
    divisions = np.interp(
        x=np.linspace(0, n - 1, npartitions + 1),
        xp=np.linspace(0, n - 1, n),
        fp=divisions,
    ).tolist()

5. 数据类型特殊处理

if pd.api.types.is_categorical_dtype(index2.dtype):
    dtype = index2.dtype
    mins = pd.Categorical(mins, dtype=dtype).codes.tolist()
    maxes = pd.Categorical(maxes, dtype=dtype).codes.tolist()

6. 排序优化检查

if (mins == sorted(mins) and maxes == sorted(maxes) and 
    all(mx < mn for mx, mn in zip(maxes[:-1], mins[1:]))):
    divisions = mins + [maxes[-1]]
    result = set_sorted_index(df, index, drop=drop, divisions=divisions)
    return result.map_partitions(M.sort_index)

这个检查的作用:

  • 如果数据已经按索引排序,可以直接使用最小值和最大值作为分界点
  • 避免昂贵的shuffle操作

分位数计算详细过程

核心算法:percentiles_summary 函数
def percentiles_summary(df, num_old, num_new, upsample, state):
    """Summarize data using percentiles and derived weights."""
    # 1. 生成采样百分位
    qs = sample_percentiles(num_old, num_new, len(df), upsample, state)
    
    # 2. 计算百分位值
    vals = df.quantile(qs)
    
    # 3. 转换为权重
    return percentiles_to_weights(qs, vals, len(df))
权重计算:percentiles_to_weights 函数
def percentiles_to_weights(qs, vals, length):
    """Weigh percentile values by length and the difference between percentiles"""
    if length == 0:
        return ()
    diff = np.ediff1d(qs, 0.0, 0.0)
    weights = 0.5 * length * (diff[1:] + diff[:-1])
    return vals.tolist(), weights.tolist()

权重计算原理:

  • 每个百分位值的权重 = 0.5 × 分区长度 × (前一个百分位差 + 后一个百分位差)
  • 这样确保权重反映该值在数据分布中的重要性
合并和压缩:merge_and_compress_summaries 函数
def merge_and_compress_summaries(vals_and_weights):
    """Merge and sort percentile summaries that are already sorted."""
    vals, weights = zip(*vals_and_weights)
    vals = list(merge_sorted(*vals))
    weights = list(merge_sorted(*weights))
    
    # 压缩重复值
    compressed_vals = []
    compressed_weights = []
    # ... 压缩逻辑
    return compressed_vals, compressed_weights
最终处理:process_val_weights 函数

这个函数(我们之前分析过的)将合并后的值-权重对转换为最终的分区边界。

完整流程图和示例

让我创建一个详细的流程图来展示整个过程:

开始 set_index(df, index, divisions=None)
  ↓
检查 divisions 是否为 None
  ↓
是 → 开始动态计算分界点
  ↓
1. 并行计算:
   - 每个分区的大小 (sizes)
   - 每个分区的分位数 (divisions)
   - 每个分区的最小值 (mins)
   - 每个分区的最大值 (maxes)
  ↓
2. 触发计算:base.compute(divisions, sizes, mins, maxes)
  ↓
3. 分位数计算过程:
   ┌─────────────────────────────────────────┐
   │ 对每个分区执行 percentiles_summary:     │
   │ 1. 生成采样百分位 (等间距 + 随机)       │
   │ 2. 计算百分位值                        │
   │ 3. 转换为权重                          │
   └─────────────────────────────────────────┘
  ↓
4. 合并所有分区的摘要:
   ┌─────────────────────────────────────────┐
   │ merge_and_compress_summaries:           │
   │ 1. 合并排序的值和权重                   │
   │ 2. 压缩重复值                          │
   └─────────────────────────────────────────┘
  ↓
5. 最终处理:
   ┌─────────────────────────────────────────┐
   │ process_val_weights:                    │
   │ 1. 处理巨型值                          │
   │ 2. 计算最终分界点                      │
   │ 3. 处理数据类型                        │
   └─────────────────────────────────────────┘
  ↓
6. 后处理:
   - 转换为列表格式
   - 检查空数据
   - 处理重新分区
   - 特殊数据类型处理
  ↓
7. 排序优化检查:
   - 如果数据已排序,使用 min/max 作为分界点
   - 否则继续到 shuffle 阶段
  ↓
调用 set_partition 进行实际的数据重排
  ↓
结束

关键优化策略

  1. 采样策略:结合等间距和随机百分位,平衡计算效率和准确性
  2. 排序检测:如果数据已排序,避免昂贵的shuffle操作
  3. 数据类型感知:特别处理分类、时间等特殊数据类型
  4. 内存优化:通过压缩和合并减少内存使用
  5. 分布式计算:利用Dask的并行计算能力

性能考虑

  • 时间复杂度:O(n log n),主要由排序和分位数计算决定
  • 空间复杂度:O(n),存储采样数据和权重
  • 网络开销:需要收集所有分区的统计信息
  • 计算开销:需要两次数据遍历(统计 + shuffle)

总结

dask.dataframe.shuffle.set_index 中获取 divisions 的过程是一个复杂的分布式算法,主要包含以下步骤:

核心步骤

  1. 并行统计:计算每个分区的分位数、大小、最小值、最大值
  2. 分位数计算:使用采样策略生成代表性百分位
  3. 权重分配:根据数据分布为每个值分配权重
  4. 合并压缩:合并所有分区的统计信息并压缩重复值
  5. 分界点计算:使用 process_val_weights 计算最终分界点
  6. 优化检查:检测数据是否已排序,避免不必要的shuffle

关键特点

  • 分布式设计:充分利用Dask的并行计算能力
  • 智能采样:结合等间距和随机采样策略
  • 类型感知:特别处理不同数据类型
  • 性能优化:检测已排序数据,避免重复计算
  • 内存高效:通过压缩和合并减少内存使用

这个算法是Dask DataFrame实现高效分布式排序和分区的核心,通过巧妙的采样和合并策略,在保证准确性的同时实现了良好的性能。

自己实现

import numpy as np
import pandas as pd

# 1️⃣ 采样百分位
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):
    """简单版本:等间距百分位"""
    return np.linspace(0, 1, num_new + 1)


# 2️⃣ 计算百分位摘要(值+权重)
def percentiles_summary(series, num_old, num_new):
    qs = sample_percentiles(num_old, num_new, len(series))
    vals = series.quantile(qs).to_numpy()
    diff = np.ediff1d(qs, 0.0, 0.0)
    weights = 0.5 * len(series) * (diff[1:] + diff[:-1])
    return vals.tolist(), weights.tolist()


# 3️⃣ 合并多个分区的摘要
def merge_and_compress_summaries(summaries):
    all_vals = []
    all_weights = []
    for vals, weights in summaries:
        all_vals.extend(vals)
        all_weights.extend(weights)
    # 按值排序
    order = np.argsort(all_vals)
    vals = np.array(all_vals)[order]
    weights = np.array(all_weights)[order]

    # 压缩重复值
    compressed_vals = []
    compressed_weights = []
    last_val = None
    for v, w in zip(vals, weights):
        if last_val is not None and v == last_val:
            compressed_weights[-1] += w
        else:
            compressed_vals.append(v)
            compressed_weights.append(w)
            last_val = v
    return np.array(compressed_vals), np.array(compressed_weights)


# 4️⃣ 最终处理:计算分界点
def process_val_weights(vals, weights, npartitions):
    if len(vals) == 0:
        return np.array([])

    if len(vals) == npartitions + 1:
        return vals

    elif len(vals) < npartitions + 1:
        q_weights = np.cumsum(weights)
        q_target = np.linspace(q_weights[0], q_weights[-1], npartitions + 1)
        return np.interp(q_target, q_weights, vals)

    else:
        target_weight = weights.sum() / npartitions
        jumbo_mask = weights >= target_weight
        jumbo_vals = vals[jumbo_mask]

        trimmed_vals = vals[~jumbo_mask]
        trimmed_weights = weights[~jumbo_mask]
        trimmed_npartitions = npartitions - len(jumbo_vals)

        q_weights = np.cumsum(trimmed_weights)
        q_target = np.linspace(0, q_weights[-1], trimmed_npartitions + 1)

        left = np.searchsorted(q_weights, q_target, side="left")
        right = np.searchsorted(q_weights, q_target, side="right") - 1
        lower = np.minimum(left, right)
        trimmed = trimmed_vals[lower]

        rv = np.concatenate([trimmed, jumbo_vals])
        rv.sort()
        return rv


# 5️⃣ 模拟 set_index 中 divisions 的获取
def simulate_set_index(df, column, npartitions):
    num_old = len(df)
    # 假设原始有分区(这里手动切分成2块模拟)
    partitions = np.array_split(df[column], 2)

    summaries = [percentiles_summary(p, num_old, npartitions) for p in partitions]

    vals, weights = merge_and_compress_summaries(summaries)

    divisions = process_val_weights(vals, weights, npartitions)
    return divisions


# ========== DEMO 使用 ==========
df = pd.DataFrame({"x": np.random.randint(0, 100, size=50)})

divs = simulate_set_index(df, "x", npartitions=4)

print("原始数据示例:\n", df.head())
print("\n计算得到的 divisions:", divs)


网站公告

今日签到

点亮在社区的每一天
去签到