#
#   Copyright © 2021 Uncharted Software Inc.
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import os

from setuptools import setup, find_packages


# Single-source the package version from distil/version.py, which is expected
# to define ``__version__`` (it is passed to ``setup()`` below).
#
# The path is anchored to this script's directory instead of the current
# working directory, so the build does not depend on where it is invoked from.
_version_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "distil", "version.py"
)

# Execute the version file in a private namespace so that nothing else it
# defines can leak into, or overwrite, the names used by this script.
_version_namespace = {}
# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open(_version_path, encoding="utf-8") as f:
    exec(f.read(), _version_namespace)

__version__ = _version_namespace["__version__"]

# Runtime dependencies, passed to ``setup()`` as ``install_requires``.
_INSTALL_REQUIRES = [
    # d3m best practice moving forward is to leave the version off
    # (simplifies updates).
    "d3m",
    # Versions shared with d3m - these need to stay aligned with the core package.
    "scikit-learn==0.22.2.post1",
    "scipy==1.4.1",
    "numpy==1.18.2",
    "pandas>=1.1.3",
    "torch>=1.4.0",  # validated up to 1.7.0
    "networkx==2.4",
    "pillow==7.1.2",
    # Additional dependencies.
    "joblib>=0.13.2",
    "torchvision>=0.5.0",  # validated up to 0.8.0
    "pytorch-pretrained-bert==0.4.0",
    "sklearn_pandas==1.8.0",
    "frozendict>=1.2",
    # 'nose>=1.3.7' is deliberately not listed here: it needs to be installed
    # in the primitive install section.
    "pyprctl>=0.1,<0.2",
    "fastdtw>=0.3.2",
    "resampy>=0.2.1",
    "soundfile>=0.10.2",
    "seeded-graph-matching==1.0.3",
    "basenet",
    "rescal==0.4",
    "scikit-image<=0.17.2",
    "shap>=0.29",
    # apex is left out because it can cause errors with pretrained-bert:
    # https://github.com/NVIDIA/apex/issues/156
    # The requirement that was disabled:
    # 'apex @ git+https://github.com/NVIDIA/apex.git@47e3367fcd6636db6cd549bbb385a6e06a3861d0'
    "torchvggish>=0.2,<0.3",
    "nonneg_rescal>=0.1,<0.2",
]

# Optional extras: pick the TensorFlow distribution for CPU or GPU installs.
_EXTRAS_REQUIRE = {
    "cpu": ["tensorflow==2.2.0"],
    "gpu": ["tensorflow-gpu==2.2.0"],
}

# Entry points registered under the "d3m.primitives" group.
# Each entry has the form "<name> = <module>:<object>".
_D3M_PRIMITIVES = [
    "community_detection.parser.DistilCommunityDetection = distil.primitives.community_detection:DistilCommunityDetectionPrimitive",
    "link_prediction.link_prediction.DistilLinkPrediction = distil.primitives.link_prediction:DistilLinkPredictionPrimitive",
    "vertex_nomination.seeded_graph_matching.DistilVertexNomination = distil.primitives.vertex_nomination:DistilVertexNominationPrimitive",
    "data_transformation.load_single_graph.DistilSingleGraphLoader = distil.primitives.load_single_graph:DistilSingleGraphLoaderPrimitive",
    "data_transformation.load_edgelist.DistilEdgeListLoader = distil.primitives.load_edgelist:DistilEdgeListLoaderPrimitive",
    "graph_matching.seeded_graph_matching.DistilSeededGraphMatcher = distil.primitives.seeded_graph_matching:DistilSeededGraphMatchingPrimitive",
    "data_transformation.load_graphs.DistilGraphLoader = distil.primitives.load_graphs:DistilGraphLoaderPrimitive",
    "data_transformation.replace_singletons.DistilReplaceSingletons = distil.primitives.replace_singletons:ReplaceSingletonsPrimitive",
    "data_transformation.imputer.DistilCategoricalImputer = distil.primitives.categorical_imputer:CategoricalImputerPrimitive",
    "data_transformation.enrich_dates.DistilEnrichDates = distil.primitives.enrich_dates:EnrichDatesPrimitive",
    "learner.random_forest.DistilEnsembleForest = distil.primitives.ensemble_forest:EnsembleForestPrimitive",
    "classification.bert_classifier.DistilBertPairClassification = distil.primitives.bert_classifier:BertPairClassificationPrimitive",
    "classification.linear_svc.DistilRankedLinearSVC = distil.primitives.ranked_linear_svc:RankedLinearSVCPrimitive",
    "collaborative_filtering.link_prediction.DistilCollaborativeFiltering = distil.primitives.collaborative_filtering_link_prediction:CollaborativeFilteringPrimitive",
    "data_transformation.one_hot_encoder.DistilOneHotEncoder = distil.primitives.one_hot_encoder:OneHotEncoderPrimitive",
    "data_transformation.encoder.DistilBinaryEncoder = distil.primitives.binary_encoder:BinaryEncoderPrimitive",
    "data_transformation.encoder.DistilTextEncoder = distil.primitives.text_encoder:TextEncoderPrimitive",
    "data_transformation.satellite_image_loader.DistilSatelliteImageLoader = distil.primitives.satellite_image_loader:DataFrameSatelliteImageLoaderPrimitive",
    "data_transformation.time_series_formatter.DistilTimeSeriesFormatter = distil.primitives.time_series_formatter:TimeSeriesFormatterPrimitive",
    "classification.text_classifier.DistilTextClassifier = distil.primitives.text_classifier:TextClassifierPrimitive",
    "feature_extraction.image_transfer.DistilImageTransfer = distil.primitives.image_transfer:ImageTransferPrimitive",
    "feature_extraction.audio_transfer.DistilAudioTransfer = distil.primitives.audio_transfer:AudioTransferPrimitive",
    "data_transformation.audio_reader.DistilAudioDatasetLoader = distil.primitives.audio_reader:AudioDatasetLoaderPrimitive",
    "clustering.k_means.DistilKMeans = distil.primitives.k_means:KMeansPrimitive",
    "data_transformation.list_to_dataframe.DistilListEncoder = distil.primitives.list_to_dataframe:ListEncoderPrimitive",
    "data_transformation.column_parser.DistilColumnParser = distil.primitives.column_parser:ColumnParserPrimitive",
]

# Package definition for the distil-primitives distribution.
setup(
    name="distil-primitives",
    version=__version__,
    description="Distil primitives as a single library",
    packages=find_packages(),
    keywords=["d3m_primitive"],
    license="Apache-2.0",
    install_requires=_INSTALL_REQUIRES,
    extras_require=_EXTRAS_REQUIRE,
    entry_points={"d3m.primitives": _D3M_PRIMITIVES},
)
