Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code eval #587

Merged
merged 29 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a57dde8
add code eval
samhavens Sep 8, 2023
a63c4aa
Create coding_tasks.yaml
samhavens Sep 12, 2023
71de0bf
Update coding_tasks.yaml
samhavens Sep 12, 2023
85b727f
Update tasks.yaml
samhavens Sep 12, 2023
b0670e8
Update eval_gauntlet.yaml
samhavens Sep 12, 2023
1e868d8
update readme add code eval
samhavens Sep 13, 2023
c6e54ae
torch nograd eval_model
samhavens Sep 13, 2023
a2c877b
integrate seed
mcarbin Sep 14, 2023
071ddde
Merge branch 'mosaicml:sam/add-coding-eval' into sam/add-coding-eval
mcarbin Sep 14, 2023
12c0175
pass_at_k threading
mcarbin Sep 14, 2023
2755e8d
remove extraneous print
mcarbin Sep 22, 2023
06ae4fb
update num_beams to default of 20
mcarbin Sep 22, 2023
bfbbf17
update num_beamss for coding eval to 20 as the standard for pass@1
mcarbin Sep 22, 2023
e042eb7
remove C from default coding tasks because it's too small
mcarbin Sep 22, 2023
1fa60ec
update num_beams in full task list to default of 20
mcarbin Sep 22, 2023
9e241f5
Merge branch 'main' into sam/add-coding-eval
mcarbin Sep 25, 2023
4b5f2d9
fix break from merge and precommit
mcarbin Sep 25, 2023
c3819de
yamllint
mcarbin Sep 25, 2023
cb45e4b
remove code finetuning
mcarbin Sep 25, 2023
e3d7dda
Merge branch 'main' into sam/add-coding-eval
mcarbin Sep 25, 2023
0acec3d
Merge branch 'main' into sam/add-coding-eval
dakinggg Sep 26, 2023
79d1d8c
Merge branch 'main' into sam/add-coding-eval
dakinggg Sep 26, 2023
628e020
bump version for CI
mcarbin Sep 26, 2023
797c690
Merge remote-tracking branch 'refs/remotes/mosaicml/sam/add-coding-ev…
mcarbin Sep 26, 2023
0860f91
update gauntlet
mcarbin Sep 26, 2023
fbb7342
update tasks
mcarbin Sep 26, 2023
b10beaf
Merge branch 'main' into sam/add-coding-eval
mcarbin Sep 26, 2023
63470b4
remove torch.no_grad
mcarbin Sep 26, 2023
707a2e7
Merge branch 'main' into sam/add-coding-eval
dakinggg Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningLMAccuracy,
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(self, om_model_config: Union[DictConfig,
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics import (InContextLearningLMAccuracy,
from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -700,6 +701,7 @@ def __init__(
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
]
Expand Down
8 changes: 8 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _validate_cfg(icl_cfg: DictConfig):
]
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(
f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
Expand All @@ -244,6 +246,10 @@ def _validate_cfg(icl_cfg: DictConfig):
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
icl_cfg.pass_at_k = 1
if 'num_beams' not in icl_cfg:
icl_cfg.num_beams = 20
dakinggg marked this conversation as resolved.
Show resolved Hide resolved

for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
Expand Down Expand Up @@ -274,6 +280,8 @@ def _validate_cfg(icl_cfg: DictConfig):
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
destination_path=destination_path,
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
if hasattr(
Expand Down
35 changes: 30 additions & 5 deletions scripts/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,13 @@ This document explains the ICL formats compatible with [Composer](https://github

## Supported ICL formats

Composer currently supports four ICL formats
Composer currently supports five ICL formats:

1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L92-L253)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L256-L402)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L405-L599)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L602-L773)
1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L103)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L293)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L444)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L676)
5. [InContextLearningCodeEvalDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L852)

----

Expand Down Expand Up @@ -346,6 +347,30 @@ Below is a YAML section that works with the Winograd dataset in [`scripts/eval/l
continuation_delimiter: ' ' # this separates questions from answers
>


----

### InContextLearningCodeEvalDataset

The ICL CodeEvalDataset takes a prompt, and, working with the NLP metric [InContextLearningCodeEvalAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningCodeEvalAccuracy.html), generates code which gets run against the supplied tests, as in HumanEval ([Evaluating Large Language Models Trained on Code](https://arxiv.org/abs/2107.03374)) and MBPP ([Program Synthesis with Large Language Models](https://arxiv.org/abs/2108.07732)). This generation involves many decoding steps, so can take longer per sample than other ICL tasks. An example datum:

```json
{"task_id": "JavaScript/2", "prompt": "/* Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncateNumber(3.5)\n 0.5\n */\nconst truncateNumber = (number) => {\n", "canonical_solution": " return number % 1.0;\n}\n\n", "test": "const testTruncateNumber = () => {\n console.assert(truncateNumber(3.5) === 0.5)\n\n console.assert(Math.abs(truncateNumber(1.33) - 0.33) < 1e-6)\n\n console.assert(Math.abs(truncateNumber(123.456 - 0.456) < 1e-6))\n}\n\ntestTruncateNumber()\n", "entry_point": "truncateNumber", "test_inputs": ["3.5", "1.33", "123.456"], "test_outputs": ["0.5", "0.33", "0.456"], "language": "javascript"}
```

Required keys for each datum:

* `prompt: str`
* `test: str`
* `entry_point: str`
* `test_inputs: List[str]`
* `test_outputs: List[str]`
* `language: str`

Code evaluation can happen locally (insecure) or inside an AWS Lambda function sandbox. This is controlled by setting the environment variable `CODE_EVAL_DEVICE` to `LOCAL` or `LAMBDA`. If set to `LAMBDA`, you must also provide `CODE_EVAL_URL` and `CODE_EVAL_APIKEY` to query the API gateway in the AWS Sandbox.

----

### Build your own dataset (BYOD)
Building a dataset compatible with our eval suite is very easy if it fits with one of the four supported task types. Simply choose the appropriate task type (LM, MC, QA, or Schema) and process each dataset into a jsonl format in which each row has the format described above.

Expand Down
4 changes: 4 additions & 0 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def evaluate_model(
model_cfg: DictConfig,
dist_timeout: Union[float, int],
run_name: str,
seed: int,
icl_tasks: Union[str, ListConfig],
max_seq_len: int,
device_eval_batch_size: int,
Expand All @@ -107,6 +108,7 @@ def evaluate_model(
eval_gauntlet_df: Optional[pd.DataFrame],
icl_subset_num_batches: Optional[int],
):

print(f'Evaluating model: {model_cfg.model_name}', flush=True)
# Build tokenizer and model
tokenizer_cfg: Dict[str,
Expand Down Expand Up @@ -158,6 +160,7 @@ def evaluate_model(

trainer = Trainer(
run_name=run_name,
seed=seed,
model=composer_model,
callbacks=callbacks,
loggers=loggers,
Expand Down Expand Up @@ -276,6 +279,7 @@ def main(cfg: DictConfig):
model_cfg=model_cfg,
dist_timeout=dist_timeout,
run_name=run_name,
seed=seed,
icl_tasks=icl_tasks,
max_seq_len=max_seq_len,
device_eval_batch_size=device_eval_batch_size,
Expand Down
Loading
Loading