In [1]:
import math
import os
import sys
import time
import numpy as np
from six.moves import xrange
import tensorflow as tf
datautil = __import__("datautil")
seq2seq_model = __import__("seq2seq_model")
import datautil
import seq2seq_model
Building prefix dict from the default dictionary ... Loading model from cache /tmp/jieba.cache Loading model cost 0.657 seconds. Prefix dict has been built succesfully.
In [ ]:
tf.reset_default_graph()
steps_per_checkpoint=200
max_train_data_size= 0#(0: no limit)
dropout = 0.9
grad_clip = 5.0
batch_size = 60
num_layers =2
learning_rate =0.5
lr_decay_factor =0.99
###############翻译
hidden_size = 100
checkpoint_dir= "fanyichina/checkpoints/"
_buckets =[(20, 20), (40, 40), (50, 50), (60, 60)]
def getfanyiInfo():
vocaben, rev_vocaben=datautil.initialize_vocabulary(os.path.join(datautil.data_dir, datautil.vocabulary_fileen))
vocab_sizeen= len(vocaben)
print("vocab_size",vocab_sizeen)
vocabch, rev_vocabch=datautil.initialize_vocabulary(os.path.join(datautil.data_dir, datautil.vocabulary_filech))
vocab_sizech= len(vocabch)
print("vocab_sizech",vocab_sizech)
filesfrom,_=datautil.getRawFileList(datautil.data_dir+"fromids/")
filesto,_=datautil.getRawFileList(datautil.data_dir+"toids/")
source_train_file_path = filesfrom[0]
target_train_file_path= filesto[0]
return vocab_sizeen,vocab_sizech,rev_vocaben,rev_vocabch,source_train_file_path,target_train_file_path
################################################################
#source_train_file_path = os.path.join(datautil.data_dir, "data_source_test.txt")
#target_train_file_path = os.path.join(datautil.data_dir, "data_target_test.txt")
def main():
vocab_sizeen,vocab_sizech,rev_vocaben,rev_vocabch,source_train_file_path,target_train_file_path = getfanyiInfo()
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)
print ("checkpoint_dir is {0}".format(checkpoint_dir))
with tf.Session() as sess:
model = createModel(sess,False,vocab_sizeen,vocab_sizech)
print ("Using bucket sizes:")
print (_buckets)
source_test_file_path = source_train_file_path
target_test_file_path = target_train_file_path
print (source_train_file_path)
print (target_train_file_path)
train_set = readData(source_train_file_path, target_train_file_path,max_train_data_size)
test_set = readData(source_test_file_path, target_test_file_path,max_train_data_size)
train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
print( "bucket sizes = {0}".format(train_bucket_sizes))
train_total_size = float(sum(train_bucket_sizes))
# A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
# to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
# the size if i-th training bucket, as used later.
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in xrange(len(train_bucket_sizes))]
step_time, loss = 0.0, 0.0
current_step = 0
previous_losses = []
while True:
# Choose a bucket according to data distribution. We pick a random number
# in [0, 1] and use the corresponding interval in train_buckets_scale.
random_number_01 = np.random.random_sample()
bucket_id = min([i for i in xrange(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
# 开始训练.
start_time = time.time()
encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,target_weights, bucket_id, False)
step_time += (time.time() - start_time) / steps_per_checkpoint
loss += step_loss / steps_per_checkpoint
current_step += 1
# 保存检查点,测试数据
if current_step % steps_per_checkpoint == 0:
# Print statistics for the previous epoch.
perplexity = math.exp(loss) if loss < 300 else float('inf')
print ("global step %d learning rate %.4f step-time %.2f perplexity "
"%.2f" % (model.global_step.eval(), model.learning_rate.eval(),step_time, perplexity))
# Decrease learning rate if no improvement was seen over last 3 times.
if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
sess.run(model.learning_rate_decay_op)
previous_losses.append(loss)
# Save checkpoint and zero timer and loss.
checkpoint_path = os.path.join(checkpoint_dir, "seq2seqtest.ckpt")
print(checkpoint_path)
model.saver.save(sess, checkpoint_path, global_step=model.global_step)
step_time, loss = 0.0, 0.0
# Run evals on development set and print their perplexity.
for bucket_id in xrange(len(_buckets)):
if len(test_set[bucket_id]) == 0:
print(" eval: empty bucket %d" % (bucket_id))
continue
encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
_, eval_loss,output_logits = model.step(sess, encoder_inputs, decoder_inputs,target_weights, bucket_id, True)
eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
inputstr = datautil.ids2texts(reversed([en[0] for en in encoder_inputs]) ,rev_vocaben)
print("输入",inputstr)
print("输出",datautil.ids2texts([en[0] for en in decoder_inputs] ,rev_vocabch))
outputs = [np.argmax(logit, axis=1)[0] for logit in output_logits]
#outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
#print("outputs",outputs,datautil.EOS_ID)
if datautil.EOS_ID in outputs:
outputs = outputs[:outputs.index(datautil.EOS_ID)]
print("结果",datautil.ids2texts(outputs,rev_vocabch))
sys.stdout.flush()
def createModel(session, forward_only,from_vocab_size,to_vocab_size):
"""Create translation model and initialize or load parameters in session."""
model = seq2seq_model.Seq2SeqModel(
from_vocab_size,#from
to_vocab_size,#to
_buckets,
hidden_size,
num_layers,
dropout,
grad_clip,
batch_size,
learning_rate,
lr_decay_factor,
forward_only=forward_only,
dtype=tf.float32)
print("model is ok")
ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if ckpt!=None:
model.saver.restore(session, ckpt)
print ("Reading model parameters from {0}".format(ckpt))
else:
print ("Created model with fresh parameters.")
session.run(tf.global_variables_initializer())
return model
def readData(source_path, target_path, max_size=None):
'''
This method directly from tensorflow translation example
'''
data_set = [[] for _ in _buckets]
with tf.gfile.GFile(source_path, mode="r") as source_file:
with tf.gfile.GFile(target_path, mode="r") as target_file:
source, target = source_file.readline(), target_file.readline()
counter = 0
while source and target and (not max_size or counter < max_size):
counter += 1
if counter % 100000 == 0:
print(" reading data line %d" % counter)
sys.stdout.flush()
source_ids = [int(x) for x in source.split()]
target_ids = [int(x) for x in target.split()]
target_ids.append(datautil.EOS_ID)
for bucket_id, (source_size, target_size) in enumerate(_buckets):
if len(source_ids) < source_size and len(target_ids) < target_size:
data_set[bucket_id].append([source_ids, target_ids])
break
source, target = source_file.readline(), target_file.readline()
return data_set
if __name__ == '__main__':
main()
vocab_size 11963 vocab_sizech 15165 checkpoint_dir is fanyichina/checkpoints/ WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0. For more information, please see: * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md * https://github.com/tensorflow/addons If you depend on functionality not listed there, please file an issue. WARNING:tensorflow:From /home/python_home/WeiZhongChuang/ML/TensorFlow/Attention/seq2seq_model.py:124: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version. Instructions for updating: This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0. WARNING:tensorflow:From /home/python_home/WeiZhongChuang/ML/TensorFlow/Attention/seq2seq_model.py:128: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version. Instructions for updating: This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0. WARNING:tensorflow:At least two cells provided to MultiRNNCell are the same object and will share weights. new a cell WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py:863: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version. Instructions for updating: Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:1259: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version. Instructions for updating: Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`. WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:1444: sparse_to_dense (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version. Instructions for updating: Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead. new a cell new a cell new a cell WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/python/ops/array_grad.py:425: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.cast instead. model is ok WARNING:tensorflow:From /usr/local/python3/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to check for files with this prefix. INFO:tensorflow:Restoring parameters from fanyichina/checkpoints/seq2seqtest.ckpt-54000 Reading model parameters from fanyichina/checkpoints/seq2seqtest.ckpt-54000 Using bucket sizes: [(20, 20), (40, 40), (50, 50), (60, 60)] fanyichina/fromids/english1w.txt fanyichina/toids/chinese1w.txt bucket sizes = [1649, 4933, 1904, 1383] global step 54200 learning rate 0.3699 step-time 0.72 perplexity 3.16 fanyichina/checkpoints/seq2seqtest.ckpt eval: bucket 0 perplexity 1.74 输入 ['third', ',', 'the', 'pace', 'of', 'selling', 'public', 'houses', 'was', 'accelerated', '.', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 输出 ['_GO', '三', '是', '加快', '了', '公有', '住房', '的', '出售', '.', '_EOS', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] eval: bucket 1 perplexity 2.93 输入 ['thanks', 'to', 'the', 'equilibrium', 'of', 'the', 'international', 'payments', ',', 'china', "'s", 'exchange', 'rates', 'have', 'all', 'along', 'been', 'comparatively', 'stable', '.', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 输出 ['_GO', '由', '於', '国际', '收支平衡', ',', '中国', '的', '汇率', '一直', '比较', '稳定', '.', '_EOS', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 结果 ['记者', '中国', '关系', '进行', '对'] eval: bucket 2 perplexity 3.96 输入 ['in', 'order', 'to', 'respond', 'to', 'the', 'vast', 'business', 'opportunities', 'in', 'the', 'future', 'when', 'there', 'are', 'direct', 'cross', '-', 'strait', 'flights', ',', 'taiwan', "'s", 'fu', 'hsing', 'aviation', 'has', 'spent', 'a', 'huge', 'sum', 'on', 'buying', 'medium', 'and', 'long', '-', 'range', 'versions', 'of', 'the', 'european', 'airbus', '.', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 输出 ['_GO', '台湾', '复兴', '航空', '为', '因', '应', '未来', '两岸', '直航', '的', '庞大', '商机', ',', '大笔', '购', '进', '欧洲', '空中', '巴士', '的', '中', ',', '长程', '客机', '.', '_EOS', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 结果 ['建设', '建设', '说', '建设', '加强', '是', '说', '是', '重要', '是', '是', '是', '实现', '市场', '这', '这', '的'] eval: bucket 3 perplexity 5.66 输入 ['zhu', 'bangzao', 'said', 'that', 'after', 'hong', 'kong', "'s", 'reversion', ',', 'china', "'s", 'central', 'government', 'and', 'the', 'hong', 'kong', 'special', 'administrative', 'region', '[', 'hksar', ']', 'government', 'implemented', 'the', 'policy', 'of', '"', 'one', 'country', ',', 'two', 'systems', ',', '"', '"', 'hong', 'kong', 'people', 'governing', 'hong', 'kong', ',', '"', 'and', 'a', 'high', 'degree', 'of', 'autonomy', '.', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD'] 输出 ['_GO', '朱邦造', '说', ',', '香港', '回归', '后', ',', '中国', '中央政府', '和', '香港特区', '政府', '贯彻', '"', '一国两制', '"', '"', '港人', '治', '港', '"', '和', '高度', '自治', '的', '方针', '.', '_EOS', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD', '_PAD']
In [ ]: