diff --git a/efficientdet/keras/README.md b/efficientdet/keras/README.md index e984fac42..0c05f94f1 100644 --- a/efficientdet/keras/README.md +++ b/efficientdet/keras/README.md @@ -3,16 +3,6 @@ [1] Mingxing Tan, Ruoming Pang, Quoc V. Le. EfficientDet: Scalable and Efficient Object Detection. CVPR 2020. Arxiv link: https://arxiv.org/abs/1911.09070 -Updates: - - - **Jul20: Added keras/TF2 and new SOTA D7x: 55.1mAP with 153ms** - - Apr22: Sped up end-to-end latency: D0 has up to >200 FPS throughput on Tesla V100. - * A great collaboration with [@fsx950223](https://github.com/fsx950223). - - Apr1: Updated results for test-dev and added EfficientDet-D7. - - Mar26: Fixed a few bugs and updated all checkpoints/results. - - Mar24: Added tutorial with visualization and coco eval. - - Mar 13: Released the initial code and models. - **Quick start tutorial: [tutorial.ipynb](tutorial.ipynb)** **Quick install dependencies: ```pip install -r requirements.txt```** @@ -25,7 +15,7 @@ EfficientDets are a family of object detection models, which achieve state-of-th EfficientDets are developed based on the advanced backbone, a new BiFPN, and a new scaling technique:

- +

