While working on CS336 Assignment 1, I noticed the following function:

def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return `indices`, but with all instances of `pair` replaced with `new_index`."""
    new_indices = []
    i = 0
    while i < len(indices):
        if i + 1 < len(indices) and indices[i] == pair[0] and indices[i + 1] == pair[1]:  # the current index and the next one match `pair`
            new_indices.append(new_index)
            i += 2
        else:
            new_indices.append(indices[i])
            i += 1
    return new_indices

Its profile looks like this:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
135040250  110.687    0.000  161.293    0.000 train_tokenizer.py:26(merge)
      743   49.224    0.066   59.760    0.080 train_tokenizer.py:39(get_stats)
        1   23.879   23.879  245.661  245.661 train_tokenizer.py:55(train_bpe)
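A table like the one above can be produced with the standard library's `cProfile` and `pstats` modules. This is a minimal sketch, not the exact command used for this post: `profile_top` is a hypothetical helper, and `count_pairs` is a toy workload standing in for the real training code.

```python
import cProfile
import io
import pstats


def profile_top(func, *args, limit=3, **kwargs):
    """Run `func` under cProfile and return the top `limit` rows sorted by tottime."""
    profiler = cProfile.Profile()
    profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("tottime").print_stats(limit)
    return stream.getvalue()


def count_pairs(tokens):
    """Toy workload (hypothetical): count adjacent token pairs."""
    counts = {}
    for a, b in zip(tokens, tokens[1:]):
        counts[(a, b)] = counts.get((a, b), 0) + 1
    return counts


report = profile_top(count_pairs, list(range(1000)) * 50)
print(report)
```

Sorting by `tottime` puts the functions that spend the most time in their own bodies at the top, which is how `merge` stands out in the table.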

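BPE training time should be cut as much as possible. As the saying goes, "where there is demand, there is optimization." So I decided to rewrite this function in Rust.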

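For reference, see the video 利用Rust加速Python:PyO3与maturin实践方法 (Speeding up Python with Rust: PyO3 and maturin in practice). So far I have not managed to call Rust functions from this project through maturin, for reasons I have not figured out. For now, the only option is a crude manual workaround for speeding up Python with Rust.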

First, create a Rust project in the project folder:

cargo init --lib

Then configure Cargo.toml:

[package]
name = "cs336_basics"
version = "0.0.1"
edition = "2024"

[lib]
name = "cs336_basics"
crate-type = ["cdylib"]

[dependencies]


[dependencies.pyo3]
version = "0.27.1"
features = ["extension-module"]

Implement the functions you need in src\lib.rs, for example:

use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

/// Rust implementation of the Fibonacci sequence
#[pyfunction]
fn fib(n: u32) -> PyResult<u32> {
    let result = match n {
        0 => 0,
        1 => 1,
        _ => fib(n - 1)? + fib(n - 2)?,
    };
    Ok(result)
}

/// Python module
#[pymodule]
fn cs336_basics(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fib, m)?)?;
    Ok(())
}

Run the following on the command line:

cargo build --release

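In target\release, find the xxx.dll file and change its extension to .pyd. On Linux the file is libxxx.so and should be renamed to xxx.so instead, since .pyd is Windows-only. Move the renamed file into the folder containing your Python code, and it can be imported directly.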

# test.py
import cs336_basics

print(cs336_basics.fib(10))  # should print 55
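The manual step of renaming and copying the built library can also be scripted. This is a minimal sketch, assuming cargo's usual output names (`xxx.dll` on Windows, `libxxx.so` on Linux, `libxxx.dylib` on macOS). `library_names` and `install_extension` are hypothetical helpers, not part of PyO3 or maturin.

```python
import shutil
import sys
from pathlib import Path


def library_names(crate: str, platform: str = sys.platform) -> tuple[str, str]:
    """Return (file name cargo produces, file name Python can import)."""
    if platform.startswith("win"):
        return f"{crate}.dll", f"{crate}.pyd"
    if platform == "darwin":
        return f"lib{crate}.dylib", f"{crate}.so"
    # Assumption: everything else is treated like Linux.
    return f"lib{crate}.so", f"{crate}.so"


def install_extension(crate: str, release_dir: Path, dest_dir: Path) -> Path:
    """Copy the built library next to the Python code under an importable name."""
    built, importable = library_names(crate)
    target = dest_dir / importable
    shutil.copy2(release_dir / built, target)
    return target


# Example call (not executed here; it needs a finished `cargo build --release`):
# install_extension("cs336_basics", Path("target/release"), Path("."))
```

Copying rather than moving keeps the original artifact in target\release, so a rebuild followed by another call simply overwrites the importable copy.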

Update, November 3, 2025, 21:00
I started over in a new folder, rust4cs336, ran maturin init, and rewrote the merge and get_stats functions in Rust.

use std::collections::HashMap;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use pyo3::wrap_pyfunction;

/// Return `indices`, but with all instances of `pair` replaced with `new_index`.
#[pyfunction]
fn merge(indices: Vec<i32>, pair: (i32, i32), new_index: i32) -> Vec<i32> {
    let mut new_indices = Vec::new();
    let mut i = 0;

    while i < indices.len() {
        if i + 1 < indices.len() && indices[i] == pair.0 && indices[i + 1] == pair.1 {
            new_indices.push(new_index);
            i += 2;
        } else {
            new_indices.push(indices[i]);
            i += 1;
        }
    }

    new_indices
}


/// Count pairs of adjacent tokens in `token_groups`.
#[pyfunction]
fn get_stats(py: Python, token_groups: Vec<Vec<i64>>) -> PyResult<Py<PyAny>> {
    let mut counts: HashMap<(i64, i64), i64> = HashMap::new();

    for group in token_groups {
        if group.len() < 2 {
            continue;
        }
        for w in group.windows(2) {
            let key = (w[0], w[1]);
            *counts.entry(key).or_insert(0) += 1;
        }
    }

    let dict = PyDict::new(py);
    for ((a, b), cnt) in counts {
        // Pass the Rust tuple (i64, i64) directly as a Python tuple key
        dict.set_item((a, b), cnt)?;
    }

    // Convert the local reference into a persistent Py<PyAny>
    Ok(dict.into())
}


/// Python module definition: the module name must match the name in Cargo.toml
#[pymodule]
fn rust4cs336(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(merge, m)?)?;
    m.add_function(wrap_pyfunction!(get_stats, m)?)?;
    Ok(())
}

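The speedup is substantial: cumulative time for merge drops from 161.293 s to 37.197 s, and for get_stats from 59.760 s to 13.910 s.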

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      743   13.910    0.019   13.910    0.019 {built-in method cs336_basics.get_stats}
135040250   37.197    0.000   37.197    0.000 {built-in method cs336_basics.merge}

Building with maturin build and installing the resulting .whl file with uv pip install gives the same results.
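Before trusting the timings, it is worth checking that the Rust functions return the same results as pure-Python versions. This is a minimal sketch under a few assumptions: the module is importable as `rust4cs336` (the name used in the `#[pymodule]` definition), `get_stats_py` mirrors the Rust `get_stats` because the original Python version is not shown in this post, and `check_against_rust` is a hypothetical helper. If the extension is not installed, only the Python references run.

```python
from collections import Counter


def merge_py(indices, pair, new_index):
    """Pure-Python reference: replace every occurrence of `pair` with `new_index`."""
    out, i = [], 0
    while i < len(indices):
        if i + 1 < len(indices) and (indices[i], indices[i + 1]) == pair:
            out.append(new_index)
            i += 2
        else:
            out.append(indices[i])
            i += 1
    return out


def get_stats_py(token_groups):
    """Pure-Python reference: count adjacent token pairs within each group."""
    counts = Counter()
    for group in token_groups:
        counts.update(zip(group, group[1:]))
    return dict(counts)


try:
    import rust4cs336  # assumption: the extension was built and installed
except ImportError:
    rust4cs336 = None


def check_against_rust(indices, pair, new_index, token_groups):
    """Compute the reference results and, if available, compare with Rust."""
    merged = merge_py(indices, pair, new_index)
    stats = get_stats_py(token_groups)
    if rust4cs336 is not None:
        assert rust4cs336.merge(indices, pair, new_index) == merged
        assert rust4cs336.get_stats(token_groups) == stats
    return merged, stats


merged, stats = check_against_rust([1, 2, 3, 1, 2], (1, 2), 4, [[1, 2, 3, 1, 2], [7]])
print(merged, stats)
```

Overlapping matches are a useful edge case to include: with `pair = (1, 1)`, the input `[1, 1, 1]` should merge only the first two elements.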