数据集下载
使用THUCNews的一个子集进行训练与测试:https://www.lanzous.com/i5t0lsd
本次训练使用了其中的10个分类,每个分类6500条数据。类别如下:
体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐
cnews_loader.py 为数据的预处理文件。
- read_file(): 读取文件数据;
- build_vocab(): 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理;
- read_vocab(): 读取上一步存储的词汇表,转换为{词:id}表示;
- read_category(): 将分类目录固定,转换为{类别: id}表示;
- to_words(): 将一条由id表示的数据重新转换为文字;
- process_file(): 将数据集从文字转换为固定长度的id序列表示;
- batch_iter(): 为神经网络的训练准备经过shuffle的批次的数据。
textRNN 模型和可配置的参数,在 rnn_model.py 中。
In [2]:
from __future__ import print_function
import os
import sys
import time
from datetime import timedelta
import numpy as np
import tensorflow as tf
from sklearn import metrics
from rnn_model import TRNNConfig, TextRNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
In [3]:
# Locations of the dataset splits and the vocabulary file, all under base_dir.
base_dir = 'cnews'
train_dir, test_dir, val_dir, vocab_dir = (
    os.path.join(base_dir, file_name)
    for file_name in (
        'cnews.train.txt',
        'cnews.test.txt',
        'cnews.val.txt',
        'cnews.vocab.txt',
    )
)

