# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Seq2Seq构建"""
import math
import numpy as np
from mindspore import Tensor
from mindspore import Tensor, Parameter
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.nn.loss.loss import NLLLoss


# Build the standalone weight and bias Parameters for a GRU.
def gru_default_state(
    batch_size, input_size, hidden_size, num_layers=1, bidirectional=False
):
    """Create uniformly initialised GRU weights and biases.

    Every value is drawn from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).
    The ``3 * hidden_size`` dimension presumably packs one slice per GRU gate.

    Args:
        batch_size: not used; kept for interface compatibility.
        input_size: number of rows of the input-to-hidden weight.
        hidden_size: GRU hidden width; also sets the init bound.
        num_layers: not used.
        bidirectional: not used.

    Returns:
        ``(weight_i, weight_h, bias_i, bias_h)`` as float32 ``Parameter``s
        with shapes ``(input_size, 3*hidden_size)``,
        ``(hidden_size, 3*hidden_size)``, ``(3*hidden_size,)`` and
        ``(3*hidden_size,)``.
    """
    bound = 1 / math.sqrt(hidden_size)
    gate_width = 3 * hidden_size

    def _uniform_param(shape, param_name):
        # One uniform draw from the global NumPy RNG per call.
        values = np.random.uniform(-bound, bound, shape).astype(np.float32)
        return Parameter(Tensor(values), name=param_name)

    # The creation order fixes the RNG consumption order; keep it stable so
    # seeded runs reproduce the same initial values.
    weight_i = _uniform_param((input_size, gate_width), "weight_i")
    weight_h = _uniform_param((hidden_size, gate_width), "weight_h")
    bias_i = _uniform_param(gate_width, "bias_i")
    bias_h = _uniform_param(gate_width, "bias_h")
    return weight_i, weight_h, bias_i, bias_h


# GRU cell shared by the encoder and the decoder.
class GRU(nn.Cell):
    """Single-layer, time-major GRU built on ``mindspore.nn.GRU``.

    Args:
        config: reads ``batch_size`` (training), ``eval_batch_size``
            (evaluation) and ``hidden_size``.
        is_training: selects which batch size is stored on the cell.
    """

    def __init__(self, config, is_training=True):
        super(GRU, self).__init__()
        self.batch_size = (
            config.batch_size if is_training else config.eval_batch_size
        )
        self.hidden_size = config.hidden_size
        # Standalone weights from gru_default_state. They were the operands of
        # the former P.DynamicGRUV2 implementation and construct() no longer
        # reads them; they are kept so the cell's parameter set is unchanged.
        (
            self.weight_i,
            self.weight_h,
            self.bias_i,
            self.bias_h,
        ) = gru_default_state(self.batch_size, self.hidden_size, self.hidden_size)
        # Input width equals hidden width; batch_first=False makes it time-major.
        self.rnn = nn.GRU(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            batch_first=False,
        )
        self.cast = P.Cast()

    def construct(self, x, hidden):
        """Run the GRU over ``x`` starting from ``hidden``.

        Args:
            x: input sequence, cast to float16 here; presumably
                ``(seq_len, batch, hidden_size)`` given batch_first=False.
            hidden: initial state without the layer axis; a leading axis of
                size 1 is inserted before calling ``nn.GRU``.

        Returns:
            ``(output, final_state)`` exactly as returned by ``nn.GRU``.
        """
        inputs_fp16 = self.cast(x, mstype.float16)
        initial_state = hidden.unsqueeze(0)
        seq_out, last_state = self.rnn(inputs_fp16, initial_state)
        return seq_out, last_state


