
transformer_main.py - Low TPU usage V3-8 #8671

Closed
soares-f opened this issue Jun 13, 2020 · 5 comments
@soares-f

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • I am using the latest TensorFlow Model Garden release and TensorFlow 2.
  • I am reporting the issue to the correct repository. (Model Garden official or research directory)
  • I checked to make sure that this issue has not been filed already.

1. The entire URL of the file you are using

https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py

2. Describe the bug

I followed the instructions to use TPUs to train a transformer model, but I only get around 12% of TPU utilization when running the code.

3. Steps to reproduce

export PYTHONPATH="$PYTHONPATH:/path/to/models"

cd /path/to/models/official/nlp/transformer

# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
VOCAB_FILE=$DATA_DIR/vocab.ende.32768

# Download training/evaluation/test datasets
python3 data_download.py --data_dir=$DATA_DIR

# Train the model for 200000 steps on a Cloud TPU v3-8, evaluating every
# 30000 steps. Each train step uses a static batch size of 10048 with a
# maximum sequence length of 64.

python transformer_main.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--vocab_file=$DATA_DIR/vocab.ende.32768 \
--batch_size=10048 \
--train_steps=200000 \
--static_batch=true \
--use_ctl=true \
--param_set=big \
--steps_between_evals=30000 \
--max_length=64 \
--decode_batch_size=1024 \
--decode_max_length=97 \
--padded_decode=true \
--distribution_strategy=tpu \
--enable_metrics_in_training=true \
--enable_tensorboard=true 

capture_tpu_profile --tpu=$TPU_NAME  --monitoring_level=2

TPU type: TPU v3
Number of TPU cores: 8 (Replica count = 8, num cores per replica = 1)
TPU idle time (lower is better): 0.058%
Utilization of TPU Matrix Units (higher is better): 11.9%
Step time: 209ms (avg), 209ms (min), 209ms (max)
Infeed percentage: 0.048% (avg), 0.048% (min), 0.048% (max)
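
For reference, roughly the same monitoring numbers can also be polled from Python via the profiler client API; this is only a sketch, and the worker address below is a placeholder (the TPU profiler service typically listens on port 8466):

import tensorflow as tf

# Hypothetical TPU worker profiler address.
tpu_worker = "grpc://10.0.0.2:8466"
print(tf.profiler.experimental.client.monitor(tpu_worker, duration_ms=1000, level=2))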

4. Expected behavior

I would expect TPU usage to be higher; it seems that only one TPU core is being used.

5. Additional context

Memory usage is also low (around 12 GB), which again suggests that only one TPU core is being used.

6. System information

== check python ===================================================
python version: 3.7.3
python branch:
python build version: ('default', 'Dec 20 2019 18:57:59')
python compiler version: GCC 8.3.0
python implementation: CPython
== check os platform ===============================================
os: Linux
os kernel version: #1 SMP Debian 4.19.118-2+deb10u1 (2020-06-07)
os release version: 4.19.0-9-cloud-amd64
os platform: Linux-4.19.0-9-cloud-amd64-x86_64-with-debian-10.4
linux distribution: ('debian', '10.4', '')
linux os distribution: ('debian', '10.4', '')
mac version: ('', ('', '', ''), '')
uname: uname_result(system='Linux', node='garden', release='4.19.0-9-cloud-amd64', version='#1 SMP Debian 4.19.118-2+deb10u1 (2020-06-07)', machine='x86_64', processor='')
architecture: ('64bit', 'ELF')
machine: x86_64
== are we in docker =============================================
No
== compiler =====================================================
c++ (Debian 8.3.0-6) 8.3.0
Copyright (C) 2018 Free Software Foundation, Inc.

soares-f added the models:official and type:bug labels on Jun 13, 2020
@saberkun (Member)

Hi, which TPU topology are you using? e.g. tpu-v3-8?

Utilization of TPU Matrix Units (higher is better): 11.9%
This is the utilization of the TPU's matrix compute units rather than total TPU utilization; usually ~40% indicates good utilization.
TPU idle time (lower is better): 0.058% shows that the TPU is indeed busy.
We will need more information about profiling.
Could you use the TF profiler to see whether this model is bound by input processing? https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance
@gagika @allenwang28 for Cloud TPU help.
In the meantime, we will check what utilization we can get.
Thanks!
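
(A minimal sketch of capturing a trace with the TF profiler suggested above; the log directory and step number are assumptions. Since transformer_main.py is being run with a custom training loop (--use_ctl=true), the programmatic API is probably the simpler route.)

import tensorflow as tf

logdir = "gs://your-bucket/transformer/profile"  # hypothetical location

# Option 1: profile a training batch through the TensorBoard callback
# (only applies when training goes through model.fit).
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=100)

# Option 2: wrap an arbitrary region of the training loop programmatically.
tf.profiler.experimental.start(logdir)
# ... run a few training steps ...
tf.profiler.experimental.stop()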

@soares-f (Author)

Hi, I'm using TPU V3-8.

Here is the link to the profiler reports:
https://drive.google.com/drive/folders/1MnWTG0lq8geCz3bB_iQwUtMDgVjzQYPJ?usp=sharing

I generated two of them, just in case.
When checking them, I noticed something odd: there seems to be some sorting going on.

[Screenshots attached: tflop_util, tf_stats]

I used exactly the same generate_data.py file, only changing the source data to my own parallel training files (2.48 M lines).

I have tried changing the batch size; it didn't make any difference.

Thanks

@saberkun (Member)

The sorting and top-k ops come from the training metrics. It is surprising that they are that slow.
Could you try adding --enable_metrics_in_training=False? By default it should be false.
I believe they come from:

class MetricLayer(tf.keras.layers.Layer):
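
(For illustration only, not the actual Model Garden code: a metric layer along these lines would explain why top-k and sort ops run on the TPU at every step when metrics are enabled.)

import tensorflow as tf

class IllustrativeMetricLayer(tf.keras.layers.Layer):
  """Illustrative sketch: adds a top-5 accuracy metric inside the model.

  With enable_metrics_in_training=True, metrics like this execute on the TPU
  every training step, and the underlying top_k/sort ops show up in the profile.
  """

  def call(self, inputs):
    # Assumes logits of shape [batch, num_classes] and integer targets [batch].
    logits, targets = inputs
    top5 = tf.keras.metrics.sparse_top_k_categorical_accuracy(targets, logits, k=5)
    self.add_metric(top5, aggregation="mean", name="top5_accuracy")
    return logits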

@soares-f (Author)

The sorting and top-k ops come from the training metrics. It is surprising that they are that slow.
Could you try adding --enable_metrics_in_training=False? By default it should be false.
I believe they come from:

class MetricLayer(tf.keras.layers.Layer):

In fact, when I set --enable_metrics_in_training to False, it did increase TPU usage to around 40%:

TPU type: TPU v3
Number of TPU cores: 8 (Replica count = 8, num cores per replica = 1)
Per-core batch size: 4096
TPU idle time (lower is better): 0.151%
Utilization of TPU Matrix Units (higher is better): 38.9%
Step time: 65.1ms (avg), 65.1ms (min), 65.2ms (max)
Infeed percentage: 0.140% (avg), 0.139% (min), 0.140% (max)


I wonder if it is possible to get even better TPU usage.

Thanks,

@saberkun (Member)

saberkun commented Jun 17, 2020

Utilization of TPU Matrix Units (higher is better): 38.9% already looks OK to me.
The matrix units are a bit special in that the data format may trigger padding.
For the transformer, there are tricks to improve efficiency by providing packed sequences as inputs on TPU; we have not implemented this yet.
Increasing the batch size further should also increase TPU utilization.
Thanks.
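
(A rough sketch of the packing idea mentioned above, with all names illustrative: several short examples are concatenated into one max_length sequence, and segment/position ids let the attention mask keep them separate, so far fewer cycles go to padding.)

import tensorflow as tf

def pack_pair(seq_a, seq_b, max_length=64):
  """Illustrative only: packs two short token sequences into one example.

  Real packing also threads the segment and position ids through the model so
  that attention does not cross example boundaries.
  """
  packed = tf.concat([seq_a, seq_b], axis=0)[:max_length]
  segment_ids = tf.concat([tf.ones_like(seq_a), 2 * tf.ones_like(seq_b)], axis=0)[:max_length]
  positions = tf.concat([tf.range(tf.size(seq_a)), tf.range(tf.size(seq_b))], axis=0)[:max_length]
  # Pad everything out to the fixed max_length so batches stay static-shaped.
  pad_amount = max_length - tf.shape(packed)[0]
  paddings = tf.reshape(tf.stack([0, pad_amount]), [1, 2])
  packed = tf.pad(packed, paddings)
  segment_ids = tf.pad(segment_ids, paddings)
  positions = tf.pad(positions, paddings)
  return packed, segment_ids, positions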
