Replace apex with torch.cuda.amp; update to latest transformers and Python versions #1257

Open · wants to merge 2 commits into master
103 changes: 52 additions & 51 deletions layoutlm/deprecated/examples/seq_labeling/run_seq_labeling.py
@@ -23,7 +23,6 @@
import os
import random
import shutil

import numpy as np
import torch
from seqeval.metrics import (
@@ -49,17 +48,14 @@
    get_linear_schedule_with_warmup,
)

import torch.cuda.amp as amp

from layoutlm import FunsdDataset, LayoutlmConfig, LayoutlmForTokenClassification


logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (BertConfig, RobertaConfig, LayoutlmConfig)
    ),
    (),
)
ALL_MODELS = (
    'bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased',
    'bert-base-multilingual-uncased', 'bert-base-multilingual-cased', 'bert-base-chinese',
    'bert-base-german-cased', 'bert-large-uncased-whole-word-masking',
    'bert-large-cased-whole-word-masking',
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-base-cased-finetuned-mrpc',
    'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'bert-base-japanese',
    'bert-base-japanese-whole-word-masking', 'bert-base-japanese-char',
    'bert-base-japanese-char-whole-word-masking', 'bert-base-finnish-cased-v1',
    'bert-base-finnish-uncased-v1', 'bert-base-dutch-cased', 'roberta-base', 'roberta-large',
    'roberta-large-mnli', 'distilroberta-base', 'roberta-base-openai-detector',
    'roberta-large-openai-detector',
)

MODEL_CLASSES = {
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
@@ -152,16 +148,13 @@ def train( # noqa C901
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    import torch
    # from torch.utils.data import DataLoader
    from tqdm import tqdm

    # Enable mixed precision training
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level
        )
        scaler = torch.cuda.amp.GradScaler()

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
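Note: apex's `amp.initialize(model, optimizer, opt_level=...)` has no one-to-one counterpart in native AMP: the model and optimizer stay untouched, and a `GradScaler` object carries the loss-scaling state instead. A minimal sketch of the scaler in isolation:

```python
import torch

# The scaler multiplies the loss by a large factor so small float16
# gradients don't flush to zero, then unscales them before the optimizer
# step. With enabled=False every scaler method is a no-op, letting one
# code path serve both fp16 and fp32 runs.
scaler = torch.cuda.amp.GradScaler(enabled=True)
print(scaler.get_scale())  # 65536.0 by default on a CUDA machine
```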
@@ -180,72 +173,79 @@
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(
            train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]
        )

    train_iterator = tqdm(range(int(args.num_train_epochs)), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])

        for step, batch in enumerate(epoch_iterator):
            model.train()
            inputs = {
                "input_ids": batch[0].to(args.device),
                "attention_mask": batch[1].to(args.device),
                "labels": batch[3].to(args.device),
            }
            if args.model_type in ["layoutlm"]:

            if args.model_type == "layoutlm":
                inputs["bbox"] = batch[4].to(args.device)
            inputs["token_type_ids"] = (
                batch[2].to(args.device) if args.model_type in ["bert", "layoutlm"] else None
            )  # RoBERTa doesn't use segment_ids

            outputs = model(**inputs)
            # model outputs are always tuple in pytorch-transformers (see doc)
            loss = outputs[0]

            if args.model_type in ["bert", "layoutlm"]:
                inputs["token_type_ids"] = batch[2].to(args.device)

            # Enable autocasting for the forward pass (eligible ops run in float16 automatically)
            if args.fp16:
                with torch.cuda.amp.autocast(enabled=args.fp16):
                    outputs = model(**inputs)
                    # model outputs are always tuple in pytorch-transformers (see doc)
                    loss = outputs[0]
            else:
                outputs = model(**inputs)
                loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                scaler.scale(loss).backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm
                    )
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm
                    )
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()

                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (
                    args.local_rank in [-1, 0]
                    and args.logging_steps > 0
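Note: the ordering in the hunk above is what native AMP requires: `scaler.unscale_(optimizer)` must run before gradient clipping so the clip sees true gradient norms, and `scaler.step`/`scaler.update` replace the bare `optimizer.step()` (the step is skipped when inf/NaN gradients are detected). A self-contained sketch of the same pattern on a toy model, with illustrative names not taken from this repo:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"  # torch.cuda.amp needs a GPU to do real fp16

model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
max_grad_norm = 1.0

for _ in range(3):
    inputs = torch.randn(8, 16, device=device)
    labels = torch.randint(0, 4, (8,), device=device)

    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)

    scaler.scale(loss).backward()  # plain backward when scaling is disabled
    scaler.unscale_(optimizer)     # so clipping sees unscaled gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)         # skips the step on inf/NaN gradients
    scaler.update()
    optimizer.zero_grad()
```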
@@ -686,6 +686,7 @@ def main(): # noqa C901
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
        ignore_mismatched_sizes=True,
    )

    if args.local_rank == 0:
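Note: `ignore_mismatched_sizes=True` lets `from_pretrained` proceed when a checkpoint tensor's shape disagrees with the current config (typically a classification head trained with a different number of labels): the mismatched weights are reinitialized with a warning instead of raising. A hedged sketch; the model name and label count are illustrative:

```python
from transformers import BertForTokenClassification

# If the checkpoint's head was trained with a different num_labels,
# loading would normally fail with a size-mismatch error; this flag
# reinitializes the offending weights instead.
model = BertForTokenClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=7,                  # illustrative label count
    ignore_mismatched_sizes=True,
)
```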
2 changes: 1 addition & 1 deletion layoutlm/deprecated/layoutlm/modeling/layoutlm.py
@@ -4,7 +4,7 @@
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertConfig, BertModel, BertPreTrainedModel
from transformers.modeling_bert import BertLayerNorm
BertLayerNorm = torch.nn.LayerNorm

logger = logging.getLogger(__name__)

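Note: newer transformers releases dropped the `BertLayerNorm` alias (and moved `modeling_bert` under `transformers.models.bert`); in the versions this repo targeted, the alias was effectively `torch.nn.LayerNorm`, so the local assignment preserves behavior. If both old and new transformers must be supported, a guarded import is an alternative sketch:

```python
import torch

try:
    # Older transformers still expose the alias.
    from transformers.modeling_bert import BertLayerNorm
except ImportError:
    # The alias pointed at the stock LayerNorm, so this is equivalent.
    BertLayerNorm = torch.nn.LayerNorm
```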