/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.iteration.operator.coordinator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.operator.coordinator.SharedProgressAlignerListener;
import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.ThrowingRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SharedProgressAligner {
    private static final Logger LOG = LoggerFactory.getLogger(SharedProgressAligner.class);
    public static ConcurrentHashMap<IterationID, SharedProgressAligner> instances = new ConcurrentHashMap();
    private final IterationID iterationId;
    private final int totalHeadParallelism;
    private final OperatorCoordinator.Context context;
    private final Executor executor;
    private final Map<Integer, EpochStatus> statusByEpoch;
    private boolean globallyTerminating;
    private final Map<OperatorID, SharedProgressAlignerListener> listeners;
    private final Map<Long, CheckpointStatus> checkpointStatuses;

    public static SharedProgressAligner getOrCreate(IterationID iterationId, int totalHeadParallelism, OperatorCoordinator.Context context, Supplier<Executor> executorFactory) {
        return instances.computeIfAbsent(iterationId, ignored -> new SharedProgressAligner(iterationId, totalHeadParallelism, context, (Executor)executorFactory.get()));
    }

    @VisibleForTesting
    static ConcurrentHashMap<IterationID, SharedProgressAligner> getInstances() {
        return instances;
    }

    private SharedProgressAligner(IterationID iterationId, int totalHeadParallelism, OperatorCoordinator.Context context, Executor executor) {
        this.iterationId = Objects.requireNonNull(iterationId);
        this.totalHeadParallelism = totalHeadParallelism;
        this.context = Objects.requireNonNull(context);
        this.executor = Objects.requireNonNull(executor);
        this.statusByEpoch = new HashMap<Integer, EpochStatus>();
        this.listeners = new HashMap<OperatorID, SharedProgressAlignerListener>();
        this.checkpointStatuses = new HashMap<Long, CheckpointStatus>();
    }

    public void registerAlignedListener(OperatorID operatorID, SharedProgressAlignerListener alignedConsumer) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> this.listeners.put(operatorID, alignedConsumer)), "Register listeners %s", operatorID.toHexString());
    }

    public void unregisterListener(OperatorID operatorID) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> {
            this.listeners.remove(operatorID);
            if (this.listeners.isEmpty()) {
                instances.remove((Object)this.iterationId);
            }
        }), "Unregister listeners %s", operatorID.toHexString());
    }

    public void reportSubtaskProgress(OperatorID operatorId, int subtaskIndex, SubtaskAlignedEvent subtaskAlignedEvent) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> {
            LOG.debug("Processing {} from {}-{}", new Object[]{subtaskAlignedEvent, operatorId, subtaskIndex});
            EpochStatus roundStatus = this.statusByEpoch.computeIfAbsent(subtaskAlignedEvent.getEpoch(), round -> new EpochStatus((int)round, this.totalHeadParallelism));
            boolean globallyAligned = roundStatus.report(operatorId, subtaskIndex, subtaskAlignedEvent);
            if (globallyAligned) {
                GloballyAlignedEvent globallyAlignedEvent = new GloballyAlignedEvent(subtaskAlignedEvent.getEpoch(), roundStatus.isTerminated());
                for (SharedProgressAlignerListener listeners : this.listeners.values()) {
                    listeners.onAligned(globallyAlignedEvent);
                }
                if (roundStatus.isTerminated()) {
                    this.globallyTerminating = true;
                }
            }
        }), "Report subtask %s-%d", operatorId.toHexString(), subtaskIndex);
    }

    public void requestCheckpoint(long checkpointId, int operatorParallelism, CompletableFuture<byte[]> snapshotStateFuture) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> {
            CheckpointStatus checkpointStatus = this.checkpointStatuses.computeIfAbsent(checkpointId, ignored -> new CheckpointStatus(this.totalHeadParallelism));
            boolean aligned = checkpointStatus.notify(operatorParallelism, snapshotStateFuture);
            if (aligned) {
                if (!this.globallyTerminating) {
                    CoordinatorCheckpointEvent checkpointEvent = new CoordinatorCheckpointEvent(checkpointId);
                    for (SharedProgressAlignerListener listener : this.listeners.values()) {
                        listener.onCheckpointAligned(checkpointEvent);
                    }
                }
                for (CompletableFuture<byte[]> stateFuture : checkpointStatus.getStateFutures()) {
                    stateFuture.complete(new byte[0]);
                }
                this.checkpointStatuses.remove(checkpointId);
            }
        }), "Coordinator report checkpoint %d", checkpointId);
    }

    public void notifyGloballyTerminating() {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> {
            this.globallyTerminating = true;
        }), "Report globally terminating", new Object[0]);
    }

    public void removeProgressInfo(OperatorID operatorId) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> this.statusByEpoch.values().forEach(status -> status.remove(operatorId))), "remove the progress information for {}", operatorId);
    }

    public void removeProgressInfo(OperatorID operatorId, int subtaskIndex) {
        this.runInEventLoop((ThrowingRunnable<Throwable>)((ThrowingRunnable)() -> this.statusByEpoch.values().forEach(status -> status.remove(operatorId, subtaskIndex))), "remove the progress information for {}-{}", operatorId, subtaskIndex);
    }

    private void runInEventLoop(ThrowingRunnable<Throwable> action, String actionName, Object ... actionNameFormatParameters) {
        this.executor.execute(() -> {
            try {
                action.run();
            }
            catch (Throwable t) {
                ExceptionUtils.rethrowIfFatalErrorOrOOM((Throwable)t);
                String actionString = String.format(actionName, actionNameFormatParameters);
                LOG.error("Uncaught exception in the SharedProgressAligner for iteration {} while {}. Triggering job failover.", new Object[]{this.iterationId, actionString, t});
                this.context.failJob(t);
            }
        });
    }

    @VisibleForTesting
    int getNumberListeners() {
        return this.listeners.size();
    }

    private static class CheckpointStatus {
        private final long totalHeadParallelism;
        private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<CompletableFuture<byte[]>>();
        private int notifiedCoordinatorParallelism;

        private CheckpointStatus(long totalHeadParallelism) {
            this.totalHeadParallelism = totalHeadParallelism;
        }

        public boolean notify(int parallelism, CompletableFuture<byte[]> stateFuture) {
            this.stateFutures.add(stateFuture);
            this.notifiedCoordinatorParallelism += parallelism;
            return (long)this.notifiedCoordinatorParallelism == this.totalHeadParallelism;
        }

        public List<CompletableFuture<byte[]>> getStateFutures() {
            return this.stateFutures;
        }
    }

    private static class EpochStatus {
        private final int epoch;
        private final long totalHeadParallelism;
        private final Map<OperatorInstanceID, SubtaskAlignedEvent> reportedSubtasks;

        public EpochStatus(int epoch, long totalHeadParallelism) {
            this.epoch = epoch;
            this.totalHeadParallelism = totalHeadParallelism;
            this.reportedSubtasks = new HashMap<OperatorInstanceID, SubtaskAlignedEvent>();
        }

        public boolean report(OperatorID operatorID, int subtaskIndex, SubtaskAlignedEvent event) {
            this.reportedSubtasks.put(new OperatorInstanceID(subtaskIndex, operatorID), event);
            Preconditions.checkState(((long)this.reportedSubtasks.size() <= this.totalHeadParallelism ? 1 : 0) != 0, (Object)("Received more subtasks" + this.reportedSubtasks + "than the expected total parallelism " + this.totalHeadParallelism));
            return (long)this.reportedSubtasks.size() == this.totalHeadParallelism;
        }

        public void remove(OperatorID operatorID) {
            this.reportedSubtasks.entrySet().removeIf(entry -> ((OperatorInstanceID)entry.getKey()).getOperatorId().equals((Object)operatorID));
        }

        public void remove(OperatorID operatorID, int subtaskIndex) {
            this.reportedSubtasks.remove(new OperatorInstanceID(subtaskIndex, operatorID));
        }

        public boolean isTerminated() {
            Preconditions.checkState(((long)this.reportedSubtasks.size() == this.totalHeadParallelism ? 1 : 0) != 0, (Object)"The round is not globally aligned yet");
            if (this.epoch == 0) {
                return false;
            }
            long totalRecord = 0L;
            boolean hasCriteriaStream = false;
            long totalCriteriaRecord = 0L;
            for (SubtaskAlignedEvent event : this.reportedSubtasks.values()) {
                totalRecord += event.getNumRecordsThisRound();
                if (!event.isCriteriaStream()) continue;
                hasCriteriaStream = true;
                totalCriteriaRecord += event.getNumRecordsThisRound();
            }
            return totalRecord == 0L || hasCriteriaStream && totalCriteriaRecord == 0L;
        }
    }
}