# Encoder: embeds a source token sequence and encodes it with a GRU, returning
# both the per-step GRU output and the final GRU state.
class Encoder(nn.Cell):
    """Embedding + GRU encoder for the source (English) side.

    Args:
        config: reads ``en_vocab_size``, ``hidden_size``, ``batch_size`` and
            ``eval_batch_size``.
        is_training: selects training or evaluation batch size.
    """

    def __init__(self, config, is_training=True):
        super(Encoder, self).__init__()
        self.vocab_size = config.en_vocab_size
        self.hidden_size = config.hidden_size
        self.batch_size = (
            config.batch_size if is_training else config.eval_batch_size
        )

        # Swaps the first two axes; presumably (batch, seq, hidden) ->
        # (seq, batch, hidden) to match the time-major GRU.
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        self.gru = GRU(config, is_training=is_training).to_float(mstype.float16)
        # All-zero float16 initial state, one row per batch element.
        self.h = Tensor(
            np.zeros((self.batch_size, self.hidden_size), dtype=np.float16)
        )

    def construct(self, encoder_input):
        """Return ``(output, hidden)`` from the GRU for ``encoder_input`` ids."""
        token_vectors = self.embedding(encoder_input)
        time_major = self.trans(token_vectors, self.perm)
        output, hidden = self.gru(time_major, self.h)
        return output, hidden


# Decoder: GRU followed by a dense output layer that scores every target-vocab
# token at each time step.
class Decoder(nn.Cell):
    """Embedding + GRU + Dense + LogSoftmax decoder for the target (Chinese) side.

    Args:
        config: reads ``ch_vocab_size`` and ``hidden_size`` (plus what GRU reads).
        is_training: forwarded to the inner GRU.
    """

    def __init__(self, config, is_training=True):
        super(Decoder, self).__init__()
        self.vocab_size = config.ch_vocab_size
        self.hidden_size = config.hidden_size

        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        self.gru = GRU(config, is_training=is_training).to_float(mstype.float16)
        self.dense = nn.Dense(self.hidden_size, self.vocab_size)
        # Log-probabilities over axis 2 (the vocab axis of the dense output);
        # LogSoftmax is more numerically stable than Softmax followed by log.
        self.softmax = nn.LogSoftmax(axis=2)
        self.cast = P.Cast()

    def construct(self, decoder_input, hidden):
        """Return ``(log_probs, hidden)`` for ``decoder_input`` token ids."""
        step_vectors = self.trans(self.embedding(decoder_input), self.perm)
        gru_out, hidden = self.gru(step_vectors, hidden)
        # Cast the float16 GRU output back to float32 before the projection.
        scores = self.dense(self.cast(gru_out, mstype.float32))
        return self.softmax(scores), hidden


# Seq2Seq model: Encoder + Decoder, teacher-forced in training and greedy
# token-by-token decoding in evaluation.
class Seq2Seq(nn.Cell):
    """Encoder-decoder translation model.

    Args:
        config: reads ``max_seq_length`` plus whatever Encoder/Decoder read.
        is_train: True runs the decoder once over the whole ``dst`` sequence;
            False decodes greedily for ``max_seq_length`` steps.
    """

    def __init__(self, config, is_train=True):
        super(Seq2Seq, self).__init__()
        self.max_len = config.max_seq_length  # maximum token-sequence length
        self.is_train = is_train  # selects training vs. greedy-decoding path

        self.encoder = Encoder(config, is_train)  # source-side encoder
        self.decoder = Decoder(config, is_train)  # target-side decoder
        self.expanddims = P.ExpandDims()  # not used in construct()
        self.squeeze = P.Squeeze(axis=0)  # drops axis 0
        self.argmax = P.ArgMaxWithValue(axis=int(2), keep_dims=True)  # (index, value) of the max along axis 2
        self.concat = P.Concat(axis=1)  # concatenates along axis 1
        self.concat2 = P.Concat(axis=0)  # concatenates along axis 0; not used in construct()
        self.select = P.Select()  # not used in construct()

    def construct(self, src, dst):
        """Run the model.

        Args:
            src: source token ids for the encoder.
            dst: target token ids. In training, the full decoder input
                (teacher forcing). In evaluation only ``dst[:, 0:1]`` is
                read, as the first input token of each row.

        Returns:
            Training: the decoder's log-probabilities, presumably time-major
            ``(seq, batch, vocab)``.
            Evaluation: the greedily chosen token ids of every step joined
            along axis 1, presumably ``(batch, max_len)``.
        """
        encoder_output, hidden = self.encoder(src)  # `hidden` is not used below
        # Context vector: the encoder output at time step max_len - 2 (the
        # second-to-last position) with the time axis squeezed away.
        # NOTE(review): presumably the last position is EOS/padding; confirm
        # why neither the final step nor `hidden` is used.
        decoder_hidden = self.squeeze(
            encoder_output[self.max_len - 2 : self.max_len - 1 : 1, ::, ::]
        )
        if self.is_train:
            # Teacher forcing: decode the whole target sequence in one call.
            outputs, _ = self.decoder(dst, decoder_hidden)
        else:
            # Greedy decoding, seeded with the first target token of each row.
            decoder_input = dst[::, 0:1:1]
            decoder_outputs = ()
            for i in range(0, self.max_len):
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden
                )
                # nn.GRU returns its state with a leading layer axis
                # (presumably size 1); drop it so the next step receives a
                # (batch_size, hidden_size) state again.
                decoder_hidden = self.squeeze(decoder_hidden)
                decoder_output, _ = self.argmax(decoder_output)  # most likely token id per position
                decoder_output = self.squeeze(
                    decoder_output
                )  # drop the length-1 time axis
                decoder_outputs += (decoder_output,)  # collect this step's prediction
                decoder_input = decoder_output  # feed the prediction back in
            outputs = self.concat(decoder_outputs)  # join all steps along axis 1
        return outputs