* **Backbone**: we employ [EfficientNets](https://arxiv.org/abs/1905.11946) as our backbone networks. @@ -38,10 +28,10 @@ Our model family starts from EfficientDet-D0, which has comparable accuracy as [
- + - +
@@ -56,15 +46,15 @@ We have provided a list of EfficientDet checkpoints and results as follows: | Model | APtest | AP50 | AP75 |APS | APM | APL | APval | | #params | #FLOPs | |---------- |------ |------ |------ | -------- | ------| ------| ------ |------ |------ | :------: | -| EfficientDet-D0 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B | -| EfficientDet-D1 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B | -| EfficientDet-D2 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B | -| EfficientDet-D3 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.h5), 
[ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B | -| EfficientDet-D4 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B | -| EfficientDet-D5 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B | -| EfficientDet-D6 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B | -| EfficientDet-D7 
([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B | -| EfficientDet-D7x ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B | +| EfficientDet-D0 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B | +| EfficientDet-D1 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B | +| EfficientDet-D2 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), 
[val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B | +| EfficientDet-D3 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B | +| EfficientDet-D4 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B | +| EfficientDet-D5 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B | +| EfficientDet-D6 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B | +| EfficientDet-D7 
([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B | +| EfficientDet-D7x ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B | val denotes validation results, test-dev denotes test-dev2017 results. APval is for validation accuracy, all other AP results in the table are for COCO test-dev2017. All accuracy numbers are for single-model single-scale without ensemble or test-time augmentation. EfficientDet-D0 to D6 are trained for 300 epochs and D7/D7x are trained for 600 epochs. 
@@ -73,11 +63,11 @@ In addition, the following table includes a list of models trained with fixed 64 | Model | mAP | Latency | | ------ | ------ | ------ | -| D2(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms | -| D3(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms | -| D4(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms | -| D5(640 [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms | -| D6(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms | +| D2(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms | +| D3(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms | +| D4(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms | +| D5(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms | +| D6(640) 
[ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms | @@ -96,10 +86,9 @@ Then you will get: - saved model under `savedmodeldir/` - frozen graph with name `savedmodeldir/efficientdet-d0_frozen.pb` - TensorRT saved model under `savedmodeldir/tensorrt_fp32/` - - tflite file with name `efficientdet-d0.tflite` + - tflite file with name `savedmodeldir/fp32.tflite` Notably, - --tflite_path only works after 2.3.0-dev20200521 , --model_dir=xx/archive is the folder for exporting the best model. @@ -154,12 +143,12 @@ latency and throughput are: # Step2: inference image. !python inspector.py --mode=infer \ - --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \ + --model_name=efficientdet-d0 --saved_model_dir=/tmp/saved_model \ --hparams="image_size=1920x1280" \ --input_image=img.png --output_image_dir=/tmp/ -Alternatively, if you want to do inference using frozen graph instead of saved model, you can run +If you want to do inference using frozen graph, you can run # Step 1 is the same as before. # Step 2: do inference with frozen graph. @@ -168,18 +157,28 @@ Alternatively, if you want to do inference using frozen graph instead of saved m --saved_model_dir=/tmp/saved_model/efficientdet-d0_frozen.pb \ --input_image=img.png --output_image_dir=/tmp/ +If you want to do inference using tflite, you can run + + # Step 1 is the same as before. + # Step 2: do inference with tflite. + !python inspector.py --mode=infer \ + --model_name=efficientdet-d0 \ + --saved_model_dir=/tmp/saved_model/fp32.tflite \ + --input_image=img.png --output_image_dir=/tmp/ + Lastly, if you only have one image and just want to run a quick test, you can also run the following command (it is slow because it needs to construct the graph from scratch): # Run inference for a single image. 
- !python inspector.py --mode=infer --model_name=$MODEL \ + !python inspector.py --mode=infer \ + --model_name=efficientdet-d0 --model_dir=$CKPT_PATH \ --hparams="image_size=1920x1280" \ - --model_dir=$CKPT_PATH --input_image=img.png --output_image_dir=/tmp + --input_image=img.png --output_image_dir=/tmp/ # you can visualize the output /tmp/0.jpg Here is an example of EfficientDet-D0 visualization: more on [tutorial](tutorial.ipynb)

- +

## 6. Inference for videos. @@ -243,14 +242,15 @@ Create a config file for the PASCAL VOC dataset called voc_config.yaml and put t var_freeze_expr: '(efficientnet|fpn_cells|resample_p6)' label_map: {1: aeroplane, 2: bicycle, 3: bird, 4: boat, 5: bottle, 6: bus, 7: car, 8: cat, 9: chair, 10: cow, 11: diningtable, 12: dog, 13: horse, 14: motorbike, 15: person, 16: pottedplant, 17: sheep, 18: sofa, 19: train, 20: tvmonitor} -Finetune needs to use --ckpt rather than --backbone_ckpt. +Finetune needs to use --pretrained_ckpt. !python train.py --training_file_pattern=tfrecord/pascal*.tfrecord \ --val_file_pattern=tfrecord/pascal*.tfrecord \ + --val_file_pattern=tfrecord/*.json \ --model_name=efficientdet-d0 \ --model_dir=/tmp/efficientdet-d0-finetune \ - --ckpt=efficientdet-d0 \ + --pretrained_ckpt=efficientdet-d0 \ --batch_size=64 \ --eval_samples=1024 \ --num_examples_per_epoch=5717 --num_epochs=50 \ @@ -258,52 +258,9 @@ Finetune needs to use --ckpt rather than --backbone_ckpt. If you want to continue to train the model, simply re-run the above command because the `num_epochs` is a maximum number of epochs. For example, to reproduce the result of efficientdet-d0, set `--num_epochs=300` then run the command multiple times until the training is finished. -If you want to do inference for custom data, you can run - - # Setting hparams-flag is needed sometimes. - !python inspector.py --mode=infer \ - --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \ - --hparams=voc_config.yaml \ - --input_image=img.png --output_image_dir=/tmp/ - -You should check more details of runmode which is written in caption-4. - ## 9. Train on multi GPUs. -Create a config file for the PASCAL VOC dataset called voc_config.yaml and put this in it. 
- - num_classes: 21 - var_freeze_expr: '(efficientnet|fpn_cells|resample_p6)' - label_map: {1: aeroplane, 2: bicycle, 3: bird, 4: boat, 5: bottle, 6: bus, 7: car, 8: cat, 9: chair, 10: cow, 11: diningtable, 12: dog, 13: horse, 14: motorbike, 15: person, 16: pottedplant, 17: sheep, 18: sofa, 19: train, 20: tvmonitor} - -Download efficientdet coco checkpoint. - - !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz - !tar xf efficientdet-d0.tar.gz - -Finetune needs to use --ckpt rather than --backbone_ckpt. - - python train.py \ - --training_file_pattern=tfrecord/pascal*.tfrecord \ - --val_file_pattern=tfrecord/pascal*.tfrecord \ - --model_name=efficientdet-d0 \ - --model_dir=/tmp/efficientdet-d0-finetune \ - --ckpt=efficientdet-d0 \ - --batch_size=64 \ - --eval_samples=1024 \ - --num_examples_per_epoch=5717 --num_epochs=50 \ - --hparams=voc_config.yaml \ - --strategy=gpus - -If you want to do inference for custom data, you can run - - # Setting hparams-flag is needed sometimes. - !python inspector.py --mode=infer \ - --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \ - --hparams=voc_config.yaml \ - --input_image=img.png --output_image_dir=/tmp/ - -You should check more details of runmode which is written in caption-4. +Just add ```--strategy=gpus``` ## 10. Training EfficientDets on TPUs. @@ -335,7 +292,7 @@ EfficientDets use a lot of GPU memory for a few reasons: * Large internal activations for backbone: our backbone uses a relatively large expansion ratio (6), causing the large expanded activations. * Deep BiFPN: our BiFPN has multiple top-down and bottom-up paths, which leads to a lot of intermediate memory usage during training. -To train this model on GPU with low memory there is an experimental option gradient_checkpointing. +To train this model on GPU with low memory there is an experimental option grad_checkpoint. Check these links for a high-level idea of what gradient checkpointing is doing: 1. 
https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9 diff --git a/efficientdet/keras/eval.py b/efficientdet/keras/eval.py index 7153df6d0..79871a1a4 100644 --- a/efficientdet/keras/eval.py +++ b/efficientdet/keras/eval.py @@ -76,6 +76,7 @@ def main(_): model.build((None, *config.image_size, 3)) util_keras.restore_ckpt(model, tf.train.latest_checkpoint(FLAGS.model_dir), + config.moving_average_decay, skip_mismatch=False) @tf.function def model_fn(images, labels): diff --git a/efficientdet/keras/inference.py b/efficientdet/keras/inference.py index 6b9af12b1..cc377e3c1 100644 --- a/efficientdet/keras/inference.py +++ b/efficientdet/keras/inference.py @@ -202,7 +202,9 @@ def build(self, params_override=None): self.model = efficientdet_keras.EfficientDetModel(config=config) image_size = utils.parse_image_size(params['image_size']) self.model.build((self.batch_size, *image_size, 3)) - util_keras.restore_ckpt(self.model, self.ckpt_path, skip_mismatch=False) + util_keras.restore_ckpt(self.model, self.ckpt_path, + self.params['moving_average_decay'], + skip_mismatch=False) def visualize(self, image, boxes, classes, scores, **kwargs): """Visualize prediction on image.""" diff --git a/efficientdet/keras/train.py b/efficientdet/keras/train.py index 6fcd59054..049640f46 100644 --- a/efficientdet/keras/train.py +++ b/efficientdet/keras/train.py @@ -220,7 +220,7 @@ def get_dataset(is_training, config): model = setup_model(config) if FLAGS.pretrained_ckpt: ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt) - util_keras.restore_ckpt(model, ckpt_path) + util_keras.restore_ckpt(model, ckpt_path, config.moving_average_decay) init_experimental(config) val_dataset = get_dataset(False, config).repeat() model.fit( diff --git a/efficientdet/keras/train_lib.py b/efficientdet/keras/train_lib.py index 54f2f3a88..f55fde007 100644 --- a/efficientdet/keras/train_lib.py +++ b/efficientdet/keras/train_lib.py @@ -398,13 +398,13 @@ def _draw_inference(self, 
step): def get_callbacks(params, val_dataset): """Get callbacks for given params.""" - if False: + if params['moving_average_decay']: from tensorflow_addons.callbacks import AverageModelCheckpoint avg_callback = AverageModelCheckpoint( filepath=os.path.join(params['model_dir'], 'ckpt'), verbose=1, save_weights_only=True, - update_weights=True) + update_weights=False) callbacks = [avg_callback] else: ckpt_callback = tf.keras.callbacks.ModelCheckpoint( diff --git a/efficientdet/keras/util_keras.py b/efficientdet/keras/util_keras.py index 5cde4b45f..79cd2b411 100644 --- a/efficientdet/keras/util_keras.py +++ b/efficientdet/keras/util_keras.py @@ -93,7 +93,7 @@ def average_name(ema, var): var.name.split(':')[0] + '/' + ema.name, mark_as_used=False) -def restore_ckpt(model, ckpt_path_or_file, ema_decay=0., skip_mismatch=True): +def restore_ckpt(model, ckpt_path_or_file, ema_decay=0.9998, skip_mismatch=True): """Restore variables from a given checkpoint. Args: