panoptic-segmentation sync #13

Closed
72 changes: 49 additions & 23 deletions official/nlp/modeling/models/t5.py
@@ -1087,16 +1087,17 @@ def __init__(self,

@tf.Module.with_name_scope
def __call__(self,
inputs,
inputs=None,
encoder_mask=None,
dense_inputs=None,
training=False):
"""Applies Transformer model on the inputs.

Args:
inputs: input data
inputs: input word ids. Optional if dense data are provided.
encoder_mask: the encoder self-attention mask.
dense_inputs: dense input data, concat after the embedding.
dense_inputs: dense input data, concatenated after the embedding if word
ids are provided.
training: whether it is training pass, affecting dropouts.

Returns:
@@ -1106,16 +1107,27 @@ def __call__(self,
if encoder_mask is not None:
encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
cfg = self.config
x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
inputs_array = []
if inputs is not None:
inputs_array.append(
self.input_embed(inputs, one_hot=cfg.one_hot_embedding))
if dense_inputs is not None:
x = tf.concat([x, dense_inputs], axis=1)
inputs_array.append(dense_inputs)
if not inputs_array:
raise ValueError("At least one of inputs and dense_inputs must not be "
"None.")
x = tf.concat(inputs_array, axis=1)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
input_length = tf_utils.get_shape_list(inputs)[1]
if inputs is not None:
input_length = tf_utils.get_shape_list(inputs)[1]
else:
input_length = 0
position_bias = self.relative_embedding(input_length, input_length)
if dense_inputs is not None:
# Here we ignore relative position bias for dense embeddings.
# TODO(yejiayu): If we proceed to video use cases, rework this part.
dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
# Position bias shape: [batch, 1, len, len]
paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
@@ -1320,25 +1332,35 @@ def __init__(self,
compute_dtype=self.compute_dtype)

def encode(self,
encoder_input_tokens,
encoder_input_tokens=None,
encoder_segment_ids=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
training=False):
eligible_positions = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
eligible_position_array = []
if encoder_input_tokens is not None:
eligible_position_array.append(
tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype))
if encoder_dense_inputs is not None:
eligible_dense_position = tf.cast(
eligible_dense_positions = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_positions = tf.concat(
[eligible_positions, eligible_dense_position], axis=1)
eligible_position_array.append(eligible_dense_positions)
if not eligible_position_array:
raise ValueError("At least one of encoder_input_tokens and"
" encoder_dense_inputs must be provided.")

eligible_positions = tf.concat(eligible_position_array, axis=1)
encoder_mask = make_attention_mask(
eligible_positions, eligible_positions, dtype=tf.bool)

encoder_segment_id_array = []
if encoder_segment_ids is not None:
if encoder_dense_segment_ids is not None:
encoder_segment_ids = tf.concat(
[encoder_segment_ids, encoder_dense_segment_ids], axis=1)
encoder_segment_id_array.append(encoder_segment_ids)
if encoder_dense_segment_ids is not None:
encoder_segment_id_array.append(encoder_dense_segment_ids)
if encoder_segment_id_array:
encoder_segment_ids = tf.concat(encoder_segment_id_array, axis=1)
segment_mask = make_attention_mask(
encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
@@ -1353,7 +1375,7 @@ def decode(
self,
encoded,
decoder_target_tokens,
encoder_input_tokens, # only used for masks
encoder_input_tokens=None, # only used for masks
encoder_dense_inputs=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
@@ -1364,14 +1386,18 @@ def decode(
max_decode_len=None,
decode=False,
training=False):
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
eligible_inputs_array = []
if encoder_input_tokens is not None:
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
eligible_inputs_array.append(eligible_inputs)
if encoder_dense_inputs is not None:
eligible_dense_inputs = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_inputs = tf.concat([eligible_inputs, eligible_dense_inputs],
axis=1)
eligible_inputs_array.append(eligible_dense_inputs)
eligible_inputs = tf.concat(eligible_inputs_array, axis=1)

if decode:
# For decoding, the decoder_input_tokens is the decoder_target_tokens.
decoder_input_tokens = decoder_target_tokens
@@ -1430,8 +1456,8 @@ def decode(

@tf.Module.with_name_scope
def __call__(self,
encoder_input_tokens,
decoder_target_tokens,
encoder_input_tokens=None,
decoder_target_tokens=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
decoder_input_tokens=None,
@@ -1456,7 +1482,7 @@ def __call__(self,
a dictionary of logits/cache.
"""
encoded = self.encode(
encoder_input_tokens,
encoder_input_tokens=encoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
encoder_dense_inputs=encoder_dense_inputs,
encoder_dense_segment_ids=encoder_dense_segment_ids,
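To make the new mixed-input masking path easier to follow, here is a minimal standalone sketch (not the PR's code; the toy shapes and the einsum formulation of `make_attention_mask` are assumptions) of how eligible positions from word ids and dense embeddings are concatenated into one encoder self-attention mask:

```python
import tensorflow as tf

# Toy shapes for illustration only.
batch, dense_len, d_model = 2, 2, 8
word_ids = tf.constant([[2, 4, 1, 3, 0, 0],
                        [5, 1, 2, 0, 0, 0]])  # 0 marks padding.
dense_inputs = tf.random.normal((batch, dense_len, d_model))

# A token position is eligible if its id is non-zero; a dense position is
# eligible if any feature in its embedding vector is non-zero.
eligible_tokens = tf.cast(tf.not_equal(word_ids, 0), tf.float32)
eligible_dense = tf.cast(
    tf.reduce_any(tf.not_equal(dense_inputs, 0), axis=-1), tf.float32)
eligible = tf.concat([eligible_tokens, eligible_dense], axis=1)  # [2, 8]

# make_attention_mask in t5.py yields a [batch, 1, q_len, kv_len] mask; an
# outer product over the eligibility vectors is one way to express it.
mask = tf.einsum("bq,bk->bqk", eligible, eligible)[:, tf.newaxis, :, :]
mask = tf.cast(mask, tf.bool)
print(mask.shape)  # (2, 1, 8, 8)
```

Either input may now be omitted, in which case its eligibility vector simply drops out of the concatenation, which is why the code raises a `ValueError` only when both are `None`.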
68 changes: 68 additions & 0 deletions official/nlp/modeling/models/t5_test.py
@@ -372,6 +372,22 @@ def test_encoder_with_dense(self, dtype):
dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
self.assertEqual(encoded.shape, (4, 10, config.d_model))

@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_encoder_only_dense(self, dtype):
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=dtype)
encoded = encoder(dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
self.assertEqual(encoded.shape, (4, 2, config.d_model))

def test_decoder(self):
max_decode_len = 10
config = t5.T5TransformerParams(
@@ -515,6 +531,58 @@ def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)

@parameterized.named_parameters(
("t5_10", ("relu",), True, 26, False, tf.float32),)
def test_transformer_with_dense_only(self, ffn_activations,
logits_via_embedding,
expect_num_variables, layer_sharing,
dtype):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
layer_sharing=layer_sharing,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)

decoder_inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
decoder_segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))

dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
outputs = transformer(
encoder_dense_inputs=dense_inputs,
encoder_dense_segment_ids=dense_segments,
decoder_input_tokens=decoder_inputs,
decoder_target_tokens=decoder_inputs,
decoder_segment_ids=decoder_segments)
cache = {}
batch_size = 2
cache[0] = _create_cache(
batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
outputs = transformer.decode(
encoder_dense_inputs=dense_inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)

@parameterized.named_parameters(
("t5_10", ("relu",), True, 39, tf.float32, 2),
("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
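The dense-only decode test above relies on the test file's `_create_cache` helper, which is outside this diff. A plausible sketch of such a helper, assuming a zero-initialized key/value dictionary (the field names and shape order are assumptions, not the repository's definition):

```python
import tensorflow as tf


def _create_cache(batch_size, max_decode_len, num_heads, d_kv,
                  dtype=tf.float32):
  """Zero-initialized key/value cache for one decoder layer (sketch)."""
  return {
      "key": tf.zeros((batch_size, max_decode_len, num_heads, d_kv), dtype),
      "value": tf.zeros((batch_size, max_decode_len, num_heads, d_kv), dtype),
  }
```

During incremental decoding, `decode_position` indexes into these buffers so that each step writes its new key/value pair in place rather than recomputing attention over the full prefix.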
@@ -24,8 +24,8 @@
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as base_config
from official.vision.configs import common
from official.vision.configs import image_classification as base_config


@dataclasses.dataclass
Expand Down
@@ -25,10 +25,10 @@
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import semantic_segmentation as base_cfg


@dataclasses.dataclass
@@ -26,8 +26,8 @@
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.configs import backbones
from official.vision.configs import semantic_segmentation as base_cfg

# ADE 20K Dataset
ADE20K_TRAIN_EXAMPLES = 20210
@@ -16,8 +16,8 @@
# Import libraries
import tensorflow as tf

from official.vision.beta.dataloaders import classification_input
from official.vision.beta.ops import preprocess_ops
from official.vision.dataloaders import classification_input
from official.vision.ops import preprocess_ops

MEAN_RGB = (0.5 * 255, 0.5 * 255, 0.5 * 255)
STDDEV_RGB = (0.5 * 255, 0.5 * 255, 0.5 * 255)
@@ -17,8 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from official.projects.edgetpu.vision.dataloaders import classification_input
from official.vision.beta.configs import common
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.configs import common
from official.vision.dataloaders import tfexample_utils

IMAGE_FIELD_KEY = 'image/encoded'
LABEL_FIELD_KEY = 'image/class/label'
@@ -22,7 +22,7 @@
from official.modeling import hyperparams
from official.projects.edgetpu.vision.modeling.mobilenet_edgetpu_v1_model import MobilenetEdgeTPU
from official.projects.edgetpu.vision.modeling.mobilenet_edgetpu_v2_model import MobilenetEdgeTPUV2
from official.vision.beta.modeling.backbones import factory
from official.vision.modeling.backbones import factory

layers = tf.keras.layers

2 changes: 1 addition & 1 deletion official/projects/edgetpu/vision/serving/export_util.py
@@ -31,7 +31,7 @@
from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu
from official.projects.edgetpu.vision.tasks import image_classification
from official.projects.edgetpu.vision.tasks import semantic_segmentation as edgetpu_semantic_segmentation
from official.vision.beta.tasks import semantic_segmentation
from official.vision.tasks import semantic_segmentation
# pylint: enable=unused-import

MEAN_RGB = [127.5, 127.5, 127.5]
@@ -28,8 +28,8 @@
from official.projects.edgetpu.vision.dataloaders import classification_input
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
from official.vision.beta.configs import image_classification as base_cfg
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.configs import image_classification as base_cfg
from official.vision.dataloaders import input_reader_factory


def _copy_recursively(src: str, dst: str) -> None:
@@ -20,11 +20,11 @@
import orbit
import tensorflow as tf

from official.common import registry_imports
from official.core import exp_factory
from official.modeling import optimization
from official.projects.edgetpu.vision.configs import mobilenet_edgetpu_config
from official.projects.edgetpu.vision.tasks import image_classification
from official.vision import registry_imports


# Dummy ImageNet TF dataset.
10 changes: 5 additions & 5 deletions official/projects/edgetpu/vision/tasks/semantic_segmentation.py
@@ -27,11 +27,11 @@
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu # pylint: disable=unused-import
from official.projects.edgetpu.vision.modeling.heads import bifpn_head
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.tasks import semantic_segmentation
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import segmentation_input
from official.vision.dataloaders import tfds_factory
from official.vision.ops import preprocess_ops
from official.vision.tasks import semantic_segmentation


class ClassMappingParser(segmentation_input.Parser):
4 changes: 1 addition & 3 deletions official/projects/edgetpu/vision/train.py
@@ -19,9 +19,6 @@
from absl import flags
import gin

# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
@@ -35,6 +32,7 @@
from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu
from official.projects.edgetpu.vision.tasks import image_classification
from official.projects.edgetpu.vision.tasks import semantic_segmentation
from official.vision import registry_imports
# pylint: enable=unused-import

FLAGS = flags.FLAGS
17 changes: 16 additions & 1 deletion official/projects/qat/vision/README.md
@@ -28,7 +28,7 @@ $ python3 train.py \
--mode=train_and_eval
```

## Model Accuracy
## Image Classification

<figure align="center">
<img width=70% src=https://storage.googleapis.com/tf_model_garden/models/qat/images/readme-qat-classification-plot.png>
@@ -46,3 +46,18 @@ Note: The Top-1 model accuracy is measured on the validation set of [ImageNet](h
|ResNet50 |224x224 |76.710% |76.420% |77.200% |[config](https://github.com/tensorflow/models/blob/master/official/projects/qat/vision/configs/experiments/image_classification/imagenet_resnet50_qat_gpu.yaml) |[TFLite(Int8/QAT)](https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_224_int8.tflite) |
|MobileNetV3.5 MultiAVG|224x224 |75.212% |74.122% |75.130% |[config](https://github.com/tensorflow/models/blob/master/official/projects/qat/vision/configs/experiments/image_classification/imagenet_mobilenetv3.5_qat_gpu.yaml)|[TFLite(Int8/QAT)](https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3.5multiavg_1.0_int8/mobilenet_v3.5multiavg_1.00_224_int8.tflite)|

## Semantic Segmentation


Models are pretrained on the COCO train set. Two datasets, the Pascal VOC
segmentation dataset and the Cityscapes dataset (the latter only for DeepLab
v3+), are used to train and evaluate the models. Model accuracy is measured on
the full Pascal VOC segmentation validation set.

### Pre-trained Models

model | resolution | mIoU | mIoU (FP32) | mIoU (FP16) | mIoU (INT8) | mIoU (QAT INT8) | download (tflite)
:-------------------------- | :--------: | ----: | ----------: | ----------: | ----------: | --------------: | :-------------------------------------------------------
MobileNet v2 + DeepLab v3 | 512x512 | 75.27 | 75.30 | 75.32 | 73.95 | 74.68 | [FP32](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/model_none.tflite) \| [FP16](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/model_fp16.tflite) \| [INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21model_int8_full.tflite) \| [QAT INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/Fmodel_default.tflite)
MobileNet v2 + DeepLab v3+ | 1024x2048 | 73.82 | 73.84 | 73.65 | 72.33 | 73.49 | [FP32](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/model_none.tflite) \| [FP16](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/Fmodel_fp16.tflite) \| [INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/model_int8_full.tflite) \| [QAT INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/Fmodel_default.tflite)
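
As a usage sketch (not part of this PR), any of the exported `.tflite` files above can be run with the standard `tf.lite.Interpreter` API. The 512x512 input size follows the table's first row; the [-1, 1] normalization mirrors the MEAN_RGB/STDDEV_RGB constants touched elsewhere in this PR and is an assumption for these particular files:

```python
import numpy as np
import tensorflow as tf

# Illustrative path; substitute a .tflite file downloaded from the table.
interpreter = tf.lite.Interpreter(model_path="model_default.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Dummy 512x512 RGB image, normalized to [-1, 1] (assumed preprocessing).
image = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype(np.float32)
image = (image - 127.5) / 127.5

interpreter.set_tensor(input_details["index"], image)
interpreter.invoke()
logits = interpreter.get_tensor(output_details["index"])
segmentation = np.argmax(logits, axis=-1)  # per-pixel class ids
```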
