Skip to content

Commit

Permalink
Fix distance imports in the models.
Browse files Browse the repository at this point in the history
  • Loading branch information
owenvallis committed Sep 3, 2023
1 parent 6584cbb commit 235e181
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
9 changes: 5 additions & 4 deletions tensorflow_similarity/models/contrastive_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
from termcolor import cprint
from tqdm.auto import tqdm

from tensorflow_similarity import distances
import tensorflow_similarity.distances
from tensorflow_similarity.classification_metrics import ( # noqa
ClassificationMetric,
make_classification_metric,
)
from tensorflow_similarity.distances import Distance
from tensorflow_similarity.evaluators.evaluator import Evaluator
from tensorflow_similarity.indexer import Indexer
from tensorflow_similarity.layers import ActivationStdLoggingLayer
Expand Down Expand Up @@ -189,7 +190,7 @@ def compile(
weighted_metrics: Metric | DistanceMetric | str | Mapping | Sequence | None = None, # noqa
run_eagerly: bool = False,
steps_per_execution: int = 1,
distance: distances.Distance | str = "cosine",
distance: Distance | str = "cosine",
kv_store: Store | str = "memory",
search: Search | str = "linear",
evaluator: Evaluator | str = "memory",
Expand Down Expand Up @@ -271,7 +272,7 @@ def compile(
ValueError: In case of invalid arguments for
`optimizer`, `loss` or `metrics`.
"""
distance_obj = distances.get(distance)
distance_obj = tensorflow_similarity.distances.get(distance)

# init index
self.create_index(
Expand Down Expand Up @@ -481,7 +482,7 @@ def predict(

def create_index(
self,
distance: distances.Distance | str = "cosine",
distance: Distance | str = "cosine",
search: Search | str = "linear",
kv_store: Store | str = "memory",
evaluator: Evaluator | str = "memory",
Expand Down
11 changes: 6 additions & 5 deletions tensorflow_similarity/models/similarity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@
from tensorflow.keras.optimizers import Optimizer
from tqdm.auto import tqdm

from tensorflow_similarity import distances
import tensorflow_similarity.distances
from tensorflow_similarity.classification_metrics import (
ClassificationMetric,
make_classification_metric,
)
from tensorflow_similarity.distances import Distance
from tensorflow_similarity.evaluators.evaluator import Evaluator
from tensorflow_similarity.indexer import Indexer
from tensorflow_similarity.losses import MetricLoss
Expand Down Expand Up @@ -98,7 +99,7 @@ def compile(
weighted_metrics: Metric | DistanceMetric | str | Mapping | Sequence | None = None, # noqa
run_eagerly: bool = False,
steps_per_execution: int = 1,
distance: distances.Distance | str = "auto",
distance: Distance | str = "auto",
embedding_output: int | None = None,
kv_store: Store | str = "memory",
search: Search | str = "linear",
Expand Down Expand Up @@ -199,7 +200,7 @@ def compile(
# Fetching the distance used from the first loss if auto
if distance == "auto":
if loss is None:
distance = distances.get("cosine")
distance = tensorflow_similarity.distances.get("cosine")
else:
metric_loss = loss[0] if isinstance(loss, list) else loss

Expand All @@ -211,7 +212,7 @@ def compile(

print(f"Distance metric automatically set to {distance} use the " "distance arg to override.")
else:
distance = distances.get(distance)
distance = tensorflow_similarity.distances.get(distance)

# init index
self.create_index(
Expand Down Expand Up @@ -265,7 +266,7 @@ def _index(self, index: Indexer) -> None:

def create_index(
self,
distance: distances.Distance | str = "cosine",
distance: Distance | str = "cosine",
search: Search | str = "linear",
kv_store: Store | str = "memory",
evaluator: Evaluator | str = "memory",
Expand Down

0 comments on commit 235e181

Please sign in to comment.