# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################验证 gru 网络######################
"""
import argparse
import os
import numpy as np
from src.dataset import create_dataset
from src.seq2seq import Seq2Seq, InferCell
from src.config import cfg
from mindspore import Tensor, nn, Model, context, DatasetHelper
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode

def load_vocab(vocab_path):
    """Read a vocabulary file and return its lines as a list.

    The position of each entry in the returned list is used as its token id.
    The content is split on '\\n', so a trailing newline produces a final
    empty entry (same as reading the file and calling ``split('\\n')``).

    Args:
        vocab_path: Path to a UTF-8 text file with one token per line.

    Returns:
        List of token strings indexed by token id.
    """
    with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
        return vocab_file.read().split('\n')


def ids_to_text(token_ids, vocab, sep='', skip_ids=()):
    """Convert a sequence of token ids into text using ``vocab``.

    Decoding stops at the first id equal to 0; ids listed in ``skip_ids`` are
    dropped. Every id is converted with ``int()`` so that NumPy integer
    scalars can be used to index the Python ``vocab`` list.

    Args:
        token_ids: Iterable of integer-like token ids (list or 1-D array).
        vocab: List mapping token id -> token string.
        sep: String placed between consecutive tokens.
        skip_ids: Ids to omit from the output.

    Returns:
        The decoded string.
    """
    tokens = []
    for token_id in token_ids:
        token_id = int(token_id)
        if token_id == 0:
            break
        if token_id in skip_ids:
            continue
        tokens.append(vocab[token_id])
    return sep.join(tokens)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore GRU Example')
    parser.add_argument('--dataset_path', type=str, default='./preprocess', help='dataset path.')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint path.')
    parser.add_argument('--device_target', type=str, default='Ascend',
                        help='device to run evaluation on (default: Ascend).')
    args = parser.parse_args()
    # Fail early with a clear CLI error instead of deep inside load_checkpoint.
    if not os.path.isfile(args.checkpoint_path):
        parser.error('--checkpoint_path must point to an existing checkpoint file, got %r'
                     % args.checkpoint_path)

    # Evaluate in MindSpore graph mode on the requested device.
    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target=args.device_target)

    ds_eval = create_dataset(args.dataset_path, cfg.eval_batch_size, is_training=False)

    # Build the seq2seq network in inference mode and load the trained weights.
    network = Seq2Seq(cfg, is_train=False)
    network = InferCell(network, cfg)
    network.set_train(False)
    parameter_dict = load_checkpoint(args.checkpoint_path)
    load_param_into_net(network, parameter_dict)

    en_vocab = load_vocab(os.path.join(args.dataset_path, "en_vocab.txt"))
    ch_vocab = load_vocab(os.path.join(args.dataset_path, "ch_vocab.txt"))

    # Single pass over the eval set. output_numpy=True yields NumPy arrays,
    # whose elements can be turned into plain ints for vocabulary lookup.
    for batch in ds_eval.create_dict_iterator(num_epochs=1, output_numpy=True):
        encoder_data = batch['encoder_data']
        decoder_data = batch['decoder_data']
        # Only the first sample of each batch is decoded and printed.
        en_data = ids_to_text(encoder_data[0], en_vocab, sep=' ')
        # Id 1 is skipped in the reference translation (presumably a
        # start-of-sequence marker -- confirm against src.dataset).
        ch_data = ids_to_text(decoder_data[0], ch_vocab, skip_ids=(1,))
        output = network(Tensor(encoder_data), Tensor(decoder_data))
        out = ids_to_text(output[0].asnumpy(), ch_vocab)
        print('English:', en_data)
        print('expect Chinese:', ch_data)
        print('predict Chinese:', out)
        print(' ')
