Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run mypy tests on commit #56

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

name: check

on:
pull_request:
push:
branches:
- main

jobs:
check:
name: Style-check ${{ matrix.python-version }}
runs-on: ubuntu-latest
strategy:
matrix:
# Only lint on the min and max supported Python versions.
# It's extremely unlikely that there's a lint issue on any version in between
# that doesn't show up on the min or max versions.
#
# GitHub rate-limits how many jobs can be running at any one time.
# Starting new jobs is also relatively slow,
# so linting on fewer versions makes CI faster.
python-version:
- "3.8"
- "3.11"

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev,all]

- name: check-sort-import
run: |
make check-sort-imports

- name: check-black-format
run: |
make check-format

- name: check-mypy
run: |
make mypy
25 changes: 15 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,32 @@ help:
# help: Style
# help: -------

# help: style - Sort imports and format with black
.PHONY: style
style: sort-imports format


# help: check-style - check code style compliance
.PHONY: check-style
check-style: check-sort-imports check-format

# help: check - run all checks for a commit
.PHONY: check
check: check-format check-sort-imports mypy

# help: format - perform code style format
.PHONY: format
format:
format: sort-imports
@black ./redisvl ./tests/


# help: check-format - check code format compliance
.PHONY: check-format
check-format:
@black --check ./redisvl


# help: sort-imports - apply import sort ordering
.PHONY: sort-imports
sort-imports:
@isort ./redisvl ./tests/ --profile black

# help: check-sort-imports - check imports are sorted
.PHONY: check-sort-imports
check-sort-imports:
@isort ./redisvl --check-only --profile black


# help: check-lint - run static analysis checks
.PHONY: check-lint
Expand Down
18 changes: 9 additions & 9 deletions redisvl/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self,
name: str,
prefix: str = "rvl",
storage_type: Optional[str] = "hash",
storage_type: str = "hash",
fields: Optional[List["Field"]] = None,
):
self._name = name
Expand Down Expand Up @@ -62,7 +62,7 @@ def search(self, *args, **kwargs) -> List["Result"]:
Returns:
List[Result]: A list of search results
"""
results: List["Result"] = self._redis_conn.ft(self._name).search(
results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore
*args, **kwargs
)
return results
Expand Down Expand Up @@ -148,7 +148,7 @@ def disconnect(self):
"""Disconnect from the Redis instance"""
self._redis_conn = None

def _get_key(self, record: Dict[str, Any], key_field: str = None) -> str:
def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> str:
"""Construct the Redis HASH top level key.

Args:
Expand Down Expand Up @@ -236,7 +236,7 @@ def __init__(
self,
name: str,
prefix: str = "rvl",
storage_type: Optional[str] = "hash",
storage_type: str = "hash",
fields: Optional[List["Field"]] = None,
):
super().__init__(name, prefix, storage_type, fields)
Expand Down Expand Up @@ -313,7 +313,7 @@ def create(self, overwrite: Optional[bool] = False):
# set storage_type, default to hash
storage_type = IndexType.HASH
if self._storage.lower() == "json":
self._storage = IndexType.JSON
storage_type = IndexType.JSON

# Create Index
# will raise correct response error if index already exists
Expand Down Expand Up @@ -358,7 +358,7 @@ def load(

# Check if outer interface passes in TTL on load
ttl = kwargs.get("ttl")
with self._redis_conn.pipeline(transaction=False) as pipe:
with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore
for record in data:
key = self._get_key(record, key_field)
pipe.hset(key, mapping=record) # type: ignore
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(
self,
name: str,
prefix: str = "rvl",
storage_type: Optional[str] = "hash",
storage_type: str = "hash",
fields: Optional[List["Field"]] = None,
):
super().__init__(name, prefix, storage_type, fields)
Expand Down Expand Up @@ -467,7 +467,7 @@ async def create(self, overwrite: Optional[bool] = False):
# set storage_type, default to hash
storage_type = IndexType.HASH
if self._storage.lower() == "json":
self._storage = IndexType.JSON
storage_type = IndexType.JSON

# Create Index
await self._redis_conn.ft(self._name).create_index( # type: ignore
Expand Down Expand Up @@ -516,7 +516,7 @@ async def _load(record: dict):
key = self._get_key(record, key_field)
await self._redis_conn.hset(key, mapping=record) # type: ignore
if ttl:
await self._redis_conn.expire(key, ttl)
await self._redis_conn.expire(key, ttl) # type: ignore

# gather with concurrency
await asyncio.gather(*[_load(record) for record in data])
Expand Down
26 changes: 13 additions & 13 deletions redisvl/llmcache/semantic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Union

from redis.commands.search.field import VectorField
from redis.commands.search.field import Field, VectorField

from redisvl.index import SearchIndex
from redisvl.llmcache.base import BaseLLMCache
Expand All @@ -14,8 +14,8 @@ class SemanticCache(BaseLLMCache):
"""Cache for Large Language Models."""

# TODO allow for user to change default fields
_vector_field_name = "prompt_vector"
_default_fields = [
_vector_field_name: str = "prompt_vector"
_default_fields: List[Field] = [
VectorField(
_vector_field_name,
"FLAT",
Expand All @@ -25,27 +25,27 @@ class SemanticCache(BaseLLMCache):

def __init__(
self,
index_name: Optional[str] = "cache",
prefix: Optional[str] = "llmcache",
threshold: Optional[float] = 0.9,
index_name: str = "cache",
prefix: str = "llmcache",
threshold: float = 0.9,
ttl: Optional[int] = None,
vectorizer: Optional[BaseVectorizer] = HFTextVectorizer(
vectorizer: BaseVectorizer = HFTextVectorizer(
"sentence-transformers/all-mpnet-base-v2"
),
redis_url: Optional[str] = "redis://localhost:6379",
redis_url: str = "redis://localhost:6379",
connection_args: Optional[dict] = None,
index: Optional[SearchIndex] = None,
):
"""Semantic Cache for Large Language Models.

Args:
index_name (Optional[str], optional): The name of the index. Defaults to "cache".
prefix (Optional[str], optional): The prefix for the index. Defaults to "llmcache".
threshold (Optional[float], optional): Semantic threshold for the cache. Defaults to 0.9.
index_name (str, optional): The name of the index. Defaults to "cache".
prefix (str, optional): The prefix for the index. Defaults to "llmcache".
threshold (float, optional): Semantic threshold for the cache. Defaults to 0.9.
ttl (Optional[int], optional): The TTL for the cache. Defaults to None.
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
Defaults to HFTextVectorizer("sentence-transformers/all-mpnet-base-v2").
redis_url (Optional[str], optional): The redis url. Defaults to "redis://localhost:6379".
redis_url (str, optional): The redis url. Defaults to "redis://localhost:6379".
connection_args (Optional[dict], optional): The connection arguments for the redis client. Defaults to None.
index (Optional[SearchIndex], optional): The underlying search index to use for the semantic cache. Defaults to None.

Expand Down
2 changes: 1 addition & 1 deletion redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from redisvl.query.query import FilterQuery, VectorQuery, RangeQuery
from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery

__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]
Loading
Loading