Skip to content

Commit

Permalink
Fix CI Test for Basnet OOM and PyCoCo Test Failure for JAX (#2322)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb committed Jan 31, 2024
1 parent a3fd57b commit fd56980
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .kokoro/github/ubuntu/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pip install --no-deps -e "." --progress-bar off
# Run Extra Large Tests for Continuous builds
if [ "${RUN_XLARGE:-0}" == "1" ]
then
pytest --check_gpu --run_large --run_extra_large --durations 0 \
pytest --cache-clear --check_gpu --run_large --run_extra_large --durations 0 \
keras_cv/bounding_box \
keras_cv/callbacks \
keras_cv/losses \
Expand All @@ -65,7 +65,7 @@ then
keras_cv/models/segmentation \
keras_cv/models/stable_diffusion
else
pytest --check_gpu --run_large --durations 0 \
pytest --cache-clear --check_gpu --run_large --durations 0 \
keras_cv/bounding_box \
keras_cv/callbacks \
keras_cv/losses \
Expand Down
3 changes: 3 additions & 0 deletions keras_cv/metrics/coco/pycoco_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _convert_predictions_to_coco_annotations(predictions):
num_batches = len(predictions["source_id"])
for i in range(num_batches):
batch_size = predictions["source_id"][i].shape[0]
predictions["detection_boxes"][i] = predictions["detection_boxes"][
i
].copy()
for j in range(batch_size):
max_num_detections = predictions["num_detections"][i][j]
predictions["detection_boxes"][i][j] = _yxyx_to_xywh(
Expand Down
14 changes: 9 additions & 5 deletions keras_cv/models/segmentation/basnet/basnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os

import numpy as np
Expand All @@ -23,13 +24,13 @@
from keras_cv.backend import ops
from keras_cv.backend.config import keras_3
from keras_cv.models import BASNet
from keras_cv.models import ResNet34Backbone
from keras_cv.models import ResNet18Backbone
from keras_cv.tests.test_case import TestCase


class BASNetTest(TestCase):
def test_basnet_construction(self):
backbone = ResNet34Backbone()
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
)
Expand All @@ -41,7 +42,7 @@ def test_basnet_construction(self):

@pytest.mark.large
def test_basnet_call(self):
backbone = ResNet34Backbone()
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
)
Expand All @@ -61,7 +62,7 @@ def test_weights_change(self):
ds = ds.repeat(2)
ds = ds.batch(2)

backbone = ResNet34Backbone()
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
)
Expand Down Expand Up @@ -99,7 +100,7 @@ def test_with_model_preset_forward_pass(self):
def test_saved_model(self):
target_size = [288, 288, 3]

backbone = ResNet34Backbone()
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
)
Expand All @@ -112,6 +113,9 @@ def test_saved_model(self):
model.save(save_path)
else:
model.save(save_path, save_format="keras_v3")
# Free up model memory
del model
gc.collect()
restored_model = keras.models.load_model(save_path)

# Check we got the real object back.
Expand Down

0 comments on commit fd56980

Please sign in to comment.