Skip to content

Commit

Permalink
Merge pull request #865 from pavaris-pm/update-pos-tag-transformers
Browse files Browse the repository at this point in the history
Update `pos_tag_transformers` function
  • Loading branch information
wannaphong committed Nov 25, 2023
2 parents a319d08 + 5574ce3 commit abfbf02
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 26 deletions.
74 changes: 52 additions & 22 deletions pythainlp/tag/pos_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,36 @@ def pos_tag_sents(


def pos_tag_transformers(
words: str, engine: str = "bert-base-th-cased-blackboard"
):
sentence: str,
engine: str = "bert",
corpus: str = "blackboard",
)->List[List[Tuple[str, str]]]:
"""
"wangchanberta-ud-thai-pud-upos",
"mdeberta-v3-ud-thai-pud-upos",
"bert-base-th-cased-blackboard",
Marks sentences with part-of-speech (POS) tags.
:param str sentence: a list of lists of tokenized words
:param str engine:
* *bert* - BERT: Bidirectional Encoder Representations from Transformers (default)
* *wangchanberta* - fine-tuned version of airesearch/wangchanberta-base-att-spm-uncased on pud corpus (support PUD cotpus only)
* *mdeberta* - mDeBERTa: Multilingual Decoding-enhanced BERT with disentangled attention (support PUD corpus only)
:param str corpus: the corpus that is used to create the language model for tagger
* *blackboard* - `blackboard treebank (support bert engine only) <https://bitbucket.org/kaamanita/blackboard-treebank/src/master/>`_
* *pud* - `Parallel Universal Dependencies (PUD)\
<https://github.com/UniversalDependencies/UD_Thai-PUD>`_ \
treebanks, natively use Universal POS tags (support wangchanberta and mdeberta engine)
:return: a list of lists of tuples (word, POS tag)
:rtype: list[list[tuple[str, str]]]
:Example:
Labels POS for given sentence::
from pythainlp.tag import pos_tag_transformers
sentences = "แมวทำอะไรตอนห้าโมงเช้า"
pos_tag_transformers(sentences, engine="bert", corpus='blackboard')
# output:
# [[('แมว', 'NOUN'), ('ทําอะไร', 'VERB'), ('ตอนห้าโมงเช้า', 'NOUN')]]
"""

try:
Expand All @@ -196,28 +219,35 @@ def pos_tag_transformers(
raise ImportError(
"Not found transformers! Please install transformers by pip install transformers")

if not words:
if not sentence:
return []

if engine == "wangchanberta-ud-thai-pud-upos":
model = AutoModelForTokenClassification.from_pretrained(
"Pavarissy/wangchanberta-ud-thai-pud-upos")
tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos")
elif engine == "mdeberta-v3-ud-thai-pud-upos":
model = AutoModelForTokenClassification.from_pretrained(
"Pavarissy/mdeberta-v3-ud-thai-pud-upos")
tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos")
elif engine == "bert-base-th-cased-blackboard":
model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai")
tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai")
_blackboard_support_engine = {
"bert" : "lunarlist/pos_thai",
}

_pud_support_engine = {
"wangchanberta" : "Pavarissy/wangchanberta-ud-thai-pud-upos",
"mdeberta" : "Pavarissy/mdeberta-v3-ud-thai-pud-upos",
}

if corpus == 'blackboard' and engine in _blackboard_support_engine.keys():
base_model = _blackboard_support_engine.get(engine)
model = AutoModelForTokenClassification.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)
elif corpus == 'pud' and engine in _pud_support_engine.keys():
base_model = _pud_support_engine.get(engine)
model = AutoModelForTokenClassification.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)
else:
raise ValueError(
"pos_tag_transformers not support {0} engine.".format(
engine
"pos_tag_transformers not support {0} engine or {1} corpus.".format(
engine, corpus
)
)

pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, grouped_entities=True)
pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")

outputs = pipeline(words)
return outputs
outputs = pipeline(sentence)
word_tags = [[(tag['word'], tag['entity_group']) for tag in outputs]]
return word_tags
11 changes: 7 additions & 4 deletions tests/test_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,13 @@ def test_NNER_class(self):

def test_pos_tag_transformers(self):
self.assertIsNotNone(pos_tag_transformers(
words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert-base-th-cased-blackboard"))
words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert", corpus="blackboard"))
self.assertIsNotNone(pos_tag_transformers(
words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta-v3-ud-thai-pud-upos"))
words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta", corpus="pud"))
self.assertIsNotNone(pos_tag_transformers(
words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta-ud-thai-pud-upos"))
words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta", corpus="pud"))
with self.assertRaises(ValueError):
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine")
with self.assertRaises(ValueError):
pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert",
corpus="non-existing corpus")

0 comments on commit abfbf02

Please sign in to comment.