/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.common.datastream;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;

/**
 * Element-wise sum ("all-reduce") of {@code double[]} records across the parallel subtasks of a
 * {@link DataStream}.
 *
 * <p>Each subtask of the input may contribute at most one array (a second array fails the job),
 * and contributed arrays must share the same length. The reduction runs in three stages, all at
 * the input's parallelism:
 *
 * <ol>
 *   <li><b>all-reduce-send</b> splits the local array into chunks of {@link #CHUNK_SIZE} doubles
 *       and emits {@code (chunkId, arrayLength, chunk)} for every chunk.
 *   <li><b>all-reduce-sum</b> receives chunks routed by {@code chunkId % parallelism}, sums all
 *       chunks that share a chunk id, and at end of input sends every summed chunk to every
 *       downstream subtask as {@code (taskId, chunkId, arrayLength, summedChunk)}.
 *   <li><b>all-reduce-recv</b> receives records routed by {@code taskId % parallelism}, copies each
 *       summed chunk into place, and emits the reassembled array at end of input.
 * </ol>
 *
 * <p>If no chunk reaches a receiving subtask (no input, or only zero-length arrays), that subtask
 * emits nothing.
 *
 * <p>NOTE(review): the length check only compares contributions to the same chunk id. A
 * zero-length array contributes no chunks, so mixing it with non-empty arrays is not detected.
 */
@Internal
class AllReduceImpl {
    /** Number of doubles carried by each chunk exchanged between the stages. */
    @VisibleForTesting
    static final int CHUNK_SIZE = 4096;

    AllReduceImpl() {
    }

    /**
     * Builds the all-reduce-sum topology on top of {@code input}.
     *
     * @param input stream in which each parallel subtask holds at most one {@code double[]}
     * @return stream in which each subtask emits the element-wise sum of all input arrays
     */
    static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
        int parallelism = input.getParallelism();

        // Routes records by an integer key; used for both the chunk-id and task-id exchanges.
        Partitioner<Integer> moduloPartitioner = (key, numPartitions) -> key % numPartitions;
        KeySelector<Tuple3<Integer, Integer, double[]>, Integer> byChunkId = record -> record.f0;
        KeySelector<Tuple4<Integer, Integer, Integer, double[]>, Integer> byTaskId =
                record -> record.f0;

        TypeInformation<Tuple4<Integer, Integer, Integer, double[]>> sumOutputType =
                new TupleTypeInfo<>(
                        BasicTypeInfo.INT_TYPE_INFO,
                        BasicTypeInfo.INT_TYPE_INFO,
                        BasicTypeInfo.INT_TYPE_INFO,
                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);

        DataStream<Tuple3<Integer, Integer, double[]>> allReduceSend =
                input.flatMap(new AllReduceSend())
                        .setParallelism(parallelism)
                        .name("all-reduce-send");

        DataStream<Tuple4<Integer, Integer, Integer, double[]>> allReduceSum =
                allReduceSend
                        .partitionCustom(moduloPartitioner, byChunkId)
                        .transform("all-reduce-sum", sumOutputType, new AllReduceSum())
                        .setParallelism(parallelism)
                        .name("all-reduce-sum");

