CRF layer v3.0 continued #1999

Merged
merged 21 commits into tensorflow:master on Jul 19, 2020

Conversation

jaspersjsun
Contributor

This is a continuation of PR #1733 by @gabrieldemarmiesse.

Several suggestions from code review were applied.

Following is the original comment.


With a subclassing approach, we get a nicer API and it's very flexible.

Works only with TF 2.2+

@howl-anderson for the review and the CLA

The plan is to show users how to do the subclassing for the CRF. We shouldn't provide an API to save them some code there, because it would become very complex to design a good API and to maintain it later on.

So the CRF layer is a public API, and for the CRF loss we provide a good tutorial on subclassing.

Quick tutorial right now:

import tensorflow as tf
from tensorflow_addons.layers.crf import CRF
from tensorflow_addons.text.crf import crf_log_likelihood

def unpack_data(data):
    if len(data) == 2:
        return data[0], data[1], None
    elif len(data) == 3:
        return data
    else:
        raise TypeError("Expected data to be a tuple of size 2 or 3.")


class ModelWithCRFLoss(tf.keras.Model):
    """Wrapper around the base model for custom training logic."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        return self.base_model(inputs)

    def compute_loss(self, x, y, sample_weight, training=False):
        y_pred = self(x, training=training)
        # The CRF layer outputs (decoded_sequence, potentials, sequence_length,
        # chain_kernel); the decoded sequence is not needed for the loss.
        _, potentials, sequence_length, chain_kernel = y_pred

        crf_loss = -crf_log_likelihood(potentials, y, sequence_length, chain_kernel)[0]

        if sample_weight is not None:
            crf_loss = crf_loss * sample_weight

        return tf.reduce_mean(crf_loss), sum(self.losses)

    def train_step(self, data):
        x, y, sample_weight = unpack_data(data)

        with tf.GradientTape() as tape:
            crf_loss, internal_losses = self.compute_loss(
                x, y, sample_weight, training=True
            )
            total_loss = crf_loss + internal_losses

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {"crf_loss": crf_loss, "internal_losses": internal_losses}

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        return {"crf_loss": crf_loss, "internal_losses": internal_losses}


# get_test_data() stands in for whatever produces the (features, labels) numpy arrays.
x_np, y_np = get_test_data()

x_input = tf.keras.layers.Input(shape=x_np.shape[1:])
crf_outputs = CRF(5)(x_input)
base_model = tf.keras.Model(x_input, crf_outputs)
model = ModelWithCRFLoss(base_model)

model.compile("adam")
model.fit(x=x_np, y=y_np)
model.evaluate(x_np, y_np)
model.predict(x_np)
model.save("my_model.tf")
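
As a side note, since unpack_data also accepts a 3-element tuple, per-sequence sample weights can be passed through fit in the usual Keras way and reach train_step as (x, y, sample_weight). A minimal sketch, assuming a hypothetical w_np array with one weight per example:

import numpy as np

# Hypothetical per-example weights; Keras packs (x, y, sample_weight) into a
# 3-tuple, which unpack_data() above splits back apart in train_step/test_step.
w_np = np.ones(x_np.shape[0], dtype="float32")

model.fit(x=x_np, y=y_np, sample_weight=w_np)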

@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


ℹ️ Googlers: Go here for more info.

@jaspersjsun
Contributor Author

@googlebot I signed it!

@googlebot

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@gabrieldemarmiesse
Member

gabrieldemarmiesse commented Jul 14, 2020

@googlebot I consent.

@howl-anderson
Contributor

@googlebot I signed it!

@jaspersjsun
Contributor Author

jaspersjsun commented Jul 15, 2020

The wheels build broke on Ubuntu 18.04 with Python 3.5. There seem to be some problems with the Docker environment. Is there anything I need to do to pass those two checks?

Looks like #2002 fixed it.

@jaspersjsun
Contributor Author

@howl-anderson Thank you so much for your contribution. Would you mind replying with only @googlebot I consent. to grant consent to this PR?

@howl-anderson
Contributor

@googlebot I consent.

@googlebot

CLAs look good, thanks!

ℹ️ Googlers: Go here for more info.

@jaspersjsun
Contributor Author

@facaiy @seanpmorgan Hi guys, would you mind having a look at this PR whenever you are available? The history of adding a CRF layer goes back a year or more (#22, #314, #377, and many other issues and PRs). Having this published will help a lot. Much appreciated!

@WindQAQ self-requested a review July 19, 2020 00:52
Member

@WindQAQ left a comment


Thank you for the contribution to this long journey 👍

@WindQAQ merged commit 4f65776 into tensorflow:master Jul 19, 2020
@luozhouyang

Thank you so much for this contribution!

Can we use the tf.keras.layers.Layer.add_loss API inside the CRF layer to calculate the crf_loss, instead of using a wrapper model like ModelWithCRFLoss?

@howl-anderson
Contributor

@luozhouyang I think the pattern you describe is called the endpoint pattern. In that pattern, users need to pass the true labels as one of the model inputs, which I don't think makes for a user-friendly API. tensorflow/tensorflow#37818 is a promising solution to this, but it is still under review.
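
For reference, a minimal sketch of that endpoint pattern (not what this PR ships), reusing the imports and the x_np / y_np arrays from the tutorial above; the CRFLossEndpoint layer and the variable names here are hypothetical:

class CRFLossEndpoint(tf.keras.layers.Layer):
    """Registers the negative CRF log-likelihood via add_loss."""

    def call(self, inputs):
        potentials, labels, sequence_length, chain_kernel = inputs
        log_likelihood, _ = crf_log_likelihood(
            potentials, labels, sequence_length, chain_kernel
        )
        self.add_loss(-tf.reduce_mean(log_likelihood))
        return potentials


x_input = tf.keras.layers.Input(shape=x_np.shape[1:])
# The true labels have to be fed in as a second model input, which is the
# usability problem mentioned above.
y_input = tf.keras.layers.Input(shape=y_np.shape[1:], dtype=tf.int32)

decoded, potentials, seq_len, kernel = CRF(5)(x_input)
potentials = CRFLossEndpoint()([potentials, y_input, seq_len, kernel])

endpoint_model = tf.keras.Model([x_input, y_input], decoded)
endpoint_model.compile("adam")
endpoint_model.fit(x=[x_np, y_np], y=None)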

@eloukas

eloukas commented Nov 5, 2020

This approach can work, but it's incredibly slow, especially when we have a large number of classes/labels.
Is there any way you can incorporate Viterbi decoding into it?
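
If I'm reading the merged layer correctly, Viterbi decoding already happens inside the CRF layer: its first output is the decoded tag sequence, which the tutorial above discards as _ when computing the loss. A minimal sketch of reading it back out, assuming the model and x_np from the tutorial:

# The wrapper's call() forwards the base model's four outputs; the first one
# is the decoded (Viterbi) tag sequence computed inside the CRF layer.
decoded_tags, potentials, sequence_lengths, chain_kernel = model.predict(x_np)
print(decoded_tags.shape)  # (batch_size, max_seq_len)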

jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Squash all.

* Cleanup for easier review.

* Calming the angry bazel.

* Fix the strange bug.

* Replaced one bug by another bug.

* Minor simplification.

* Fix unused parameter.

* Simplified the signature.

* Removing boilerplate

* Unused import.

* CRF layer v3.0

* Finish the conversion.

* Some renaming here and there.

* Added a test where some training is done after reloading the model.

* Apply suggestions from CR

* update ops in _compute_mask_[left|right]_boundary

Co-authored-by: howl-anderson <u1mail2me@gmail.com>
Co-authored-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>