-
Notifications
You must be signed in to change notification settings - Fork 5
/
predict.py
168 lines (153 loc) · 5.24 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from numpy import *
import operator
import MySQLdb
class Predict:
    """Compute per-user song recommendations from stored embedding vectors.

    Loads user, relation and song vectors plus the ``songlike`` rows from a
    MySQL database, scores every song for every user that appears in
    ``songlike`` with ``distance(user_vec, relation_vec, song_vec)``, keeps
    the ``recoNum`` lowest-scoring songs the user is not already paired with
    in ``songlike``, and writes the result to the ``recommand`` table.

    Methods report failure by printing a message and returning -1; they
    return None on success.
    """

    def __init__(self, recoNum=10, host="localhost", user="root",
                 passwd="", db="music3"):
        """Create an empty predictor.

        :param recoNum: number of songs to recommend per user.
        :param host: MySQL host.
        :param user: MySQL user name.
        :param passwd: MySQL password.
        :param db: MySQL database name.

        The connection defaults are the values that used to be hard-coded
        in ``connectDB``; they are passed positionally to ``MySQLdb.connect``.
        """
        self.userList = {}         # first column of user_vector -> vector
        self.relationList = {}     # first column of relation_vector -> vector (looked up by user id)
        self.songList = {}         # first column of song_vector -> vector
        self.tripleListTrain = []  # (user, user, song) built from songlike columns 1 and 2
        self.predictList = []      # distinct songlike column-1 values, in first-seen order
        self.recoNum = recoNum
        self.rank = []             # list of (user id, [song ids]) results
        self.dbParams = (host, user, passwd, db)

    def predict(self):
        """Run the whole pipeline: connect, load, rank, write, disconnect.

        The connection is closed even when loading or writing fails.
        """
        if self.connectDB() == -1:
            return
        print("Succeed to connect datase")
        try:
            if self.loadData() == -1:
                return
            print("Succeed to load data")
            print("Start to predict...")
            self.getSongRank()
            print("Succeed to get song rank")
            if self.writeRankToDB() == -1:
                return
            print("Succeed to write rank to databse")
        finally:
            # Previously the connection leaked whenever a step failed.
            self.closeDB()

    def connectDB(self):
        """Open ``self.dbc``; print a message and return -1 on failure."""
        try:
            self.dbc = MySQLdb.connect(*self.dbParams)
        except MySQLdb.Error:
            print("Fail to connect database!!!!!!")
            return -1

    def closeDB(self):
        """Close the connection opened by ``connectDB``."""
        self.dbc.close()

    def writeRankToDB(self):
        """Store ``self.rank`` in the ``recommand`` table.

        Each user's list is written as ``str(list)`` into ``ranklist``.
        Users that already have a row (matched by the table's first column,
        the id used by the positional INSERT) are UPDATEd. All others are
        INSERTed, so users who appear after the first run also get rows.
        Each row is committed individually; on the first failing row the
        transaction is rolled back and -1 is returned.
        """
        print(self.rank)
        print("writing rank to database...")
        cursor = self.dbc.cursor()
        try:
            try:
                cursor.execute("SELECT * FROM recommand")
                results = cursor.fetchall()
            except MySQLdb.Error:
                print("Error:fail to write song rank of somebody to database")
                return -1
            existingIds = set(row[0] for row in results)
            for userId, songs in self.rank:
                rankText = str(songs)
                # Parameterized queries: str(list) may contain quotes that
                # would break (or inject into) a hand-formatted statement.
                if userId in existingIds:
                    sql = "UPDATE recommand SET ranklist = %s WHERE id = %s"
                    params = (rankText, userId)
                else:
                    sql = "INSERT INTO recommand values (%s, %s)"
                    params = (userId, rankText)
                try:
                    cursor.execute(sql, params)
                    self.dbc.commit()
                except MySQLdb.Error:
                    self.dbc.rollback()
                    print("Error:fail to write song rank of somebody to database")
                    return -1
        finally:
            cursor.close()

    def getSongRank(self):
        """Append ``(user, top_song_ids)`` to ``self.rank`` for each user.

        Songs the user is already paired with in ``tripleListTrain`` are
        skipped. The remaining songs are ordered by ascending ``distance``,
        with ties kept in ``songList`` iteration order, and the first
        ``recoNum`` song ids are kept. ``recoNum <= 0`` yields an empty list.
        Progress is printed every 10000 users.
        """
        # Build the set once: O(1) membership instead of scanning every
        # training triple for every (user, song) pair.
        liked = set(self.tripleListTrain)
        limit = max(self.recoNum, 0)
        count = 0
        for predictUser in self.predictList:
            scores = {}
            for songTemp, songVec in self.songList.items():
                if (predictUser, predictUser, songTemp) in liked:
                    continue
                scores[songTemp] = distance(self.userList[predictUser],
                                            self.relationList[predictUser],
                                            songVec)
            ranked = sorted(scores.items(), key=operator.itemgetter(1))
            topSongs = [song for song, _ in ranked[:limit]]
            self.rank.append((predictUser, topSongs))
            count += 1
            if count % 10000 == 0:
                print(count)

    def _fetchAll(self, cursor, sql):
        """Execute ``sql`` and return all rows, or None after printing on a DB error."""
        try:
            cursor.execute(sql)
            return cursor.fetchall()
        except MySQLdb.Error:
            print("Error: unable to fetch data")
            return None

    def loadData(self):
        """Load vectors and ``songlike`` rows into memory.

        Fills ``userList``, ``songList`` and ``relationList`` (key = column 0,
        value = ``str2vec(column 1)``), then ``tripleListTrain`` and
        ``predictList`` from ``songlike`` columns 1 and 2 (presumably user id
        and song id, to be confirmed against the schema). Returns -1 if any
        query fails; tables loaded before the failure stay loaded. The cursor
        is always closed.
        """
        cursor = self.dbc.cursor()
        try:
            vectorTables = (
                ("SELECT * FROM user_vector", self.userList),
                ("SELECT * FROM song_vector", self.songList),
                ("SELECT * FROM relation_vector", self.relationList),
            )
            for sql, target in vectorTables:
                rows = self._fetchAll(cursor, sql)
                if rows is None:
                    return -1
                for row in rows:
                    target[row[0]] = str2vec(row[1])

            allTriplets = self._fetchAll(cursor, "SELECT * FROM songlike")
            if allTriplets is None:
                return -1
            # Order-preserving de-duplication with O(1) membership tests.
            seenUsers = set(self.predictList)
            for triplet in allTriplets:
                self.tripleListTrain.append((triplet[1], triplet[1], triplet[2]))
                if triplet[1] not in seenUsers:
                    seenUsers.add(triplet[1])
                    self.predictList.append(triplet[1])
        finally:
            cursor.close()
def distance(h, r, t):
    """Score how far ``t`` lies from ``h`` translated by ``r``.

    Returns ``linalg.norm(h + r - t)``, which is the Euclidean length of the
    residual for 1-D vectors. Smaller means closer.
    """
    return linalg.norm(h + r - t)
def str2vec(str):
    """Parse a bracketed list of numbers such as ``"[0.1, -0.2]"`` into a float array.

    The first and last non-blank characters (the enclosing brackets) are
    dropped. The remainder is split on commas and/or whitespace, so
    ``"[1,2]"``, ``"[1, 2]"`` and surrounding blanks all parse. An empty
    list ``"[]"`` yields an empty array instead of raising ValueError.

    The parameter keeps its historical name ``str``, which shadows the
    builtin inside this function, so keyword callers keep working.
    """
    text = str.strip()
    fields = text[1:-1].replace(",", " ").split()
    return array([float(field) for field in fields])
if __name__ == '__main__':
    # Script entry point: run the full pipeline with the default
    # constructor arguments.
    predictor = Predict()
    predictor.predict()