import os  # 导入标准库OS
import re
import sys
import numpy as np
from mindspore import dataset as ds


# Training row operation: derive the encoder input, decoder input and target from the encoder/decoder sequences
def target_operation(encoder_data, decoder_data):
    """Slice one training row into encoder input, decoder input and target.

    Args:
        encoder_data: Sliceable encoder sequence.
        decoder_data: Sliceable decoder sequence.

    Returns:
        A 3-tuple ``(encoder_input, decoder_input, target)`` where
        ``encoder_input`` is ``encoder_data`` without its first element,
        ``decoder_input`` is ``decoder_data`` without its last element, and
        ``target`` is ``decoder_data`` without its first element (i.e. the
        decoder input shifted by one position).
    """
    without_first = slice(1, None)
    without_last = slice(None, -1)
    return (
        encoder_data[without_first],
        decoder_data[without_last],
        decoder_data[without_first],
    )


# Evaluation row operation
def eval_operation(encoder_data, decoder_data):
    """Slice one evaluation row into encoder input and decoder input.

    Args:
        encoder_data: Sliceable encoder sequence.
        decoder_data: Sliceable decoder sequence.

    Returns:
        A 2-tuple ``(encoder_input, decoder_input)`` where ``encoder_input``
        is ``encoder_data`` without its first element and ``decoder_input``
        is ``decoder_data`` without its last element.
    """
    without_first = slice(1, None)
    without_last = slice(None, -1)
    return encoder_data[without_first], decoder_data[without_last]


# Build the training or evaluation dataset
def create_dataset(
    data_home, batch_size, repeat_num=1, is_training=True, device_num=1, rank=0
):
    """Build the batched GRU dataset pipeline from a MindRecord file.

    Reads ``gru_train.mindrecord`` (training) or ``gru_eval.mindrecord``
    (evaluation) from ``data_home``, maps the matching row operation over the
    ``encoder_data`` / ``decoder_data`` columns, then shuffles, batches and
    repeats the result.

    Args:
        data_home: Directory containing the MindRecord files.
        batch_size: Number of rows per batch; an incomplete trailing batch is
            dropped (``drop_remainder=True``).
        repeat_num: Forwarded to ``repeat`` as ``count``.
        is_training: When true, use the training file and
            ``target_operation`` (adds a ``target_data`` column); otherwise
            use the evaluation file and ``eval_operation``.
        device_num: Forwarded to ``MindDataset`` as ``num_shards``.
        rank: Forwarded to ``MindDataset`` as ``shard_id``.

    Returns:
        The dataset object produced by the final ``repeat`` call.
    """
    source_columns = ["encoder_data", "decoder_data"]

    # Everything that differs between training and evaluation is chosen here.
    if is_training:
        record_name = "gru_train.mindrecord"
        result_columns = ["encoder_data", "decoder_data", "target_data"]
    else:
        record_name = "gru_eval.mindrecord"
        result_columns = ["encoder_data", "decoder_data"]
    record_path = os.path.join(data_home, record_name)

    pipeline = ds.MindDataset(
        record_path,
        columns_list=list(source_columns),
        num_parallel_workers=4,
        num_shards=device_num,
        shard_id=rank,
    )

    row_transform = target_operation if is_training else eval_operation
    pipeline = pipeline.map(
        operations=row_transform,
        input_columns=list(source_columns),
        output_columns=list(result_columns),
    )
    pipeline = pipeline.project(list(result_columns))

    # The shuffle buffer is sized to the dataset size reported at this point.
    shuffle_buffer = pipeline.get_dataset_size()
    pipeline = pipeline.shuffle(buffer_size=shuffle_buffer)
    pipeline = pipeline.batch(batch_size=batch_size, drop_remainder=True)
    return pipeline.repeat(count=repeat_num)
