Text classification with fastText¶
Download the Tsinghua University news classification dataset (THUCNews): https://thunlp.oss-cn-qingdao.aliyuncs.com/THUCNews.zip
After unzipping, rename the Chinese category directories to the English names below so the script can read the data easily:
['affairs','constellation','economic','edu','ent','fashion','game','home','house','lottery','science','sports','society','stock']
Step 1: build the labeled text. Output format: sample + sample label
In [ ]:
import jieba
import os
In [5]:
basedir = "/home/python_home/WeiZhongChuang/ML/Embedding/THUCNews/"  # my local path; change it to wherever THUCNews was unpacked
dir_list = ['affairs','constellation','economic','edu','ent','fashion','game','home','house','lottery','science','sports','society','stock']
## Generate the fastText training and test sets
ftrain = open("news_fasttext_train.txt", "w", encoding="utf-8")
ftest = open("news_fasttext_test.txt", "w", encoding="utf-8")
num = -1
for e in dir_list:
    num += 1
    indir = basedir + e + '/'
    files = os.listdir(indir)
    count = 0
    for fileName in files:
        count += 1
        filepath = indir + fileName
        with open(filepath, 'r', encoding='utf-8') as fr:
            text = fr.read()
        # segment with jieba, flattening tabs and newlines into spaces first
        seg_text = jieba.cut(text.replace("\t", " ").replace("\n", " "))
        outline = " ".join(seg_text)
        outline = outline + "\t__label__" + e + "\n"
        if count < 10000:        # files with count < 10000 go to the training set
            ftrain.write(outline)
            ftrain.flush()
            continue
        elif count < 20000:      # the next 10,000 files per class go to the test set
            ftest.write(outline)
            ftest.flush()
            continue
        else:
            break
ftrain.close()
ftest.close()
print('Finished writing the data!')
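Each line in the generated files is the space-joined segmented text, a tab, then the label, e.g. `词1 词2 词3 ...	__label__sports`. A minimal, optional sanity check (not part of the pipeline above) that peeks at the label distribution of the generated files:
In [ ]:
from collections import Counter

def label_counts(path):
    """Count how many samples each __label__ tag has in a fastText data file."""
    counts = Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            # the label sits after the tab inserted by the preparation step
            counts[line.rstrip("\n").split("\t")[-1]] += 1
    return counts

print(label_counts("news_fasttext_train.txt"))
print(label_counts("news_fasttext_test.txt"))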
Step 2: train the model with fastText (if the data files are already prepared, you can start directly from this step)¶
In [1]:
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
import fasttext
# Train the model; label= defaults to "__label__", which matches the files generated above
classifier = fasttext.train_supervised("news_fasttext_train.txt", label="__label__")
# Or load a previously trained model:
# classifier = fasttext.load_model('news_fasttext.model.bin')
print('Training finished!')
Training finished!
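To make the commented-out load_model call above usable later, the trained classifier can be written to disk; a minimal sketch (the filename news_fasttext.model.bin is just the one referenced in that comment):
In [ ]:
# Persist the trained classifier so it can later be reloaded with fasttext.load_model()
classifier.save_model("news_fasttext.model.bin")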
Test the model¶
In [2]:
result = classifier.test("news_fasttext_test.txt")
print('precision:', result[1])
precision: 0.7845266342650989
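In the official Python bindings, classifier.test() returns a tuple (number of samples, precision@1, recall@1); since each line carries exactly one label, precision@1 and recall@1 coincide here. A minimal sketch that prints all three values:
In [ ]:
# Unpack the full result tuple from test()
n, precision, recall = classifier.test("news_fasttext_test.txt")
print("samples:", n, "precision@1:", precision, "recall@1:", recall)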
fastText's test() only reports overall precision and recall, so per-class statistics have to be computed with our own code.¶
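Depending on the installed version of the bindings, there may also be a built-in shortcut: test_label() reports precision, recall and F1 per label. A minimal sketch, assuming the method is available in your fasttext version:
In [ ]:
# Per-label precision/recall/F1, if the installed bindings provide test_label()
per_label = classifier.test_label("news_fasttext_test.txt")
for label, scores in sorted(per_label.items()):
    print(label, scores)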
In [10]:
labels_right = []
texts = []
with open("news_fasttext_test.txt", encoding="utf-8") as fr:
    for line in fr:
        line = line.rstrip()
        labels_right.append(line.split("\t")[1].replace("__label__", ""))
        texts.append(line.split("\t")[0])

labels_predict = [term[0] for term in classifier.predict(texts)[0]]  # predict() returns (labels, probabilities), one entry per input text

text_labels = list(set(labels_right))
text_predict_labels = list(set(labels_predict))
print(text_predict_labels)
print(text_labels)
print()

A = dict.fromkeys(text_labels, 0)          # correctly predicted samples per class
B = dict.fromkeys(text_labels, 0)          # test samples per class
C = dict.fromkeys(text_predict_labels, 0)  # predictions per class
for i in range(0, len(labels_right)):
    B[labels_right[i]] += 1
    C[labels_predict[i]] += 1
    if labels_right[i] == labels_predict[i].replace('__label__', ''):
        A[labels_right[i]] += 1
print('Correct predictions per class:', A)
print()
print('Test samples per class:', B)
print()
print('Predictions per class:', C)
print()
# Compute precision, recall and F1 per class
for key in B:
    try:
        r = float(A[key]) / float(B[key])
        p = float(A[key]) / float(C['__label__' + key])
        f = p * r * 2 / (p + r)
        print("%s:\t p:%f\t r:%f\t f:%f" % (key, p, r, f))
    except (KeyError, ZeroDivisionError):
        print("error:", key, "right:", A.get(key, 0), "real:", B.get(key, 0), "predict:", C.get(key, 0))
['__label__economic', '__label__sports', '__label__fashion', '__label__stock', '__label__edu', '__label__lottery', '__label__home', '__label__house', '__label__game', '__label__constellation', '__label__ent', '__label__science', '__label__affairs', '__label__society']
['home', 'sports', 'science', 'stock', 'game', 'economic', 'ent', 'house', 'fashion', 'edu', 'society', 'affairs']
Correct predictions per class: {'home': 9390, 'sports': 9273, 'science': 9507, 'stock': 4257, 'game': 9387, 'economic': 8844, 'ent': 8509, 'house': 8588, 'fashion': 1236, 'edu': 8165, 'society': 3467, 'affairs': 8318}
Test samples per class: {'home': 10000, 'sports': 10000, 'science': 10000, 'stock': 10000, 'game': 10000, 'economic': 10000, 'ent': 10000, 'house': 10000, 'fashion': 3369, 'edu': 10000, 'society': 10000, 'affairs': 10000}
Predictions per class: {'__label__economic': 9489, '__label__sports': 13467, '__label__fashion': 1305, '__label__stock': 4896, '__label__edu': 8554, '__label__lottery': 396, '__label__home': 15112, '__label__house': 10030, '__label__game': 10532, '__label__constellation': 734, '__label__ent': 9305, '__label__science': 13474, '__label__affairs': 12543, '__label__society': 3532}
home: p:0.621361 r:0.939000 f:0.747850
sports: p:0.688572 r:0.927300 f:0.790301
science: p:0.705581 r:0.950700 f:0.810003
stock: p:0.869485 r:0.425700 f:0.571563
game: p:0.891284 r:0.938700 f:0.914378
economic: p:0.932027 r:0.884400 f:0.907589
ent: p:0.914455 r:0.850900 f:0.881533
house: p:0.856231 r:0.858800 f:0.857514
fashion: p:0.947126 r:0.366874 f:0.528883
edu: p:0.954524 r:0.816500 f:0.880134
society: p:0.981597 r:0.346700 f:0.512415
affairs: p:0.663159 r:0.831800 f:0.737967
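The same per-class breakdown can be obtained in a single call with scikit-learn, which also adds macro and weighted averages. A minimal sketch, assuming scikit-learn is installed and labels_right / labels_predict from the cell above are still in memory:
In [ ]:
from sklearn.metrics import classification_report

# Strip the "__label__" prefix so predictions and gold labels share the same names
predicted = [lab.replace("__label__", "") for lab in labels_predict]
print(classification_report(labels_right, predicted, digits=4))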