/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

public class SpoofFEDInstruction
extends FEDInstruction {
    private final SpoofOperator _op;
    private final CPOperand[] _inputs;
    private final CPOperand _output;

    /**
     * Creates a federated spoof instruction. Instances are obtained through
     * {@link #parseInstruction(String)}; the constructor itself is private.
     *
     * @param op      instantiated generated operator that this instruction wraps
     * @param in      input operands; the array is stored as-is (no defensive copy)
     * @param out     output operand
     * @param opcode  opcode handed to the {@code FEDInstruction} super constructor
     * @param instStr instruction string handed to the {@code FEDInstruction} super constructor
     */
    private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String instStr) {
        super(FEDInstruction.FEDType.SpoofFused, opcode, instStr);
        _output = out;
        _inputs = in;
        _op = op;
    }

    /**
     * Builds a {@code SpoofFEDInstruction} from its string representation.
     *
     * <p>The string is split with
     * {@code InstructionUtils.getInstructionPartsWithValueType}. Of the resulting
     * parts this method reads:
     * <ul>
     *   <li>part 0 - prefix of the opcode (the operator's spoof type is appended),</li>
     *   <li>part 2 - name of the generated operator class, resolved and instantiated
     *       through {@code CodegenUtils},</li>
     *   <li>parts 3 .. length-3 - the input operands,</li>
     *   <li>part length-2 - the output operand.</li>
     * </ul>
     * Part 1 and the last part are not read here.
     *
     * @param str instruction string to parse
     * @return the parsed instruction
     */
    public static SpoofFEDInstruction parseInstruction(String str) {
        final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);

        final int firstInputPos = 3;
        final int outputPos = tokens.length - 2;

        // Allocate the input array before touching the operator class, as the
        // inputs occupy every slot between the class name and the output operand.
        final CPOperand[] inputs = new CPOperand[outputPos - firstInputPos];

        final SpoofOperator operator = CodegenUtils.createInstance(CodegenUtils.getClass(tokens[2]));
        final String opcode = tokens[0] + operator.getSpoofType();

        for (int i = 0; i < inputs.length; i++) {
            inputs[i] = new CPOperand(tokens[firstInputPos + i]);
        }
        final CPOperand output = new CPOperand(tokens[outputPos]);

        return new SpoofFEDInstruction(operator, inputs, output, opcode, str);
    }

    /**
     * Runs the fused operator on the federated workers and stores the combined
     * result in the execution context.
     *
     * <p>Steps:
     * <ol>
     *   <li>pick the template-specific handler from the operator's superclass,</li>
     *   <li>take the federation map of the first federated matrix input,</li>
     *   <li>for every input, either reuse the id of its federated mapping or create a
     *       (sliced or full) broadcast request and remember its id,</li>
     *   <li>issue broadcast, compute, {@code GET_VAR} and cleanup requests in one
     *       {@code executeMultipleSlices} call,</li>
     *   <li>let the handler combine the responses into the output.</li>
     * </ol>
     *
     * @param ec execution context used to look up inputs and to publish the output
     * @throws DMLRuntimeException if the operator template is not supported or no
     *         input is a federated matrix
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SpoofFEDType spoofType = createSpoofType();

        FederationMap fedMap = findFederationMap(ec);
        if (fedMap == null) {
            // Every later step dereferences the map; fail with a clear message
            // instead of an anonymous NullPointerException.
            throw new DMLRuntimeException("Federated spoof instruction requires at least one federated matrix input.");
        }

        // Full broadcasts are single requests, sliced broadcasts are request arrays.
        ArrayList<FederatedRequest> frBroadcast = new ArrayList<FederatedRequest>();
        ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<FederatedRequest[]>();
        long[] frIds = new long[this._inputs.length];
        int index = 0;

        for (CPOperand cpo : this._inputs) {
            Data tmpData = ec.getVariable(cpo);
            if (tmpData instanceof MatrixObject) {
                MatrixObject mo = (MatrixObject) tmpData;
                if (mo.isFederated()) {
                    // Already on the workers: reference it by the id of its mapping.
                    frIds[index++] = mo.getFedMapping().getID();
                } else if (spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
                    FederatedRequest[] slicedFr = spoofType.broadcastSliced(mo, fedMap);
                    frIds[index++] = slicedFr[0].getID();
                    frBroadcastSliced.add(slicedFr);
                } else {
                    FederatedRequest fullFr = fedMap.broadcast(mo);
                    frIds[index++] = fullFr.getID();
                    frBroadcast.add(fullFr);
                }
            } else if (tmpData instanceof ScalarObject) {
                FederatedRequest scalarFr = fedMap.broadcast((ScalarObject) tmpData);
                frIds[index++] = scalarFr.getID();
                frBroadcast.add(scalarFr);
            }
            // Any other kind of data is skipped and does not consume an id slot.
        }

        // NOTE(review): this rewrites every occurrence of "true" in the instruction
        // string, presumably to clear literal flags of operands that are shipped as
        // broadcast variables - confirm that no other token can contain "true".
        this.instString = this.instString.replace("true", "false");

        FederatedRequest frCompute = FederationUtils.callInstruction(this.instString, this._output, this._inputs, frIds);
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCompute.getID());

        // Clean up the compute result and every broadcast variable afterwards.
        ArrayList<FederatedRequest> frCleanup = new ArrayList<FederatedRequest>();
        frCleanup.add(fedMap.cleanup(this.getTID(), frCompute.getID()));
        for (FederatedRequest fr : frBroadcast) {
            frCleanup.add(fedMap.cleanup(this.getTID(), fr.getID()));
        }
        for (FederatedRequest[] fr : frBroadcastSliced) {
            frCleanup.add(fedMap.cleanup(this.getTID(), fr[0].getID()));
        }

        // Request order: full broadcasts, compute, get, cleanups.
        FederatedRequest[] frAll = ArrayUtils.addAll(
            ArrayUtils.addAll(frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet),
            frCleanup.toArray(new FederatedRequest[0]));
        Future<FederatedResponse>[] response = fedMap.executeMultipleSlices(
            this.getTID(), true, frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll);

        spoofType.setOutput(ec, response, fedMap);
    }

    /** Selects the handler that matches the direct superclass of the generated operator. */
    private SpoofFEDType createSpoofType() {
        Class<?> scla = this._op.getClass().getSuperclass();
        if (scla == SpoofCellwise.class) {
            return new SpoofFEDCellwise(this._op, this._output);
        }
        if (scla == SpoofRowwise.class) {
            return new SpoofFEDRowwise(this._op, this._output);
        }
        if (scla == SpoofMultiAggregate.class) {
            return new SpoofFEDMultiAgg(this._op, this._output);
        }
        if (scla == SpoofOuterProduct.class) {
            return new SpoofFEDOuterProduct(this._op, this._output);
        }
        throw new DMLRuntimeException("Federated code generation only supported for cellwise, rowwise, multiaggregate, and outerproduct templates.");
    }

    /** Returns the federation map of the first federated matrix input, or {@code null} if there is none. */
    private FederationMap findFederationMap(ExecutionContext ec) {
        for (CPOperand cpo : this._inputs) {
            Data tmpData = ec.getVariable(cpo);
            if (tmpData instanceof MatrixObject && ((MatrixObject) tmpData).isFederated()) {
                return ((MatrixObject) tmpData).getFedMapping();
            }
        }
        return null;
    }

    private static class SpoofFEDOuterProduct
    extends SpoofFEDType {
        private final SpoofOuterProduct _op;

        SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
            super(out);
            this._op = (SpoofOuterProduct)op;
        }

        @Override
        protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
            return fedMap.broadcastSliced(mo, fedMap.getType() == FederationMap.FType.COL);
        }

        @Override
        protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
            boolean retVal = false;
            FederationMap.FType fedType = fedMap.getType();
            retVal |= rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1);
            if (fedType == FederationMap.FType.ROW) {
                retVal |= rowNum == fedMap.getMaxIndexInRange(0) && inputIndex != 2;
            } else if (fedType == FederationMap.FType.COL) {
                retVal |= rowNum == fedMap.getMaxIndexInRange(1) && inputIndex != 1;
            } else {
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
            }
            return retVal;
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        @Override
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            FederationMap.FType fedType = fedMap.getType();
            SpoofOuterProduct.OutProdType outProdType = this._op.getOuterProdType();
            if (outProdType == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                if (fedType == FederationMap.FType.ROW) {
                    AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
                    return;
                } else {
                    if (fedType != FederationMap.FType.COL) throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.bind(response, false));
                }
                return;
            } else if (outProdType == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                if (fedType == FederationMap.FType.ROW) {
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.bind(response, false));
                    return;
                } else {
                    if (fedType != FederationMap.FType.COL) throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
                }
                return;
            } else if (outProdType == SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT) {
                if (fedType == FederationMap.FType.ROW) {
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.bind(response, false));
                    return;
                } else {
                    if (fedType != FederationMap.FType.COL) throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    ec.setMatrixOutput(this._output.getName(), FederationUtils.bind(response, true));
                }
                return;
            } else {
                if (outProdType != SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT) throw new DMLRuntimeException("Outer Product Type " + (Object)((Object)outProdType) + " not supported yet.");
                AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                ec.setVariable(this._output.getName(), FederationUtils.aggScalar(aop, response));
            }
        }
    }

    /**
     * Handler for generated operators whose direct superclass is
     * {@code SpoofMultiAggregate}. Broadcast decisions are inherited from
     * {@code SpoofFEDType}.
     */
    private static class SpoofFEDMultiAgg
    extends SpoofFEDType {
        private final SpoofMultiAggregate _op;

        SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
            super(out);
            _op = (SpoofMultiAggregate) op;
        }

        /**
         * Merges the partial results of all workers and publishes the first block
         * as matrix output.
         *
         * <p>Each further block is passed together with the first one to
         * {@code SpoofMultiAggregate.aggregatePartialResults}; its return value is
         * not used, so the helper is expected to accumulate into the first block.
         */
        @Override
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            final MatrixBlock[] partials = FederationUtils.getResults(response);
            final SpoofCellwise.AggOp[] aggOps = _op.getAggOps();

            for (int i = 1, n = partials.length; i < n; i++) {
                SpoofMultiAggregate.aggregatePartialResults(aggOps, partials[0], partials[i]);
            }

            ec.setMatrixOutput(_output.getName(), partials[0]);
        }
    }

    private static class SpoofFEDRowwise
    extends SpoofFEDType {
        private final SpoofRowwise _op;

        SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
            super(out);
            this._op = (SpoofRowwise)op;
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        @Override
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            SpoofRowwise.RowType rowType = this._op.getRowType();
            if (rowType == SpoofRowwise.RowType.FULL_AGG) {
                AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                ec.setVariable(this._output.getName(), FederationUtils.aggScalar(aop, response));
                return;
            } else if (rowType == SpoofRowwise.RowType.ROW_AGG) {
                AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
                ec.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
                return;
            } else if (rowType == SpoofRowwise.RowType.COL_AGG || rowType == SpoofRowwise.RowType.COL_AGG_T || rowType == SpoofRowwise.RowType.COL_AGG_B1 || rowType == SpoofRowwise.RowType.COL_AGG_B1_T || rowType == SpoofRowwise.RowType.COL_AGG_B1R || rowType == SpoofRowwise.RowType.COL_AGG_CONST) {
                AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
                ec.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
                return;
            } else {
                if (rowType != SpoofRowwise.RowType.NO_AGG && rowType != SpoofRowwise.RowType.NO_AGG_B1 && rowType != SpoofRowwise.RowType.NO_AGG_CONST) throw new DMLRuntimeException("AggregationType not supported yet.");
                if (fedMap.getType() != FederationMap.FType.ROW) throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
                ec.setMatrixOutput(this._output.getName(), FederationUtils.bind(response, false));
            }
        }
    }

    /**
     * Handler for generated operators whose direct superclass is
     * {@code SpoofCellwise}. Broadcast decisions are inherited from
     * {@code SpoofFEDType}.
     */
    private static class SpoofFEDCellwise
    extends SpoofFEDType {
        private final SpoofCellwise _op;

        SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
            super(out);
            _op = (SpoofCellwise) op;
        }

        /**
         * Combines the worker responses according to cell type and partitioning:
         * <ul>
         *   <li>{@code FULL_AGG}: scalar aggregate ({@code uak+ / uamin / uamax}),</li>
         *   <li>{@code ROW_AGG} on a COL map: matrix aggregate
         *       ({@code uark+ / uarmin / uarmax}),</li>
         *   <li>{@code COL_AGG} on a ROW map: matrix aggregate
         *       ({@code uack+ / uacmin / uacmax}),</li>
         *   <li>every other supported combination: {@code bind(response, flag)} with
         *       {@code flag == true} exactly for a COL map.</li>
         * </ul>
         * The boolean passed to {@code bind} presumably selects column-wise versus
         * row-wise binding - confirm against {@code FederationUtils}.
         *
         * @throws DMLRuntimeException for an unsupported cell type, partitioning or
         *         aggregation operation
         */
        @Override
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            final FederationMap.FType fedType = fedMap.getType();
            final SpoofCellwise.AggOp aggOp = _op.getAggOp();
            final SpoofCellwise.CellType cellType = _op.getCellType();

            // A full aggregate collapses to a scalar, whatever the partitioning is.
            if (cellType == SpoofCellwise.CellType.FULL_AGG) {
                AggregateUnaryOperator agg = selectAggregate(aggOp, "uak+", "uamin", "uamax");
                ec.setVariable(_output.getName(), FederationUtils.aggScalar(agg, response));
                return;
            }

            final boolean rowAgg = cellType == SpoofCellwise.CellType.ROW_AGG;
            final boolean colAgg = cellType == SpoofCellwise.CellType.COL_AGG;
            if (!rowAgg && !colAgg && cellType != SpoofCellwise.CellType.NO_AGG) {
                throw new DMLRuntimeException("Aggregation type not supported yet.");
            }

            final boolean rowFed = fedType == FederationMap.FType.ROW;
            final boolean colFed = fedType == FederationMap.FType.COL;
            if (!rowFed && !colFed) {
                // The message depends on whether an aggregate was requested.
                if (rowAgg || colAgg) {
                    throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
                }
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
            }

            if (rowAgg && colFed) {
                AggregateUnaryOperator agg = selectAggregate(aggOp, "uark+", "uarmin", "uarmax");
                ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(agg, response, fedMap));
            } else if (colAgg && rowFed) {
                AggregateUnaryOperator agg = selectAggregate(aggOp, "uack+", "uacmin", "uacmax");
                ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(agg, response, fedMap));
            } else {
                // Aggregation direction and partitioning agree, or there is no
                // aggregation at all: the partial results are only bound together.
                ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, colFed));
            }
        }

        /**
         * Parses the aggregate operator that corresponds to {@code aggOp}:
         * {@code sumOpcode} for SUM and SUM_SQ, {@code minOpcode} for MIN and
         * {@code maxOpcode} for MAX.
         *
         * @throws DMLRuntimeException for any other aggregation operation
         */
        private static AggregateUnaryOperator selectAggregate(SpoofCellwise.AggOp aggOp,
                String sumOpcode, String minOpcode, String maxOpcode) {
            if (aggOp == SpoofCellwise.AggOp.SUM || aggOp == SpoofCellwise.AggOp.SUM_SQ) {
                return InstructionUtils.parseBasicAggregateUnaryOperator(sumOpcode);
            }
            if (aggOp == SpoofCellwise.AggOp.MIN) {
                return InstructionUtils.parseBasicAggregateUnaryOperator(minOpcode);
            }
            if (aggOp == SpoofCellwise.AggOp.MAX) {
                return InstructionUtils.parseBasicAggregateUnaryOperator(maxOpcode);
            }
            throw new DMLRuntimeException("Aggregation operation not supported yet.");
        }
    }

    /**
     * Base class of the template-specific handlers. It holds the output operand,
     * supplies the default broadcast behaviour and leaves the combination of the
     * worker responses to the subclasses.
     */
    private static abstract class SpoofFEDType {
        // Output operand; read by the subclasses when publishing the result.
        CPOperand _output;

        protected SpoofFEDType(CPOperand out) {
            _output = out;
        }

        /**
         * Creates a sliced broadcast of {@code mo}, passing {@code false} as the
         * boolean flag of {@code FederationMap.broadcastSliced}.
         */
        protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
            return fedMap.broadcastSliced(mo, false);
        }

        /**
         * Decides whether an input of the given dimensions must be broadcast sliced.
         *
         * <p>This is the case when the input matches the maximum index of the map in
         * both dimensions, or, additionally:
         * <ul>
         *   <li>ROW map: the row count matches dimension 0 and the column count is 1,
         *       equals {@code fedMap.getSize()}, or the map's maximum index in
         *       dimension 1 is 1,</li>
         *   <li>COL map: the column count matches dimension 1 and the row count is 1,
         *       equals {@code fedMap.getSize()}, or the map's maximum index in
         *       dimension 0 is 1.</li>
         * </ul>
         *
         * @param inputIndex position of the input; not used by this default implementation
         * @throws DMLRuntimeException if the map is neither row nor column partitioned
         */
        protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
            final FederationMap.FType fedType = fedMap.getType();
            final boolean matchesFullShape =
                rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1);

            final boolean matchesAlignedVector;
            if (fedType == FederationMap.FType.ROW) {
                matchesAlignedVector = rowNum == fedMap.getMaxIndexInRange(0)
                    && (colNum == 1L || colNum == (long) fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1L);
            } else if (fedType == FederationMap.FType.COL) {
                matchesAlignedVector = colNum == fedMap.getMaxIndexInRange(1)
                    && (rowNum == 1L || rowNum == (long) fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1L);
            } else {
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
            }
            return matchesFullShape || matchesAlignedVector;
        }

        /**
         * Combines the responses of the federated workers and stores the result
         * under the output operand in the execution context.
         *
         * @param ec       execution context that receives the output
         * @param response pending responses of the federated workers
         * @param fedMap   federation map the requests were executed on
         */
        protected abstract void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap);
    }
}

