-
Notifications
You must be signed in to change notification settings - Fork 1
/
single_label_classification_analysis.py
56 lines (46 loc) · 1.84 KB
/
single_label_classification_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import division
import nltk
import csv
import re
import random
import math
import sys
import os
import util
from keyword_frequency_classifier import KeywordFrequencyClassifier
from sklearn.linear_model import LogisticRegression
from nltk.classify.scikitlearn import SklearnClassifier
from nltk import NaiveBayesClassifier
def main(subdir='data/single_tags/', fname='dataset.csv',
         train_fraction=0.8, seed=None):
    """Analyse the single-label classification models.

    Loads the data set through ``util.parse_data(subdir, fname,
    single_label=True, extract_features=True)``, shuffles it, and splits it
    into a training and a test portion. It then trains three classifiers on
    the training portion:

    - NLTK's ``NaiveBayesClassifier``;
    - scikit-learn's ``LogisticRegression`` wrapped in NLTK's
      ``SklearnClassifier``;
    - the project's ``KeywordFrequencyClassifier``.

    It prints each model's accuracy on the test portion, and finally prints
    each model's label for one example sentence.

    The samples are presumably NLTK-style ``(featureset, label)`` pairs, as
    ``NaiveBayesClassifier.train`` requires -- confirm against ``util``.

    Args:
        subdir: directory argument passed to ``util.parse_data``.
        fname: file-name argument passed to ``util.parse_data``.
        train_fraction: share of the shuffled samples used for training; the
            remainder forms the test set. The default 0.8 is an 80/20 split.
        seed: optional seed for the shuffle, for reproducible runs. ``None``
            (the default) shuffles with the module-level ``random`` state.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
        SystemExit: if the split leaves the training or the test set empty.
    """
    if not 0 < train_fraction < 1:
        raise ValueError('train_fraction must be strictly between 0 and 1, '
                         'got %r' % (train_fraction,))

    # Read in the data set.
    data = util.parse_data(subdir, fname, single_label=True,
                           extract_features=True)

    # Randomise the sample order so the split does not follow file order.
    if seed is None:
        random.shuffle(data)
    else:
        random.Random(seed).shuffle(data)

    # Split into training and test data; int() truncates like math.trunc did.
    split_index = int(len(data) * train_fraction)
    train_set = data[:split_index]
    test_set = data[split_index:]
    if not train_set or not test_set:
        # Fail with a clear message instead of an obscure error from a trainer.
        raise SystemExit('Need at least one training and one test sample; '
                         'the data set has %d.' % len(data))
    n_test = len(test_set)

    # Train the classification models.
    print('Training models on %d data samples...' % len(train_set))
    nb = NaiveBayesClassifier.train(train_set)
    lr = SklearnClassifier(LogisticRegression()).train(train_set)
    kwfc = KeywordFrequencyClassifier()
    kwfc.train(train_set)

    # Calculate and report model accuracy on the held-out test set.
    print('\nKey Word Frequency Classifier accuracy based on %d samples:'
          % n_test)
    print(kwfc.accuracy(test_set))
    print('\nNaive Bayes accuracy based on %d samples:' % n_test)
    print(nltk.classify.util.accuracy(nb, test_set))
    print('\nLogistic Regression accuracy based on %d samples:' % n_test)
    print(nltk.classify.util.accuracy(lr, test_set))

    # Attempt to classify an example sentence with every model.
    sample_post = 'How many numbers less than 70 are relatively prime to it?'
    sample_features = util.features(sample_post)
    print('\nAn Example:\n%s' % sample_post)
    # Single-element tuples keep '%' formatting safe whatever the label type.
    print('Naive Bayes: %s' % (nb.classify(sample_features),))
    print('Keyword Classifier: %s' % (kwfc.predict(sample_features),))
    print('Logistic Regression: %s' % (lr.classify(sample_features),))


if __name__ == "__main__":
    main()