Skip to content

Commit

Permalink
Fix formatting and modue import sort order.
Browse files Browse the repository at this point in the history
  • Loading branch information
owenvallis committed Aug 8, 2023
1 parent fc00864 commit 5eb932d
Show file tree
Hide file tree
Showing 13 changed files with 9 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tensorflow_similarity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.17.1"
__version__ = "0.17.2"


from . import algebra # noqa
Expand Down
2 changes: 0 additions & 2 deletions tensorflow_similarity/augmenters/barlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def augment(
num_augmentations_per_example: int = 2,
is_warmup: bool = True,
) -> list[Any]:

with tf.device("/cpu:0"):
if y is None:
y = tf.constant([0])
Expand Down Expand Up @@ -162,7 +161,6 @@ def augment(
solarize_thresh=self.solarize_thresh,
)
for _ in range(num_augmentations_per_example):

view = tf.map_fn(
lambda img: augment_fn(image=img),
inputs,
Expand Down
2 changes: 0 additions & 2 deletions tensorflow_similarity/augmenters/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def __init__(
version: str = "v2",
num_cpu: int | None = os.cpu_count(),
):

self.width = width
self.height = height
self.is_training = is_training
Expand All @@ -132,7 +131,6 @@ def __init__(

@tf.function
def augment(self, x: Tensor, y: Tensor, num_views: int, is_warmup: bool) -> list[Tensor]:

with tf.device("/cpu:0"):
inputs = tf.stack(x)
inputs = tf.cast(inputs, dtype="float32") / 255.0
Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/evaluators/memory_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ def _optimal_cutpoint(
"value": metrics[calibration_metric.name][idx].item(),
}
for metric_name in metrics.keys():

optimal_cp[metric_name] = metrics[metric_name][idx].item()

return optimal_cp
Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ def single_lookup(self, prediction: FloatTensor, k: int = 5) -> list[Lookup]:
return lookups

def batch_lookup(self, predictions: FloatTensor, k: int = 5, verbose: int = 1) -> list[list[Lookup]]:

"""Find the k closest matches for a set of embeddings
Args:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_similarity/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from .barlow import Barlow # noqa
from .circle_loss import CircleLoss # noqa
from .lifted_structure_loss import LiftedStructLoss # noqa
from .metric_loss import MetricLoss # noqa
from .multinegrank_loss import MultiNegativesRankLoss # noqa
from .multisim_loss import MultiSimilarityLoss # noqa
Expand All @@ -25,6 +26,5 @@
from .simsiam import SimSiamLoss # noqa
from .softnn_loss import SoftNearestNeighborLoss # noqa
from .triplet_loss import TripletLoss # noqa
from .lifted_structure_loss import LiftedStructLoss # noqa
from .vicreg import VicReg # noqa
from .xbm_loss import XBM # noqa
5 changes: 4 additions & 1 deletion tensorflow_similarity/losses/lifted_structure_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
https://arxiv.org/abs/1511.06452
"""
from __future__ import annotations

import tensorflow as tf

from tensorflow_similarity import losses as tfsim_losses
from tensorflow_similarity.algebra import build_masks
from tensorflow_similarity.distances import Distance, distance_canonicalizer
from tensorflow_similarity.types import FloatTensor, IntTensor
from tensorflow_similarity import losses as tfsim_losses

from .metric_loss import MetricLoss
from .utils import positive_distances

Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/losses/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(self, temperature: float = 0.05, **kwargs):
self.temperature = temperature

def contrast(self, hidden1: FloatTensor, hidden2: FloatTensor) -> FloatTensor:

# local replica batch size
batch_size = tf.shape(hidden1)[0]

Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/matchers/match_majority_vote.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class MatchMajorityVote(ClassificationMatch):
"""Match metrics for the most common label in a result set."""

def __init__(self, name: str = "majority_vote", **kwargs) -> None:

if "canonical_name" not in kwargs:
kwargs["canonical_name"] = "match_majority_vote"

Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/matchers/match_nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class MatchNearest(ClassificationMatch):
"""Match metrics for labels at k=1."""

def __init__(self, name: str = "nearest", **kwargs) -> None:

if "canonical_name" not in kwargs:
kwargs["canonical_name"] = "match_nearest"

Expand Down
2 changes: 0 additions & 2 deletions tensorflow_similarity/training_metrics/distance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(
negative_mining_strategy: str = "hard",
**kwargs,
):

if not name:
name = "%s_%s" % (aggregate, anchor[:3])
super().__init__(name=name, **kwargs)
Expand All @@ -63,7 +62,6 @@ def __init__(
self.aggregated_distances = tf.Variable(0, dtype=tf.keras.backend.floatx())

def update_state(self, labels: IntTensor, embeddings: FloatTensor, sample_weight: FloatTensor) -> None:

# [distances]
pairwise_distances = self.distance(embeddings, embeddings)

Expand Down
1 change: 0 additions & 1 deletion tensorflow_similarity/visualization/vizualize_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def visualize_views(
# Plot the images
fig, axes = plt.subplots(num_row, num_col, figsize=fig_size)
for i in range(num_imgs):

# If the number of rows is 1, the axes array is one-dimensional
if num_row == 1:
ax = axes[i % num_col]
Expand Down
4 changes: 3 additions & 1 deletion tests/losses/test_lifted_structure_loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.python.framework import combinations
from tensorflow.keras.losses import Reduction
from tensorflow.python.framework import combinations

from tensorflow_similarity import losses

from . import utils


Expand Down

0 comments on commit 5eb932d

Please sign in to comment.