JAX study notes[15]

发布于:2025-07-07 ⋅ 阅读:(17) ⋅ 点赞:(0)

the symmetric difference of sets

The symmetric difference can be expressed as follows:
$$A \Delta B = (A \setminus B) \cup (B \setminus A)$$
The symmetric difference of sets A and B consists of the elements that belong to A or to B, but not to their intersection.

import jax.numpy as jnp
from jax import jit

def symmetric_difference(a, b):
    """Return the symmetric difference of two 1-D JAX arrays.

    The result is ``jnp.setdiff1d(a, b)`` followed by ``jnp.setdiff1d(b, a)``.
    Each half is sorted and de-duplicated by ``setdiff1d``, but the
    concatenation as a whole is not re-sorted.

    Args:
        a, b: 1-D JAX arrays representing the input sets.

    Returns:
        1-D JAX array holding the elements found in only one of the inputs.
    """
    # First the elements of `a` missing from `b`, then those of `b` missing from `a`.
    halves = [jnp.setdiff1d(lhs, rhs) for lhs, rhs in ((a, b), (b, a))]
    return jnp.concatenate(halves)




# Build the two example sets.
set_a, set_b = (
    jnp.array([11, 22, 26, 41, 5]),
    jnp.array([41, 52, 26, 7, 8]),
)

# Compute and display their symmetric difference.
result = symmetric_difference(set_a, set_b)
print(result)
[ 5 11 22  7  8 52]

limit superior and limit inferior

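(Figure: definitions of the limit superior and limit inferior of a sequence of sets.)

$$\limsup_{n\to\infty} A_n = \bigcap_{n=1}^{\infty}\bigcup_{k=n}^{\infty} A_k, \qquad \liminf_{n\to\infty} A_n = \bigcup_{n=1}^{\infty}\bigcap_{k=n}^{\infty} A_k$$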

  • In JAX (and NumPy), when you apply jnp.cumsum to a boolean array, the boolean values are automatically upcast to integers before the cumulative sum is computed.
    How It Works:

    True is treated as 1

    False is treated as 0

import jax.numpy as jnp

# cumsum promotes the booleans to integers: True counts as 1, False as 0.
bool_arr = jnp.array([True, False, True, True, False])
result = bool_arr.cumsum()
print(result)  # Output: [1 1 2 3 3]

To handle multi-dimensional arrays, pass the `axis` argument:

bool_matrix = jnp.array([[True, False], [False, True]])

# axis=0: accumulate downward within each column (across the rows).
print(bool_matrix.cumsum(axis=0))
# Output:
# [[1 0]
#  [1 1]]

# axis=1: accumulate rightward within each row (across the columns).
print(bool_matrix.cumsum(axis=1))
# Output:
# [[1 1]
#  [0 1]]

With axis=0, JAX accumulates downward within each column (across the rows); with axis=1, it accumulates rightward within each row (across the columns).

The following example shows how to use jnp.all along each axis:

matrix = jnp.array([[True, False], [True, True]])
print(matrix.all(axis=0))  # reduce down each column -> [True, False]
print(matrix.all(axis=1))  # reduce across each row -> [False, True]

To calculate the cumulative product of array elements, use jnp.cumprod:

import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])
result = arr.cumprod()
print(result)  # Output: [1, 2, 6, 24] (steps: 1, 1×2=2, 2×3=6, 6×4=24)
matrix = jnp.array([[1, 2], [3, 4]])

# axis=0: multiply downward within each column (across the rows)
print(matrix.cumprod(axis=0))
# Output: [[1, 2], [1×3=3, 2×4=8]]

# axis=1: multiply rightward within each row (across the columns)
print(matrix.cumprod(axis=1))
# Output: [[1, 1×2=2], [3, 3×4=12]]
import jax.numpy as jnp
from jax import vmap, jit
import jax


def sets_to_mask(sets, all_elements):
    """Convert a sequence of sets into a boolean membership matrix.

    Args:
        sets: Sequence of 1-D arrays, one per set. The sets may have
            different lengths.
        all_elements: 1-D array of candidate elements (the "universe").

    Returns:
        Boolean array of shape ``(len(sets), len(all_elements))``. Entry
        ``[i, j]`` is True when ``all_elements[j]`` occurs in ``sets[i]``.
    """
    all_elements = jnp.asarray(all_elements)
    if len(sets) == 0:
        # jnp.stack rejects an empty sequence; return a mask with zero rows.
        return jnp.zeros((0, all_elements.shape[0]), dtype=bool)
    # Build each row independently instead of vmapping over jnp.stack(sets).
    # Stacking forces every set to have the same length, which sets in
    # general do not.
    return jnp.stack([jnp.isin(all_elements, s) for s in sets])



def safety_concat(arr):
    """Merge a sequence of 1-D arrays into the sorted set of their values.

    Args:
        arr: Sequence of 1-D JAX arrays.

    Returns:
        1-D array of the distinct values across all inputs, as produced by
        ``jnp.unique`` (ascending order).
    """
    merged = jnp.concatenate(arr)
    return jnp.unique(merged)


def limsup_jax(sets):
    """Return the limit superior of a finite sequence of sets.

    Evaluates the finite analogue of lim sup A_n = ∩_n ∪_{k>=n} A_k: an
    element qualifies when every tail ``sets[n:]`` contains it at least once.

    NOTE: for a finite sequence the tails are nested, so this reduces to
    membership in the last set.

    Args:
        sets: Sequence of 1-D JAX arrays, one per set A_1, ..., A_N.

    Returns:
        1-D array of the qualifying elements, in the ascending order of the
        universe built by ``safety_concat``.
    """
    universe = safety_concat(sets)
    membership = sets_to_mask(sets, universe)  # rows: sets, columns: elements
    # Flip the rows so that a running count along axis 0 tallies each
    # element's occurrences in the tails sets[N-1:], sets[N-2:], ..., sets[0:].
    tail_counts = jnp.cumsum(membership[::-1], axis=0)
    in_every_tail = jnp.all(tail_counts > 0, axis=0)
    return universe[in_every_tail]
 

def liminf_jax(sets):
    """Return the limit inferior of a finite sequence of sets.

    Evaluates a finite analogue of lim inf A_n = ∪_n ∩_{k>=n} A_k: an element
    qualifies when it belongs to every set of some tail ``sets[n:]``.

    NOTE: the tail made of the last set alone is deliberately left out, so
    that membership in the final set is not enough on its own. Because the
    tails are nested, the result equals A_{N-1} ∩ A_N. It is empty when
    only one set is given.

    Args:
        sets: Sequence of 1-D JAX arrays, one per set A_1, ..., A_N.

    Returns:
        1-D array of the qualifying elements, in the ascending order of the
        universe built by ``safety_concat``.
    """
    universe = safety_concat(sets)
    membership = sets_to_mask(sets, universe)  # rows: sets, columns: elements
    # in_all_tail[n, j] is True iff element j is in every set of sets[n:].
    # It is computed in one vectorized pass as a running product over the
    # flipped rows. This avoids a Python loop that re-stacks partial results
    # row by row. The cast keeps the product integer-valued (0/1).
    flipped = membership[::-1].astype(jnp.int32)
    in_all_tail = jnp.cumprod(flipped, axis=0)[::-1] > 0
    # Drop the last row (the tail sets[N-1:]); see the NOTE above.
    return universe[jnp.any(in_all_tail[:-1], axis=0)]



# Test case: the sequence of sets A1..A4.
sets = [
    jnp.array([1, 2]),  # A1
    jnp.array([2, 3]),  # A2
    jnp.array([1, 3]),  # A3
    jnp.array([2, 3]),  # A4
]

# Inspect the intermediate values before checking the expected results.
print("所有元素:", safety_concat(sets))  # [1 2 3]
print("掩码矩阵:\n", sets_to_mask(sets, safety_concat(sets)))

print("\nLimsup:", limsup_jax(sets))  # expected: [2 3]

print("\nLiminf:", liminf_jax(sets))  # expected: [3]
所有元素: [1 2 3]
掩码矩阵:
 [[ True  True False]
 [False  True  True]
 [ True False  True]
 [False  True  True]]

Limsup: [2 3]

Liminf: [3]

references

  1. 《实变函数论(周民强)》
  2. deepseek

网站公告

今日签到

点亮在社区的每一天
去签到