class WithLossCell(nn.Cell):
    """Wrap a backbone with a per-sample NLL loss averaged over the batch.

    Args:
        backbone: cell called as ``backbone(src, dst)``. Its output is sliced
            as ``out[:, i, :]`` per batch element, so it must be rank 3 with
            the batch on axis 1 (presumably time-major
            ``(seq, batch, vocab)`` log-probabilities).
        config: reads ``batch_size`` and ``max_seq_length``.
    """

    def __init__(self, backbone, config):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self.batch_size = config.batch_size
        self.loss_fn = NLLLoss()
        self.max_len = config.max_seq_length
        # Kept for backward compatibility; construct() uses the axis-specific
        # squeezes below instead.
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
        self.print = P.Print()
        # Squeeze only the axis made singleton by slicing. A bare Squeeze()
        # would also drop the time axis when the sequence length is 1 (or the
        # vocab axis when it is 1), giving NLLLoss inputs of the wrong rank.
        self.squeeze_batch = P.Squeeze(axis=1)
        self.squeeze_row = P.Squeeze(axis=0)

    def construct(self, src, dst, label):
        """Return the mean over the batch of each sample's NLL loss.

        Args:
            src: encoder input forwarded to the backbone.
            dst: decoder input forwarded to the backbone.
            label: target ids indexed as ``label[i, :]`` per batch element.
        """
        out = self._backbone(src, dst)
        loss_total = 0
        for i in range(self.batch_size):
            # (T, 1, V) -> (T, V): this sample's scores at every step.
            sample_scores = self.squeeze_batch(out[::, i : i + 1 : 1, ::])
            # (1, T) -> (T,): this sample's target ids.
            sample_labels = self.squeeze_row(label[i : i + 1 : 1, ::])
            loss_total += self.loss_fn(sample_scores, sample_labels)
        return loss_total / self.batch_size


class InferCell(nn.Cell):
    """Inference wrapper that forwards ``(src, dst)`` to ``network``.

    Args:
        network: cell called as ``network(src, dst)``.
        config: accepted but not used.
    """

    def __init__(self, network, config):
        super(InferCell, self).__init__(auto_prefix=False)
        self.expanddims = P.ExpandDims()  # not used by construct()
        self.network = network

    def construct(self, src, dst):
        """Return whatever the wrapped network returns for ``(src, dst)``."""
        return self.network(src, dst)
