Skip to content

Commit

Permalink
feat: implement flair sklearn vectorizer wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Nov 15, 2022
1 parent b1a97cd commit 7413318
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 10 deletions.
14 changes: 5 additions & 9 deletions embeddings/embedding/sklearn_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional

import pandas as pd
import scipy
from sklearn.base import BaseEstimator as AnySklearnVectorizer

from embeddings.embedding.embedding import Embedding
Expand All @@ -9,21 +10,16 @@

class SklearnEmbedding(Embedding[ArrayLike, pd.DataFrame]):
def __init__(
self,
vectorizer: AnySklearnVectorizer,
vectorizer_has_sparse_output: bool = True,
vectorizer_kwargs: Optional[Dict[str, Any]] = None,
self, vectorizer: AnySklearnVectorizer, vectorizer_kwargs: Optional[Dict[str, Any]] = None
):
super().__init__()
self.vectorizer_kwargs = vectorizer_kwargs if vectorizer_kwargs else {}
self.vectorizer_has_sparse_output = vectorizer_has_sparse_output
self.vectorizer = vectorizer(**self.vectorizer_kwargs)
self.vectorizer = vectorizer(**vectorizer_kwargs if vectorizer_kwargs else {})

def fit(self, data: ArrayLike) -> None:
self.vectorizer.fit(data)

def embed(self, data: ArrayLike) -> pd.DataFrame:
embedded = self.vectorizer.transform(data)
if self.vectorizer_has_sparse_output:
if scipy.sparse.issparse(embedded):
embedded = embedded.A
return pd.DataFrame(embedded, columns=self.vectorizer.get_feature_names_out())
return pd.DataFrame(embedded)
Empty file.
46 changes: 46 additions & 0 deletions embeddings/embedding/vectorizer/flair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import abc
from typing import Any, Dict, Generic, List, Optional, TypeVar

import numpy as np
from flair.data import Sentence
from numpy import typing as nptyping
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import _VectorizerMixin

from embeddings.embedding.flair_embedding import FlairEmbedding
from embeddings.utils.array_like import ArrayLike

Output = TypeVar("Output")


# ignoring the mypy error due to no types (Any) in TransformerMixin and BaseEstimator classes
class FlairVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator, Generic[Output]): # type: ignore
def __init__(self, flair_embedding: FlairEmbedding) -> None:
self.embedder = flair_embedding

def fit(self, x: ArrayLike, y: Optional[ArrayLike] = None) -> None:
pass

@abc.abstractmethod
def transform(self, x: ArrayLike) -> Output:
pass

def fit_transform(self, x: ArrayLike, y: Optional[ArrayLike] = None, **kwargs: Any) -> Output:
return self.transform(x)


class FlairDocumentVectorizer(FlairVectorizer[nptyping.NDArray[np.float_]]):
def transform(self, x: ArrayLike) -> nptyping.NDArray[np.float_]:
sentences = [Sentence(example) for example in x]
embeddings = [sentence.embedding.numpy() for sentence in self.embedder.embed(sentences)]
return np.vstack(embeddings)


class FlairWordVectorizer(FlairVectorizer[List[List[Dict[int, float]]]]):
def transform(self, x: ArrayLike) -> List[List[Dict[int, float]]]:
sentences = [Sentence(example) for example in x]
embeddings = [sentence for sentence in self.embedder.embed(sentences)]
return [
[{i: value for i, value in enumerate(word.embedding.numpy())} for word in sent]
for sent in embeddings
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ module = [
"spacy",
"appdirs",
"dataset.arrow_dataset",
"seqeval.*"
"seqeval.*",
"scipy"
]
ignore_missing_imports = true

Expand Down

0 comments on commit 7413318

Please sign in to comment.