Rust从入门到精通之进阶篇:14.并发编程

发布于:2025-03-26 ⋅ 阅读:(32) ⋅ 点赞:(0)

并发编程

并发编程允许程序同时执行多个独立的任务,充分利用现代多核处理器的性能。Rust 提供了强大的并发原语,同时通过类型系统和所有权规则在编译时防止数据竞争和其他常见的并发错误。在本章中,我们将探索 Rust 的并发编程模型。

线程基础

创建线程

Rust 标准库提供了 std::thread 模块,用于创建和管理线程:

use std::thread;
use std::time::Duration;

fn main() {
    // 创建一个新线程
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("在新线程中: {}", i);
            thread::sleep(Duration::from_millis(1));
        }
    });
    
    // 主线程中的代码
    for i in 1..5 {
        println!("在主线程中: {}", i);
        thread::sleep(Duration::from_millis(1));
    }
    
    // 等待新线程完成
    handle.join().unwrap();
}

thread::spawn 函数接受一个闭包,该闭包包含要在新线程中执行的代码,并返回一个 JoinHandle。调用 join 方法会阻塞当前线程,直到新线程完成。

线程与所有权

当我们将闭包传递给 thread::spawn 时,Rust 需要知道闭包将在哪个线程中运行以及它将使用哪些数据。闭包默认会捕获其环境中的变量,但在线程间传递数据时,我们需要考虑所有权问题。

use std::thread;

fn main() {
    let v = vec![1, 2, 3];
    
    // 错误:Rust 无法确定 v 的生命周期
    // let handle = thread::spawn(|| {
    //     println!("这是向量: {:?}", v);
    // });
    
    // 使用 move 关键字转移所有权
    let handle = thread::spawn(move || {
        println!("这是向量: {:?}", v);
    });
    
    // 错误:v 的所有权已经转移到新线程
    // println!("向量: {:?}", v);
    
    handle.join().unwrap();
}

使用 move 关键字可以强制闭包获取其使用的值的所有权,而不是借用它们。这对于确保数据在线程运行期间有效非常重要。

消息传递

一种处理并发的流行方法是消息传递,其中线程或执行者通过发送消息进行通信。Rust 的标准库提供了通道(channel)实现,这是一种实现消息传递并发的方式。

创建通道

use std::sync::mpsc;
use std::thread;

fn main() {
    // 创建一个通道
    let (tx, rx) = mpsc::channel();
    
    // 在新线程中发送消息
    thread::spawn(move || {
        let val = String::from("你好");
        tx.send(val).unwrap();
        // 错误:val 的所有权已转移
        // println!("val 是 {}", val);
    });
    
    // 在主线程中接收消息
    let received = rx.recv().unwrap();
    println!("收到: {}", received);
}

mpsc 代表"多生产者,单消费者"(multiple producers, single consumer)。通道有两部分:发送端(tx)和接收端(rx)。

