Skip to content

Commit

Permalink
feat: add async rerank (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Apr 30, 2022
1 parent 12d33c4 commit 33efcb0
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 23 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,16 @@ Fun time! Note, unlike the previous example, here the input is an image and the
</table>


### Rerank image-text matches via CLIP model
### Rank image-text matches via CLIP model

From `0.3.0` CLIP-as-service adds a new `/rerank` endpoint that re-ranks cross-modal matches according to their joint likelihood in CLIP model. For example, given an image Document with some predefined sentence matches as below:
From `0.3.0` CLIP-as-service adds a new `/rank` endpoint that re-ranks cross-modal matches according to their joint likelihood in CLIP model. For example, given an image Document with some predefined sentence matches as below:

```python
from clip_client import Client
from docarray import Document

c = Client(server='grpc://demo-cas.jina.ai:51000')
r = c.rerank(
r = c.rank(
[
Document(
uri='.github/README-img/rerank.png',
Expand Down
2 changes: 1 addition & 1 deletion client/clip_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.3.6'
__version__ = '0.4.0'

import os

Expand Down
15 changes: 8 additions & 7 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def _prepare_single_doc(d: 'Document'):
def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)
if not _get(d):
raise ValueError(f'`.rerank()` requires every doc to have `.{_source}`')
raise ValueError(f'`.rank()` requires every doc to have `.{_source}`')
d = Client._prepare_single_doc(d)
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d
Expand Down Expand Up @@ -367,25 +367,25 @@ def _iter_rank_docs(

def _get_rank_payload(self, content, kwargs):
return dict(
on='/rerank',
on='/rank',
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)

def rerank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rerank image-text matches according to the server CLIP model.
def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rank image-text matches according to the server CLIP model.
Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
text/image; this method reranks the matches according to the CLIP model.
text/image; this method ranks the matches according to the CLIP model.
Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
More details can be found in: https://github.com/openai/CLIP#usage
:param docs: the input Documents
:return: the reranked Documents in a DocumentArray.
:return: the ranked Documents in a DocumentArray.
"""
self._prepare_streaming(
Expand All @@ -398,7 +398,8 @@ def rerank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
)
return self._results

async def arerank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
from rich import filesize

self._prepare_streaming(
not kwargs.get('show_progress'),
Expand Down
4 changes: 4 additions & 0 deletions docs/changelog/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ CLIP-as-service follows semantic versioning. However, before the project reach 1

This chapter only tracks the most important breaking changes and explain the rationale behind them.

# 0.4.0: rename `rerank` concept to `rank`

"Reranking" is a new feature introduced since 0.3.3. This feature allows user to rank and score `document.matches` in a cross-modal way. From 0.4.0, this feature as well as all related functions will refer it simply as "rank".

## 0.2.0: improve the service scalability with replicas

This change is mainly intended to improve the inference performance with replicas.
Expand Down
12 changes: 6 additions & 6 deletions docs/user-guides/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,15 @@ asyncio.run(main())

The final time cost will be less than `3s + time(t2)`.

## Reranking
## Ranking

```{tip}
This feature is only available with `clip_server>=0.3.0` and the server is running with PyTorch backend.
```

One can also rerank cross-modal matches via {meth}`~clip_client.client.Client.rerank`. First construct a cross-modal Document where the root contains an image and `.matches` contain sentences to rerank. One can also construct text-to-image rerank as below:
One can also rank cross-modal matches via {meth}`~clip_client.client.Client.rank` or {meth}`~clip_client.client.Client.arank`. First construct a cross-modal Document where the root contains an image and `.matches` contain sentences to rerank. One can also construct text-to-image rerank as below:

````{tab} Given image, rerank sentences
````{tab} Given image, rank sentences
```python
from docarray import Document
Expand All @@ -285,7 +285,7 @@ d = Document(
````

````{tab} Given sentence, rerank images
````{tab} Given sentence, rank images
```python
from docarray import Document
Expand All @@ -304,13 +304,13 @@ d = Document(



Then call `rerank`, you can feed it with multiple Documents as a list:
Then call `rank`, you can feed it with multiple Documents as a list:

```python
from clip_client import Client

c = Client(server='grpc://demo-cas.jina.ai:51000')
r = c.rerank([d])
r = c.rank([d])

print(r['@m', ['text', 'scores__clip_score__value']])
```
Expand Down
2 changes: 1 addition & 1 deletion server/clip_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.6'
__version__ = '0.4.0'
4 changes: 2 additions & 2 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def _split_img_txt_da(d, _img_da, _txt_da):
elif d.uri:
_img_da.append(d)

@requests(on='/rerank')
async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
Expand Down
35 changes: 32 additions & 3 deletions tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_torch_executor_rank_img2texts():
d.matches.append(Document(text='hello, world!'))
d.matches.append(Document(text='goodbye, world!'))

await ce.rerank(da, {})
await ce.rank(da, {})
print(da['@m', 'scores__clip_score__value'])
for d in da:
for c in d.matches:
Expand All @@ -36,7 +36,7 @@ async def test_torch_executor_rank_text2imgs():
f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
)
)
await ce.rerank(db, {})
await ce.rank(db, {})
print(db['@m', 'scores__clip_score__value'])
for d in db:
for c in d.matches:
Expand All @@ -63,7 +63,36 @@ async def test_torch_executor_rank_text2imgs():
)
def test_docarray_inputs(make_torch_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}')
r = c.rerank([d])
r = c.rank([d])
assert isinstance(r, DocumentArray)
rv = r['@m', 'scores__clip_score__value']
for v in rv:
assert v is not None
assert v > 0


@pytest.mark.parametrize(
'd',
[
Document(
uri='https://docarray.jina.ai/_static/favicon.png',
matches=[Document(text='hello, world'), Document(text='goodbye, world')],
),
Document(
text='hello, world',
matches=[
Document(uri='https://docarray.jina.ai/_static/favicon.png'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
),
],
),
],
)
@pytest.mark.asyncio
async def test_async_arank(make_torch_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}')
r = await c.arank([d])
assert isinstance(r, DocumentArray)
rv = r['@m', 'scores__clip_score__value']
for v in rv:
Expand Down

0 comments on commit 33efcb0

Please sign in to comment.