/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.config.mllib;

import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWiseBaseModelConfig;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.edgecombination.EdgeCombinationMethod;
import oracle.pgx.config.mllib.edgecombination.EdgeCombinationMethods;

/**
 * Base configuration shared by edge-wise GraphWise models.
 *
 * <p>Extends {@link GraphWiseBaseModelConfig} with edge-specific settings:
 * <ul>
 *   <li>the target edge label sets;</li>
 *   <li>the edge embedding dimension;</li>
 *   <li>the {@link EdgeCombinationMethod};</li>
 *   <li>the {@link EdgeWiseConvModelVariant}.</li>
 * </ul>
 *
 * <p>The variant is write-once: {@link #setVariant} only succeeds while the current variant is
 * still {@code null}.
 */
public abstract class EdgeWiseModelConfig
extends GraphWiseBaseModelConfig {
    /** Default edge embedding dimension: {@code null}, i.e. no dimension is set by default. */
    public static final Integer EDGE_EMBEDDING_SIZE = null;

    /** Default model variant: {@code null}, so a variant can still be assigned via {@link #setVariant}. */
    public static final EdgeWiseConvModelVariant DEFAULT_MODE = null;

    private EdgeWiseConvModelVariant variant = DEFAULT_MODE;
    private Integer edgeEmbeddingDim = EDGE_EMBEDDING_SIZE;
    private EdgeCombinationMethod edgeCombinationMethod = EdgeCombinationMethods.DEFAULT_CONCAT_METHOD;
    private List<Set<String>> targetEdgeLabelSets;

    /** Creates a configuration whose edge-specific settings keep their field defaults. */
    EdgeWiseModelConfig() {
    }

    /**
     * Creates a fully specified configuration.
     *
     * <p>All base-model parameters are forwarded unchanged to {@link GraphWiseBaseModelConfig}.
     * The edge-specific parameters ({@code targetEdgeLabelSets}, {@code edgeEmbeddingDim},
     * {@code variant}, {@code edgeCombinationMethod}) are stored as given. The
     * {@code targetEdgeLabelSets} list is stored by reference, not copied.
     */
    public EdgeWiseModelConfig(int batchSize, int numEpochs, double learningRate, double weightDecay, int embeddingDim, Integer seed, GraphWiseConvLayerConfig[] convLayerConfigs, boolean standardize, boolean shuffle, List<String> vertexInputPropertyNames, List<String> edgeInputPropertyNames, boolean fitted, double trainingLoss, int vertexInputFeatureDim, int edgeInputFeatureDim, List<Set<String>> targetEdgeLabelSets, GraphWiseBaseModelConfig.Backend backend, Integer edgeEmbeddingDim, EdgeWiseConvModelVariant variant, EdgeCombinationMethod edgeCombinationMethod) {
        super(batchSize, numEpochs, learningRate, weightDecay, embeddingDim, seed, convLayerConfigs, standardize, shuffle, vertexInputPropertyNames, edgeInputPropertyNames, fitted, trainingLoss, vertexInputFeatureDim, edgeInputFeatureDim, backend);
        this.targetEdgeLabelSets = targetEdgeLabelSets;
        this.edgeEmbeddingDim = edgeEmbeddingDim;
        this.edgeCombinationMethod = edgeCombinationMethod;
        this.variant = variant;
    }

    /**
     * Copy constructor.
     *
     * <p>The target edge label sets are deep-copied. The new configuration and {@code source}
     * therefore never share the mutable list or its sets. The remaining edge-specific values are
     * copied by reference.
     *
     * @param source the configuration to copy
     */
    EdgeWiseModelConfig(EdgeWiseModelConfig source) {
        super(source);
        // Previously the list was shared, so mutating the copy also mutated the source.
        this.setTargetEdgeLabelSets(EdgeWiseModelConfig.copyLabelSets(source.getTargetEdgeLabelSets()));
        this.edgeEmbeddingDim = source.getEdgeEmbeddingDim();
        this.edgeCombinationMethod = source.getEdgeCombinationMethod();
        this.variant = source.getVariant();
    }

    /**
     * Returns the target edge label sets.
     *
     * @return the internal list (not a copy); may be {@code null}
     */
    public List<Set<String>> getTargetEdgeLabelSets() {
        return this.targetEdgeLabelSets;
    }

    /**
     * Sets the target edge label sets. The list is stored by reference.
     *
     * @param targetEdgeLabelSets the label sets; may be {@code null}
     */
    public final void setTargetEdgeLabelSets(List<Set<String>> targetEdgeLabelSets) {
        this.targetEdgeLabelSets = targetEdgeLabelSets;
    }

    /**
     * Sets the target edge labels, wrapping each label in its own single-element set.
     *
     * @param targetEdgeLabels the labels; {@code null} clears the target edge label sets
     */
    public void setTargetEdgeLabels(List<String> targetEdgeLabels) {
        // Accept null like setTargetEdgeLabelSets does, instead of failing with an NPE in stream().
        this.targetEdgeLabelSets = targetEdgeLabels == null
            ? null
            : EdgeWiseModelConfig.listOfStringsToListOfSetOfStrings(targetEdgeLabels);
    }

    /**
     * Returns the edge embedding dimension.
     *
     * @return the dimension, or {@code null} if none is set
     */
    public Integer getEdgeEmbeddingDim() {
        return this.edgeEmbeddingDim;
    }

    /**
     * Sets the edge embedding dimension.
     *
     * @param edgeEmbeddingDim the dimension; may be {@code null}
     */
    public void setEdgeEmbeddingDim(Integer edgeEmbeddingDim) {
        this.edgeEmbeddingDim = edgeEmbeddingDim;
    }

    /**
     * Returns the method used to combine edge information.
     *
     * @return the combination method; defaults to {@link EdgeCombinationMethods#DEFAULT_CONCAT_METHOD}
     */
    public EdgeCombinationMethod getEdgeCombinationMethod() {
        return this.edgeCombinationMethod;
    }

    /**
     * Sets the method used to combine edge information.
     *
     * @param edgeCombinationMethod the combination method
     */
    public void setEdgeCombinationMethod(EdgeCombinationMethod edgeCombinationMethod) {
        this.edgeCombinationMethod = edgeCombinationMethod;
    }

    /**
     * Assigns the model variant. The variant is write-once: it can only be assigned while the
     * current value is {@code null}.
     *
     * @param variant the variant to use
     * @throws IllegalStateException if a non-null variant is already set, whether by a previous
     *     call or by a constructor
     */
    public final void setVariant(EdgeWiseConvModelVariant variant) {
        if (this.variant != null) {
            throw new IllegalStateException(ErrorMessages.getMessage("IMMUTABLE_EDGEWISE_VARIANT", new Object[0]));
        }
        this.variant = variant;
    }

    /**
     * Returns the model variant.
     *
     * @return the variant, or {@code null} if none has been assigned
     */
    public EdgeWiseConvModelVariant getVariant() {
        return this.variant;
    }

    /**
     * Converts each label into a mutable single-element set, preserving list order.
     *
     * @param targetEdgeLabels the labels; must not be {@code null}
     * @return a new list with one set per label
     */
    private static List<Set<String>> listOfStringsToListOfSetOfStrings(List<String> targetEdgeLabels) {
        return targetEdgeLabels.stream()
            .<Set<String>>map(label -> new HashSet<String>(Collections.singleton(label)))
            .collect(Collectors.toList());
    }

    /**
     * Deep-copies a list of label sets. Each inner set's iteration order is preserved, and
     * {@code null} list or element values are tolerated.
     *
     * @param labelSets the label sets to copy; may be {@code null}
     * @return a new list of new sets, or {@code null} if {@code labelSets} is {@code null}
     */
    private static List<Set<String>> copyLabelSets(List<Set<String>> labelSets) {
        if (labelSets == null) {
            return null;
        }
        return labelSets.stream()
            .<Set<String>>map(labels -> labels == null ? null : new LinkedHashSet<String>(labels))
            .collect(Collectors.toList());
    }

    /** The convolution variants available to edge-wise models. */
    public enum EdgeWiseConvModelVariant {
        EDGEWISE,
        INTERTWINED_EDGEWISE
    }
}

