# finetune.py
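"""Fine-tune a Hugging Face extractive question-answering model on SQuAD-format data.

Example invocation (the checkpoint name, data path, and hyperparameter values
below are illustrative placeholders, not values shipped with this repo):

    python finetune.py \
        --model_checkpoint bert-base-uncased \
        --train_path data/train-v1.1.json \
        --max_length 384 \
        --doc_stride 128 \
        --epochs 3 \
        --batch_size 16 \
        --accumulation_steps 1 \
        --lr 3e-5 \
        --weight_decay 0.01 \
        --warmup_ratio 0.1 \
        --seed 42 \
        --dropout 0.1
"""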
from datasets import Dataset
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, AutoConfig,
                          TrainingArguments, Trainer, default_data_collator)
import os
from utils import seed_everything, read_squad, prepare_train_features, get_time
from pprint import pprint
import datetime
import argparse
import json
from pathlib import Path
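# The Rust tokenizers keep their own thread pool; disabling parallelism avoids
# the fork-related warning once the DataLoader workers start.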
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args():
    parser = argparse.ArgumentParser(
        description='Fine-tune an extractive QA model on SQuAD-format data.')
    arg = parser.add_argument
    arg('--model_checkpoint', required=True, type=str)
    arg('--train_path', required=True, type=str)
    arg('--max_length', required=True, type=int)
    arg('--doc_stride', required=True, type=int)
    arg('--epochs', required=True, type=int)
    arg('--batch_size', required=True, type=int)
    arg('--accumulation_steps', required=True, type=int)
    arg('--lr', required=True, type=float)
    arg('--weight_decay', required=True, type=float)
    arg('--warmup_ratio', required=True, type=float)
    arg('--seed', required=True, type=int)
    arg('--dropout', required=True, type=float)
    return parser.parse_args()
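# Parse the command-line hyperparameters and fix all RNG seeds for reproducibility.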
args = parse_args()
seed_everything(args.seed)
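# Flatten the nested SQuAD-format JSON into a column-oriented Hugging Face Dataset.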
train_dataset = Dataset.from_dict(read_squad(args.train_path))
tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)
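# Tokenize question/context pairs; overlong contexts are split into overlapping
# windows of max_length tokens (doc_stride overlap) and answer spans are mapped
# to start/end token positions, as in the standard Hugging Face QA preprocessing.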
tokenized_train_ds = train_dataset.map(
    lambda x: prepare_train_features(
        x, tokenizer, args.max_length, args.doc_stride, tokenizer.padding_side == "right"),
    batched=True, remove_columns=train_dataset.column_names)
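# Override the pretrained config's dropout probabilities with the CLI value
# before loading the weights.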
cfg = AutoConfig.from_pretrained(args.model_checkpoint)
cfg.hidden_dropout_prob = args.dropout
cfg.attention_probs_dropout_prob = args.dropout
model = AutoModelForQuestionAnswering.from_pretrained(
    args.model_checkpoint, config=cfg)
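# Write checkpoints and the hyperparameter dump to a timestamped run directory.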
timenow = get_time()
out_dir = Path(f'./model/{timenow}/')
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / 'hyp.json', 'w') as f:
    d = args.__dict__
    d['time'] = timenow
    json.dump(d, f, indent=4)
pprint(d)
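# Training configuration; a distinct name avoids shadowing the parsed CLI args.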
training_args = TrainingArguments(
    out_dir,
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=args.lr,
    warmup_ratio=args.warmup_ratio,
    gradient_accumulation_steps=args.accumulation_steps,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    num_train_epochs=args.epochs,
    weight_decay=args.weight_decay,
    fp16=True,
    report_to='none',
    dataloader_num_workers=4,
)
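# Assumption: prepare_train_features pads every feature to max_length (as in the
# standard Hugging Face QA example), so the no-op default_data_collator suffices.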
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_train_ds,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
)
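# Run fine-tuning; a checkpoint is saved at the end of every epoch.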
trainer.train()
print(datetime.datetime.now())