Skip to content

Commit

Permalink
Add text embedding model i
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed May 22, 2023
1 parent 4de95a4 commit 6f4bc58
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/ml/pytorch/test_pytorch_model_upload_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@
)
]

TEXT_EMBEDDING_MODELS = [
(
"sentence-transformers/all-MiniLM-L6-v2",
"text_embedding",
"Paris is the capital of France.",
)
]


@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
Expand Down Expand Up @@ -94,8 +102,25 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):

class TestPytorchModel:
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
def test_text_classification(self, model_id, task, text_input, value):
def test_text_prediction(self, model_id, task, text_input, value):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
result = ptm.infer(docs=[{"text_field": text_input}])
assert result["predicted_value"] == value

@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
def test_text_embedding(self, model_id, task, text_input):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
ptm.infer(docs=[{"text_field": text_input}])

if ES_VERSION >= (8, 8, 0):
configs = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
assert (
int(
configs["trained_model_configs"][0]["inference_config"][
"text_embedding"
]["embedding_size"]
)
> 0
)

0 comments on commit 6f4bc58

Please sign in to comment.