发送多个值

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = mpsc::channel();
    
    thread::spawn(move || {
        let vals = vec![
            String::from("你好"),
            String::from("来自"),
            String::from("线程"),
        ];
        
        for val in vals {
            tx.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    for received in rx {
        println!("收到: {}", received);
    }
}

多个生产者

通过克隆发送者,我们可以有多个生产者:

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = mpsc::channel();
    
    let tx1 = tx.clone();
    thread::spawn(move || {
        let vals = vec![
            String::from("你好"),
            String::from("来自"),
            String::from("线程1"),
        ];
        
        for val in vals {
            tx1.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    thread::spawn(move || {
        let vals = vec![
            String::from("更多"),
            String::from("来自"),
            String::from("线程2"),
        ];
        
        for val in vals {
            tx.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    for received in rx {
        println!("收到: {}", received);
    }
}

共享状态

另一种处理并发的方法是允许多个线程访问同一块数据。这种方法需要特别小心,以避免数据竞争。

互斥锁(Mutex)

互斥锁(Mutex,mutual exclusion)确保在任何时刻只有一个线程可以访问数据:

use std::sync::Mutex;

fn main() {
    let m = Mutex::new(5);
    
    {
        let mut num = m.lock().unwrap();
        *num = 6;
    } // 锁在这里被释放
    
    println!("m = {:?}", m);
}

在线程间共享 Mutex

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("结果: {}", *counter.lock().unwrap());
}

Arc<T> 是一个原子引用计数类型,允许在线程间安全地共享所有权。它类似于 Rc<T>,但可以在并发环境中使用。

死锁

使用互斥锁时需要注意死锁问题,当两个线程各自持有一个锁并尝试获取对方的锁时,就会发生死锁:

use std::sync::{Mutex, MutexGuard};
use std::thread;
use std::time::Duration;

fn main() {
    let mutex_a = Mutex::new(5);
    let mutex_b = Mutex::new(5);
    
    let thread_a = thread::spawn(move || {
        // 线程 A 先锁定 mutex_a
        let mut a: MutexGuard<i32> = mutex_a.lock().unwrap();
        println!("线程 A 获取了 mutex_a");
        
        // 睡眠一段时间,让线程 B 有机会锁定 mutex_b
        thread::sleep(Duration::from_millis(100));
        
        // 线程 A 尝试锁定 mutex_b
        println!("线程 A 尝试获取 mutex_b");
        let mut b: MutexGuard<i32> = mutex_b.lock().unwrap();
        
        *a += *b;
    });
    
    let thread_b = thread::spawn(move || {
        // 线程 B 先锁定 mutex_b
        let mut b: MutexGuard<i32> = mutex_b.lock().unwrap();
        println!("线程 B 获取了 mutex_b");
        
        // 睡眠一段时间,让线程 A 有机会锁定 mutex_a
        thread::sleep(Duration::from_millis(100));
        
        // 线程 B 尝试锁定 mutex_a
        println!("线程 B 尝试获取 mutex_a");
        let mut a: MutexGuard<i32> = mutex_a.lock().unwrap();
        
        *b += *a;
    });
    
    // 注意:这个例子会导致死锁!
}

读写锁(RwLock)

读写锁允许多个读取器或一个写入器访问数据:

use std::sync::RwLock;

fn main() {
    let lock = RwLock::new(5);
    
    // 多个读取器可以同时访问数据
    {
        let r1 = lock.read().unwrap();
        let r2 = lock.read().unwrap();
        println!("读取器: {} {}", r1, r2);
    } // 读锁在这里被释放
    
    // 只能有一个写入器
    {
        let mut w = lock.write().unwrap();
        *w += 1;
        println!("写入后: {}", *w);
    } // 写锁在这里被释放
}

原子类型

对于简单的计数器和标志,可以使用原子类型,它们提供了无锁的线程安全操作:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let counter = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            for _ in 0..1000 {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("结果: {}", counter.load(Ordering::SeqCst));
}

Ordering 参数指定内存顺序约束:

  • Relaxed:最宽松的顺序,只保证原子性
  • Release:写操作使用,确保之前的操作不会被重排到此操作之后
  • Acquire:读操作使用,确保之后的操作不会被重排到此操作之前
  • AcqRel:结合了 AcquireRelease
  • SeqCst:最严格的顺序,提供全序关系

条件变量

条件变量允许线程等待某个条件变为真:

use std::sync::{Arc, Mutex, Condvar};
use std::thread;

fn main() {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = Arc::clone(&pair);
    
    thread::spawn(move || {
        let (lock, cvar) = &*pair2;
        let mut started = lock.lock().unwrap();
        println!("改变条件变量之前");
        *started = true;
        // 通知等待的线程
        cvar.notify_one();
        println!("条件变量已改变");
    });
    
    let (lock, cvar) = &*pair;
    let mut started = lock.lock().unwrap();
    // 等待条件变为真
    while !*started {
        started = cvar.wait(started).unwrap();
    }
    
    println!("条件已满足,主线程继续执行");
}

屏障(Barrier)

屏障确保多个线程在某一点同步:

use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let mut handles = Vec::with_capacity(10);
    let barrier = Arc::new(Barrier::new(10));
    
    for i in 0..10 {
        let b = Arc::clone(&barrier);
        handles.push(thread::spawn(move || {
            println!("线程 {} 开始工作", i);
            // 模拟工作
            thread::sleep(std::time::Duration::from_millis(i * 100));
            println!("线程 {} 到达屏障", i);
            
            // 等待所有线程到达屏障
            b.wait();
            
            println!("线程 {} 继续执行", i);
        }));
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
}

线程池

创建线程有开销,线程池可以重用线程,提高性能:

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Job>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    fn new(size: usize) -> ThreadPool {
        assert!(size > 0);
        
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        
        let mut workers = Vec::with_capacity(size);
        
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }
        
        ThreadPool { workers, sender }
    }
    
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.send(job).unwrap();
    }
}

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let job = receiver.lock().unwrap().recv().unwrap();
            println!("工作线程 {} 获得了一个任务", id);
            job();
        });
        
        Worker { id, thread }
    }
}

