
[Question] Actual usage examples? #29

Open
Maghoumi opened this issue Feb 19, 2019 · 7 comments

Comments

Maghoumi commented Feb 19, 2019

Besides the toy examples listed in the docs and tests, are there actual examples of this library available anywhere?

I'm interested in using this library for a sequence labeling project, but I'm curious to know if I'm using this library correctly. What I have is something like this:

import torch.nn as nn
from torchcrf import CRF

class MyModel(nn.Module):
    def __init__(self, num_features, num_classes):
        super(MyModel, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.lstm = nn.LSTM(num_features, 128)
        self.fc = nn.Linear(128, num_classes)
        self.crf = CRF(num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)  # out: (seq_length, batch_size, 128)
        return self.fc(out)    # emission scores: (seq_length, batch_size, num_classes)

# ----------------------------------------------------------
model = MyModel(...)

# Training loop:
optimizer.zero_grad()
y_hat = model(batch)  # The network's forward returns fc(lstm(batch))
loss = -model.crf(y_hat, y)
loss.backward()
optimizer.step()

Although this seems to work and the loss is decreasing, I have a feeling that I might be missing something.
Any help is appreciated. Thanks!

kmkurn (Owner) commented Feb 21, 2019

Hi,

Your usage seems alright. The examples are meant to show how to use the CRF layer given that one has produced the emission scores, i.e. (unnormalized) log P(y_t | X) where y_t is the tag at position t and X is the input sentence. In your code, y_hat would have a shape of (seq_length, batch_size, num_classes) where each y_hat[i, j, k] contains the score of the j-th example in the batch having tag k in the i-th position, which is as expected. I'll consider adding a more complete example in the docs. Thanks for the suggestion!

@Maghoumi (Author)

Thanks for your response. What was confusing to me originally was the fact that your CRF layer is actually a loss that one can minimize, whereas other PyTorch implementations had a separately-defined Viterbi loss module.

Yes, the dimensions you mentioned coincide with what I have in my code.
After reading your explanation, I think the only change needed in my pseudo-code above would be to change loss = -model.crf(y_hat, y) to loss = -model.crf(y_hat.log_softmax(2), y) (given that the output of the FC layer is returned directly from the network, but we need emission scores).

kmkurn (Owner) commented Feb 21, 2019

> What was confusing to me originally was the fact that your CRF layer is actually a loss that one can minimize, whereas other PyTorch implementations had a separately-defined Viterbi loss module.

Actually, this is something I think about every now and then. Right now the forward method returns the loss, which does not really fit the common pattern where forward returns some kind of prediction and a separate loss object computes the loss from that prediction and the gold target, as you mentioned. It may be helpful to provide such a loss class for those who are more comfortable with that pattern, such as yourself. Thanks for bringing this up.
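For what it's worth, the wrapper pattern under discussion could look something like this sketch (CRFLoss is a hypothetical name, not part of the library; in real PyTorch code it would subclass nn.Module):

```python
class CRFLoss:
    """Hypothetical loss wrapper: negates a CRF layer's log-likelihood so it
    can be minimized, matching the usual prediction-then-loss pattern."""
    def __init__(self, crf):
        self.crf = crf  # any callable (emissions, tags) -> log-likelihood

    def __call__(self, emissions, tags):
        return -self.crf(emissions, tags)

# Toy stand-in for a CRF layer that returns a fixed log-likelihood.
fake_crf = lambda emissions, tags: -2.5
loss_fn = CRFLoss(fake_crf)
print(loss_fn(None, None))  # → 2.5
```

The model's forward would then return the emission scores as its prediction, and the loss object alone would touch the gold tags.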

> would be to change loss = -model.crf(y_hat, y) to loss = -model.crf(y_hat.log_softmax(2), y)

You don't have to. The CRF layer accepts unnormalized emission scores just fine. It will normalize the score of y over all possible sequences of tags.
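This can be checked with a small framework-free example (brute-force over all tag sequences, toy numbers; the transitions matrix is made up). Because log_softmax subtracts a per-position constant from every tag's score, every candidate tag sequence's total score shifts by the same amount, and that shift cancels between the numerator and the partition function:

```python
import itertools
import math

def crf_log_likelihood(emissions, transitions, tags):
    """Brute-force linear-chain CRF log-likelihood for one sequence.
    emissions: [seq_len][num_tags], transitions: [num_tags][num_tags]."""
    num_tags = len(emissions[0])
    def score(seq):
        s = emissions[0][seq[0]]
        for t in range(1, len(seq)):
            s += transitions[seq[t - 1]][seq[t]] + emissions[t][seq[t]]
        return s
    # Partition function: sum over every possible tag sequence.
    log_z = math.log(sum(math.exp(score(seq))
                         for seq in itertools.product(range(num_tags),
                                                      repeat=len(emissions))))
    return score(tags) - log_z

def log_softmax(row):
    z = math.log(sum(math.exp(e) for e in row))
    return [e - z for e in row]

emissions = [[0.5, -1.2, 0.3], [2.0, 0.1, -0.4]]
transitions = [[0.0, 0.2, -0.1], [0.3, 0.0, 0.1], [-0.2, 0.4, 0.0]]
tags = (0, 2)

ll_raw = crf_log_likelihood(emissions, transitions, tags)
ll_norm = crf_log_likelihood([log_softmax(r) for r in emissions],
                             transitions, tags)
assert abs(ll_raw - ll_norm) < 1e-9  # identical log-likelihood either way
```

So under these assumptions, loss = -model.crf(y_hat, y) and loss = -model.crf(y_hat.log_softmax(2), y) optimize the same objective.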

@Maghoumi (Author)

> It may be helpful to provide such a loss class for those who are more comfortable with this pattern such as yourself. Thanks for bringing this up.

No problem! That would be an excellent idea.
Also, thanks for the clarification regarding the log_softmax() bit.

Feel free to close this issue, or keep it open as a reminder if you decide to incorporate more examples and also change the library such that loss is separate from the CRF's output. I'd personally be very happy if the changes/examples are added.
Also, thanks for the great library!

xiaodaoyoumin commented Jun 11, 2019

Hi, I ran into a problem.
I downloaded test_crf.py, in which you provide some examples.
Then at the end of the file I added the following code:

if __name__ == "__main__":
    m = TestDecode()
    m.test_batched_decode()

The first error was that the CRF model has no attribute batch_first; I worked around it manually with CRF.batch_first = False.

Then when I ran the above code, I got this error:

*** IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

I am not sure whether it is because my torch version is 1.1.0. By the way, which torch version do you recommend?

@xiaodaoyoumin

I have solved it: with torch version <= 1.0.0 there is no error.

kmkurn (Owner) commented Jun 13, 2019

@Huijun-Cui Thanks for letting me know. Next time please open a separate issue.

Repository owner locked as off-topic and limited conversation to collaborators Jun 13, 2019