From 235e1812cf41522218939435d89ae9b02fa922bd Mon Sep 17 00:00:00 2001 From: Owen Vallis Date: Sun, 3 Sep 2023 23:14:43 +0000 Subject: [PATCH] Fix distance imports in the models. --- tensorflow_similarity/models/contrastive_model.py | 9 +++++---- tensorflow_similarity/models/similarity_model.py | 11 ++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow_similarity/models/contrastive_model.py b/tensorflow_similarity/models/contrastive_model.py index 0ac24745..2e1a81a3 100644 --- a/tensorflow_similarity/models/contrastive_model.py +++ b/tensorflow_similarity/models/contrastive_model.py @@ -28,11 +28,12 @@ from termcolor import cprint from tqdm.auto import tqdm -from tensorflow_similarity import distances +import tensorflow_similarity.distances from tensorflow_similarity.classification_metrics import ( # noqa ClassificationMetric, make_classification_metric, ) +from tensorflow_similarity.distances import Distance from tensorflow_similarity.evaluators.evaluator import Evaluator from tensorflow_similarity.indexer import Indexer from tensorflow_similarity.layers import ActivationStdLoggingLayer @@ -189,7 +190,7 @@ def compile( weighted_metrics: Metric | DistanceMetric | str | Mapping | Sequence | None = None, # noqa run_eagerly: bool = False, steps_per_execution: int = 1, - distance: distances.Distance | str = "cosine", + distance: Distance | str = "cosine", kv_store: Store | str = "memory", search: Search | str = "linear", evaluator: Evaluator | str = "memory", @@ -271,7 +272,7 @@ def compile( ValueError: In case of invalid arguments for `optimizer`, `loss` or `metrics`. """ - distance_obj = distances.get(distance) + distance_obj = tensorflow_similarity.distances.get(distance) # init index self.create_index( @@ -481,7 +482,7 @@ def predict( def create_index( self, - distance: distances.Distance | str = "cosine", + distance: Distance | str = "cosine", search: Search | str = "linear", kv_store: Store | str = "memory", evaluator: Evaluator | str = "memory", diff --git a/tensorflow_similarity/models/similarity_model.py b/tensorflow_similarity/models/similarity_model.py index bc8c6662..fe84387a 100644 --- a/tensorflow_similarity/models/similarity_model.py +++ b/tensorflow_similarity/models/similarity_model.py @@ -53,11 +53,12 @@ from tensorflow.keras.optimizers import Optimizer from tqdm.auto import tqdm -from tensorflow_similarity import distances +import tensorflow_similarity.distances from tensorflow_similarity.classification_metrics import ( ClassificationMetric, make_classification_metric, ) +from tensorflow_similarity.distances import Distance from tensorflow_similarity.evaluators.evaluator import Evaluator from tensorflow_similarity.indexer import Indexer from tensorflow_similarity.losses import MetricLoss @@ -98,7 +99,7 @@ def compile( weighted_metrics: Metric | DistanceMetric | str | Mapping | Sequence | None = None, # noqa run_eagerly: bool = False, steps_per_execution: int = 1, - distance: distances.Distance | str = "auto", + distance: Distance | str = "auto", embedding_output: int | None = None, kv_store: Store | str = "memory", search: Search | str = "linear", @@ -199,7 +200,7 @@ def compile( # Fetching the distance used from the first loss if auto if distance == "auto": if loss is None: - distance = distances.get("cosine") + distance = tensorflow_similarity.distances.get("cosine") else: metric_loss = loss[0] if isinstance(loss, list) else loss @@ -211,7 +212,7 @@ def compile( print(f"Distance metric automatically set to {distance} use the " "distance arg to override.") else: - distance = distances.get(distance) + distance = tensorflow_similarity.distances.get(distance) # init index self.create_index( @@ -265,7 +266,7 @@ def _index(self, index: Indexer) -> None: def create_index( self, - distance: distances.Distance | str = "cosine", + distance: Distance | str = "cosine", search: Search | str = "linear", kv_store: Store | str = "memory", evaluator: Evaluator | str = "memory",