fn main() {
    let pool = ThreadPool::new(4);
    
    for i in 0..8 {
        pool.execute(move || {
            println!("执行任务 {}", i);
            thread::sleep(std::time::Duration::from_secs(1));
            println!("任务 {} 完成", i);
        });
    }
    
    // 给线程池一些时间来处理任务
    thread::sleep(std::time::Duration::from_secs(10));
}

并发最佳实践

1. 优先使用消息传递

当可能时,优先使用消息传递而不是共享状态:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel();
    
    // 启动工作线程
    for i in 0..4 {
        let tx = tx.clone();
        thread::spawn(move || {
            // 模拟工作
            let result = i * i;
            tx.send(result).unwrap();
        });
    }
    
    // 丢弃原始发送者
    drop(tx);
    
    // 收集结果
    let mut results = Vec::new();
    for received in rx {
        results.push(received);
    }
    
    println!("结果: {:?}", results);
}

2. 使用适当的同步原语

根据需求选择合适的同步原语:

  • 对于简单计数器:使用 AtomicUsize
  • 对于需要独占访问的数据:使用 Mutex
  • 对于读多写少的数据:使用 RwLock
  • 对于一次性初始化:使用 lazy_staticOnceCell/OnceLock

3. 避免过度同步

过度同步会导致性能下降:

use std::sync::{Arc, Mutex};
use std::thread;

// 不好的做法:锁的粒度太大
fn process_data_bad(data: &[i32]) -> i32 {
    let result = Arc::new(Mutex::new(0));
    let mut handles = vec![];
    
    for chunk in data.chunks(100) {
        let result = Arc::clone(&result);
        let chunk = chunk.to_vec();
        
        handles.push(thread::spawn(move || {
            // 计算和
            let sum: i32 = chunk.iter().sum();
            
            // 获取锁并更新结果
            let mut result = result.lock().unwrap();
            *result += sum;
        }));
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    *result.lock().unwrap()
}

// 好的做法:减少锁的竞争
fn process_data_good(data: &[i32]) -> i32 {
    let mut handles = vec![];
    
    for chunk in data.chunks(100) {
        let chunk = chunk.to_vec();
        
        handles.push(thread::spawn(move || {
            // 计算和并返回
            chunk.iter().sum::<i32>()
        }));
    }
    
    // 收集结果
    let mut final_result = 0;
    for handle in handles {
        final_result += handle.join().unwrap();
    }
    
    final_result
}

4. 使用线程局部存储

对于每个线程需要独立状态的情况,使用线程局部存储:

use std::cell::RefCell;
use std::thread;
use std::thread_local;

thread_local! {
    static COUNTER: RefCell<u32> = RefCell::new(0);
}

fn main() {
    let mut handles = vec![];
    
    for _ in 0..5 {
        handles.push(thread::spawn(|| {
            for _ in 0..10 {
                COUNTER.with(|c| {
                    *c.borrow_mut() += 1;
                });
            }
            
            let result = COUNTER.with(|c| *c.borrow());
            println!("线程计数器: {}", result);
        }));
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
}

5. 使用 parking_lot 库

parking_lot 库提供了更高性能的同步原语:

// 在 Cargo.toml 中添加:
// [dependencies]
// parking_lot = "0.12.0"

use parking_lot::{Mutex, RwLock};
use std::sync::Arc;
use std::thread;

fn main() {
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock();
            *num += 1;
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("结果: {}", *counter.lock());
}

练习题

  1. 实现一个并发计数器,使用不同的同步原语(MutexRwLockAtomicUsize)并比较它们的性能。

  2. 创建一个简单的生产者-消费者系统,其中多个生产者线程生成随机数,多个消费者线程计算这些数字的平方并打印结果。

  3. 实现一个并发哈希表,允许多个线程同时读取,但只允许一个线程写入。使用适当的同步原语确保线程安全。

  4. 编写一个程序,使用屏障(Barrier)同步多个线程,让它们同时开始执行一个计算密集型任务,并测量完成时间。

  5. 实现一个简单的线程池,可以提交任务并等待所有任务完成。包括一个优雅的关闭机制,确保所有任务都能完成。

总结

在本章中,我们探讨了 Rust 的并发编程模型:

  • 线程基础和所有权规则
  • 消息传递并发(通道)
  • 共享状态并发(互斥锁、读写锁、原子类型)
  • 条件变量和屏障
  • 线程池实现
  • 并发编程最佳实践

Rust 的类型系统和所有权规则使得并发编程更加安全,在编译时就能捕获许多常见的并发错误。通过选择适当的并发模型和同步原语,你可以编写高效、安全的并发代码。在下一章中,我们将探索 Rust 的异步编程模型,它提供了一种更轻量级的并发方式。