
multilingual coreference and singletons support #1403

Closed
wants to merge 7 commits

Conversation

@Jemoka (Member) commented Jul 16, 2024

Adds support for multilingual coreference and singletons through xlm-roberta-large and t5-large.

  • adds an Rn -> R1 projection anaphoricity scorer for start-of-chain mentions in order to support singletons (a minimal sketch follows this list)
  • integrates the newer PEFT architecture for fine-tuning xlm-roberta with adapters
  • adds an adapter for CorefUD data
  • training throughput fixes for long documents
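As a concrete illustration of the first bullet, here is a minimal sketch of an Rn -> R1 start-of-chain (anaphoricity) scorer; the class name, sizes, and the way the score is concatenated to the pairwise scores are illustrative assumptions, not the PR's exact code.

import torch
from torch import nn

class StartOfChainScorer(nn.Module):
    """Sketch: project each mention representation (R^n) down to a single
    start-of-chain score (R^1)."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 1)

    def forward(self, mentions: torch.Tensor) -> torch.Tensor:
        # mentions: [n_mentions, hidden_size] -> [n_mentions, 1]
        return self.proj(mentions)

# usage sketch: prepend the start-of-chain column to the antecedent scores so
# that an argmax over [start-of-chain | antecedents] can mark a mention as
# opening a new chain, i.e. a potential singleton:
#   scores = torch.cat([scorer(mentions), pairwise_scores], dim=1)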

@Jemoka Jemoka requested a review from AngledLuffa July 16, 2024 19:10
@Jemoka Jemoka marked this pull request as ready for review July 16, 2024 20:09
@@ -87,10 +87,6 @@ def predict(self, batch, unsort=True):
pred_tokens.append("".join(pred_seq))
else:
pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
# if any tokens are predicted to expand to blank,
Member Author (@Jemoka, Jul 16, 2024)

I honestly forgot why this is removed; perhaps it is a merging artifact?

Collaborator

I suspect so. I just made a couple changes on the MWT processing last week in order to fix some weird tokenization of previously unknown lemmas in Spanish

Collaborator

yes, would try to merge in that change or otherwise undo this. shouldn't be here

Member Author

done


@staticmethod
def _get_pair_matrix(all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
def _get_pair_matrix(mentions_batch: torch.Tensor,
Collaborator

so the change here is to push the dereferencing up into the caller? sounds fair, could maybe split that out for readability of the PR but it's not necessary

Member Author (@Jemoka, Jul 23, 2024)

also, the indexing is required multiple times instead of once, so I felt it would be easier than passing everything around between the two stacks.
happy to split it; do you mean I should undo the change and apply another diff?

thanks in advance!

Collaborator

no change needed, was just thinking in terms of making the big change more readable with smaller pieces cut off. it's not an issue though.

you can if you like split off individual edits with git rebase -i dev and then edit the change. it's really not necessary in this case, though, unless you wanted the practice

@@ -40,14 +44,19 @@ def get_subwords_batches(doc: Doc,
while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
end -= 1

# if we ended up at prev end, well, looks like we will
Collaborator

this is what happens if a single sentence is longer than the maximum length of the transformer?

Member Author

yes, clarified comment
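For context, a minimal sketch of the batching fallback under discussion; the function name and the prev_end/batch_size variables are illustrative assumptions, not the PR's exact code.

def next_batch_end(doc: dict, prev_end: int, batch_size: int) -> int:
    """Sketch: prefer to end a subword batch on a sentence boundary, but if a
    single sentence is longer than the transformer window, split it
    mid-sentence rather than loop forever."""
    total_len = len(doc["word_id"])
    end = min(prev_end + batch_size, total_len)
    if end == total_len:
        return end
    # walk back so the batch ends on a sentence boundary
    sent_id = doc["sent_id"][doc["word_id"][end - 1]]
    while end > prev_end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
        end -= 1
    # we ended up back at the previous end: this one sentence exceeds the window
    if end == prev_end:
        end = min(prev_end + batch_size, total_len)
    return end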

@@ -15,7 +21,24 @@ def __init__(self):
self._r = 0.0
self._p_weight = 0.0
self._r_weight = 0.0
self._num_preds = 0.0

# muc
Collaborator

these names are somewhat opaque but i assume they're just the standard scoring names? seems reasonable

Member Author

updated with better variable names; they are underscore-prefixed, so hopefully folks won't try to access them from the outside

try:
return int(split[-1].replace(")", "").strip())
except ValueError:
breakpoint()
Collaborator

could remove this?

Member Author

done
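For reference, a minimal sketch of the cleaned-up parse once the debugging breakpoint() is gone; the function wrapper, the split on "e", and the error message are assumptions, only the int(...) line comes from the diff above.

def parse_cluster_id(marker: str) -> int:
    # e.g. a closing marker such as "e12)" -> 12
    split = marker.split("e")
    try:
        return int(split[-1].replace(")", "").strip())
    except ValueError as err:
        # raise instead of dropping into a debugger
        raise ValueError(f"could not parse cluster marker: {marker!r}") from err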

cluster_info_lst.append(f"e{cluster_marker})")


# we need our clusters to be ordered such that the one that closest first is listed last
Collaborator

"that closest first" --> "that is closest to the first" ?

Member Author

done

breakpoint()
else:
# we want everything that's a closer to be first
return 1000000000
Collaborator

how about float('inf')

Member Author

of course; no clue why I did this....
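For what it's worth, a tiny sketch of the suggested fix; the helper name and argument are assumptions about the surrounding sort key, and float("inf") replaces the 1000000000 sentinel from the diff above.

def sort_key(is_closer: bool) -> float:
    if is_closer:
        # we want everything that's a closer to be first
        return 0.0
    # everything else sorts after the closers
    return float("inf")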

#dev_data = "data/coref/corefud_concat_v1_0_langid-bal.dev.json"
#test_data = "data/coref/corefud_concat_v1_0_langid-bal.test.json"

train_data = "data/coref/corefud_concat_v1_0_langid.train.json"
Collaborator

is there a script or an explanation of how to build this?

Collaborator

wondering if some of the others could be cleaned up

Member Author

done. left the ontonotes + gum + balanced langid. Also, it looks like I missed committing the conversion scripts: would love your review.

  • convert_udcoref.py: converts depparse-annotated udcoref files into our format
  • balance_languages.py: takes a dataset built by the previous script and balances the document counts for each language within the JSON (a rough sketch of this balancing step follows this list)
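For reference, a rough sketch of what the balancing step could look like, assuming the concatenated JSON is a list of documents that each carry a language id; the field name "lang" and the downsample-to-smallest strategy are assumptions, not necessarily what balance_languages.py does.

import json
import random
from collections import defaultdict

def balance_languages(in_path: str, out_path: str, seed: int = 1234) -> None:
    """Sketch: downsample every language to the size of the smallest one so
    each language contributes the same number of documents."""
    with open(in_path, encoding="utf-8") as fin:
        docs = json.load(fin)

    by_lang = defaultdict(list)
    for doc in docs:
        by_lang[doc["lang"]].append(doc)

    target = min(len(group) for group in by_lang.values())
    random.seed(seed)
    balanced = []
    for group in by_lang.values():
        balanced.extend(random.sample(group, target))

    with open(out_path, "w", encoding="utf-8") as fout:
        json.dump(balanced, fout, ensure_ascii=False)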

lora_dropout = 0.1
lora_alpha = 128
Collaborator

could be undone just to make the change cleaner

Member Author

done. no clue how this ended up changed. apologies

@@ -58,7 +60,18 @@ def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in
distance = torch.where(distance < 5, distance - 1, log_distance + 2)
distance = self.distance_emb(distance)

genre = torch.tensor(self.genre2int[doc["document_id"][:2]],
if not self.__full_pw:
Collaborator

this is for VRAM OOM issues?

Member Author

not quite: this is for documents that have genre and speaker embeddings, which don't exist for UDCoref

Collaborator

ah, got it. is that detected automatically from the input files? i might have missed that if it is. if not, it would be simpler for the user to do that rather than make it an option
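To make the distinction concrete, a minimal sketch of gating the genre/speaker pairwise features behind a flag like __full_pw; the function and its arguments are illustrative assumptions, not the PR's code.

import torch
from typing import Optional

def pairwise_features(distance: torch.Tensor,
                      speaker: Optional[torch.Tensor],
                      genre: Optional[torch.Tensor],
                      full_pw: bool) -> torch.Tensor:
    """Sketch: only concatenate genre/speaker features when the dataset
    (e.g. OntoNotes) actually provides them; UDCoref data has neither."""
    if full_pw and speaker is not None and genre is not None:
        return torch.cat([distance, speaker, genre], dim=-1)
    return distance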

@@ -29,6 +29,7 @@ def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in
Returns rough anaphoricity scores for candidates, which consist of
the bilinear output of the current model summed with mention scores.
"""

Collaborator

maybe undo just to keep things cleaner?

Member Author

done

help="Adjust the dummy mix")
argparser.add_argument("--bert_finetune_begin_epoch", type=float,
help="Adjust the bert finetune begin epoch")
argparser.add_argument("--warm_start", action="store_true",
Collaborator

worth adding an argument for --full_pairwise here?

Member Author

done. I expect that for non-OntoNotes documents this will rarely be used, however, because most datasets don't have speaker embeddings.

Collaborator

heh, adding that flag might have been the opposite of what i just suggested above with the __full_pw option
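Following the style of the argparse snippet above, the flag in question might look roughly like this (the exact name and help text in the PR may differ):

argparser.add_argument("--full_pairwise", action="store_true",
                       help="Use the full pairwise features (genre and speaker "
                            "embeddings), which only OntoNotes-style data provides")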

@@ -30,13 +30,67 @@
from stanza.models.coref.utils import GraphNode
from stanza.models.coref.word_encoder import WordEncoder

from torch.utils.data import Dataset
from functools import lru_cache, wraps
import weakref
Collaborator

is this still used? maybe it could go away

Member Author

yes, apologies. it used to be used for dataset memoization, but it turns out we were hitting system OOM on long docs instead

Collaborator

same with wraps, a quick ctrl-f doesn't find it anywhere else

Member Author

removed, thanks

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict

from stanza.utils.get_tqdm import get_tqdm # type: ignore
tqdm = get_tqdm()

logger = logging.getLogger('stanza')

class CorefDataset(Dataset):
Collaborator

can / should this be refactored into a different module?

Member Author

done

self.config = config
self.tokenizer = tokenizer

self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
Collaborator

maybe leave a comment here to specify that the default is to not filter anything? it takes a couple seconds to understand, so maybe that time can be saved for the reader instead

Member Author

done
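The requested comment could be as small as this; the lambda default shown is an assumption about what the original code falls back to.

# default to a pass-through filter, i.e. keep every subword token, when the
# chosen bert model has no registered tokenizer filter
self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
                                           lambda _: True)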

@@ -91,14 +144,15 @@ def __init__(self,
modules_to_save=self.config.lora_fully_tune,
bias="none")

self.bert = get_peft_model(self.bert, self.__peft_config)
self.bert.train()
Collaborator

is this switch necessary? it was working before

Member Author

Reverted. There's a chance I did this because certain loading paths load the model in eval mode, which makes PEFT do weird things. But I can't seem to reproduce it. Apologies

Collaborator

no worries, mostly was wondering about needing to change other models
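For readers unfamiliar with the PEFT step being discussed, a minimal sketch of wrapping an encoder with LoRA adapters; the rank and target modules are illustrative assumptions, while lora_alpha and lora_dropout mirror the values in the config diff above.

from transformers import AutoModel
from peft import LoraConfig, get_peft_model

bert = AutoModel.from_pretrained("xlm-roberta-large")
peft_config = LoraConfig(r=64,
                         lora_alpha=128,
                         lora_dropout=0.1,
                         target_modules=["query", "value"],
                         bias="none")
# the thread above settled on wrapping first, then switching to train mode
bert = get_peft_model(bert, peft_config)
bert.train()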

self.bert.train()
self.bert = get_peft_model(self.bert, self.__peft_config)
self.trainable["bert"] = self.bert

if build_optimizers:
self._build_optimizers()
self._set_training(False)
self._coref_criterion = CorefLoss(self.config.bce_loss_weight)
Collaborator

maybe a comment on the distinction between the coref & rough criterions would be helpful

Member Author

done

@@ -117,13 +171,15 @@ def training(self, new_value: bool):
@torch.no_grad()
def evaluate(self,
data_split: str = "dev",
word_level_conll: bool = False
word_level_conll: bool = False,
eval_lang=None
Collaborator

consistency on the typing might be nice

Member Author

done
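For instance, the consistently typed signature might read as follows; Optional[str] for eval_lang is an assumption about the intended type.

from typing import Optional

@torch.no_grad()
def evaluate(self,
             data_split: str = "dev",
             word_level_conll: bool = False,
             eval_lang: Optional[str] = None):
    ...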

@@ -185,8 +244,9 @@ def evaluate(self,
f" p: {s_lea[1]:.5f},"
f" r: {s_lea[2]:<.5f}"
)
logger.info(f"BAKE!: {w_checker.bakeoff:.5f}")
Collaborator

i do think a more informative log line would be helpful

Member Author

done. apologies

@@ -421,12 +488,17 @@ def train(self, log=False):
for doc_indx, doc_id in enumerate(pbar):
doc = docs[doc_id]

# skip very long documents during training time
Collaborator

could this be an option?

Member Author

I believe we discussed this being strictly good, simply because it quadruples the memory limit during training and seems to confer no actual performance benefits. Happy to make this a flag if needed, however.

Collaborator

that's fair. although i was thinking that there could be a cutoff where those lines become batches by themselves. however, it's also not necessary to do that, i think, especially if it's not giving any benefit

Collaborator

maybe a comment on how many training lines will be skipped, in that case?
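A minimal sketch of the skip plus the suggested count logging; the threshold name (max_train_len), the doc["subwords"] length check, and the log wording are assumptions, while pbar, docs, and logger come from the surrounding code.

# skip very long documents during training and report how many were dropped
skipped_docs = 0
for doc_indx, doc_id in enumerate(pbar):
    doc = docs[doc_id]
    if len(doc["subwords"]) > max_train_len:
        skipped_docs += 1
        continue
    ...  # normal training step
if skipped_docs:
    logger.info("Skipped %d training documents longer than %d subwords",
                skipped_docs, max_train_len)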

running_c_loss += c_loss.item()
running_s_loss += s_loss.item()

# log every 50 docs
if log and doc_indx % 50 == 0:
# log every 100 docs
Collaborator

this could also be an option

Member Author

happy to do so; do you think there are other areas where this flag would be used? i.e.: this only affects wandb logs; would love to hear how best I could implement this.

Thanks in advance!

Collaborator

eh, guess not. having some default behavior for the logs is fine until someone complains about wanting more granularity

@@ -490,30 +564,44 @@ def train(self, log=False):
# ========================================================= Private methods

def _bertify(self, doc: Doc) -> torch.Tensor:
Collaborator

some similar pieces of logic are also in models/common/bert_embedding.py
not saying this needs to happen this time around, but it would be useful to unify that so there's just one source of truth

that code also handles some other model types, such as the VI extension to bert (phobert)

Member Author

it looks like the original design put bert.py there as bert utilities (which have no logic dependent upon bert choice/initialization), whereas this does. happy to refactor if you think that's best

y = (y == cluster_ids.unsqueeze(1)) # True if coreferent
# For all rows with no gold antecedents setting dummy to True
y[y.sum(dim=1) == 0, 0] = True

if singletons:
# add another dummy for firts coref
Collaborator

*first

Member Author

whoops; thanks
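For concreteness, a sketch of the extra dummy column being discussed, extending the target construction shown in the diff above; which rows get the new column set (here: mentions with a gold cluster but no antecedent) is an assumption about the PR's target format.

import torch

def add_singleton_dummy(y: torch.Tensor, cluster_ids: torch.Tensor,
                        singletons: bool) -> torch.Tensor:
    """Sketch: y already has a dummy in column 0 and candidate-antecedent
    cluster ids in the remaining columns; cluster_ids[i] is mention i's gold
    cluster id (0 if it is in no cluster)."""
    y = (y == cluster_ids.unsqueeze(1))        # True if coreferent
    # rows with no gold antecedent point at the dummy column
    y[y.sum(dim=1) == 0, 0] = True
    if singletons:
        # add another dummy for the first mention of a chain; marking it for
        # mentions that have a gold cluster but no antecedent is an assumption
        first = (y[:, 0] & (cluster_ids > 0)).unsqueeze(1)
        y = torch.cat([first, y], dim=1)
    return y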

@AngledLuffa (Collaborator)

Overall it looks good, thanks! Just a bunch of random nitpicks and of course the MWT code being reverted.

I would think that with the config changes the way they are, the original model is no longer viable, right? Or does loading the pipeline with the old model and the new code still work?

If the old model is dead (which is fine) we should either fix the existing .pt file or rebuild it.

IS_UDCOREF_FORMAT = True
UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1

# TODO: move this to a utility module and try it on other languages
Collaborator

this is just a copy of the one in convert_ontonotes.py, right? can we refactor that now?

Member Author

ah, yes, absolutely. will do ASAP with the other items today (just got a temp laptop, setting up dev env)

@Jemoka (Member Author) commented Jul 26, 2024

closing in favor of #1406

@Jemoka Jemoka closed this Jul 26, 2024
@AngledLuffa AngledLuffa deleted the multilingual-coref-2 branch July 31, 2024 21:53