Merge pull request #138 from BaderLab/biobert
BioBERT

Former-commit-id: 6969515
JohnGiorgi authored May 23, 2019
2 parents caedd7c + 7265e0d commit b59bd29
Showing 9 changed files with 287 additions and 70 deletions.
7 changes: 5 additions & 2 deletions saber/constants.py
@@ -30,7 +30,9 @@
END = '<END>' # end-of-sentence token
OUTSIDE = 'O' # 'outside' tag of the IOB, BIO, and IOBES tag formats
WORDPIECE = 'X' # special tag used by BERT's wordpiece tokenizer
CLASSIFICATION = '[CLS]' # special tag used by BERT's classifiers

CLS = '[CLS]' # special BERT classification token
SEP = '[SEP]' # special BERT sequence separator token

# DATA
RANDOM_STATE = 42 # random seed
@@ -59,6 +61,7 @@
'PRGE': '1xOmxpgNjQJK8OJSvih9wW5AITGQX6ODT',
'DISO': '1qmrBuqz75KM57Ug5MiDBfp0d5H3S_5ih',
'CHED': '13s9wvu3Mc8fG73w51KD8RArA31vsuL1c',
'biobert_v1.1_pubmed': '1jI1HyzMzSShjHfeO1pSmw5su8R6p5Vsv'
}
# relative path to pre-trained model directory
PRETRAINED_MODEL_DIR = resource_filename(__name__, 'pretrained_models')
@@ -85,7 +88,7 @@
KERAS = 'keras'
PYTORCH = 'pytorch'
# which pre-trained BERT model to use
PYTORCH_BERT_MODEL = 'bert-base-cased'
PYTORCH_BERT_MODEL = 'biobert_v1.1_pubmed'

# EXTRACT 2.0 API
# arguments passed in a get request to the EXTRACT 2.0 API to specify entity type
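For orientation, the new CLS and SEP constants are BERT's special classification and separator tokens. A minimal sketch (not part of this diff, and assuming the pytorch_pretrained_bert package that provides BertTokenizer) of how a wordpiece-tokenized sentence is typically wrapped before being converted to input ids:

from pytorch_pretrained_bert import BertTokenizer

CLS = '[CLS]'  # mirrors saber.constants.CLS
SEP = '[SEP]'  # mirrors saber.constants.SEP

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
wordpieces = tokenizer.tokenize('Famotidine is a histamine H2-receptor antagonist.')
tokens = [CLS] + wordpieces + [SEP]        # wrap the sequence with the special tokens
input_ids = tokenizer.convert_tokens_to_ids(tokens)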
8 changes: 4 additions & 4 deletions saber/models/base_model.py
@@ -22,10 +22,10 @@ class BaseModel():
models (list): A list of Keras or PyTorch models.
"""
def __init__(self, config, datasets, embeddings=None, **kwargs):
self.config = config # hyperparameters and model details
self.datasets = datasets # dataset(s) tied to this instance
self.embeddings = embeddings # pre-trained word embeddings tied to this instance
self.model = None # Keras / PyTorch model tied to this instance
self.config = config # Hyperparameters and model details
self.datasets = datasets # Dataset(s) tied to this instance
self.embeddings = embeddings # Pre-trained word embeddings tied to this instance
self.model = None # Saber model tied to this instance

for key, value in kwargs.items():
setattr(self, key, value)
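The kwargs loop at the end of the constructor simply promotes any extra keyword arguments to instance attributes (the test fixtures below rely on this via totally_arbitrary='arbitrary'). A tiny self-contained illustration of the pattern, independent of this diff:

class Example:
    def __init__(self, **kwargs):
        # every extra keyword argument becomes an attribute on the instance
        for key, value in kwargs.items():
            setattr(self, key, value)

model = Example(totally_arbitrary='arbitrary')
assert model.totally_arbitrary == 'arbitrary'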
90 changes: 52 additions & 38 deletions saber/models/bert_token_classifier.py
@@ -51,11 +51,12 @@ def __init__(self, config, datasets, pretrained_model_name_or_path='bert-base-ca
self.device = torch.device("cpu")
# number of GPUs available
self.n_gpus = 0

# the name or path of a pre-trained BERT model
self.pretrained_model_name_or_path = pretrained_model_name_or_path
# a tokenizer which corresponds to the pre-trained model to load
self.tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_name_or_path,
do_lower_case=False)

# +1 necessary to account for 'X' tag introduced by wordpiece tokenization
self.num_labels = [len(dataset.idx_to_tag) + 1 for dataset in self.datasets]

self.model_name = 'bert-ner'
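The +1 on num_labels reserves an output index for the 'X' tag that wordpiece tokenization introduces for sub-word pieces. A hypothetical worked example (the tag set below is invented for illustration):

idx_to_tag = {0: 'O', 1: 'B-PRGE', 2: 'I-PRGE', 3: '<PAD>'}  # hypothetical dataset tags
num_labels = len(idx_to_tag) + 1  # == 5; the extra index is reserved for the wordpiece tag 'X'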

@@ -70,18 +71,24 @@ def load(self, model_filepath):
"""
# TODO (johngiorgi): Test that saving loading from CPU/GPU works as expected
model_state_dict = torch.load(model_filepath, map_location=lambda storage, loc: storage)
# this is a trick to get the number of labels

# This is a trick to get the number of labels
num_labels = len(model_state_dict['classifier.bias'])

# TODO (johngiorgi): Can we get the model name from the model_state_dict?
model = BertForTokenClassification.from_pretrained(self.pretrained_model_name_or_path,
num_labels=num_labels,
state_dict=model_state_dict)
# get the device the model will live on, along with number of GPUs available
tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_name_or_path,
do_lower_case=False)

# Get the device the model will live on, along with number of GPUs available
self.device, self.n_gpus = model_utils.get_device(model)

self.model = model
self.tokenizer = tokenizer

return model
return model, tokenizer
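The "trick" above works because BertForTokenClassification's final linear layer has one bias term per label, so the length of classifier.bias in the saved state dict equals the number of labels. A hedged usage sketch of the updated method (the filepath and the classifier instance are illustrative only):

# `classifier` stands in for an already-constructed BertTokenClassifier instance.
model, tokenizer = classifier.load('pretrained_models/biobert_ner/pytorch_model.bin')
print(model.classifier.bias.shape[0])  # the number of labels recovered from the state dict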

def specify(self):
"""Specifies an op-for-op PyTorch implementation of Google's BERT for sequence tagging.
@@ -90,16 +97,22 @@ def specify(self):
is appended to `self.models`.
"""
# (TODO): Update to support MT learning.
# plus 1 is necessary to account for 'X' tag introduced by wordpiece tokenization
num_labels = len(self.datasets[0].type_to_idx['tag']) + 1
if self.pretrained_model_name_or_path in constants.PRETRAINED_MODELS:
self.pretrained_model_name_or_path = \
model_utils.download_model_from_gdrive(self.pretrained_model_name_or_path)

model = BertForTokenClassification.from_pretrained(self.pretrained_model_name_or_path,
num_labels=num_labels)
# get the device the model will live on, along with number of GPUs available
num_labels=self.num_labels[0])
tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_name_or_path,
do_lower_case=False)

# Get the device the model will live on, along with number of GPUs available
self.device, self.n_gpus = model_utils.get_device(model)

self.model = model
self.tokenizer = tokenizer

return model
return model, tokenizer
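Callers of specify() now unpack a (model, tokenizer) pair, and passing a key of constants.PRETRAINED_MODELS triggers the Google Drive download. A hedged sketch of the call pattern, assuming config and datasets prepared as elsewhere in Saber:

bert = BertTokenClassifier(config, datasets,
                           pretrained_model_name_or_path='biobert_v1.1_pubmed')
model, tokenizer = bert.specify()  # downloads the BioBERT weights from Google Drive if needed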

def prepare_data_for_training(self):
"""Returns a list containing the training data for each dataset at `self.datasets`.
@@ -202,23 +215,22 @@ def train_valid_test(training_data, output_dir, optimizers):

# setup a progress bar
pbar_descr = f'Epoch: {epoch + 1}/{self.config.epochs}'
pbar = tqdm(training_data['train_dataloader'], unit='batch', desc=pbar_descr)
pbar = tqdm(training_data['train_dataloader'],
unit='batch',
desc=pbar_descr,
dynamic_ncols=True)

for _, batch in enumerate(pbar):

optimizer.zero_grad()

# update train loss in progress bar
train_loss = train_loss / train_steps if train_steps > 0 else 0.
pbar.set_postfix(train_loss=train_loss)

# add batch to device
batch = tuple(t.to(self.device) for t in batch)
b_input_ids, b_input_mask, b_labels = batch

# forward pass
loss = self.model(b_input_ids,
token_type_ids=None,
token_type_ids=torch.zeros_like(b_input_ids),
attention_mask=b_input_mask,
labels=b_labels)
# backward pass
@@ -235,6 +247,9 @@
# update parameters
optimizer.step()

# update train loss in progress bar
pbar.set_postfix(loss=f'{train_loss / train_steps:.4f}')

# need to feed epoch argument manually, as this is a keras callback object
metrics.on_epoch_end(epoch=metrics.current_epoch)
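Passing token_type_ids=torch.zeros_like(b_input_ids) makes the segment ids explicit: for single-sequence token classification every wordpiece belongs to segment A, so the segment-id tensor is all zeros. A small illustration (input ids invented for the example):

import torch

b_input_ids = torch.tensor([[101, 7592, 2088, 102]])  # hypothetical wordpiece ids
token_type_ids = torch.zeros_like(b_input_ids)        # tensor([[0, 0, 0, 0]])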

@@ -244,7 +259,7 @@ def cross_validation(training_data, output_dir, optimizers):
# get the train / valid partitioned data for all datasets and all folds
training_data = data_utils.collect_cv_data(training_data, self.config.k_folds)

# TODO (John): In the future, this will be a list.
# TODO (John): In the future, this will be a list
optimizer = optimizers

# training loop
@@ -256,23 +271,21 @@
datasets=self.datasets,
training_data=training_data,
# TODO (John): Drop index when MTM is implemented
output_dir=output_dir,)[0]

# TODO (John): Drop when MTM is implemented
training_data = training_data[0]
output_dir=output_dir,
fold=fold)[0]

# TODO (John): Dataloaders should be handled outside of the train loop
# append dataloaders to training_data
training_data[fold]['train_dataloader'] = \
model_utils.get_dataloader_for_ber(x=training_data[fold]['x_train'][0],
y=training_data[fold]['y_train'],
attention_mask=training_data[fold]['x_train'][-1],
training_data[0][fold]['train_dataloader'] = \
model_utils.get_dataloader_for_ber(x=training_data[0][fold]['x_train'][0],
y=training_data[0][fold]['y_train'],
attention_mask=training_data[0][fold]['x_train'][-1],
config=self.config,
data_partition='train')
training_data[fold]['valid_dataloader'] = \
model_utils.get_dataloader_for_ber(x=training_data[fold]['x_valid'][0],
y=training_data[fold]['y_valid'],
attention_mask=training_data[fold]['x_valid'][-1],
training_data[0][fold]['valid_dataloader'] = \
model_utils.get_dataloader_for_ber(x=training_data[0][fold]['x_valid'][0],
y=training_data[0][fold]['y_valid'],
attention_mask=training_data[0][fold]['x_valid'][-1],
config=self.config,
data_partition='eval')

@@ -287,24 +300,22 @@ def cross_validation(training_data, output_dir, optimizers):
# setup a progress bar
fold_and_epoch = (fold + 1, self.config.k_folds, epoch + 1, self.config.epochs)
pbar_descr = 'Fold: {}/{}, Epoch: {}/{}'.format(*fold_and_epoch)
pbar = \
tqdm(training_data[fold]['train_dataloader'], unit='batch', desc=pbar_descr)
pbar = tqdm(training_data[0][fold]['train_dataloader'],
unit='batch',
desc=pbar_descr,
dynamic_ncols=True)

for _, batch in enumerate(pbar):

optimizer.zero_grad()

# update train loss in progress bar
train_loss = train_loss / train_steps if train_steps > 0 else 0.
pbar.set_postfix(train_loss=train_loss)

# add batch to gpu
batch = tuple(t.to(self.device) for t in batch)
b_input_ids, b_input_mask, b_labels = batch

# forward pass
loss = self.model(b_input_ids,
token_type_ids=None,
token_type_ids=torch.zeros_like(b_input_ids),
attention_mask=b_input_mask,
labels=b_labels)

@@ -327,6 +338,9 @@
# update parameters
optimizer.step()

# update train loss in progress bar
pbar.set_postfix(loss=f'{train_loss / train_steps:.4f}')

# need to feed epoch argument manually, as this is a keras callback object
metrics.on_epoch_end(epoch=metrics.current_epoch)

@@ -335,6 +349,7 @@
# clear and rebuild the model at end of each fold (except for the last fold)
if fold < self.config.k_folds - 1:
self.reset_model()
optimizer = model_utils.get_bert_optimizer(self.model, self.config)

# TODO: User should be allowed to overwrite this
if training_data[0]['x_valid'] is not None or training_data[0]['x_test'] is not None:
@@ -370,7 +385,6 @@ def evaluate(self, training_data, model_idx=-1, partition='train'):

# get the dataset / dataloader for the given partition
dataset = self.datasets[model_idx]
print(training_data)
dataloader = training_data[f'{partition}_dataloader']

for batch in dataloader:
16 changes: 3 additions & 13 deletions saber/saber.py
@@ -10,7 +10,6 @@
from pprint import pprint

import numpy as np
from google_drive_downloader import GoogleDriveDownloader as gdd
from keras.models import Model
from spacy import displacy

@@ -260,18 +259,9 @@ def load(self, directory):
directory = [directory]

for dir_ in directory:
# get what might be a pretrained model name
pretrained_model = os.path.splitext(dir_)[0].strip().upper()

# allows user to provide names of pre-trained models (e.g. 'PRGE-base')
if pretrained_model in constants.PRETRAINED_MODELS:
dir_ = os.path.join(constants.PRETRAINED_MODEL_DIR, pretrained_model)
# download model from Google Drive, will skip if already exists
file_id = constants.PRETRAINED_MODELS[pretrained_model]
dest_path = '{}.tar.bz2'.format(dir_)
gdd.download_file_from_google_drive(file_id=file_id, dest_path=dest_path)

LOGGER.info('Loaded pre-trained model %s from Google Drive', pretrained_model)
# If directory is an available pretrained model, download it from Google Drive
if dir_ in constants.PRETRAINED_MODELS:
dir_ = model_utils.download_model_from_gdrive(pretrained_model=dir_)

dir_ = generic_utils.clean_path(dir_)
generic_utils.extract_directory(dir_)
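After this simplification, an end user loads a pre-trained model by passing any key of constants.PRETRAINED_MODELS; the download and caching are delegated to model_utils.download_model_from_gdrive. A hedged end-user sketch:

from saber.saber import Saber  # import path assumed; adjust to however Saber is exposed in your install

sb = Saber()
sb.load('PRGE')  # fetched from Google Drive via download_model_from_gdrive if not already cached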
66 changes: 63 additions & 3 deletions saber/tests/conftest.py
@@ -7,7 +7,9 @@
from ..dataset import Dataset
from ..embeddings import Embeddings
from ..metrics import Metrics
from ..models.base_model import BaseModel
from ..models.base_model import BaseKerasModel
from ..models.base_model import BasePyTorchModel
from ..models.bert_token_classifier import BertTokenClassifier
from ..models.multi_task_lstm_crf import MultiTaskLSTMCRF
from ..preprocessor import Preprocessor
@@ -376,6 +378,15 @@ def single_mt_bilstm_model_specify(single_mt_bilstm_model):
return single_mt_bilstm_model


@pytest.fixture
def compound_mt_bilstm_model_specify(compound_mt_bilstm_model):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and
a compound specified model."""
compound_mt_bilstm_model.specify()

return compound_mt_bilstm_model


@pytest.fixture
def single_mt_bilstm_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and
@@ -398,7 +409,27 @@ def single_mt_bilstm_model_embeddings_specify(single_mt_bilstm_model_embeddings)


@pytest.fixture
def single_base_keras_model(dummy_config, dummy_dataset_1, dummy_embeddings):
def single_base_model(dummy_config, dummy_dataset_1):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BaseModel(config=dummy_config,
datasets=[dummy_dataset_1],
# to test passing of arbitrary keyword args to constructor
totally_arbitrary='arbitrary')
return model


@pytest.fixture
def compound_base_model(dummy_config, dummy_dataset_1, dummy_dataset_2):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BaseModel(config=dummy_config,
datasets=[dummy_dataset_1, dummy_dataset_2],
# to test passing of arbitrary keyword args to constructor
totally_arbitrary='arbitrary')
return model


@pytest.fixture
def single_base_keras_model(dummy_config, dummy_dataset_1):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BaseKerasModel(config=dummy_config,
datasets=[dummy_dataset_1],
@@ -407,6 +438,16 @@ def single_base_keras_model(dummy_config, dummy_dataset_1, dummy_embeddings):
return model


@pytest.fixture
def compound_base_keras_model(dummy_config, dummy_dataset_1, dummy_dataset_2):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BaseKerasModel(config=dummy_config,
datasets=[dummy_dataset_1, dummy_dataset_2],
# to test passing of arbitrary keyword args to constructor
totally_arbitrary='arbitrary')
return model


@pytest.fixture
def single_base_keras_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and
@@ -418,15 +459,34 @@ def single_base_keras_model_embeddings(dummy_config, dummy_dataset_1, dummy_embe
totally_arbitrary='arbitrary')
return model


@pytest.fixture
def single_base_pytorch_model(dummy_config, dummy_dataset_1):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BasePyTorchModel(config=dummy_config,
datasets=[dummy_dataset_1],
# to test passing of arbitrary keyword args to constructor
totally_arbitrary='arbitrary')
return model


@pytest.fixture
def compound_base_pytorch_model(dummy_config, dummy_dataset_1, dummy_dataset_2):
"""Returns an instance of MultiTaskLSTMCRF initialized with the default configuration."""
model = BasePyTorchModel(config=dummy_config,
datasets=[dummy_dataset_1, dummy_dataset_2],
# to test passing of arbitrary keyword args to constructor
totally_arbitrary='arbitrary')
return model

# BERT models


@pytest.fixture
def bert_tokenizer():
"""Tokenizer for pre-trained BERT model.
"""
bert_tokenizer = BertTokenizer.from_pretrained(constants.PYTORCH_BERT_MODEL,
do_lower_case=False)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

return bert_tokenizer
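Note the fixture now pins 'bert-base-cased' directly rather than constants.PYTORCH_BERT_MODEL, which now names the BioBERT weights. A hypothetical test using the fixture (tokenize and convert_tokens_to_ids are standard BertTokenizer methods):

def test_bert_tokenizer_wordpieces(bert_tokenizer):
    wordpieces = bert_tokenizer.tokenize('BioBERT tags proteins')
    assert bert_tokenizer.convert_tokens_to_ids(wordpieces)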
