# -*- coding: utf-8 -*-
"""
@brief      test log(time=10s)
"""
import unittest
import numpy
from numpy.testing import assert_almost_equal
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from skl2onnx import to_onnx
from skl2onnx.sklapi import TraceableTfidfVectorizer, TraceableCountVectorizer
from skl2onnx.sklapi.sklearn_text_onnx import register
from skl2onnx.common.data_types import StringTensorType
from test_utils import dump_data_and_model, TARGET_OPSET


class TestSklearnText(unittest.TestCase):
    """Compare traceable text vectorizers with their scikit-learn originals.

    ``TraceableCountVectorizer`` / ``TraceableTfidfVectorizer`` must produce
    the same matrices as ``CountVectorizer`` / ``TfidfVectorizer`` for the
    same parameters, and every key of their ``vocabulary_`` must be a tuple.
    """

    # n-gram ranges exercised by every comparison test.
    _NGRAM_RANGES = ((1, 1), (1, 2), (2, 2), (1, 3))

    # (token_pattern, expect_spaces): the first pattern lets a token contain
    # a space, so the traceable vocabulary must hold tokens with spaces; the
    # second pattern cannot match a space.
    _TOKEN_PATTERNS = (
        ("[a-zA-Z ]{1,4}", True),
        ("[a-zA-Z]{1,4}", False),
    )

    @staticmethod
    def _make_corpus():
        """Return a fresh 1-D array of documents (the last one is empty)."""
        return numpy.array([
            "This is the first document.",
            "This document is the second document.",
            "And this is the third one.",
            "Is this the first document?",
            "",
        ]).reshape((5, ))

    def _check_traceable(self, sk_cls, traceable_cls, corpus,
                         expect_spaces=False, **kwargs):
        """Fit a reference and a traceable vectorizer and compare them.

        :param sk_cls: scikit-learn vectorizer class used as the reference
        :param traceable_cls: traceable counterpart of *sk_cls*
        :param corpus: documents used both to fit and to transform
        :param expect_spaces: when True, require more than one vocabulary
            token of the traceable model to contain a space
        :param kwargs: constructor parameters passed to both classes
        """
        mod1 = sk_cls(**kwargs)
        mod1.fit(corpus)
        mod2 = traceable_cls(**kwargs)
        mod2.fit(corpus)

        pred1 = mod1.transform(corpus)
        pred2 = mod2.transform(corpus)
        assert_almost_equal(pred1.todense(), pred2.todense())

        # Every vocabulary key must be a tuple of tokens; meanwhile count
        # the tokens containing a space (only asserted when expected).
        spaces = 0
        for key in mod2.vocabulary_:
            self.assertIsInstance(key, tuple)
            spaces += sum(1 for token in key if ' ' in token)
        if expect_spaces:
            self.assertGreater(spaces, 1)

    def test_count_vectorizer(self):
        """CountVectorizer vs TraceableCountVectorizer, default tokens."""
        corpus = self._make_corpus()
        for ng in self._NGRAM_RANGES:
            with self.subTest(ngram_range=ng):
                self._check_traceable(
                    CountVectorizer, TraceableCountVectorizer, corpus,
                    ngram_range=ng)

    def test_count_vectorizer_regex(self):
        """CountVectorizer vs TraceableCountVectorizer, custom patterns."""
        corpus = self._make_corpus()
        for pattern, expect_spaces in self._TOKEN_PATTERNS:
            for ng in self._NGRAM_RANGES:
                with self.subTest(token_pattern=pattern, ngram_range=ng):
                    self._check_traceable(
                        CountVectorizer, TraceableCountVectorizer, corpus,
                        expect_spaces=expect_spaces,
                        ngram_range=ng, token_pattern=pattern)

    def test_tfidf_vectorizer(self):
        """TfidfVectorizer vs TraceableTfidfVectorizer, default tokens."""
        corpus = self._make_corpus()
        for ng in self._NGRAM_RANGES:
            with self.subTest(ngram_range=ng):
                self._check_traceable(
                    TfidfVectorizer, TraceableTfidfVectorizer, corpus,
                    ngram_range=ng)

    def test_tfidf_vectorizer_regex(self):
        """TfidfVectorizer vs TraceableTfidfVectorizer, custom patterns."""
        corpus = self._make_corpus()
        for pattern, expect_spaces in self._TOKEN_PATTERNS:
            for ng in self._NGRAM_RANGES:
                with self.subTest(token_pattern=pattern, ngram_range=ng):
                    self._check_traceable(
                        TfidfVectorizer, TraceableTfidfVectorizer, corpus,
                        expect_spaces=expect_spaces,
                        ngram_range=ng, token_pattern=pattern)

    @unittest.skipIf(TARGET_OPSET < 10, reason="not available")
    def test_model_tfidf_vectorizer_issue(self):
        """Convert a TraceableTfidfVectorizer whose tokens may hold spaces.

        The fitted model is converted with ``to_onnx`` and passed with the
        corpus to ``dump_data_and_model`` (presumably comparing ONNX and
        scikit-learn outputs — see test_utils).
        """
        # NOTE(review): presumably registers the converters for the
        # traceable classes so to_onnx can handle them — confirm.
        register()
        corpus = numpy.array([
            'the-first document.',
            'this-is the-third-one.',
            'this-the first-document?',
        ]).reshape((3, 1))
        vect = TraceableTfidfVectorizer(
            ngram_range=(1, 2),
            token_pattern=r"\b[a-z ]+\b")
        vect.fit(corpus.ravel())
        model_onnx = to_onnx(vect, 'TfidfVectorizer',
                             initial_types=[('input', StringTensorType([1]))],
                             target_opset=TARGET_OPSET)
        dump_data_and_model(
            corpus, vect, model_onnx,
            basename="SklearnTfidfVectorizerIssue-OneOff-SklCol")


if __name__ == "__main__":
    # Run the tests with the unittest runner when executed as a script.
    unittest.main()
