This repository has been archived by the owner on Jun 7, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
81 lines (64 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
from typing import List, Dict
import jieba.analyse
import numpy as np
import sys
# Main computation pipeline
def get_similarity(source_path: str, copy_path: str) -> float:
    """Return the similarity score of the texts stored in two files.

    Both files are read first, then each text is converted to a
    keyword -> weight mapping, the two mappings are aligned into two
    equal-length weight lists, and those lists are scored.

    Args:
        source_path: Path of the original text file.
        copy_path: Path of the text file compared against the original.

    Returns:
        The value computed by calculate_similarity for the two weight lists.
    """
    paths = (source_path, copy_path)
    # Read both files before extracting keywords so that a missing file is
    # reported before any text is analysed.
    texts = [read_file(path) for path in paths]
    weight_maps = [get_tfidf_dict(text) for text in texts]
    source_vector, copy_vector = get_tfidf_list(*weight_maps)
    return calculate_similarity(source_vector, copy_vector)
# Read a file and return its contents
def read_file(path: str) -> str:
    """Return the whole content of the file at *path*, decoded as UTF-8.

    Args:
        path: Location of the file to read.

    Returns:
        The text of the file.

    Raises:
        ValueError: If os.path.exists reports that *path* does not exist.
    """
    if os.path.exists(path):
        with open(path, mode='r', encoding='utf-8') as handle:
            return handle.read()
    raise ValueError('文件不存在')
# Build the TF-IDF dict of one text
def get_tfidf_dict(content: str) -> Dict[str, float]:
    """Map each keyword extracted from *content* to its weight.

    The (word, weight) pairs returned by jieba.analyse.extract_tags become
    the entries of the resulting dict.

    Args:
        content: The text to analyse.

    Returns:
        A dict whose keys are the extracted words and whose values are the
        weights reported for them.

    Raises:
        ValueError: If no keyword was extracted from *content*.
    """
    # NOTE(review): topK=0 presumably disables the keyword limit - confirm
    # against the jieba documentation.
    weighted_tags = jieba.analyse.extract_tags(content, topK=0, withWeight=True)
    weights = {word: weight for word, weight in weighted_tags}
    if not weights:
        raise ValueError('无效文本')
    return weights
# Turn two TF-IDF dicts into two aligned lists
def get_tfidf_list(source_tfidf_dict: Dict[str, float], copy_tfidf_dict: Dict[str, float]) -> List[List[float]]:
    """Align two word -> weight dicts into two equal-length weight lists.

    Position i of both lists refers to the same word. The words of
    *source_tfidf_dict* come first, in its iteration order, followed by the
    words that appear only in *copy_tfidf_dict*. A word missing from one of
    the dicts contributes 0 to that dict's list.

    Args:
        source_tfidf_dict: Weights of the original text.
        copy_tfidf_dict: Weights of the compared text.

    Returns:
        A two-element list: [source weights, copy weights].
    """
    # Shared vocabulary: every source word, then the copy-only words.
    vocabulary = list(source_tfidf_dict)
    vocabulary.extend(
        word for word in copy_tfidf_dict if word not in source_tfidf_dict)

    source_weights = [source_tfidf_dict.get(word, 0) for word in vocabulary]
    copy_weights = [copy_tfidf_dict.get(word, 0) for word in vocabulary]
    return [source_weights, copy_weights]
# Compute the result
def calculate_similarity(source_tfidf_list: List[float], copy_tfidf_list: List[float]) -> float:
    """Return the cosine similarity of two equal-length weight vectors.

    Args:
        source_tfidf_list: Weights of the original text.
        copy_tfidf_list: Weights of the compared text, aligned position by
            position with *source_tfidf_list*.

    Returns:
        dot(source, copy) / (norm(source) * norm(copy)) as a plain float.
        When either vector has zero norm (all zeros, or empty) the quotient
        would be 0 / 0, so 0.0 is returned instead of NaN.
    """
    source_vector = np.array(source_tfidf_list, dtype=float)
    copy_vector = np.array(copy_tfidf_list, dtype=float)

    norm_product = np.linalg.norm(source_vector) * np.linalg.norm(copy_vector)
    # A zero-norm vector has no direction; avoid dividing by zero.
    if norm_product == 0:
        return 0.0
    return float(np.dot(source_vector, copy_vector) / norm_product)
# Write the result to a file
def write_consequence_to_file(similarity: float, path: str):
    """Overwrite the file at *path* with the rounded similarity score.

    The written text is str(round(similarity, 2)), encoded as UTF-8.

    Args:
        similarity: The score to store.
        path: Location of an already existing file to overwrite.

    Raises:
        ValueError: If os.path.exists reports that *path* does not exist.
    """
    # NOTE(review): mode 'w' could create the file, yet a missing target is
    # rejected here - confirm that requiring a pre-existing file is intended.
    if os.path.exists(path):
        with open(path, mode='w', encoding='utf-8') as out:
            # round() keeps at most two decimals; trailing zeros are not
            # padded, so 0.5 is written as '0.5'.
            out.write(str(round(similarity, 2)))
        return
    raise ValueError('文件不存在')
def main():
    """Command-line entry point: compare two files and store the score.

    After the script name, sys.argv must hold the source file path, the copy
    file path and the answer file path, in that order. Any further arguments
    are ignored.

    Raises:
        ValueError: If fewer than three paths are supplied.
    """
    if len(sys.argv) < 4:
        raise ValueError('参数缺失')
    # Take exactly three values: unpacking all of sys.argv[1:] would fail
    # with an unrelated "too many values to unpack" error on extra arguments,
    # although the guard above accepts them.
    source_path, copy_path, ans_path = sys.argv[1:4]
    similarity = get_similarity(source_path, copy_path)
    write_consequence_to_file(similarity, ans_path)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()