import os
import re
import sys
import random
import numpy as np
import unicodedata  # 使用unicodedata模块先将文本标准化
from mindspore import dataset as ds
from mindspore.mindrecord import FileWriter

# 预备特殊字元，在开头添加 <SOS>，在结尾添加 <EOS>
EOS = "<eos>"
SOS = "<sos>"
MAX_SEQ_LEN = 10


# 多用于那些需要包含音调的字符体系中，Unicode体系中，使用Decompose(分离)分别存储字符(U+0043)本身和音调(U+0327)本身。
# 从给定的字符串中删除重音符号。 输入文本是unicode字符串，返回带有重音符号的输入字符串，作为unicode。
# normalize() 第一个参数指定字符串标准化的方式。 NFD表示字符应该分解为多个组合字符表示。
def unicodeToAscii(s):
    return "".join(
        c for c in unicodedata.normalize("NFD", s) if unicodedata.category(c) != "Mn"
    )


# 标准化处理字符串
def normalizeString(s):
    s = s.lower().strip()  # lower将整个字符串改为小写；strip删除字符串前后的空白。
    s = unicodeToAscii(s)  # 调用函数将Unicode转化成Ascii
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)  # 将符号“.!?”前用空格隔开
    return s


def prepare_data(data_path, vocab_save_path, max_seq_len):
    with open(data_path, "r", encoding="utf-8") as f:
        data = f.read()  # 读取文件

    # 得到文件中的内容
    data = data.split("\n")

    data = data[:2000]

    # 拆分英文句子和中文句子
    en_data = [
        normalizeString(line.split("\t")[0]) for line in data
    ]  # 得到标准化处理的英文句子
    ch_data = [line.split("\t")[1] for line in data]  # 得到中文句子

    # 获取单词并存储
    en_vocab = set(" ".join(en_data).split(" "))  # 获取不重复的英文单词
    id2en = [EOS] + [SOS] + list(en_vocab)  # 英文单词表中加上两个始末特殊字元
    en2id = {c: i for i, c in enumerate(id2en)}  # 遍历所有英文单词组合为一个索引序列
    en_vocab_size = len(id2en)  # 查看英文单词个数
    np.savetxt(
        os.path.join(vocab_save_path, "en_vocab.txt"), np.array(id2en), fmt="%s"
    )  # 将英文单词表保存

    ch_vocab = set("".join(ch_data))  # 获取不重复的中文单词
    id2ch = [EOS] + [SOS] + list(ch_vocab)  # 中文单词表中加上两个始末特殊字元
    ch2id = {
        c: i for i, c in enumerate(id2ch)
    }  # 遍历所有中文单词组合为一个索引序列，即获取每个单词的id
    ch_vocab_size = len(id2ch)  # 查看中文单词个数
    np.savetxt(
        os.path.join(vocab_save_path, "ch_vocab.txt"), np.array(id2ch), fmt="%s"
    )  # 将中文单词表保存

    # 将中英文句子转换为单词ids组合 --> [SOS] + sentences ids + [EOS]
    en_num_data = np.array(
        [[1] + [int(en2id[en]) for en in line.split(" ")] + [0] for line in en_data],
        dtype=object,
    )
    ch_num_data = np.array(
        [[1] + [int(ch2id[ch]) for ch in line] + [0] for line in ch_data], dtype=object
    )

    # 将上述句子的索引ID组合长度延长到自定义的max_length
    for i in range(len(en_num_data)):
        num = max_seq_len + 1 - len(en_num_data[i])
        if num >= 0:
            en_num_data[i] += [0] * num
        else:
            en_num_data[i] = en_num_data[i][:max_seq_len] + [0]

    for i in range(len(ch_num_data)):
        num = max_seq_len + 1 - len(ch_num_data[i])
        if num >= 0:
            ch_num_data[i] += [0] * num
        else:
            ch_num_data[i] = ch_num_data[i][:max_seq_len] + [0]

    return en_num_data, ch_num_data, en_vocab_size, ch_vocab_size


# 转换保存mindspore的中英文单词表
def convert_to_mindrecord(data_path, mindrecord_save_path, max_seq_len):
    en_num_data, ch_num_data, en_vocab_size, ch_vocab_size = prepare_data(
        data_path, mindrecord_save_path, max_seq_len
    )

    data_list_train = []
    for en, de in zip(en_num_data, ch_num_data):
        en = np.array(en).astype(np.int32)  # 将英文句子ID强制转换为指定的整数类型。
        de = np.array(de).astype(np.int32)  # 将中文句子ID强制转换为指定的整数类型。
        data_json = {"encoder_data": en.reshape(-1), "decoder_data": de.reshape(-1)}
        data_list_train.append(data_json)  # 将英文作为编码器，中文作为解码器加入
    data_list_eval = random.sample(data_list_train, 20)

    data_dir = os.path.join(
        mindrecord_save_path, "gru_train.mindrecord"
    )  # 把目录和文件名合成一个路径.

    writer = FileWriter(data_dir)  # 用于将用户定义的原始数据写入MindRecord File系列。
    schema_json = {
        "encoder_data": {"type": "int32", "shape": [-1]},
        "decoder_data": {"type": "int32", "shape": [-1]},
    }  # 设计编码器和解码器架构
    writer.add_schema(
        schema_json, "gru_schema"
    )  # 添加架构，如果成功添加架构，则返回架构ID，或引发异常。
    writer.write_raw_data(
        data_list_train
    )  # 默认情况下，写入原始数据，生成MindRecord File的顺序对，并根据预定义的模式对数据进行校验。
    writer.commit()  # 将数据刷新到磁盘并生成相应的db文件。

    data_dir = os.path.join(mindrecord_save_path, "gru_eval.mindrecord")
    writer = FileWriter(data_dir)
    writer.add_schema(schema_json, "gru_schema")
    writer.write_raw_data(data_list_eval)
    writer.commit()

    print("en_vocab_size: ", en_vocab_size)  # 打印出英文单词长度
    print("ch_vocab_size: ", ch_vocab_size)  # 打印出中文单词长度

    return en_vocab_size, ch_vocab_size


if __name__ == "__main__":
    convert_to_mindrecord("src/cmn_zhsim.txt", "./preprocess", MAX_SEQ_LEN)
