Skip to content

Commit

Permalink
type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald committed Sep 18, 2024
1 parent f544d2f commit ae9bc59
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions pyterrier/transformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import types
from matchpy import Wildcard, Symbol, Operation, Arity
from matchpy import Wildcard, Symbol, Operation, Arity, ReplacementRule
from warnings import warn
import pandas as pd
from deprecated import deprecated
from typing import Iterator, List, Union
from typing import Iterator, List, Union, Tuple
import pyterrier as pt
from . import __version__

Expand Down Expand Up @@ -42,7 +42,7 @@ def get_transformer(v, stacklevel=1):
return SourceTransformer(v)
raise ValueError("Passed parameter %s of type %s cannot be coerced into a transformer" % (str(v), type(v)))

rewrite_rules = []
rewrite_rules : List[ReplacementRule] = []


class Scalar(Symbol):
Expand Down Expand Up @@ -173,7 +173,7 @@ def __call__(self, inp: Union[pd.DataFrame, pt.model.IterDict, List[pt.model.Ite
return list(out)
return out

def transform_gen(self, input : pd.DataFrame, batch_size=1, output_topics=False) -> Iterator[pd.DataFrame]:
def transform_gen(self, input : pd.DataFrame, batch_size=1, output_topics=False) -> Union[Iterator[pd.DataFrame], Iterator[Tuple[pd.DataFrame, pd.DataFrame]]]:
"""
Method for executing a transformer pipeline on smaller batches of queries.
The input dataframe is grouped into batches of batch_size queries, and a generator
Expand All @@ -191,7 +191,7 @@ def transform_gen(self, input : pd.DataFrame, batch_size=1, output_topics=False)
queries = input[["qid"]].drop_duplicates()
else:
queries = input
batch=[]
batch : List[pd.DataFrame] = []
for query in queries.itertuples():
if len(batch) == batch_size:
batch_topics = pd.concat(batch)
Expand Down Expand Up @@ -366,7 +366,7 @@ def index(self, iter : pt.model.IterDict, **kwargs):
def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
raise NotImplementedError('You called `transform()` on an indexer. Did you mean to call `index()`?')

def transform_iter(self, inp: pd.DataFrame) -> pd.DataFrame:
def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
raise NotImplementedError('You called `transform_iter()` on an indexer. Did you mean to call `index()`?')

class IterDictIndexerBase(Indexer):
Expand Down

0 comments on commit ae9bc59

Please sign in to comment.