rust-candle Learning Notes 9: Loading the Qwen3 Tokenizer with tokenizers and Processing Text with the Tokenizer

Published: 2025-05-10

References: about-pytorch, about-tokenizers

Download the Qwen3 tokenizer.json file from the ModelScope community.

Add the dependency:

cargo add tokenizers
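
The examples below also pull in candle-core and rand (see the use statements in the test program at the end); assuming they are not already in your Cargo.toml, add them the same way:

cargo add candle-core rand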

A first look at the tokenizers library:

use tokenizers::tokenizer::{Result, Tokenizer};

fn main() -> Result<()> {
    // Load the tokenizer definition downloaded from ModelScope.
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let text = "Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace.";
    // The second argument is add_special_tokens; `false` skips the post-processor,
    // but tokens registered in tokenizer.json, such as <|endoftext|>, are still
    // matched inside the text.
    let encoding = tokenizer.encode(text, false)?;
    println!("{:?}\n", encoding.get_tokens());
    let ids = encoding.get_ids();
    println!("{:?}\n", ids);
    // The second argument is skip_special_tokens; `false` keeps them in the output.
    let text = tokenizer.decode(ids, false)?;
    println!("{:?}\n", text);
    Ok(())
}

Define a Dataset trait containing the commonly used methods:

trait Dataset {
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)>;
    fn len(&self) -> usize;
    fn shuffle(&mut self) -> Result<()>;
}

Define TokenDataset:

struct TokenDataset {
    inputs_ids: Tensor, // (num_samples, max_length) input token ids
    target_ids: Tensor, // the inputs shifted right by one position
    device: Device,
}

Implement the Dataset trait for TokenDataset:

impl Dataset for TokenDataset {
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)> {
        Ok((self.inputs_ids.i((start..end, ..))?, self.target_ids.i((start..end, ..))?))
    }

    fn len(&self) -> usize {
        self.inputs_ids.shape().dims()[0]
    }

    fn shuffle(&mut self) -> Result<()> {
        let len = self.len();
        // index_select expects an integer index tensor, so build a u32 permutation.
        let mut indices: Vec<u32> = (0..len).map(|i| i as u32).collect();
        let mut rng = rand::rng();
        indices.shuffle(&mut rng);
        let idx_tensor = Tensor::from_vec(indices, (len,), &self.device)?;
        // Reorder inputs and targets with the same permutation so they stay aligned.
        self.inputs_ids = self.inputs_ids.index_select(&idx_tensor, 0)?;
        self.target_ids = self.target_ids.index_select(&idx_tensor, 0)?;
        Ok(())
    }
}

Define a new constructor for TokenDataset:

impl TokenDataset {
    fn new(
        txt: String,
        tokenizer: Tokenizer,
        max_length: usize,
        stride: usize,
        device: Device,
    ) -> Result<Self> {
        let tokens = tokenizer.encode(txt, true)?;
        let tokens_id = tokens.get_ids();
        let token_len = tokens_id.len();
        if token_len <= max_length {
            return Err(Box::new(candle_core::Error::msg("Text is too short for given max_length")));
        }
        // Slide a window of max_length over the token ids, advancing by stride;
        // each target window is its input window shifted right by one token.
        let max_start_index = token_len - max_length;
        let mut inputs_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);
        let mut target_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);

        for i in (0..max_start_index).step_by(stride) {
            inputs_ids_vec.extend_from_slice(&tokens_id[i..i + max_length]);
            target_ids_vec.extend_from_slice(&tokens_id[i + 1..i + 1 + max_length]);
        }
        let total_samples = inputs_ids_vec.len() / max_length;
        let inputs_ids = Tensor::from_vec(inputs_ids_vec, (total_samples, max_length), &device)?;
        let target_ids = Tensor::from_vec(target_ids_vec, (total_samples, max_length), &device)?;
        Ok(Self { inputs_ids, target_ids, device })
    }

    fn get_item(&self, idx: usize) -> Result<(Tensor, Tensor)> {
        Ok((self.inputs_ids.i((idx, ..))?, self.target_ids.i((idx, ..))?))
    }
}
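
To see what the sliding window produces, here is a minimal standalone sketch; the token ids 0..10 and the values max_length = 4, stride = 2 are toy assumptions standing in for real tokenizer output:

fn main() {
    // Toy token sequence standing in for real tokenizer output.
    let tokens: Vec<u32> = (0..10).collect();
    let (max_length, stride) = (4usize, 2usize);
    let max_start_index = tokens.len() - max_length; // 6
    for i in (0..max_start_index).step_by(stride) {
        // Input window and its right-shifted target window.
        println!(
            "input: {:?} -> target: {:?}",
            &tokens[i..i + max_length],
            &tokens[i + 1..i + 1 + max_length]
        );
    }
    // Prints windows starting at i = 0, 2, 4:
    // input: [0, 1, 2, 3] -> target: [1, 2, 3, 4]
    // input: [2, 3, 4, 5] -> target: [3, 4, 5, 6]
    // input: [4, 5, 6, 7] -> target: [5, 6, 7, 8]
}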

Define a DataLoader; any struct that implements the Dataset trait can be loaded with it:

struct DataLoader<'a> {
    dataset: Box<dyn Dataset + 'a>,
    batch_size: usize,
    shuffle: bool,
    current_index: usize, // next batch index; reset() rewinds this to 0
}

Implement the common methods for DataLoader:

impl<'a> DataLoader<'a> {
    pub fn new<D: Dataset + 'a>(dataset: D, batch_size: usize, shuffle: bool) -> Self {
        Self {
            dataset: Box::new(dataset),
            batch_size,
            shuffle,
            current_index: 0,
        }
    }

    pub fn reset(&mut self) {
        self.current_index = 0;
        if self.shuffle {
            // A shuffle error is ignored here; the previous order is simply kept.
            let _ = self.dataset.shuffle();
        }
    }
}

Implement the Iterator trait for DataLoader:

impl<'a> Iterator for DataLoader<'a> {
    type Item = Result<(Tensor, Tensor)>;
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current_index * self.batch_size;
        // The last batch may be smaller than batch_size.
        let end = std::cmp::min(start + self.batch_size, self.dataset.len());
        if start >= end {
            return None;
        }
        self.current_index += 1;
        // Propagate get_batch errors to the caller instead of silently ending iteration.
        Some(self.dataset.get_batch(start, end))
    }
}
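
Because iteration only advances current_index, call reset() before each pass over the data. A minimal multi-epoch sketch, assuming the `loader` built as in the test program below and a surrounding function that returns the same Result:

for epoch in 0..3 {
    loader.reset(); // rewinds the index and reshuffles when shuffle is true
    for batch in &mut loader {
        let (_x, _y) = batch?;
        // ... a training step would go here ...
    }
    println!("epoch {} done", epoch);
}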

Test the DataLoader:

use tokenizers::tokenizer::{Result, Tokenizer};
#[allow(unused)]
mod learn_tokenizer;
use learn_tokenizer::read_txt;
use candle_core::{Device, Tensor, IndexOp};
use rand::seq::SliceRandom;

fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;

    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;
    // max_length = 512, stride = 256: consecutive windows overlap by half.
    let dataset = TokenDataset::new(text, tokenizer, 512, 256, device.clone())?;
    let (inputs, targets) = dataset.get_item(0)?;
    println!("{:?}\n", inputs);
    println!("{:?}\n", targets);
    let len = dataset.len();
    println!("{:?}", len);
    let mut loader = DataLoader::new(dataset, 6, true);
    loader.reset();
    for batch in &mut loader {
        let (x, y) = batch?;
        println!("input: {:?}", x);
        println!("target: {:?}", y);
    }

    Ok(())
}
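
With these parameters, each sample from get_item is a u32 tensor of shape (512,), and full batches come out as (6, 512); the final batch may have fewer rows because end is clamped to the dataset length in next().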