import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.ops.functional as F


class NLLLoss(nn.LossBase):
    """
    Negative log-likelihood loss implemented as a MindSpore loss cell.

    The per-sample loss is the negated log probability that ``logits`` assigns
    to the target class. It is computed by masking ``logits`` with a one-hot
    encoding of ``label`` and summing over the last (class) axis. The result
    is then passed to ``get_loss`` of the base class, which applies the
    ``reduction`` chosen at construction time.

    Args:
        reduction: Reduction mode forwarded unchanged to ``nn.LossBase``.
            Defaults to ``"mean"``.

    Inputs:
        logits: Tensor of log probabilities whose last axis indexes the
            classes. If you only have probabilities, apply a log to them
            before calling this loss; no log is applied here.
        label: Tensor of class indices. It is cast to int32 before one-hot
            encoding.

    Returns:
        The value produced by ``get_loss`` for the per-sample losses.

    Note:
        Because the target entry is selected by multiplication with a 0/1
        mask, an infinite value anywhere in a row of ``logits`` (for example
        ``log(0)``) turns that row's loss into NaN (``inf * 0``).
    """

    def __init__(self, reduction="mean"):
        super(NLLLoss, self).__init__(reduction=reduction)
        self.one_hot = P.OneHot()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        # Fill values for the one-hot mask. Built once here rather than on
        # every forward pass, since they never change.
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        # OneHot needs integer indices.
        label = label.astype(mstype.int32)
        # The size of the last axis of `logits` is used as the one-hot depth.
        num_classes = F.shape(logits)[-1]
        label_one_hot = self.one_hot(
            label, num_classes, self.on_value, self.off_value
        )

        # Only the target-class entry survives the mask; summing over the
        # class axis leaves one negated log probability per sample.
        loss = self.reduce_sum(-1.0 * logits * label_one_hot, -1)
        return self.get_loss(loss)


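# Example usage:
# `logits` must hold log probabilities (e.g. the output of a log-softmax) and
# `labels` the integer ground-truth class indices.
# import numpy as np
# probs = np.random.rand(10, 5).astype(np.float32)
# probs /= probs.sum(axis=-1, keepdims=True)  # normalise each row to sum to 1
# logits = Tensor(np.log(probs))
# labels = Tensor(np.array([1, 0, 4, 1, 2, 3, 0, 4, 1, 2]), dtype=mstype.int32)
# loss_fn = NLLLoss(reduction='mean')
# loss = loss_fn(logits, labels)
# print("Calculated Loss:", loss)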