# Checkpoint directory and the path prefix used for the best validation result.
save_dir = 'checkpoints/textrnn'
save_path = os.path.join(save_dir, 'best_validation')
def get_time_dif(start_time):
    """Return the time elapsed since ``start_time``.

    Args:
        start_time: A timestamp in seconds, as returned by ``time.time()``.

    Returns:
        A ``timedelta`` holding the elapsed time rounded to whole seconds.
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)
def feed_data(x_batch, y_batch, keep_prob):
    """Build the feed dict that binds one batch to the global ``model``.

    Args:
        x_batch: Value fed to ``model.input_x``.
        y_batch: Value fed to ``model.input_y``.
        keep_prob: Value fed to ``model.keep_prob``.

    Returns:
        A dict keyed by the three model inputs listed above.
    """
    return {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob,
    }
def evaluate(sess, x_, y_, batch_size=128):
    """Compute the mean loss and accuracy of the global ``model`` on a dataset.

    The data is processed in batches produced by ``batch_iter``; per-batch
    results are weighted by batch length so that a shorter final batch
    contributes proportionally to the averages.

    Args:
        sess: Session used to run ``model.loss`` and ``model.acc``.
        x_: Input samples; must support ``len()``.
        y_: Labels matching ``x_``.
        batch_size: Number of samples per evaluation batch (default 128).

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple over all samples.

    Raises:
        ValueError: If ``x_`` is empty, since no average can be computed.
    """
    data_len = len(x_)
    if data_len == 0:
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError("evaluate() requires at least one sample")

    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_iter(x_, y_, batch_size):
        batch_len = len(x_batch)
        # keep_prob is fed as 1.0 for evaluation.
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len
def _run_training_loop(session, writer, saver, merged_summary,
                       x_train, y_train, x_val, y_val):
    """Run the epoch/batch optimisation loop with periodic validation.

    Every ``config.save_per_batch`` steps the merged summary is written to
    ``writer``; every ``config.print_per_batch`` steps the current batch and
    the validation set are evaluated and the model is checkpointed to
    ``save_path`` when validation accuracy improves. The loop returns early
    once more than 1000 steps pass without an improvement.
    """
    start_time = time.time()
    total_batch = 0  # optimisation steps run so far
    best_acc_val = 0.0  # best validation accuracy seen so far
    last_improved = 0  # step at which validation accuracy last improved
    require_improvement = 1000  # stop early after this many steps without improvement

    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        for x_batch, y_batch in batch_iter(x_train, y_train, config.batch_size):
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

            if total_batch % config.save_per_batch == 0:
                # Periodically write the training scalars to TensorBoard.
                summary = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(summary, total_batch)

            if total_batch % config.print_per_batch == 0:
                # Periodically report performance on the current training
                # batch and on the validation set, with keep_prob fed as 1.0.
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run(
                    [model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)

                if acc_val > best_acc_val:
                    # Keep only the checkpoint with the best validation accuracy.
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = ('Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},'
                       ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}')
                print(msg.format(total_batch, loss_train, acc_train,
                                 loss_val, acc_val, time_dif, improved_str))

            # Restore the training keep_prob before the optimisation step.
            feed_dict[model.keep_prob] = config.dropout_keep_prob
            session.run(model.optim, feed_dict=feed_dict)
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # Validation accuracy has not improved for a long time.
                print("No optimization for a long time, auto-stopping...")
                return


def train():
    """Train the global ``model`` and checkpoint the best validation result.

    Sets up TensorBoard summaries under ``tensorboard/textrnn`` and a Saver
    writing to ``save_path``, loads the training and validation sets with
    ``process_file``, then runs the training loop. The session and the
    summary writer are closed when training finishes or fails.

    Relies on the module-level ``model``, ``config``, ``word_to_id`` and
    ``cat_to_id`` being defined before it is called.
    """
    print("Configuring TensorBoard and Saver...")
    # When retraining, delete the tensorboard folder first, otherwise the
    # new curves are drawn over the old ones.
    tensorboard_dir = 'tensorboard/textrnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    try:
        saver = tf.train.Saver()
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("Loading training and validation data...")
        start_time = time.time()
        x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
        x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
        print("Time usage:", get_time_dif(start_time))

        # The context manager closes the session even if training raises.
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            writer.add_graph(session.graph)

            print('Training and evaluating...')
            _run_training_loop(session, writer, saver, merged_summary,
                               x_train, y_train, x_val, y_val)
    finally:
        # Flush pending summaries to disk and release the event file.
        writer.close()
def test():
    """Evaluate the checkpoint at ``save_path`` on the test set.

    Loads the test data with ``process_file``, restores the saved model,
    prints test loss and accuracy, then predicts every sample in batches and
    prints the classification report and confusion matrix, followed by the
    elapsed time. The session is closed once predictions are collected.

    Relies on the module-level ``model``, ``config``, ``word_to_id``,
    ``cat_to_id`` and ``categories`` being defined before it is called.
    """
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    batch_size = 128
    data_len = len(x_test)
    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)  # predicted class ids

    # The context manager closes the session even if evaluation raises.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved model

        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))

        # Predict batch by batch, with keep_prob fed as 1.0.
        for start_id in range(0, data_len, batch_size):
            end_id = min(start_id + batch_size, data_len)
            feed_dict = {
                model.input_x: x_test[start_id:end_id],
                model.keep_prob: 1.0,
            }
            y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    # Per-class metrics
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
开始训练
In [4]:
# Run mode: 'train' runs train(); any other value runs test().
type_ = 'train'

print('Configuring RNN model...')
config = TRNNConfig()

# Build the vocabulary file only when it is not already on disk.
if not os.path.exists(vocab_dir):
    build_vocab(train_dir, vocab_dir, config.vocab_size)

words, word_to_id = read_vocab(vocab_dir)
categories, cat_to_id = read_category()

# Replace the configured vocabulary size with the number of words read back.
config.vocab_size = len(words)
model = TextRNN(config)

(train if type_ == 'train' else test)()
W0826 20:37:22.551977 140497609688896 module_wrapper.py:136] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/util/module_wrapper.py:163: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead. W0826 20:37:22.569656 140497609688896 lazy_loader.py:50] The TensorFlow contrib module will not be included in TensorFlow 2.0. For more information, please see: * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md * https://github.com/tensorflow/addons * https://github.com/tensorflow/io (for I/O related ops) If you depend on functionality not listed there, please file an issue. W0826 20:37:22.570458 140497609688896 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/rnn_model.py:48: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version. Instructions for updating: This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0. W0826 20:37:22.588945 140497609688896 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/rnn_model.py:65: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version. Instructions for updating: This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0. W0826 20:37:22.589862 140497609688896 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/rnn_model.py:67: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version. Instructions for updating: Please use `keras.layers.RNN(cell)`, which is equivalent to this API W0826 20:37:22.668138 140497609688896 deprecation.py:323] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:558: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. 
Instructions for updating: Please use `layer.add_weight` method instead. W0826 20:37:22.677227 140497609688896 deprecation.py:506] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:564: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor W0826 20:37:22.689308 140497609688896 deprecation.py:506] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:574: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor
Configuring RNN model...
W0826 20:37:22.789028 140497609688896 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/rnn_model.py:72: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.Dense instead. W0826 20:37:22.790071 140497609688896 deprecation.py:323] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: Please use `layer.__call__` method instead. W0826 20:37:22.828924 140497609688896 deprecation.py:323] From /home/python_home/WeiZhongChuang/ML/RNN/rnn_model.py:82: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version. Instructions for updating: Future major versions of TensorFlow will allow gradients to flow into the labels input on backprop by default. See `tf.nn.softmax_cross_entropy_with_logits_v2`. W0826 20:37:22.848501 140497609688896 module_wrapper.py:136] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/util/module_wrapper.py:163: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead. W0826 20:37:23.638830 140497609688896 module_wrapper.py:136] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/util/module_wrapper.py:163: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.
Configuring TensorBoard and Saver... Loading training and validation data... Time usage: 0:00:13 Training and evaluating... Epoch: 1 Iter: 0, Train Loss: 2.3, Train Acc: 11.72%, Val Loss: 2.3, Val Acc: 9.28%, Time: 0:00:16 * Iter: 100, Train Loss: 0.9, Train Acc: 70.31%, Val Loss: 1.2, Val Acc: 61.48%, Time: 0:03:20 * Iter: 200, Train Loss: 0.53, Train Acc: 84.38%, Val Loss: 0.82, Val Acc: 74.44%, Time: 0:06:34 * Iter: 300, Train Loss: 0.4, Train Acc: 86.72%, Val Loss: 0.64, Val Acc: 82.04%, Time: 0:09:52 * Epoch: 2 Iter: 400, Train Loss: 0.41, Train Acc: 89.84%, Val Loss: 0.61, Val Acc: 82.94%, Time: 0:13:10 * Iter: 500, Train Loss: 0.38, Train Acc: 89.84%, Val Loss: 0.51, Val Acc: 86.54%, Time: 0:16:30 * Iter: 600, Train Loss: 0.23, Train Acc: 92.19%, Val Loss: 0.54, Val Acc: 85.46%, Time: 0:19:51 Iter: 700, Train Loss: 0.11, Train Acc: 96.88%, Val Loss: 0.5, Val Acc: 86.48%, Time: 0:23:13 Epoch: 3 Iter: 800, Train Loss: 0.2, Train Acc: 95.31%, Val Loss: 0.49, Val Acc: 87.02%, Time: 0:26:34 * Iter: 900, Train Loss: 0.19, Train Acc: 96.09%, Val Loss: 0.41, Val Acc: 89.22%, Time: 0:29:56 * Iter: 1000, Train Loss: 0.1, Train Acc: 96.88%, Val Loss: 0.4, Val Acc: 90.08%, Time: 0:33:18 * Iter: 1100, Train Loss: 0.18, Train Acc: 94.53%, Val Loss: 0.37, Val Acc: 90.50%, Time: 0:36:41 * Epoch: 4 Iter: 1200, Train Loss: 0.11, Train Acc: 96.09%, Val Loss: 0.38, Val Acc: 90.24%, Time: 0:40:03 Iter: 1300, Train Loss: 0.13, Train Acc: 94.53%, Val Loss: 0.34, Val Acc: 91.80%, Time: 0:43:26 * Iter: 1400, Train Loss: 0.28, Train Acc: 93.75%, Val Loss: 0.31, Val Acc: 91.44%, Time: 0:46:48 Iter: 1500, Train Loss: 0.18, Train Acc: 92.97%, Val Loss: 0.31, Val Acc: 92.32%, Time: 0:50:11 * Epoch: 5 Iter: 1600, Train Loss: 0.029, Train Acc: 100.00%, Val Loss: 0.47, Val Acc: 87.54%, Time: 0:53:32 Iter: 1700, Train Loss: 0.086, Train Acc: 96.88%, Val Loss: 0.37, Val Acc: 90.70%, Time: 0:56:55 Iter: 1800, Train Loss: 0.21, Train Acc: 93.75%, Val Loss: 0.32, Val Acc: 91.16%, Time: 1:00:18 
Iter: 1900, Train Loss: 0.16, Train Acc: 95.31%, Val Loss: 0.31, Val Acc: 91.26%, Time: 1:03:41 Epoch: 6 Iter: 2000, Train Loss: 0.059, Train Acc: 98.44%, Val Loss: 0.35, Val Acc: 91.70%, Time: 1:07:03 Iter: 2100, Train Loss: 0.085, Train Acc: 98.44%, Val Loss: 0.34, Val Acc: 90.58%, Time: 1:10:26 Iter: 2200, Train Loss: 0.067, Train Acc: 98.44%, Val Loss: 0.31, Val Acc: 91.86%, Time: 1:13:49 Iter: 2300, Train Loss: 0.035, Train Acc: 99.22%, Val Loss: 0.32, Val Acc: 91.66%, Time: 1:17:11 Epoch: 7 Iter: 2400, Train Loss: 0.026, Train Acc: 100.00%, Val Loss: 0.39, Val Acc: 90.64%, Time: 1:20:34 Iter: 2500, Train Loss: 0.027, Train Acc: 98.44%, Val Loss: 0.44, Val Acc: 89.42%, Time: 1:23:57 No optimization for a long time, auto-stopping...
模型测试
In [5]:
test()
Loading test data...
Testing...
Test Loss: 0.19, Test Acc: 94.51%
Precision, Recall and F1-Score...
precision recall f1-score support
体育 0.99 0.99 0.99 1000
财经 0.92 0.99 0.95 1000
房产 1.00 1.00 1.00 1000
家居 0.98 0.83 0.90 1000
教育 0.88 0.91 0.90 1000
科技 0.93 0.97 0.95 1000
时尚 0.90 0.97 0.93 1000
时政 0.96 0.86 0.91 1000
游戏 0.96 0.96 0.96 1000
娱乐 0.95 0.98 0.96 1000
accuracy 0.95 10000
macro avg 0.95 0.95 0.94 10000
weighted avg 0.95 0.95 0.94 10000
Confusion Matrix...
[[987 1 0 0 8 2 0 0 2 0]
[ 1 987 0 0 2 0 0 8 0 2]
[ 0 0 996 2 0 0 0 0 1 1]
[ 4 26 0 831 27 25 65 4 4 14]
[ 5 8 0 5 914 16 13 23 10 6]
[ 1 1 0 4 7 967 4 1 11 4]
[ 1 0 0 5 8 1 971 1 3 10]
[ 0 51 0 2 60 19 1 863 1 3]
[ 0 0 0 0 9 5 20 0 957 9]
[ 2 3 0 2 2 1 7 1 4 978]]
Time usage: 0:01:05
In [ ]: