From caf8716b57d9cab0053b45f0defc3b2bb8088c3e Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Wed, 25 Oct 2023 18:25:04 -0700
Subject: [PATCH] Fix for embed-multi bug, closes #3

---
 llm_embed_jina.py        |  2 +-
 tests/test_embed_jina.py | 22 ++++++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/llm_embed_jina.py b/llm_embed_jina.py
index 77fefc2..21c94fb 100644
--- a/llm_embed_jina.py
+++ b/llm_embed_jina.py
@@ -22,5 +22,5 @@ def embed_batch(self, texts):
             self._model = AutoModel.from_pretrained(
                 "jinaai/{}".format(self.model_id), trust_remote_code=True
             )
-        results = self._model.encode(texts)
+        results = self._model.encode(list(texts))
         return (list(map(float, result)) for result in results)
diff --git a/tests/test_embed_jina.py b/tests/test_embed_jina.py
index 76b8ed8..085922d 100644
--- a/tests/test_embed_jina.py
+++ b/tests/test_embed_jina.py
@@ -1,3 +1,5 @@
+from click.testing import CliRunner
+from llm.cli import cli
 import llm
 
 
@@ -6,3 +8,23 @@ def test_jina_embed_small():
     floats = model.embed("hello world")
     assert len(floats) == 512
     assert all(isinstance(f, float) for f in floats)
+
+
+def test_jina_embed_multi(tmpdir):
+    db_path = str(tmpdir / "test.db")
+    runner = CliRunner()
+    result = runner.invoke(
+        cli,
+        [
+            "embed-multi",
+            "-m",
+            "jina-embeddings-v2-small-en",
+            "test",
+            "-",
+            "-d",
+            db_path,
+        ],
+        input='[{"id": "a", "text": "abc"}]',
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 0