Skip to content

Commit

Permalink
fix parse_image_size
Browse files Browse the repository at this point in the history
  • Loading branch information
smedegaard committed Jun 11, 2021
1 parent f2b4480 commit e5dd6a5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions efficientdet/.#utils.py
14 changes: 9 additions & 5 deletions efficientdet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Common utils."""
import contextlib
import os
from typing import Text, Tuple, Union
from typing import Text, Tuple, Union, Dict
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
Expand Down Expand Up @@ -482,11 +482,12 @@ def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
return True


def parse_image_size(image_size: Union[Text, int, Tuple[int, int]]):
def parse_image_size(image_size: Union[Text, int, Tuple[int, int], Dict[int, int]]):
"""Parse the image size and return (height, width).
Args:
image_size: A integer, a tuple (H, W), or a string with HxW format.
image_size: A integer, a tuple (H, W), a string with HxW format,
or a dict with keys 'height' and 'width' with corresponding int values.
Returns:
A tuple of integer (height, width).
Expand All @@ -496,10 +497,13 @@ def parse_image_size(image_size: Union[Text, int, Tuple[int, int]]):
return (image_size, image_size)

if isinstance(image_size, str):
# image_size is a string with format WxH
width, height = image_size.lower().split('x')
# image_size is a string with format HxW
height, width = image_size.lower().split('x')
return (int(height), int(width))

if isinstance(image_size, dict):
return (image_size['height'], image_size['width'])

if isinstance(image_size, tuple):
return image_size

Expand Down
4 changes: 3 additions & 1 deletion efficientdet/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def test_archive_ckpt(self):
self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'backup')))

def test_image_size(self):
self.assertEqual(utils.parse_image_size('1280x640'), (640, 1280))
self.assertEqual(utils.parse_image_size('1280x640'), (1280,640))
self.assertEqual(utils.parse_image_size(1280), (1280, 1280))
self.assertEqual(utils.parse_image_size((1280, 640)), (1280, 640))
self.assertEqual(utils.parse_image_size({'width': 640, 'height': 1280}), (1280, 640))
self.assertEqual(utils.parse_image_size({'height': 3744, 'width': 5616}), (3744, 5616))

def test_get_feat_sizes(self):
feats = utils.get_feat_sizes(640, 2)
Expand Down

0 comments on commit e5dd6a5

Please sign in to comment.