Skip to content

Commit

Permalink
Merge pull request #371 from Winter523/master
Browse files Browse the repository at this point in the history
add inference/run_classifier_mt_infer.py
  • Loading branch information
zhezhaoa committed Aug 11, 2023
2 parents 12ead9c + 6ceaac9 commit 3b2ba28
Showing 1 changed file with 179 additions and 0 deletions.
179 changes: 179 additions & 0 deletions inference/run_classifier_mt_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
This script provides an example to wrap UER-py for multi-task classification inference.
"""
import sys
import os
import torch
import argparse
import collections
import torch.nn as nn

uer_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(uer_dir)

from uer.embeddings import *
from uer.encoders import *
from uer.utils.constants import *
from uer.utils import *
from uer.utils.config import load_hyperparam
from uer.utils.seed import set_seed
from uer.utils.misc import pooling
from uer.model_loader import *
from uer.opts import infer_opts, tokenizer_opts, log_opts


class MultitaskClassifier(nn.Module):
def __init__(self, args):
super(MultitaskClassifier, self).__init__()
self.embedding = Embedding(args)
for embedding_name in args.embedding:
tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab))
self.embedding.update(tmp_emb, embedding_name)
self.encoder = str2encoder[args.encoder](args)
self.pooling_type = args.pooling
self.output_layers_1 = nn.ModuleList([nn.Linear(args.hidden_size, args.hidden_size) for _ in args.labels_num_list])
self.output_layers_2 = nn.ModuleList([nn.Linear(args.hidden_size, labels_num) for labels_num in args.labels_num_list])

def forward(self, src, tgt, seg, soft_tgt=None):
"""
Args:
src: [batch_size x seq_length]
tgt: [batch_size]
seg: [batch_size x seq_length]
"""
# Embedding.
emb = self.embedding(src, seg)
# Encoder.
memory_bank = self.encoder(emb, seg)
# Target.
memory_bank = pooling(memory_bank, seg, self.pooling_type)
logits = []
for i in range(len(self.output_layers_1)):
output_i = torch.tanh(self.output_layers_1[i](memory_bank))
logits_i = self.output_layers_2[i](output_i)
logits.append(logits_i)

return None, logits


def read_dataset(args, path):
dataset, columns = [], {}
with open(path, mode="r", encoding="utf-8") as f:
for line_id, line in enumerate(f):
if line_id == 0:
line = line.rstrip("\r\n").split("\t")
for i, column_name in enumerate(line):
columns[column_name] = i
continue
line = line.rstrip("\r\n").split("\t")
if "text_b" not in columns: # Sentence classification.
text_a = line[columns["text_a"]]
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
seg = [1] * len(src)
else: # Sentence pair classification.
text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN])
src = src_a + src_b
seg = [1] * len(src_a) + [2] * len(src_b)

if len(src) > args.seq_length:
src = src[: args.seq_length]
seg = seg[: args.seq_length]
PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
while len(src) < args.seq_length:
src.append(PAD_ID)
seg.append(0)
dataset.append((src, seg))

return dataset


def batch_loader(batch_size, src, seg):
instances_num = src.size()[0]
for i in range(instances_num // batch_size):
src_batch = src[i * batch_size : (i + 1) * batch_size, :]
seg_batch = seg[i * batch_size : (i + 1) * batch_size, :]
yield src_batch, seg_batch
if instances_num > instances_num // batch_size * batch_size:
src_batch = src[instances_num // batch_size * batch_size :, :]
seg_batch = seg[instances_num // batch_size * batch_size :, :]
yield src_batch, seg_batch


def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

infer_opts(parser)

tokenizer_opts(parser)

parser.add_argument("--output_logits", action="store_true", help="Write logits to output file.")
parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.")
parser.add_argument("--labels_num_list", default=[], nargs='+', type=int, help="Dataset labels num list.")
log_opts(parser)

args = parser.parse_args()

# Load the hyperparameters from the config file.
args = load_hyperparam(args)

# Build tokenizer.
args.tokenizer = str2tokenizer[args.tokenizer](args)

# Build classification model and load parameters.
args.soft_targets, args.soft_alpha = False, False
model = MultitaskClassifier(args)
model = load_model(model, args.load_model_path)

# For simplicity, we use DataParallel wrapper to use multiple GPUs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
if torch.cuda.device_count() > 1:
print("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
model = torch.nn.DataParallel(model)

dataset = read_dataset(args, args.test_path)

src = torch.LongTensor([sample[0] for sample in dataset])
seg = torch.LongTensor([sample[1] for sample in dataset])

batch_size = args.batch_size
instances_num = src.size()[0]

print("The number of prediction instances: {0}".format(instances_num))

model.eval()

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
f.write("label")
if args.output_logits:
f.write("\t" + "logits")
if args.output_prob:
f.write("\t" + "prob")
f.write("\n")
for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
src_batch = src_batch.to(device)
seg_batch = seg_batch.to(device)
with torch.no_grad():
_, logits = model(src_batch, None, seg_batch)

pred = [torch.argmax(logits_i, dim=-1) for logits_i in logits]
prob = [nn.Softmax(dim=-1)(logits_i) for logits_i in logits]

logits = [x.cpu().numpy().tolist() for x in logits]
pred = [x.cpu().numpy().tolist() for x in pred]
prob = [x.cpu().numpy().tolist() for x in prob]

for j in range(len(pred[0])):
f.write("|".join([str(v[j]) for v in pred]))
if args.output_logits:
f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in logits]))
if args.output_prob:
f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in prob]))
f.write("\n")
f.close()


if __name__ == "__main__":
main()

0 comments on commit 3b2ba28

Please sign in to comment.