Skip to content

Commit

Permalink
allow users to load custom hmm model
Browse files Browse the repository at this point in the history
  • Loading branch information
name1e5s committed Dec 18, 2022
1 parent 38fe280 commit 64158fc
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 17 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ required-features = ["tfidf", "textrank"]
regex = "1.0"
lazy_static = "1.0"
phf = "0.11"
hashbrown = { version = "0.12", default-features = false, features = ["inline-more"] }
cedarwood = "0.4"
ordered-float = { version = "3.0", optional = true }
once_cell = "1"
fxhash = "0.2.1"

[build-dependencies]
phf_codegen = "0.11"

[features]
default = ["default-dict"]
default = ["default-dict", "default-hmm-model"]
default-dict = []
default-hmm-model = []
tfidf = ["ordered-float"]
textrank = ["ordered-float"]

Expand Down
8 changes: 4 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ fn main() {
let mut lines = reader.lines().map(|x| x.unwrap()).skip_while(|x| x.starts_with('#'));
let prob_start = lines.next().unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static INITIAL_PROBS: StatusSet = [").unwrap();
write!(&mut file, "pub static INITIAL_PROBS: StatusSet = [").unwrap();
for prob in prob_start.split(' ') {
write!(&mut file, "{}, ", prob).unwrap();
}
write!(&mut file, "];\n\n").unwrap();
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
write!(&mut file, "pub static TRANS_PROBS: [StatusSet; 4] = [").unwrap();
for line in lines
.by_ref()
.skip_while(|x| x.starts_with('#'))
Expand All @@ -38,7 +38,7 @@ fn main() {
continue;
}
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
write!(&mut file, "static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
write!(&mut file, "pub static EMIT_PROB_{}: phf::Map<&'static str, f64> = ", i).unwrap();
let mut map = phf_codegen::Map::new();
for word_prob in line.split(',') {
let mut parts = word_prob.split(':');
Expand All @@ -50,5 +50,5 @@ fn main() {
i += 1;
}
writeln!(&mut file, "#[allow(clippy::style)]").unwrap();
writeln!(&mut file, "static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
writeln!(&mut file, "pub static EMIT_PROBS: [&'static phf::Map<&'static str, f64>; 4] = [&EMIT_PROB_0, &EMIT_PROB_1, &EMIT_PROB_2, &EMIT_PROB_3];").unwrap();
}
149 changes: 139 additions & 10 deletions src/hmm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::cmp::Ordering;

use lazy_static::lazy_static;
use once_cell::sync::OnceCell;
use regex::Regex;
use std::cmp::Ordering;
use std::collections::HashMap;

use crate::SplitMatches;

Expand All @@ -20,17 +21,108 @@ pub enum Status {
S = 3,
}

pub struct HmmModel {
pub initial_probs: StatusSet,
pub trans_probs: [StatusSet; 4],
pub emit_probs: [HashMap<String, f64>; 4],
}

static CUSTOM_HMM_MODEL: OnceCell<HmmModel> = OnceCell::new();

pub fn get_custom_hmm_model() -> Option<&'static HmmModel> {
CUSTOM_HMM_MODEL.get()
}

pub fn init_custom_hmm_model(model: HmmModel) -> Result<(), HmmModel> {
CUSTOM_HMM_MODEL.set(model)
}

static PREV_STATUS: [[Status; 2]; 4] = [
[Status::E, Status::S], // B
[Status::B, Status::M], // E
[Status::M, Status::B], // M
[Status::S, Status::E], // S
];

include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));

const MIN_FLOAT: f64 = -3.14e100;

#[cfg(feature = "default-hmm-model")]
mod default_hmm {
use super::*;

include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
}

#[inline]
fn get_initial_prob(index: usize) -> f64 {
debug_assert!(index < 4);
if let Some(model) = get_custom_hmm_model() {
model.initial_probs[index]
} else {
#[cfg(feature = "default-hmm-model")]
{
default_hmm::INITIAL_PROBS[index]
}
#[cfg(not(feature = "default-hmm-model"))]
{
debug_assert!(
true,
"No default hmm model is provided, please use `set_custom_hmm_model` to set a custom model."
);
MIN_FLOAT
}
}
}

#[inline]
fn get_emit_prob(index: usize, word: &str) -> f64 {
debug_assert!(index < 4);
if let Some(model) = get_custom_hmm_model() {
model.emit_probs[index].get(word).cloned().unwrap_or(MIN_FLOAT)
} else {
#[cfg(feature = "default-hmm-model")]
{
default_hmm::EMIT_PROBS[index].get(word).cloned().unwrap_or(MIN_FLOAT)
}
#[cfg(not(feature = "default-hmm-model"))]
{
debug_assert!(
true,
"No default hmm model is provided, please use `set_custom_hmm_model` to set a custom model."
);
MIN_FLOAT
}
}
}

#[inline]
fn get_trans_prob(from_index: usize, to_index: usize) -> f64 {
debug_assert!(from_index < 4);
debug_assert!(to_index < 4);
if let Some(model) = get_custom_hmm_model() {
model.trans_probs[from_index]
.get(to_index)
.cloned()
.unwrap_or(MIN_FLOAT)
} else {
#[cfg(feature = "default-hmm-model")]
{
default_hmm::TRANS_PROBS[from_index]
.get(to_index)
.cloned()
.unwrap_or(MIN_FLOAT)
}
#[cfg(not(feature = "default-hmm-model"))]
{
debug_assert!(
true,
"No default hmm model is provided, please use `set_custom_hmm_model` to set a custom model."
);
MIN_FLOAT
}
}
}

#[allow(non_snake_case)]
fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, best_path: &mut Vec<Status>) {
let str_len = sentence.len();
Expand All @@ -57,7 +149,8 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
let x2 = *curr.peek().unwrap();
for y in &states {
let first_word = &sentence[x1..x2];
let prob = INITIAL_PROBS[*y as usize] + EMIT_PROBS[*y as usize].get(first_word).cloned().unwrap_or(MIN_FLOAT);
let index = *y as usize;
let prob = get_initial_prob(index) + get_emit_prob(index, first_word);
V[*y as usize] = prob;
}

Expand All @@ -66,14 +159,12 @@ fn viterbi(sentence: &str, V: &mut Vec<f64>, prev: &mut Vec<Option<Status>>, bes
for y in &states {
let byte_end = *curr.peek().unwrap_or(&str_len);
let word = &sentence[byte_start..byte_end];
let em_prob = EMIT_PROBS[*y as usize].get(word).cloned().unwrap_or(MIN_FLOAT);
let em_prob = get_emit_prob(*y as usize, word);
let (prob, state) = PREV_STATUS[*y as usize]
.iter()
.map(|y0| {
(
V[(t - 1) * R + (*y0 as usize)]
+ TRANS_PROBS[*y0 as usize].get(*y as usize).cloned().unwrap_or(MIN_FLOAT)
+ em_prob,
V[(t - 1) * R + (*y0 as usize)] + get_trans_prob(*y0 as usize, *y as usize) + em_prob,
*y0,
)
})
Expand Down Expand Up @@ -197,15 +288,52 @@ pub fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
cut_with_allocated_memory(sentence, words, &mut V, &mut prev, &mut path);
}

#[cfg(all(test, not(feature = "default-hmm-model")))]
pub fn test_init_custom_hmm_model() {
use std::convert::TryInto;

mod hmm_prob {
use super::*;
include!(concat!(env!("OUT_DIR"), "/hmm_prob.rs"));
}

if get_custom_hmm_model().is_none() {
let initial_probs = hmm_prob::INITIAL_PROBS;
let trans_probs = hmm_prob::TRANS_PROBS;
let emit_probs: [HashMap<_, _>; 4] = {
let mut emit_probs = Vec::new();
for prob in hmm_prob::EMIT_PROBS {
let mut probs = HashMap::new();
for (k, v) in prob {
probs.insert(k.to_string(), *v);
}
emit_probs.push(probs);
}
emit_probs.try_into().unwrap()
};
let _ = init_custom_hmm_model(HmmModel {
initial_probs,
trans_probs,
emit_probs,
});
}
}

#[cfg(all(test, feature = "default-hmm-model"))]
pub fn test_init_custom_hmm_model() {
// nothing
}

#[cfg(test)]
mod tests {
use super::{cut, viterbi, Status};
use super::*;

#[test]
#[allow(non_snake_case)]
fn test_viterbi() {
use super::Status::*;

test_init_custom_hmm_model();
let sentence = "小明硕士毕业于中国科学院计算所";

let R = 4;
Expand All @@ -219,6 +347,7 @@ mod tests {

#[test]
fn test_hmm_cut() {
test_init_custom_hmm_model();
let sentence = "小明硕士毕业于中国科学院计算所";
let mut words = Vec::with_capacity(sentence.chars().count() / 2);
cut(sentence, &mut words);
Expand Down
11 changes: 10 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ use std::cmp::Ordering;
use std::io::BufRead;

use cedarwood::Cedar;
use hashbrown::HashMap;
use regex::{Match, Matches, Regex};
use std::collections::HashMap;

pub(crate) type FxHashMap<K, V> = HashMap<K, V, fxhash::FxBuildHasher>;

Expand All @@ -88,6 +88,8 @@ pub use crate::keywords::tfidf::TFIDF;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
pub use crate::keywords::{Keyword, KeywordExtract};

pub use crate::hmm::{get_custom_hmm_model, init_custom_hmm_model};

mod errors;
mod hmm;
#[cfg(any(feature = "tfidf", feature = "textrank"))]
Expand Down Expand Up @@ -806,6 +808,7 @@ impl Jieba {
#[cfg(test)]
mod tests {
use super::{Jieba, SplitMatches, SplitState, Tag, Token, TokenizeMode, RE_HAN_DEFAULT};
use crate::hmm::test_init_custom_hmm_model;
use std::io::BufReader;

#[test]
Expand Down Expand Up @@ -900,6 +903,7 @@ mod tests {

#[test]
fn test_cut_with_hmm() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let words = jieba.cut("我们中出了一个叛徒", false);
assert_eq!(words, vec!["我们", "中", "出", "了", "一个", "叛徒"]);
Expand All @@ -917,6 +921,7 @@ mod tests {

#[test]
fn test_cut_weicheng() {
test_init_custom_hmm_model();
static WEICHENG_TXT: &str = include_str!("../examples/weicheng/src/weicheng.txt");
let jieba = Jieba::new();
for line in WEICHENG_TXT.split('\n') {
Expand All @@ -926,6 +931,7 @@ mod tests {

#[test]
fn test_cut_for_search() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let words = jieba.cut_for_search("南京市长江大桥", true);
assert_eq!(words, vec!["南京", "京市", "南京市", "长江", "大桥", "长江大桥"]);
Expand Down Expand Up @@ -962,6 +968,7 @@ mod tests {

#[test]
fn test_tag() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let tags = jieba.tag(
"我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。",
Expand Down Expand Up @@ -1078,6 +1085,7 @@ mod tests {

#[test]
fn test_tokenize() {
test_init_custom_hmm_model();
let jieba = Jieba::new();
let tokens = jieba.tokenize("南京市长江大桥", TokenizeMode::Default, false);
assert_eq!(
Expand Down Expand Up @@ -1305,6 +1313,7 @@ mod tests {

#[test]
fn test_userdict_hmm() {
test_init_custom_hmm_model();
let mut jieba = Jieba::new();
let tokens = jieba.tokenize("我们中出了一个叛徒", TokenizeMode::Default, true);
assert_eq!(
Expand Down

0 comments on commit 64158fc

Please sign in to comment.