/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.util.LongAccumulator;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.tugraz.sysds.runtime.controlprogram.paramserv.rpc.PSRpcCall;
import org.tugraz.sysds.runtime.controlprogram.paramserv.rpc.PSRpcResponse;
import org.tugraz.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.tugraz.sysds.runtime.instructions.cp.ListObject;

/**
 * {@link ParamServer} implementation that does not hold a model itself but
 * forwards every {@code push} and {@code pull} as a synchronous RPC over a
 * Spark {@link TransportClient}.
 *
 * <p>When {@code DMLScript.STATISTICS} is enabled, the elapsed time of each
 * RPC that yields a decodable response is added to the supplied
 * {@link LongAccumulator}.
 */
public class SparkPSProxy extends ParamServer {
    /** Channel used to send the serialized calls. */
    private final TransportClient _client;
    /** Timeout handed to {@code sendRpcSync}; unit presumably milliseconds -- confirm against the Spark API. */
    private final long _rpcTimeout;
    /** Accumulator receiving the measured RPC durations. */
    private final LongAccumulator _aRPC;

    /**
     * Creates a proxy bound to the given transport client.
     *
     * @param client     client used to issue the synchronous RPCs
     * @param rpcTimeout timeout passed unchanged to every RPC
     * @param aRPC       accumulator that collects RPC request times
     */
    public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
        _client = client;
        _rpcTimeout = rpcTimeout;
        _aRPC = aRPC;
    }

    /**
     * Stops the given timer and adds its reading to the RPC accumulator,
     * but only while statistics collection is switched on.
     *
     * @param tRpc timer started before the RPC; callers pass {@code null}
     *             when statistics were disabled at start time, in which
     *             case it is never dereferenced as long as the flag is
     *             still off (NOTE(review): assumes the flag does not flip
     *             during a call)
     */
    private void accRpcRequestTime(Timing tRpc) {
        if (!DMLScript.STATISTICS) {
            return;
        }
        _aRPC.add((long) tRpc.stop());
    }

    /**
     * Sends the given value on behalf of a worker (RPC method code 1).
     *
     * @param workerID id of the worker issuing the push
     * @param value    payload to transmit
     * @throws DMLRuntimeException if the call cannot be (de)serialized or
     *                             the response reports a failure
     */
    @Override
    public void push(int workerID, ListObject value) {
        roundTrip(1, workerID, value,
            "SparkPSProxy: spark worker_%d failed to push gradients.",
            "SparkPSProxy: spark worker_%d failed to push gradients. \n%s");
    }

    /**
     * Requests the current model on behalf of a worker (RPC method code 2).
     *
     * @param workerID id of the worker issuing the pull
     * @return the result model carried by the response
     * @throws DMLRuntimeException if the call cannot be (de)serialized or
     *                             the response reports a failure
     */
    @Override
    public ListObject pull(int workerID) {
        PSRpcResponse reply = roundTrip(2, workerID, null,
            "SparkPSProxy: spark worker_%d failed to pull models.",
            "SparkPSProxy: spark worker_%d failed to pull models. \n%s");
        return reply.getResultModel();
    }

    /**
     * Performs one timed, synchronous RPC and validates the response.
     *
     * @param method        RPC method code placed into the call
     * @param workerID      id of the calling worker
     * @param payload       data attached to the call, may be {@code null}
     * @param ioErrorFormat format (one {@code %d}) used when an
     *                      {@link IOException} is raised
     * @param failureFormat format ({@code %d}, then {@code %s}) used when
     *                      the response is not successful
     * @return the successful response
     */
    private PSRpcResponse roundTrip(int method, int workerID, ListObject payload,
            String ioErrorFormat, String failureFormat) {
        Timing timer = null;
        if (DMLScript.STATISTICS) {
            timer = new Timing(true);
        }

        PSRpcResponse reply;
        try {
            // Serialization, the blocking send and response decoding stay in a
            // single expression so any IOException among them is mapped alike.
            reply = new PSRpcResponse(
                _client.sendRpcSync(new PSRpcCall(method, workerID, payload).serialize(), _rpcTimeout));
        }
        catch (IOException ex) {
            throw new DMLRuntimeException(String.format(ioErrorFormat, workerID), ex);
        }

        // Time is recorded before the success check, i.e. also for responses
        // that report an error.
        accRpcRequestTime(timer);

        if (reply.isSuccessful()) {
            return reply;
        }
        throw new DMLRuntimeException(String.format(failureFormat, workerID, reply.getErrorMessage()));
    }
}

