Skip to content

Commit

Permalink
Add a small bert test including testing an empty word
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Aug 24, 2024
1 parent a40e84c commit c896431
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions stanza/tests/common/test_bert_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
import torch

from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

BERT_MODEL = "hf-internal-testing/tiny-bert"

@pytest.fixture(scope="module")
def tiny_bert():
m, t = load_bert(BERT_MODEL)
return m, t

def test_load_bert(tiny_bert):
"""
Empty method that just tests loading the bert
"""
m, t = tiny_bert

def test_run_bert(tiny_bert):
m, t = tiny_bert
device = next(m.parameters()).device
extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "a", "test"]], device, True)

def test_run_bert_empty_word(tiny_bert):
m, t = tiny_bert
device = next(m.parameters()).device
foo = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "-", "a", "test"]], device, True)
bar = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "", "a", "test"]], device, True)

assert len(foo) == 1
assert torch.allclose(foo[0], bar[0])

0 comments on commit c896431

Please sign in to comment.