Skip to content

Commit

Permalink
#48 add deep_classifier , #43 add test for deep_classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
yangyaofei committed Dec 6, 2021
1 parent c2b4448 commit 54f59cf
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 5 deletions.
7 changes: 7 additions & 0 deletions docs/nlpir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ nlpir.eye_checker module
:undoc-members:
:show-inheritance:

nlpir.deep_classifier module
-------------------------------

.. automodule:: nlpir.deep_classifier
:members:
:undoc-members:
:show-inheritance:

nlpir.tools module
--------------------
Expand Down
49 changes: 49 additions & 0 deletions nlpir/deep_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#! coding=utf-8
"""
high-level toolbox for text classify
"""
import re
import typing
import nlpir
from nlpir import get_instance as __get_instance__
from nlpir import native

# class and class instance
__cls__ = native.deep_classifier.DeepClassifier
__instance__: typing.Optional[native.deep_classifier.DeepClassifier] = None
# Location of DLL
__lib__ = None
# Data directory
__data__ = None
# license_code
__license_code__ = None
# encode
__nlpir_encode__ = native.UTF8_CODE

__handler__ = None


@__get_instance__
def get_native_instance() -> native.deep_classifier.DeepClassifier:
"""
返回原生NLPIR接口,使用更多函数
:return: The singleton instance
"""
return __instance__


@__get_instance__
def classify(txt: str) -> str:
"""
Text classify
:param txt: text
:return: class
"""
global __handler__
if __handler__ is None:
# default model
__handler__ = __instance__.new_instance(800)
__instance__.load_train_result(__handler__)
return __instance__.classify(txt, handler=__handler__)
3 changes: 0 additions & 3 deletions nlpir/eye_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
import os
import re
import typing
from enum import Enum
from pathlib import Path

from pydantic import BaseModel

import nlpir
from nlpir import get_instance as __get_instance__
from nlpir import native

Expand Down
6 changes: 5 additions & 1 deletion nlpir/native/deep_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def add_train(self, classname: str, text: str, handler: int = 0) -> bool:
:param handler: classifier handler
:return: add success or not
"""
return self.get_func("DeepClassifier_AddTrain", [c_char_p, c_char_p, POINTER(c_int)], c_bool)(classname, text, handler)
return self.get_func(
"DeepClassifier_AddTrain",
[c_char_p, c_char_p, POINTER(c_int)],
c_bool
)(classname, text, handler)

@NLPIRBase.byte_str_transform
def add_train_file(self, classname: str, filename: str, handler: int = 0) -> int:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deep_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# coding=utf-8
"""
Tested function:
- :func:`nlpir.deep_classifier.classify`
"""
from nlpir import deep_classifier


def test_classify():
from tests.strings import test_str
assert deep_classifier.classify(txt=test_str) == "教育"
12 changes: 11 additions & 1 deletion tests/test_eye_checker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# coding=utf-8
"""
Tested function:
- :func:`nlpir.eye_checker.import_kgb_rules`
- :func:`nlpir.eye_checker.list_rules`
- :func:`nlpir.eye_checker.delete_rules`
- :func:`nlpir.eye_checker.extract_knowledge`
"""

import pytest

from nlpir import eye_checker
Expand All @@ -15,7 +25,7 @@ def test_extract():

@pytest.mark.run(order=1)
def test_rule_manage():
from tests.strings import test_kgb_test_text, test_kgb_rules
from tests.strings import test_kgb_rules
rule_set = {1, 2, 3, 4, 6, 7, 9}
for rule in rule_set:
assert eye_checker.import_kgb_rules(rule_text=test_kgb_rules, report_type=rule, overwrite=True)
Expand Down

0 comments on commit 54f59cf

Please sign in to comment.