        return allReduceSum
                .partitionCustom(moduloPartitioner, byTaskId)
                .transform(
                        "all-reduce-recv",
                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                        new AllReduceRecv())
                .setParallelism(parallelism)
                .name("all-reduce-recv");
    }

    /** Returns {@code ceil(len / CHUNK_SIZE)}, computed without risk of int overflow. */
    private static int getNumChunks(int len) {
        int fullChunks = len / CHUNK_SIZE;
        return len % CHUNK_SIZE == 0 ? fullChunks : fullChunks + 1;
    }

    /**
     * Returns how many valid elements chunk {@code chunkId} holds for an array of length {@code
     * len}: {@link #CHUNK_SIZE} for every chunk except possibly the last one.
     *
     * <p>Expects {@code 0 <= chunkId < getNumChunks(len)}.
     */
    private static int getLengthOfChunk(int chunkId, int len) {
        return Math.min(CHUNK_SIZE, len - chunkId * CHUNK_SIZE);
    }

    /** Stage 3: reassembles the summed chunks into the full result array. */
    private static class AllReduceRecv extends AbstractStreamOperator<double[]>
            implements OneInputStreamOperator<Tuple4<Integer, Integer, Integer, double[]>, double[]>,
                    BoundedOneInput {
        /** Result being assembled; allocated on the first received chunk. */
        private double[] resultArray;

        @Override
        public void processElement(
                StreamRecord<Tuple4<Integer, Integer, Integer, double[]>> streamRecord) {
            Tuple4<Integer, Integer, Integer, double[]> record = streamRecord.getValue();
            int chunkId = record.f1;
            int originalArrayLength = record.f2;
            double[] aggregatedArrayChunk = record.f3;
            if (resultArray == null) {
                resultArray = new double[originalArrayLength];
            }
            // The incoming chunk may be padded to CHUNK_SIZE; copy only its valid prefix.
            System.arraycopy(
                    aggregatedArrayChunk,
                    0,
                    resultArray,
                    chunkId * CHUNK_SIZE,
                    getLengthOfChunk(chunkId, resultArray.length));
        }

        @Override
        public void endInput() {
            if (resultArray != null) {
                output.collect(new StreamRecord<>(resultArray));
            }
        }
    }

    /** Stage 2: sums all contributions per chunk id and broadcasts the sums to every receiver. */
    private static class AllReduceSum
            extends AbstractStreamOperator<Tuple4<Integer, Integer, Integer, double[]>>
            implements OneInputStreamOperator<
                            Tuple3<Integer, Integer, double[]>,
                            Tuple4<Integer, Integer, Integer, double[]>>,
                    BoundedOneInput {
        /** Running sum per chunk id, paired with the original array length for validation. */
        private final Map<Integer, Tuple2<Integer, double[]>> aggregatedArrayChunkByChunkId =
                new HashMap<>();

        @Override
        public void processElement(StreamRecord<Tuple3<Integer, Integer, double[]>> streamRecord) {
            Tuple3<Integer, Integer, double[]> record = streamRecord.getValue();
            int chunkId = record.f0;
            int originalArrayLength = record.f1;
            double[] arrayChunk = record.f2;

            Tuple2<Integer, double[]> aggregated = aggregatedArrayChunkByChunkId.get(chunkId);
            if (aggregated == null) {
                // The first contribution becomes the accumulator and is summed into in place.
                aggregatedArrayChunkByChunkId.put(chunkId, Tuple2.of(originalArrayLength, arrayChunk));
                return;
            }

            int aggregatedArrayLength = aggregated.f0;
            if (aggregatedArrayLength != originalArrayLength) {
                throw new RuntimeException("The input double array must have same length.");
            }
            double[] runningSum = aggregated.f1;
            for (int i = 0; i < runningSum.length; i++) {
                runningSum[i] += arrayChunk[i];
            }
        }

        @Override
        public void endInput() {
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            for (Map.Entry<Integer, Tuple2<Integer, double[]>> entry :
                    aggregatedArrayChunkByChunkId.entrySet()) {
                int chunkId = entry.getKey();
                int originalArrayLength = entry.getValue().f0;
                double[] aggregatedArrayChunk = entry.getValue().f1;
                // Every downstream subtask needs every chunk to rebuild the full array.
                for (int taskId = 0; taskId < numTasks; taskId++) {
                    output.collect(
                            new StreamRecord<>(
                                    Tuple4.of(
                                            taskId,
                                            chunkId,
                                            originalArrayLength,
                                            aggregatedArrayChunk)));
                }
            }
        }
    }

    /** Stage 1: splits the subtask's single input array into fixed-size chunks. */
    private static class AllReduceSend
            implements FlatMapFunction<double[], Tuple3<Integer, Integer, double[]>> {
        private boolean hasReceivedOneRecord = false;

        @Override
        public void flatMap(double[] inputArray, Collector<Tuple3<Integer, Integer, double[]>> out) {
            if (hasReceivedOneRecord) {
                throw new RuntimeException("The input cannot contain more than one double array.");
            }
            hasReceivedOneRecord = true;

            int arrayLength = inputArray.length;
            // Allocated locally: flatMap succeeds at most once per instance, so a field would only
            // enlarge the serialized function.
            double[] transferBuffer = new double[CHUNK_SIZE];
            int numChunks = getNumChunks(arrayLength);
            for (int chunkId = 0; chunkId < numChunks; chunkId++) {
                System.arraycopy(
                        inputArray,
                        chunkId * CHUNK_SIZE,
                        transferBuffer,
                        0,
                        getLengthOfChunk(chunkId, arrayLength));
                // The buffer is reused across chunks. This presumably relies on the downstream
                // custom partitioning (a network exchange) serializing each record on collect;
                // TODO confirm if this operator is ever chained. The final chunk's tail keeps
                // stale values, which the receiver ignores.
                out.collect(Tuple3.of(chunkId, arrayLength, transferBuffer));
            }
        }
    }
}

