// bleuscore/tokenizer.rs

1use cached::proc_macro::cached;
2use lazy_static::lazy_static;
3use regex::Regex;
4
// Ordered (pattern, replacement) pairs applied by the tokenizers below.
// Mirrors the punctuation-splitting regexes of sacrebleu's `tokenizer_13a.py`;
// the pairs are applied in array order, each rewriting the previous result.
lazy_static! {
    pub static ref REGEX_ARRAY: [(Regex, &'static str); 4] = [
        // 1. Pad ASCII punctuation/symbol ranges ({-~, [-`, space-&, (-+, :-@, /)
        //    with spaces so they become standalone tokens.
        (
            Regex::new(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])").unwrap(),
            r" $1 "
        ),
        // 2. Split '.' or ',' when preceded by a non-digit (keeps "3.14" intact).
        (Regex::new(r"([^0-9])([\.,])").unwrap(), r"$1 $2 "),
        // 3. Split '.' or ',' when followed by a non-digit.
        (Regex::new(r"([\.,])([^0-9])").unwrap(), r" $1 $2"),
        // 4. Split a dash that follows a digit (e.g. "1-2" -> "1 - 2").
        (Regex::new(r"([0-9])(-)").unwrap(), r"$1 $2 "),
    ];
}
16
/// A tokenizer turns a raw line of text into a list of string tokens.
pub trait Tokenizer {
    /// Short identifier for the tokenization scheme (e.g. "re", "13a").
    fn signature(&self) -> &str;
    /// Splits `line` into a vector of tokens.
    fn tokenize(&self, line: &str) -> Vec<String>;
}
22
/// Same implementation with [huggingface/sacrebleu](https://github.com/huggingface/evaluate/blob/main/metrics/bleu/tokenizer_13a.py)
///
/// Bare regex tokenizer: applies [`REGEX_ARRAY`] to the line as-is,
/// without the 13a normalization pass (entity unescaping etc.).
#[derive(Debug)]
pub struct TokenizerRegex {
    /// Human-readable identifier for this tokenizer ("re" by default).
    pub signature: String,
}
28
29impl Default for TokenizerRegex {
30    fn default() -> Self {
31        Self {
32            signature: "re".to_string(),
33        }
34    }
35}
36
37impl TokenizerRegex {
38    pub fn new() -> Self {
39        Self::default()
40    }
41}
42
43#[cached(size = 65536)]
44fn regex_tokenize_cache(line: String) -> Vec<String> {
45    let mut res = line;
46    for &(ref re_capture, re_replace) in REGEX_ARRAY.iter() {
47        res = re_capture.replace_all(&res, re_replace).to_string();
48    }
49    res.split_whitespace().map(|x| x.to_string()).collect()
50}
51
52impl Tokenizer for TokenizerRegex {
53    fn signature(&self) -> &str {
54        &self.signature
55    }
56    fn tokenize(&self, line: &str) -> Vec<String> {
57        regex_tokenize_cache(line.to_string())
58    }
59}
60
/// Same implementation with [huggingface/sacrebleu](https://github.com/huggingface/evaluate/blob/main/metrics/bleu/tokenizer_13a.py)
///
/// "13a" tokenizer: normalizes the line (removes `<skipped>`, joins hyphenated
/// line breaks, unescapes a few HTML entities) before regex tokenization.
#[derive(Debug)]
pub struct Tokenizer13a {
    /// Human-readable identifier for this tokenizer ("13a" by default).
    pub signature: String,
}
66
67impl Default for Tokenizer13a {
68    fn default() -> Self {
69        Self {
70            signature: "13a".to_string(),
71        }
72    }
73}
74
75impl Tokenizer13a {
76    pub fn new() -> Self {
77        Self::default()
78    }
79}
80
81#[cached(size = 65536)]
82fn tokenize_13a_cache(line: String) -> Vec<String> {
83    let mut res = line;
84    res = res
85        .replace("<skipped>", "")
86        .replace("-\n", "")
87        .replace('\n', " ");
88    if res.contains('&') {
89        res = res
90            .replace("&quot;", "\"")
91            .replace("&amp;", "&")
92            .replace("&lt;", "<")
93            .replace("&gt;", ">");
94    }
95    TokenizerRegex::new().tokenize(&format!(" {res} "))
96}
97
98impl Tokenizer for Tokenizer13a {
99    fn signature(&self) -> &str {
100        &self.signature
101    }
102    fn tokenize(&self, line: &str) -> Vec<String> {
103        tokenize_13a_cache(line.to_string())
104    }
105}
106
#[cfg(test)]
mod test {
    use crate::tokenizer;
    use crate::tokenizer::Tokenizer;

    #[test]
    fn test_tokenize_regex() {
        let tok = tokenizer::TokenizerRegex::new();

        assert_eq!(
            tok.tokenize("Hello, World!"),
            vec!["Hello", ",", "World", "!"]
        );
        assert_eq!(
            tok.tokenize("/usr/sbin/sendmail - 0 errors, 12 warnings"),
            vec![
                "/", "usr", "/", "sbin", "/", "sendmail", "-", "0", "errors", ",", "12", "warnings"
            ]
        );
    }

    #[test]
    fn test_tokenize_13a_regex() {
        let tok = tokenizer::Tokenizer13a::new();

        // Entity unescaping and <skipped> removal happen before splitting.
        assert_eq!(
            tok.tokenize("Hello, &quot;World!<skipped>"),
            vec!["Hello", ",", "\"", "World", "!"]
        );
        assert_eq!(
            tok.tokenize("/usr/sbin/sendmail - 0 errors, 12 warnings"),
            vec![
                "/", "usr", "/", "sbin", "/", "sendmail", "-", "0", "errors", ",", "12", "warnings"
            ]
        );
    }
}
146
#[cfg(test)]
mod benchmark {
    use crate::tokenizer;
    use crate::tokenizer::Tokenizer;
    use test::Bencher;

    /// Benchmarks 13a tokenization of a short entity-bearing line.
    #[bench]
    fn bench_tokenizer(b: &mut Bencher) {
        let tokenizer_regex = tokenizer::Tokenizer13a::new();
        let line = "Hello, &quot;World!<skipped>";

        let iter_num: usize = 100;
        b.iter(|| {
            for _ in 1..=iter_num {
                // black_box the input and the produced token vector so the
                // optimizer cannot hoist or discard the tokenization work.
                // (Previously black_box wrapped only the loop's `()` result,
                // which protected nothing.)
                std::hint::black_box(tokenizer_regex.tokenize(std::hint::black_box(line)));
            }
        });
        // NOTE(review): tokenize() is memoized, so after the first call this
        // mostly measures cache hits — confirm that is the intended workload.
    }
}
164}