While working on CS336 Assignment 1, I came across this function:
```python
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return `indices`, but with all instances of `pair` replaced with `new_index`."""
    new_indices = []
    i = 0
    while i < len(indices):
        if i + 1 < len(indices) and indices[i] == pair[0] and indices[i + 1] == pair[1]:
            new_indices.append(new_index)
            i += 2
        else:
            new_indices.append(indices[i])
            i += 1
    return new_indices
```
Profiling the tokenizer training gives the following:
```text
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
135040250  110.687    0.000  161.293    0.000 train_tokenizer.py:26(merge)
      743   49.224    0.066   59.760    0.080 train_tokenizer.py:39(get_stats)
        1   23.879   23.879  245.661  245.661 train_tokenizer.py:55(train_bpe)
```
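The columns above are those of the report printed by Python's built-in `cProfile`/`pstats` modules. The post doesn't show how the profile was collected, so here is a minimal sketch assuming `cProfile` was used; `toy_workload` is a hypothetical stand-in for `train_bpe`, and `profile_call` is a hypothetical helper, not part of the assignment code:

```python
import cProfile
import io
import pstats


def merge(indices, pair, new_index):
    """Same algorithm as the merge above: replace non-overlapping `pair` with `new_index`."""
    new_indices = []
    i = 0
    while i < len(indices):
        if i + 1 < len(indices) and indices[i] == pair[0] and indices[i + 1] == pair[1]:
            new_indices.append(new_index)
            i += 2
        else:
            new_indices.append(indices[i])
            i += 1
    return new_indices


def toy_workload(rounds):
    """Hypothetical stand-in for train_bpe: many merge calls on short token lists."""
    words = [[0, 1, 2, 0, 1], [1, 2, 1, 2], [0, 1]]
    merged = []
    for _ in range(rounds):
        for word in words:
            merged.append(merge(word, (0, 1), 3))
    return merged


def profile_call(func, *args, sort_key="tottime", limit=10):
    """Run func(*args) under cProfile and return (result, pstats report as text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    buffer = io.StringIO()
    # Sort by time spent inside each function itself, then print the top `limit` rows.
    pstats.Stats(profiler, stream=buffer).sort_stats(sort_key).print_stats(limit)
    return result, buffer.getvalue()
```

For example, `print(profile_call(toy_workload, 10_000)[1])` prints a table in the same layout. A very large `ncalls` with a near-zero `percall`, as for `merge` above, suggests that per-call overhead multiplied by the call count is what adds up.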
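In this run `merge` is called 135,040,250 times and accounts for about 161 s of the 246 s spent in `train_bpe`. I wanted to cut BPE training time as much as possible (as the saying goes, "where there's a need, there's an optimization"), so I decided to rewrite this function in Rust.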
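For reference I followed the video 利用Rust加速Python:PyO3与maturin实践方法 ("Accelerating Python with Rust: PyO3 and maturin in Practice"). I have not yet managed to call the Rust function through maturin in this project, and I don't know why, so for now I fall back on a manual workaround.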
First, create a Rust library project in this folder (for example with `cargo init --lib`).
Configure `Cargo.toml` as follows:
```toml
[package]
name = "cs336_basics"
version = "0.0.1"
edition = "2024"

[lib]
name = "cs336_basics"
crate-type = ["cdylib"]

[dependencies]

[dependencies.pyo3]
version = "0.27.1"
features = ["extension-module"]
```
Implement the functions you need in `src\lib.rs`, for example:
```rust
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

#[pyfunction]
fn fib(n: u32) -> PyResult<u32> {
    let result = match n {
        0 => 0,
        1 => 1,
        _ => fib(n - 1)? + fib(n - 2)?,
    };
    Ok(result)
}

#[pymodule]
fn cs336_basics(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fib, m)?)?;
    Ok(())
}
```
Then build the crate in release mode from the command line with `cargo build --release`.
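On Windows, find `xxx.dll` under `target\release`, change its extension to `.pyd`, and move it into the folder that holds your Python code; it can then be imported. On Linux cargo produces `libxxx.so` instead, and Python on Linux does not load `.pyd` files, so rename it to `xxx.so` (dropping the `lib` prefix).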
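Renaming by hand after every rebuild gets tedious. The helper below is a hypothetical sketch (`copy_built_extension` is not part of cargo or PyO3). It assumes the `[lib]` name in `Cargo.toml` equals the module name and that cargo uses its default output names: `<name>.dll` on Windows, `lib<name>.so` on Linux, `lib<name>.dylib` on macOS. It copies the library under a suffix the running interpreter accepts, checked against `importlib.machinery.EXTENSION_SUFFIXES`:

```python
import importlib.machinery
import shutil
import sys
from pathlib import Path


def copy_built_extension(crate_dir, module_name, dest_dir, profile="release"):
    """Copy cargo's cdylib output into dest_dir under a name Python can import.

    Hypothetical helper; assumes the [lib] name in Cargo.toml equals `module_name`.
    """
    target_dir = Path(crate_dir) / "target" / profile
    if sys.platform == "win32":
        built = target_dir / f"{module_name}.dll"      # Windows: no "lib" prefix
        dest_name = f"{module_name}.pyd"
    elif sys.platform == "darwin":
        built = target_dir / f"lib{module_name}.dylib"
        dest_name = f"{module_name}.so"
    else:
        built = target_dir / f"lib{module_name}.so"    # Linux: cargo adds a "lib" prefix
        dest_name = f"{module_name}.so"
    if not built.is_file():
        raise FileNotFoundError(f"{built} not found; build the crate first")
    # Sanity check: the new suffix must be one this interpreter accepts for extensions.
    suffix = Path(dest_name).suffix
    if suffix not in importlib.machinery.EXTENSION_SUFFIXES:
        raise RuntimeError(f"{suffix} is not an extension suffix on this platform")
    dest = Path(dest_dir) / dest_name
    shutil.copyfile(built, dest)
    return dest
```

For example, after `cargo build --release`, calling `copy_built_extension(".", "cs336_basics", "path/to/python/code")` from the crate root; the destination path is a placeholder.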
```python
import cs336_basics

print(cs336_basics.fib(10))  # prints 55
```
Update (November 3, 2025, 21:00)
I created a fresh folder, `rust4cs336`, ran `maturin init` in it, and rewrote the `merge` and `get_stats` functions in Rust.
```rust
use std::collections::HashMap;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;

/// Replace every non-overlapping occurrence of `pair` in `indices` with `new_index`.
#[pyfunction]
fn merge(indices: Vec<i32>, pair: (i32, i32), new_index: i32) -> Vec<i32> {
    let mut new_indices = Vec::new();
    let mut i = 0;
    while i < indices.len() {
        if i + 1 < indices.len() && indices[i] == pair.0 && indices[i + 1] == pair.1 {
            new_indices.push(new_index);
            i += 2;
        } else {
            new_indices.push(indices[i]);
            i += 1;
        }
    }
    new_indices
}

/// Count adjacent token pairs across all groups and return them as a Python dict.
#[pyfunction]
fn get_stats(py: Python, token_groups: Vec<Vec<i64>>) -> PyResult<Py<PyAny>> {
    let mut counts: HashMap<(i64, i64), i64> = HashMap::new();
    for group in token_groups {
        if group.len() < 2 {
            continue;
        }
        for w in group.windows(2) {
            let key = (w[0], w[1]);
            *counts.entry(key).or_insert(0) += 1;
        }
    }
    // Convert the Rust map into a Python dict keyed by (a, b) tuples.
    let dict = PyDict::new(py);
    for ((a, b), cnt) in counts {
        dict.set_item((a, b), cnt)?;
    }
    Ok(dict.into())
}

#[pymodule]
fn rust4cs336(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(merge, m)?)?;
    m.add_function(wrap_pyfunction!(get_stats, m)?)?;
    Ok(())
}
```
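Before trusting new timings, it is worth confirming that the port returns exactly what the Python code did. A minimal sketch of a randomized cross-check, assuming the extension is importable as `rust4cs336`. The post doesn't show the original Python `get_stats`, so `get_stats_ref` is a hypothetical reference written to match the Rust version's behavior; `check_against_reference` is likewise a hypothetical helper:

```python
import random
from collections import Counter


def merge_ref(indices, pair, new_index):
    """Pure-Python reference: the original merge from the top of this post."""
    new_indices = []
    i = 0
    while i < len(indices):
        if i + 1 < len(indices) and indices[i] == pair[0] and indices[i + 1] == pair[1]:
            new_indices.append(new_index)
            i += 2
        else:
            new_indices.append(indices[i])
            i += 1
    return new_indices


def get_stats_ref(token_groups):
    """Hypothetical reference mirroring the Rust get_stats: count adjacent pairs per group."""
    counts = Counter()
    for group in token_groups:
        counts.update(zip(group, group[1:]))  # groups shorter than 2 contribute nothing
    return dict(counts)


def check_against_reference(merge_impl, get_stats_impl, trials=500, seed=0):
    """Compare implementations on random small inputs; raise AssertionError on mismatch."""
    rng = random.Random(seed)
    for _ in range(trials):
        # A tiny vocabulary makes pairs, including overlapping ones like (1, 1), common.
        groups = [[rng.randrange(4) for _ in range(rng.randrange(8))]
                  for _ in range(rng.randrange(5))]
        assert get_stats_impl(groups) == get_stats_ref(groups), groups
        for group in groups:
            pair = (rng.randrange(4), rng.randrange(4))
            assert merge_impl(group, pair, 99) == merge_ref(group, pair, 99), (group, pair)
    return True
```

After each rebuild, `import rust4cs336` followed by `check_against_reference(rust4cs336.merge, rust4cs336.get_stats)` should run without raising.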
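The improvement is substantial: cumulative time for `merge` drops from 161.293 s to 37.197 s, and for `get_stats` from 59.760 s to 13.910 s, roughly a 4.3× speedup for each.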
```text
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      743   13.910    0.019   13.910    0.019 {built-in method cs336_basics.get_stats}
135040250   37.197    0.000   37.197    0.000 {built-in method cs336_basics.merge}
```
I also built a wheel with `maturin build` and installed the `.whl` with `uv pip install`; the results were unchanged.