Merge pull request #8 from robertknight/training-guide
Add training guide and align text detection model training with recognition model
robertknight committed Jan 30, 2024
2 parents 2c696aa + 64361c2 commit 9c6c72a
Showing 9 changed files with 866 additions and 630 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -22,6 +22,8 @@ jobs:
       - name: Install pipenv
         run: pip install pipenv
       - name: Install dependencies
-        run: pipenv install --dev
+        run: |
+          pipenv install --dev
+          pipenv run pip install torch torchvision
       - name: Check formatting and types
         run: pipenv run qa
3 changes: 1 addition & 2 deletions Pipfile
@@ -4,15 +4,14 @@ verify_ssl = true
 name = "pypi"

 [packages]
-torch = "*"
 numpy = "*"
-torchvision = "*"
 pillow = "*"
 tqdm = "*"
 opencv-python = "*"
 shapely = "*"
 wandb = "*"
 pylev = "*"
+onnx = "*"

 [dev-packages]
 black = "*"
1,147 changes: 539 additions & 608 deletions Pipfile.lock

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions README.md
@@ -3,6 +3,8 @@
 This project contains tools for training PyTorch models for use with the
 [**Ocrs**](https://github.com/robertknight/ocrs/) OCR engine.

+## About the models
+
 The ocrs engine splits text detection and recognition into three phases, each
 of which corresponds to a different model in this repository:

@@ -22,3 +24,14 @@ All models can be exported to ONNX for downstream use.

 The models are trained exclusively on datasets which are a) open and b) have non-restrictive licenses. This currently includes:
 - [HierText](https://github.com/google-research-datasets/hiertext) (CC-BY-SA 4.0)
+
+## Pre-trained models
+
+Pre-trained models are available from [Hugging
+Face](https://huggingface.co/robertknight/ocrs) as PyTorch checkpoints,
+[ONNX](https://onnx.ai) and [RTen](https://github.com/robertknight/rten) models.
+
+## Training custom models
+
+See the [Training guide](docs/training.md) for a walk-through of the process to
+train models from scratch or fine-tune existing models.
190 changes: 190 additions & 0 deletions docs/training.md
@@ -0,0 +1,190 @@
# Training Ocrs models

This document describes how to train models for use with
[ocrs](https://github.com/robertknight/ocrs).

## Prerequisites

To train the models you will need:

- Python 3.10 or later
- A GPU. The initial training was done on NVIDIA A10G GPUs with 24 GB of memory,
  via [AWS EC2 G5 instances](https://aws.amazon.com/ec2/instance-types/g5/) (the
  smallest `g5.xlarge` size will work).
- Optional: A Weights and Biases account (https://wandb.ai/) to track training progress

## About Ocrs models

Ocrs splits the OCR process into three stages:

1. Text detection
2. Layout analysis
3. Text recognition

Each of these stages corresponds to a separate PyTorch model. The layout
analysis model is incomplete and is not currently used in Ocrs.

You can mix and match default/pre-trained and custom models for the different
stages. For example, you may wish to use a pre-trained detection model but a
custom recognition model; a combined CLI example appears at the end of this guide.

## Download the dataset

Following the instructions in
https://github.com/google-research-datasets/hiertext#getting-started, clone the
HierText repository and download the training data.

Note that you do **not** need to follow the step about decompressing the
`.jsonl.gz` files. The training tools will do this for you.

The compressed dataset is ~3.6 GB in total size.

```sh
# Clone the HierText repository. This contains the ground truth data.
mkdir -p datasets/
cd datasets
git clone https://github.com/google-research-datasets/hiertext.git
cd hiertext
# Download the training, validation and test images.
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/train.tgz .
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/validation.tgz .
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/test.tgz .
# Decompress the datasets.
tar -xf train.tgz
tar -xf validation.tgz
tar -xf test.tgz
```
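
As a quick sanity check, the dataset directory should now contain the ground
truth annotations from the git repository plus the decompressed image
directories. A sketch, run from the directory where you created `datasets/`
(the exact ground truth file names reflect the HierText repo at the time of
writing):

```sh
# Ground truth annotations, shipped compressed with the HierText repo.
ls datasets/hiertext/gt/    # expect train.jsonl.gz and validation.jsonl.gz
# Image directories extracted from the tarballs.
ls -d datasets/hiertext/train datasets/hiertext/validation datasets/hiertext/test
```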

## Set up the training environment

1. Install [Pipenv](https://pipenv.pypa.io/en/latest/)
2. Install dependencies, except for PyTorch:

```sh
pipenv install --dev
```

3. Install the appropriate version of PyTorch for your system, in the virtualenv
created by pipenv:

```sh
pipenv run pip install torch torchvision
```

See https://pytorch.org/get-started/locally/ for an appropriate pip command
depending on your platform and GPU.

4. Start a dummy text detection training run to verify that everything is working:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ --max-images 100
```

Wait for one successful epoch of training and validation to complete and then
exit the process with Ctrl+C.
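
If the dummy run fails with device-related errors, first check that the PyTorch
build you installed can actually see your GPU. A minimal check using only
standard PyTorch attributes:

```sh
# Prints the PyTorch version and whether a CUDA device is visible.
pipenv run python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```

If this prints `False`, you have likely installed a CPU-only build; see the
PyTorch link above for a platform-specific install command.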

## Set up Weights and Biases integration (optional)

The ocrs-models training scripts support tracking training progress using
[Weights and Biases](https://wandb.ai). To enable this you will need to create
an account and then set the `WANDB_API_KEY` environment variable before running
training scripts:

```sh
export WANDB_API_KEY=<your_api_key>
```

## Train the text detection model

To launch a training run for the text detection model, run:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
--max-epochs 50 \
--batch-size 28
```

The `--batch-size` flag should be adjusted according to the amount of GPU
memory you have available. One approach is to start with a small value and
increase it until the training process uses most of the available GPU memory.
The value above was used with a GPU that has 24 GB of memory. When training on
an NVIDIA GPU, you can use the `nvidia-smi` tool to get memory usage
statistics, as in the example below.
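
For example, to print memory usage every few seconds in another terminal while
training runs (standard `nvidia-smi` query flags):

```sh
# Report used and total GPU memory every 5 seconds.
nvidia-smi --query-gpu=memory.used,memory.total --format=csv -l 5
```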

To fine-tune an existing model, pass the `--checkpoint` flag to specify the
pre-trained model to start with.
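
For example, to fine-tune from an existing checkpoint (the checkpoint file name
and epoch count here are illustrative):

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
  --checkpoint text-detection-checkpoint.pt \
  --max-epochs 10 \
  --batch-size 28
```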

### Export the text detection model

As training progresses, the latest checkpoint will be saved to
`text-detection-checkpoint.pt`. Once training completes, you can export the
model to ONNX via:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
--checkpoint text-detection-checkpoint.pt \
--export text-detection.onnx
```

### Convert the text detection model

To use the exported ONNX model with Ocrs, you will need to convert it to
the `.rten` format used by [RTen][rten].

See the [RTen README](https://github.com/robertknight/rten#getting-started)
for current instructions on how to do this.
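
At the time of writing, the conversion is done with the RTen project's
converter tool. The commands below are a sketch (the package and tool names are
assumed from the RTen repository; prefer the RTen README if it differs):

```sh
# Install the converter (name assumed; see the RTen README).
pip install rten-convert
# Convert the exported ONNX model to .rten format.
rten-convert text-detection.onnx text-detection.rten
```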

To use the converted model with the `ocrs` CLI tool, you can either pass the
model path via CLI arguments, or replace the default models in the cache
directory (`~/.cache/ocrs`). Example using CLI arguments:

```sh
ocrs --detect-model custom-detection-model.rten image.jpg
```

[rten]: https://github.com/robertknight/rten

## Train the text recognition model

To launch a training run for the text recognition model, run:

```sh
pipenv run python -m ocrs_models.train_rec hiertext datasets/hiertext/ \
--max-epochs 50 \
--batch-size 250
```

As with detection training, adjust the `--batch-size` flag according to the
amount of GPU memory you have available. The value above was used with a GPU
that has 24 GB of memory.

To fine-tune an existing model, pass the `--checkpoint` flag to specify the
pre-trained model to start with.

### Export the text recognition model

As training progresses, the latest checkpoint will be saved to
`text-rec-checkpoint.pt`. Once training completes, you can export the model to
ONNX via:

```sh
pipenv run python -m ocrs_models.train_rec hiertext datasets/hiertext/ \
--checkpoint text-rec-checkpoint.pt \
--export text-recognition.onnx
```

### Convert the text recognition model

To use the exported ONNX model with Ocrs, convert it to `.rten` format using
the same process as for the detection model.

To use the converted model with the `ocrs` CLI tool, you can either pass the
model path via CLI arguments, or replace the default models in the cache
directory (`~/.cache/ocrs`). Example using CLI arguments:

```sh
ocrs --rec-model custom-recognition-model.rten image.jpg
```
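
Detection and recognition models can be mixed and matched, as noted at the
start of this guide. For example, to run OCR with both custom models (file
names illustrative):

```sh
ocrs --detect-model custom-detection-model.rten \
  --rec-model custom-recognition-model.rten \
  image.jpg
```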
6 changes: 0 additions & 6 deletions ocrs_models/eval_detection.py
@@ -22,9 +22,6 @@ def main():
     parser.add_argument("model")
     parser.add_argument("image")
     parser.add_argument("out_basename")
-    parser.add_argument(
-        "--export", type=str, help="Export model as ONNX after evaluation"
-    )
     args = parser.parse_args()

     model = DetectionModel()
@@ -51,9 +48,6 @@ def main():
     pred_masks = model(img)
     end = time.time()

-    if args.export:
-        torch.onnx.export(model, img, args.export)
-
     print(f"Predicted text in {end - start:.2f}s", file=sys.stderr)

     pred_masks = pred_masks[0]  # Remove dummy batch dimension
35 changes: 34 additions & 1 deletion ocrs_models/postprocess.py
@@ -79,6 +79,29 @@ def expand_quads(quads: torch.Tensor, dist: float) -> torch.Tensor:
     return torch.stack([expand_quad(quad, dist) for quad in quads])


+def lines_intersect(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
+    """
+    Return true if the lines (a_start, a_end) and (b_start, b_end) intersect.
+    """
+    if a_start <= b_start:
+        return a_end > b_start
+    else:
+        return b_end > a_start
+
+
+def bounds_intersect(
+    a: tuple[float, float, float, float], b: tuple[float, float, float, float]
+) -> bool:
+    """
+    Return true if the rects defined by two (min_x, min_y, max_x, max_y) tuples intersect.
+    """
+    a_min_x, a_min_y, a_max_x, a_max_y = a
+    b_min_x, b_min_y, b_max_x, b_max_y = b
+    return lines_intersect(a_min_x, a_max_x, b_min_x, b_max_x) and lines_intersect(
+        a_min_y, a_max_y, b_min_y, b_max_y
+    )
+
+
 def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:
     """
     Compute metrics for quality of matches between two sets of rotated rects.
@@ -99,12 +122,22 @@ def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:

     # Areas of unions of predictions and targets
     union = torch.zeros((len(pred), len(target)))

+    # Get bounding boxes of polys for a cheap intersection test.
+    pred_polys_bounds = [poly.bounds for poly in pred_polys]
+    target_polys_bounds = [poly.bounds for poly in target_polys]
+
     pred_areas = torch.zeros((len(pred),))
     for pred_index, pred_poly in enumerate(pred_polys):
         pred_areas[pred_index] = pred_poly.area
+        pred_bounds = pred_polys_bounds[pred_index]

         for target_index, target_poly in enumerate(target_polys):
-            if not pred_poly.intersects(target_poly):
+            # Do a cheap intersection test and skip computing the actual
+            # union/intersection if that fails.
+            target_bounds = target_polys_bounds[target_index]
+            if not bounds_intersect(pred_bounds, target_bounds):
                 continue

             pt_intersection = pred_poly.intersection(target_poly)
             intersection[pred_index, target_index] = pt_intersection.area