feat: Spellcheck #345

Merged · 44 commits · Aug 23, 2024
Commits
4ad41de
feat(spellcheck): :zap: Add feature: push to Argilla from an HF dataset
jeremyarancio Jun 18, 2024
128e056
fix(spellcheck): :bug: Add fixes to T5 script
jeremyarancio Jun 18, 2024
3217a89
fix(spellcheck): :bug: Add guardrail to prevent computing metrics wi…
jeremyarancio Jun 18, 2024
8bd97f2
feat(spellcheck): :art: Training pipeline using Metaflow
jeremyarancio Jun 18, 2024
a858bc8
feat(spellcheck): :art: LLM QLoRA TRL training script - Mistral - 7B …
jeremyarancio Jun 18, 2024
634b050
perf(spellcheck): :test_tube: Normalize evaluation algorithm
jeremyarancio Jun 19, 2024
e16c6a2
feat(spellcheck): :art: Implement LLM training with Sagemaker & Metaflow
jeremyarancio Jun 20, 2024
2f6f75b
feat(spellcheck): :zap: Mistral 7b instruct v3 trained
jeremyarancio Jun 20, 2024
db2780c
feat(spellcheck): :art: Update guidelines: accents
jeremyarancio Jun 25, 2024
a6bc377
refactor(spellcheck): :sparkles: Update Logging to consider script an…
jeremyarancio Jun 25, 2024
94b17d7
feat(spellcheck): :sparkles: Dataset processing methods & pipeline cr…
jeremyarancio Jun 25, 2024
8a3e9eb
build(spellcheck): :sparkles: Dataset processing (oe, percentage alig…
jeremyarancio Jun 27, 2024
fd0460b
feat(spellcheck): :recycle: Training Mistral-7B-Instruct: instruction…
jeremyarancio Jun 28, 2024
1559740
Delete previous training LLM DAG
jeremyarancio Jun 28, 2024
a612752
feat(spellcheck): :zap: Add eval normalization: remove "\n"
jeremyarancio Jul 1, 2024
cb393c6
refactor(spellcheck): :construction_worker: Foundational LLMs re-eval…
jeremyarancio Jul 1, 2024
edce9c5
Modify overfitted prompt
jeremyarancio Jul 1, 2024
1645a83
Merge branch 'develop' into spellcheck
jeremyarancio Jul 8, 2024
6170159
refactor(spellcheck): :label: Refactor Argilla extraction: modules + …
jeremyarancio Jul 8, 2024
d2f63c7
refactor(spellcheck): :arrow_up: Refactor training job: add parameter…
jeremyarancio Jul 8, 2024
b9c1dc2
feat(spellcheck): :art: Fine-tune Mistral-7b with guidelines + traini…
jeremyarancio Jul 9, 2024
6ddf189
feat(spellcheck): :sparkles: Add args for training + Train Mistral-7b…
jeremyarancio Jul 9, 2024
33e9691
fix(spellcheck): :bug: Correction error in Mistral-7B-Base fine-tunin…
jeremyarancio Jul 10, 2024
a3d310d
feat(spellcheck): :art: DPO dataset extraction and push
jeremyarancio Jul 15, 2024
fea5716
fix(spellcheck): :bookmark: small fixes
jeremyarancio Jul 17, 2024
7092b69
feat(spellcheck): :zap: DPO training script
jeremyarancio Jul 17, 2024
7448e10
refactor(spellcheck): :construction: Refactor training pipeline: WIP
jeremyarancio Jul 17, 2024
cdcc825
Update get_logger for Metaflow logging
jeremyarancio Jul 18, 2024
7cf82a9
feat(spellcheck): :sparkles: Double the benchmark size: extraction an…
jeremyarancio Jul 18, 2024
e28a470
docs(spellcheck): :memo: Document benchmark generation pipeline
jeremyarancio Jul 21, 2024
b6b5bf8
feat(spellcheck): :bug: Remove legacy metadata in Argilla
jeremyarancio Jul 21, 2024
a4f295a
refactor(spellcheck): :construction: Refactor training pipeline (WIP)
jeremyarancio Jul 22, 2024
c6aaee5
refactor(spellcheck): :construction: Refactor training script (WIP)
jeremyarancio Jul 23, 2024
4401966
refactor(spellcheck): :construction: Refactor training pipeline
jeremyarancio Jul 23, 2024
e870a24
chore(spellcheck): :sparkles: Update Python from 3.9 to 3.10
jeremyarancio Jul 23, 2024
08a097a
refactor(spellcheck): :zap: LLM training pipeline refactored
jeremyarancio Jul 24, 2024
b563be1
feat(spellcheck): :sparkles: Pretraining before fine-tuning (WIP)
jeremyarancio Jul 25, 2024
f0d924d
feat(spellcheck): :ambulance: Pretraining + Finetuning Mistral-7B
jeremyarancio Jul 25, 2024
ac4f496
refactor(spellcheck): :sparkles: Refactor training
jeremyarancio Jul 28, 2024
1600a65
feat(spellcheck): :construction: Batch processing (WIP)
jeremyarancio Aug 5, 2024
f28dbc5
feat(spellcheck): :construction: Batch job with vllm and GCP (WIP)
jeremyarancio Aug 14, 2024
c7b0709
feat(spellcheck): :zap: Batch job operational on GCP
jeremyarancio Aug 15, 2024
0ef32f8
refactor(spellcheck): :sparkles: Clean code and add logging to batch job
jeremyarancio Aug 16, 2024
70a33e1
fix(spellcheck): :package: Forgot to add batch dep requirements
jeremyarancio Aug 16, 2024
17 changes: 17 additions & 0 deletions spellcheck/Dockerfile
@@ -0,0 +1,17 @@
# CUDA-enabled PyTorch base image
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel

# Unbuffered logs, no .pyc files, quieter pip; make the bundled library importable
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PYTHONPATH="/app/src"

WORKDIR /app

# Copy the local library and the batch job scripts (including requirements.txt)
COPY ./src /app

COPY ./scripts/batch/. /app

RUN pip install --no-cache-dir -r requirements.txt

# Set the entrypoint to the batch job script
ENTRYPOINT ["python", "main.py"]
33 changes: 18 additions & 15 deletions spellcheck/README.md
@@ -14,7 +14,10 @@ From the different types of errors observed across products, we came up with the
* The only case in which whitespace around a percentage should be modified is when the *digit* is stuck to the previous word (*ex: cheese1.9% -> cheese 1.9%*)
* Some ingredients are enclosed in `_`, such as `_milk_` or `_cacahuetes_`, to flag allergens; these markers should remain unchanged. However, when the enclosed text is not an ingredient, such as `_Cacahuetes_ con cáscara tostado. _Trazas de frutos de cáscara_.`, it needs to be modified into `_Cacahuetes_ con cáscara tostado. Trazas de frutos de cáscara.`;
* Some percentages were badly parsed by the OCR. Since we cannot be sure what the right value is, it is preferable to keep them as they are.
* We're ok with accents modified or not.
* Accents and other language-specific characters:
    * In Romanian, the characters "ş" and "ţ" (Unicode code points 351 and 355) should be restored by the Spellcheck when necessary,
    * Uppercase letters should remain unchanged (ex: "ECOSSE" -> "ECOSSE"; "ÉCOSSE" -> "ÉCOSSE"),
    * For lowercase letters, missing accents should be added.
* `*` should remain in the corrected text as much as possible (*ex: Schweinefleisch\* -> Schweinefleisch\**)
* Whitespaces shouldn't be modified except in these cases:
    * When two words are stuck together (*ex: rizbrun -> riz brun*)
@@ -233,30 +236,30 @@ We evaluated **Proprietary LLMs** such as OpenAI GPTs and Anthropic Claude 3 mod

Texts are normalized so that typical, non-meaningful corrections are not counted (see the sketch after this list):
* lowercase-uppercase
* whitespaces between words
* words are stripped of surrounding whitespace
* replace ("œ", "oe")
* replace ("flavour", "flavor") - ("colour", "color") - ("pasteurized", "pasteurised")
* remove all accents using the Unidecode library
* remove linebreaks: ("\n", "")
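
A minimal sketch of that normalization, assuming the `Unidecode` package; the function name is illustrative, not the repository's actual API:

```python
from unidecode import unidecode  # pip install Unidecode

def normalize(text: str) -> str:
    """Minimal sketch of the benchmark normalization described above."""
    text = text.lower()                          # lowercase-uppercase
    text = text.replace("\n", "")                # remove linebreaks
    text = text.replace("œ", "oe")               # ligature
    for a, b in [("flavour", "flavor"), ("colour", "color"), ("pasteurized", "pasteurised")]:
        text = text.replace(a, b)                # regional spelling variants
    text = unidecode(text)                       # drop accents
    return " ".join(word.strip() for word in text.split())  # normalize whitespace
```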

In addition to computing metrics using the evaluation algorithm, predictions against the benchmark are pushed to Argilla for human evaluation. The proportion of good corrections is then calculated.

Benchmark version: **v5**
Prompt version: **v6**
Benchmark version: **v7.3** -- Prompt version: **v7**


| Model | Correction Precision | Correction Recall | Localisation Precision | Localisation Recall | Localisation F1 | Human evaluation
|----------|----------|----------|----------|----------|----------|----------|
| FlanT5-Small | **0.815** | 0.486 | **0.876** | 0.522 | 0.654 | - |
| GPT-3.5-Turbo | 0.729 | **0.779** | 0.767 | **0.820** | **0.793** | **0.894** |
| Gemini-1.0-pro | 0.499 | 0.586 | 0.561 | 0.658 | 0.605 | 0.717 |
| Gemini-1.5-flash | 0.514 | 0.693 | 0.590 | 0.795 | 0.677 | 0.790 |
| Gemini-1.5-pro | 0.364 | 0.658 | 0.415 | 0.750 | 0.534 | - |
| Mistral-7B-Instruct-v3 (not fine-tuned) | 0.381 | 0.501 | 0.488 | 0.641 | 0.554 | - |
| Model | Correction Precision | Correction Recall | Correction F1 | Human evaluation
|----------|----------|----------|----------|----------|
| GPT-3.5-Turbo | 0.557 | 0.727 | 0.631 | - |
| GPT-4o | 0.311 | 0.702 | 0.431 | - |
| Gemini-1.5-flash | 0.544 | 0.596 | 0.569 | - |
| Claude3-Sonnet-3.5 | 0.178 | **0.810** | 0.292 | - |
| **Our model** | **0.664** | 0.630 | **0.647** | - |


Notes:
* **Correction Precision**: Proportion of correct modifications.
* **Correction Recall**: Proportion of errors found and corrected.
* **Localisation Precision**: Proportion of errors rightly detected by the model.
* **Localisation Recall**: Proportion of errors found.
* **Localisation F1**: Harmonic mean of Localisation Precision and Recall.
* **Correction F1**: Harmonic mean of Correction Precision and Recall (see the sketch below).
* **Human evaluation**: Proportion of good corrections after human analysis.
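
For reference, precision, recall and F1 relate as in this small sketch (counts are illustrative, not taken from the benchmark):

```python
def precision_recall_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Precision, recall and their harmonic mean (F1) from raw counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Illustrative example: 66 correct edits out of 100 proposed, 105 reference errors in total
print(precision_recall_f1(tp=66, fp=34, fn=39))  # ≈ (0.660, 0.629, 0.644)
```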

### 100 % known-ingredients products
9 changes: 9 additions & 0 deletions spellcheck/commands/extract_from_argilla.sh
@@ -0,0 +1,9 @@
python scripts/dags/extract_from_argilla.py run \
--deploy_to_hf true \
--local_path data/dataset/deployed_data.parquet \
--argilla_dataset_name training_dataset \
--dataset_hf_repo openfoodfacts/spellcheck-dataset \
--dataset_revision v4 \
--dataset_test_size 0.1 \
--dataset_version v4.3 \
--status submitted
6 changes: 6 additions & 0 deletions spellcheck/commands/run_training.sh
@@ -0,0 +1,6 @@
python scripts/dags/training/training.py run \
--do_human_eval False \
--evaluation_data_version v8.0 \
--training_data_version v5.2 \
--experiment_tag eval_loss --experiment_tag mistral-7b-v0.3 --experiment_tag eval-normalization --experiment_tag test
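
Both commands above drive Metaflow flows under `scripts/dags/`; a hedged sketch of how such CLI flags map to Metaflow `Parameter`s (flow, step, and default values below are illustrative, not the PR's actual code):

```python
from metaflow import FlowSpec, Parameter, step

class ExtractFromArgillaFlow(FlowSpec):
    # Each `--flag value` on the command line binds to a Parameter of the same name
    deploy_to_hf = Parameter("deploy_to_hf", default=False)
    argilla_dataset_name = Parameter("argilla_dataset_name", default="training_dataset")
    dataset_hf_repo = Parameter("dataset_hf_repo", default="openfoodfacts/spellcheck-dataset")
    dataset_version = Parameter("dataset_version", default="v4.3")

    @step
    def start(self):
        # ...pull annotations from Argilla, split, and push to the HF Hub...
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    ExtractFromArgillaFlow()
```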

40 changes: 0 additions & 40 deletions spellcheck/config/training.yml

This file was deleted.

50 changes: 50 additions & 0 deletions spellcheck/config/training/pretraining_conf.yml
@@ -0,0 +1,50 @@
estimator:
entry_point: "pretraining_llm.py" # train script
source_dir: "scripts/training/llm/" # directory containing the training script and its requirements
dependencies:
- "src/" # Additional local library
output_path: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to save the artifacts
code_location: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to stage the code during the training job
base_job_name: "mistral-7b-v03" # name of the training job
instance_count: 1 # the number of instances used for training
instance_type: "ml.g5.2xlarge" # instance type used for the training job
transformers_version: "4.36" # transformers version used in the training job
pytorch_version: "2.1" # PyTorch version used in the training job
py_version: "py310" # python version used in the training job
disable_output_compression: true # do not compress the output, to save training time and cost
volume_size: 300 # the size of the EBS volume in GB

hyperparameters:
# Data
training_data: "openfoodfacts/spellcheck-corpus"
train_split: "train"

# Trainer
output_dir: "/opt/ml/model"
pretrained_model_name: "mistralai/Mistral-7B-v0.3"
num_train_epochs: 1
per_device_train_batch_size: 4
learning_rate: 0.0002 # Paper https://arxiv.org/pdf/2210.11416
warmup_steps: 0
warmup_ratio: 0.1
weight_decay: 0.1
gradient_checkpointing: true
seed: 42
optim: "adamw_torch_fused" # The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
lr_scheduler_type: "cosine"
gradient_accumulation_steps: 8
bf16: true
tf32: true
fp16: false
logging_steps: 1
save_total_limit: 1
report_to: "none" # Important to avoid overlap between the Trainer callback and our custom callback
max_seq_length: 2048
packing: true
dataset_text_field: "ingredients_text"
# add_special_tokens: true # Add bos token and other special token from the tokenizer
# append_concat_token: true # If true, appends eos_token_id at the end of each sample being packed.

# Saving
merge_weights: true
max_shard_size: "2GB"
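
For context, a minimal sketch of how a config like this could be fed to a SageMaker HuggingFace estimator; the config path and IAM role are assumptions for the sketch, not values from this PR:

```python
import yaml
from sagemaker.huggingface import HuggingFace

# Load the YAML shown above (path is an assumption for this sketch)
with open("spellcheck/config/training/pretraining_conf.yml") as f:
    conf = yaml.safe_load(f)

estimator = HuggingFace(
    role="arn:aws:iam::123456789012:role/sagemaker-execution-role",  # hypothetical IAM role
    hyperparameters=conf["hyperparameters"],
    **conf["estimator"],  # entry_point, source_dir, instance_type, framework versions, etc.
)
estimator.fit()  # launches the SageMaker training job
```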
75 changes: 75 additions & 0 deletions spellcheck/config/training/training_conf.yml
@@ -0,0 +1,75 @@
estimator:
entry_point: "refactored_llm.py" # train script
source_dir: "scripts/training/llm/" # directory containing the training script and its requirements
dependencies:
- "src/" # Additional local library
output_path: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to save the artifacts
code_location: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to stage the code during the training job
base_job_name: "mistral-7b-v03" # name of the training job
instance_count: 1 # the number of instances used for training
instance_type: "ml.g5.2xlarge" # instance type used for the training job
transformers_version: "4.36" # transformers version used in the training job
pytorch_version: "2.1" # PyTorch version used in the training job
py_version: "py310" # python version used in the training job
disable_output_compression: true # do not compress the output, to save training time and cost
volume_size: 300 # the size of the EBS volume in GB

additional_conf:
s3_evaluation_uri: "s3://open-food-facts-robotoff/spellcheck/evaluation_output/"

hyperparameters:
# Data
training_data: "openfoodfacts/spellcheck-dataset"
evaluation_data: "openfoodfacts/spellcheck-benchmark"
train_split: "train+test"
eval_split: "train"
train_text_feature: "original"
train_label_feature: "reference"
eval_text_feature: "original"
eval_label_feature: "reference"
train_data_revision: "v5"
eval_data_revision: "v8"

# TrainingArguments
output_dir: "/opt/ml/model"
pretrained_model_name: "mistralai/Mistral-7B-v0.3"
num_train_epochs: 0.01
per_device_train_batch_size: 8
per_device_eval_batch_size: 4
learning_rate: 0.0002 # Paper https://arxiv.org/pdf/2210.11416
warmup_steps: 0
warmup_ratio: 0.1
weight_decay: 0.1
gradient_checkpointing: true
seed: 42
optim: "adamw_torch_fused" # The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
lr_scheduler_type: "cosine"
gradient_accumulation_steps: 4
bf16: true
tf32: true
fp16: false
logging_steps: 5
evaluation_strategy: "steps"
save_strategy: "steps"
eval_steps: 10 # Careful: save_steps (500 by default) needs to be a multiple of eval_steps
save_total_limit: 1
report_to: "none" # Important to avoid overlap between the Trainer callback and our custom callback

# SFTConfig
max_seq_length: 1024
packing: true
dataset_text_field: "text"
# add_special_tokens: true # Add bos token and other special token from the tokenizer
# append_concat_token: true # If true, appends eos_token_id at the end of each sample being packed.

# Saving
merge_weights: true
max_shard_size: "2GB"

# Inference
max_new_token: 1024
batch_size: 1

# Data processing
batched: false
# instruction_template
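
Inside the training entry point, hyperparameters like these typically end up in a TRL `SFTTrainer`; a hedged sketch assuming a recent `trl` release with `SFTConfig` — it is not the PR's actual `refactored_llm.py`:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

train_ds = load_dataset(
    "openfoodfacts/spellcheck-dataset", revision="v5", split="train+test"
)

args = SFTConfig(
    output_dir="/opt/ml/model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    bf16=True,
    max_seq_length=1024,
    packing=True,
    dataset_text_field="text",
    report_to="none",
)

trainer = SFTTrainer(
    model="mistralai/Mistral-7B-v0.3",  # model id; the QLoRA/PEFT config is omitted in this sketch
    args=args,
    train_dataset=train_ds,
)
trainer.train()
```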
85 changes: 85 additions & 0 deletions spellcheck/data/evaluation/metrics.jsonl
@@ -466,3 +466,88 @@
"prompt_version": "v6",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.7101024890190337,
"correction_recall": 0.7860615883306321,
"precision": 0.7481698389458272,
"recall": 0.8282009724473258,
"f1": 0.7861538461538463,
"f1_beta": 0.7861538461538463,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-3.5-turbo",
"date": "25/06/2024 10:19:48",
"benchmark_version": "v5",
"prompt_version": "v6",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.5573366214549939,
"correction_recall": 0.7266881028938906,
"precision": 0.6091245376078915,
"recall": 0.7942122186495176,
"f1": 0.6894626657362177,
"f1_beta": 0.6894626657362177,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-3.5-turbo",
"date": "01/07/2024 13:08:32",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.5439882697947214,
"correction_recall": 0.5964630225080386,
"precision": 0.6304985337243402,
"recall": 0.6913183279742765,
"f1": 0.6595092024539878,
"f1_beta": 0.6595092024539878,
"beta": 1.0,
"drop_count": 0
},
"model": "gemini-1.5-flash-preview-0514",
"date": "01/07/2024 15:19:04",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.17844767844767845,
"correction_recall": 0.809748427672956,
"precision": 0.19889119889119888,
"recall": 0.9025157232704403,
"f1": 0.3259511641113004,
"f1_beta": 0.3259511641113004,
"beta": 1.0,
"drop_count": 0
},
"model": "claude-3-5-sonnet-20240620",
"date": "01/07/2024 16:07:23",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 152
}
{
"metrics": {
"correction_precision": 0.31130063965884863,
"correction_recall": 0.7019230769230769,
"precision": 0.35252309879175553,
"recall": 0.7948717948717948,
"f1": 0.4884293451501724,
"f1_beta": 0.4884293451501724,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-4o-2024-05-13",
"date": "01/07/2024 16:34:15",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
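
The entries above are pretty-printed JSON objects appended to a single file rather than strict one-object-per-line JSONL; a small hedged sketch of loading them for comparison (the path is an assumption):

```python
import json

def load_metrics(path: str) -> list[dict]:
    """Parse concatenated, possibly pretty-printed, JSON objects from a file."""
    decoder = json.JSONDecoder()
    text = open(path, encoding="utf-8").read().strip()
    entries, idx = [], 0
    while idx < len(text):
        obj, idx = decoder.raw_decode(text, idx)
        entries.append(obj)
        while idx < len(text) and text[idx].isspace():
            idx += 1  # skip whitespace between objects
    return entries

for entry in load_metrics("spellcheck/data/evaluation/metrics.jsonl"):
    m = entry["metrics"]
    print(f'{entry["model"]:<35} benchmark={entry["benchmark_version"]:<5} f1={m["f1"]:.3f}')
```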