/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.RssStageResubmitManager;
import org.apache.spark.shuffle.ShuffleHandleInfoManager;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.spark.shuffle.events.TaskReassignInfoEvent;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.OverlappingCompressionDataPusher;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.PartitionSplitMode;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shaded.org.apache.commons.collections4.CollectionUtils;
import org.apache.uniffle.shaded.org.apache.commons.lang3.StringUtils;
import org.apache.uniffle.shuffle.BlockIdManager;
import org.apache.uniffle.shuffle.ShuffleIdMappingManager;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public abstract class RssShuffleManagerBase
implements RssShuffleManagerInterface,
ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);
    // Pool sizes and replica/quorum settings read from sparkConf in the constructors.
    protected final int dataTransferPoolSize;
    protected final int dataCommitPoolSize;
    protected final int dataReplica;
    protected final int dataReplicaWrite;
    protected final int dataReplicaRead;
    protected final boolean dataReplicaSkipEnabled;
    // Block bookkeeping maps shared with the DataPusher; presumably keyed by task id (TODO confirm).
    protected final Map<String, Set<Long>> taskToSuccessBlockIds;
    protected final Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
    // Concurrent set handed to the DataPusher in the main constructor.
    private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
    // Guards the one-time reflective lookup of the two MapOutputTrackerMaster methods below.
    private AtomicBoolean isInitialized = new AtomicBoolean(false);
    private Method unregisterAllMapOutputMethod;
    private Method registerShuffleMethod;
    // Lazily created by getBlockIdManager() while holding blockIdManagerLock.
    private volatile BlockIdManager blockIdManager;
    // Only assigned in the test constructor within the visible code; presumably set by subclasses otherwise.
    protected ShuffleDataDistributionType dataDistributionType;
    private Object blockIdManagerLock = new Object();
    // NOTE(review): not referenced in the visible part of this class.
    protected AtomicReference<String> id = new AtomicReference();
    protected String appId = "";
    protected ShuffleWriteClient shuffleWriteClient;
    protected boolean dynamicConfEnabled;
    protected int maxConcurrencyPerPartitionToWrite;
    protected String clientType;
    protected final SparkConf sparkConf;
    // Derived from sparkConf via RssSparkConfig.toRssConf.
    protected final RssConf rssConf;
    // Per-shuffle metadata; entries are removed in unregisterShuffle() on the driver.
    protected Map<Integer, Integer> shuffleIdToPartitionNum;
    protected Map<Integer, Integer> shuffleIdToNumMapTasks;
    protected Supplier<ShuffleManagerClient> managerClientSupplier;
    // True when stage resubmission is enabled for write and/or fetch failures.
    protected boolean rssStageRetryEnabled;
    protected boolean rssStageRetryForWriteFailureEnabled;
    protected boolean rssStageRetryForFetchFailureEnabled;
    protected ShuffleHandleInfoManager shuffleHandleInfoManager;
    protected RssStageResubmitManager rssStageResubmitManager;
    protected ShuffleIdMappingManager shuffleIdMappingManager;
    protected int partitionReassignMaxServerNum;
    protected boolean blockIdSelfManagedEnabled;
    protected boolean partitionReassignEnabled;
    // Enabled when any feature needing the driver-side shuffle manager RPC service is on.
    protected boolean shuffleManagerRpcServiceEnabled;
    protected boolean heartbeatStarted = false;
    protected final long heartbeatInterval;
    // Defaults to half of heartbeatInterval when not configured.
    protected final long heartbeatTimeout;
    // Quota identity from spark.rss.quota.user / spark.rss.quota.uuid (uuid defaults to creation time in ms).
    protected String user;
    protected String uuid;
    protected ScheduledExecutorService heartBeatScheduledExecutorService;
    // spark.task.maxFailures and spark.speculation; together with blockIdLayout they bound task attempt ids.
    protected final int maxFailures;
    protected final boolean speculation;
    protected final BlockIdLayout blockIdLayout;
    // Created only on the driver when shuffleManagerRpcServiceEnabled is true.
    private ShuffleManagerGrpcService service;
    protected GrpcServer shuffleManagerServer;
    protected DataPusher dataPusher;
    // Only read from rssConf when partition reassign is enabled.
    private int partitionSplitLoadBalanceServerNum;
    protected PartitionSplitMode partitionSplitMode;
    // NOTE(review): these flags are not referenced in the visible part of this class.
    private AtomicBoolean reassignTriggeredOnPartitionSplit = new AtomicBoolean(false);
    private AtomicBoolean reassignTriggeredOnBlockSendFailure = new AtomicBoolean(false);
    private AtomicBoolean reassignTriggeredOnStageRetry = new AtomicBoolean(false);
    private boolean isDriver = false;

    /**
     * Creates the shuffle manager for either the Spark driver or an executor.
     *
     * <p>Validates the RSS client configuration and derives the block id layout. Overrides Spark
     * settings that conflict with remote shuffle (external shuffle service, shuffle tracking, local
     * shuffle reader, reduce locality). Creates the shuffle write client and the data pusher. On the
     * driver it may also fetch dynamic client configuration from the coordinators and start the
     * shuffle manager gRPC server.
     *
     * @param conf the Spark configuration; several keys are written back into it
     * @param isDriver {@code true} when this instance runs inside the Spark driver
     * @throws RssException if partition reassign is enabled together with more than one replica,
     *     or if the shuffle manager server fails to start
     */
    public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
        LOG.info("Uniffle {} version: {}", (Object)this.getClass().getName(), (Object)Constants.VERSION_AND_REVISION_SHORT);
        this.sparkConf = conf;
        this.isDriver = isDriver;
        this.checkSupported(this.sparkConf);
        boolean supportsRelocation = Optional.ofNullable(SparkEnv.get()).map(env -> env.serializer().supportsRelocationOfSerializedObjects()).orElse(true);
        if (!supportsRelocation) {
            LOG.warn("RSSShuffleManager requires a serializer which supports relocations of serialized object. Please set spark.serializer to org.apache.spark.serializer.KryoSerializer instead");
        }
        this.user = this.sparkConf.get("spark.rss.quota.user", "user");
        this.uuid = this.sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
        this.dynamicConfEnabled = (Boolean)this.sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
        if (isDriver && this.dynamicConfEnabled) {
            // Applied to sparkConf before rssConf is derived from it below.
            RssShuffleManagerBase.fetchAndApplyDynamicConf(this.sparkConf);
        }
        RssSparkShuffleUtils.validateRssClientConf(this.sparkConf);
        this.rssConf = RssSparkConfig.toRssConf(this.sparkConf);
        RssUtils.setExtraJavaProperties(this.rssConf);
        this.dataReplica = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
        this.dataReplicaWrite = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
        this.dataReplicaRead = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
        this.dataReplicaSkipEnabled = (Boolean)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
        LOG.info("Check quorum config [" + this.dataReplica + ":" + this.dataReplicaWrite + ":" + this.dataReplicaRead + ":" + this.dataReplicaSkipEnabled + "]");
        RssUtils.checkQuorumSetting(this.dataReplica, this.dataReplicaWrite, this.dataReplicaRead);
        this.maxConcurrencyPerPartitionToWrite = this.rssConf.get(RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
        this.clientType = (String)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        this.maxFailures = this.sparkConf.getInt("spark.task.maxFailures", 4);
        this.speculation = this.sparkConf.getBoolean("spark.speculation", false);
        // Must run after maxFailures/speculation are set: the layout depends on them.
        this.configureBlockIdLayout(this.sparkConf, this.rssConf);
        this.blockIdLayout = BlockIdLayout.from(this.rssConf);
        this.dataTransferPoolSize = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
        this.dataCommitPoolSize = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
        this.sparkConf.set("spark.shuffle.service.enabled", "false");
        this.sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
        this.sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
        LOG.info("Disable external shuffle service in RssShuffleManager.");
        this.sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false");
        LOG.info("Disable local shuffle reader in RssShuffleManager.");
        this.sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
        LOG.info("Disable shuffle data locality in RssShuffleManager.");
        this.taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
        this.taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
        this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
        this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
        this.rssStageRetryForFetchFailureEnabled = this.rssConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
        this.rssStageRetryForWriteFailureEnabled = this.rssConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
        if (this.rssStageRetryForFetchFailureEnabled || this.rssStageRetryForWriteFailureEnabled) {
            this.rssStageRetryEnabled = true;
            List<String> logTips = new ArrayList<>();
            if (this.rssStageRetryForWriteFailureEnabled) {
                logTips.add("write");
            }
            // "fetch" must follow the fetch-failure flag, not the write-failure one.
            if (this.rssStageRetryForFetchFailureEnabled) {
                logTips.add("fetch");
            }
            LOG.info("Activate the stage retry mechanism that will resubmit stage on {} failure", (Object)StringUtils.join(logTips, "/"));
        }
        this.partitionReassignEnabled = this.rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
        if (this.partitionReassignEnabled) {
            if (this.dataReplica > 1) {
                throw new RssException("The feature of task partition reassign is incompatible with multiple replicas mechanism.");
            }
            this.partitionSplitMode = this.rssConf.get(RssClientConf.RSS_CLIENT_PARTITION_SPLIT_MODE);
            this.partitionSplitLoadBalanceServerNum = this.rssConf.get(RssClientConf.RSS_CLIENT_PARTITION_SPLIT_LOAD_BALANCE_SERVER_NUMBER);
            LOG.info("Partition reassign is enabled.");
        }
        this.blockIdSelfManagedEnabled = this.rssConf.getBoolean(RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
        this.shuffleManagerRpcServiceEnabled = this.partitionReassignEnabled || this.rssStageRetryEnabled || this.blockIdSelfManagedEnabled || RssSparkShuffleUtils.isSparkUIEnabled(conf);
        if (isDriver) {
            this.heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
            if (this.shuffleManagerRpcServiceEnabled) {
                LOG.info("stage resubmit is supported and enabled");
                // Port 0: let the server bind to any free port, then publish it through sparkConf.
                this.rssConf.set(RssBaseConf.RPC_SERVER_PORT, 0);
                ShuffleManagerServerFactory factory = new ShuffleManagerServerFactory(this, this.rssConf);
                this.service = factory.getService();
                this.shuffleManagerServer = factory.getServer(this.service);
                try {
                    this.shuffleManagerServer.start();
                    this.sparkConf.set(RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, (Object)this.shuffleManagerServer.getPort());
                }
                catch (Exception e) {
                    LOG.error("Failed to start shuffle manager server", e);
                    throw new RssException(e);
                }
            }
        }
        if (this.shuffleManagerRpcServiceEnabled) {
            this.getOrCreateShuffleManagerClientSupplier();
        }
        this.heartbeatInterval = (Long)this.sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
        // Without an explicit timeout, wait at most half of the heartbeat interval.
        this.heartbeatTimeout = this.sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), this.heartbeatInterval / 2L);
        // The driver already created the executor above; creating a second one would leak the
        // first executor's thread.
        if (this.heartBeatScheduledExecutorService == null) {
            this.heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
        }
        this.shuffleWriteClient = this.createShuffleWriteClient();
        this.registerCoordinator();
        LOG.info("Rss data pusher is starting...");
        int poolSize = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
        int keepAliveTime = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
        boolean overlappingCompressionEnabled = this.rssConf.get(RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED);
        int overlappingCompressionThreadsPerVcore = this.rssConf.get(RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_THREADS_PER_VCORE);
        if (overlappingCompressionEnabled && overlappingCompressionThreadsPerVcore > 0) {
            int compressionThreads = overlappingCompressionThreadsPerVcore * this.sparkConf.getInt("spark.executor.cores", 1);
            this.dataPusher = new OverlappingCompressionDataPusher(this.shuffleWriteClient, this.taskToSuccessBlockIds, this.taskToFailedBlockSendTracker, this.failedTaskIds, poolSize, keepAliveTime, compressionThreads);
        } else {
            this.dataPusher = new DataPusher(this.shuffleWriteClient, this.taskToSuccessBlockIds, this.taskToFailedBlockSendTracker, this.failedTaskIds, poolSize, keepAliveTime);
        }
        this.partitionReassignMaxServerNum = this.rssConf.get(RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
        this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
        this.rssStageResubmitManager = new RssStageResubmitManager();
        this.shuffleIdMappingManager = new ShuffleIdMappingManager();
    }

    /**
     * Test-only constructor that injects the data pusher and the block bookkeeping maps.
     *
     * <p>Unlike the main constructor, it does not start a gRPC server, does not create a heartbeat
     * executor or data pusher of its own, and does not apply the shuffle-service/locality overrides
     * to {@code conf}.
     *
     * @param conf the Spark configuration
     * @param isDriver unused here; kept for signature compatibility
     * @param dataPusher the data pusher to use
     * @param taskToSuccessBlockIds map of successfully sent block ids
     * @param taskToFailedBlockSendTracker map of failed block send trackers
     */
    @VisibleForTesting
    protected RssShuffleManagerBase(SparkConf conf, boolean isDriver, DataPusher dataPusher, Map<String, Set<Long>> taskToSuccessBlockIds, Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
        this.sparkConf = conf;
        this.clientType = (String)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        this.rssConf = RssSparkConfig.toRssConf(this.sparkConf);
        this.dataDistributionType = this.rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
        this.blockIdLayout = BlockIdLayout.from(this.rssConf);
        this.maxConcurrencyPerPartitionToWrite = this.rssConf.get(RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
        this.maxFailures = this.sparkConf.getInt("spark.task.maxFailures", 4);
        this.speculation = this.sparkConf.getBoolean("spark.speculation", false);
        this.configureBlockIdLayout(this.sparkConf, this.rssConf);
        this.heartbeatInterval = (Long)this.sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
        this.heartbeatTimeout = this.sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), this.heartbeatInterval / 2L);
        this.dataReplica = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
        this.dataReplicaWrite = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
        this.dataReplicaRead = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
        this.dataReplicaSkipEnabled = (Boolean)this.sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
        LOG.info("Check quorum config [" + this.dataReplica + ":" + this.dataReplicaWrite + ":" + this.dataReplicaRead + ":" + this.dataReplicaSkipEnabled + "]");
        RssUtils.checkQuorumSetting(this.dataReplica, this.dataReplicaWrite, this.dataReplicaRead);
        this.dataTransferPoolSize = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
        this.dataCommitPoolSize = (Integer)this.sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
        // Keep the created client: discarding it leaves shuffleWriteClient null and the new client
        // unreachable (it could never be closed).
        this.shuffleWriteClient = this.createShuffleWriteClient();
        this.taskToSuccessBlockIds = taskToSuccessBlockIds;
        this.heartBeatScheduledExecutorService = null;
        this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
        this.dataPusher = dataPusher;
        this.partitionReassignMaxServerNum = this.rssConf.get(RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
        this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
        this.rssStageResubmitManager = new RssStageResubmitManager();
        this.shuffleIdMappingManager = new ShuffleIdMappingManager();
    }

    /**
     * Returns the shared {@link BlockIdManager}, creating it on first use.
     *
     * <p>The field is volatile and creation happens while holding {@code blockIdManagerLock}, so
     * concurrent callers all receive the same instance.
     *
     * @return the block id manager, never {@code null}
     */
    @Override
    public BlockIdManager getBlockIdManager() {
        BlockIdManager existing = this.blockIdManager;
        if (existing != null) {
            return existing;
        }
        synchronized (this.blockIdManagerLock) {
            existing = this.blockIdManager;
            if (existing == null) {
                existing = new BlockIdManager();
                this.blockIdManager = existing;
                LOG.info("BlockId manager has been initialized.");
            }
            return existing;
        }
    }

    /**
     * Forgets a shuffle. Its block ids are dropped locally; on the driver, the shuffle servers, the
     * per-shuffle metadata maps and the manager RPC service (if running) are also told to forget it.
     *
     * <p>This is best-effort: any exception is logged and swallowed.
     *
     * @param shuffleId the shuffle to unregister
     * @return always {@code true}
     */
    public boolean unregisterShuffle(int shuffleId) {
        try {
            if (this.blockIdManager != null) {
                this.blockIdManager.remove(shuffleId);
            }
            boolean runningOnDriver = SparkEnv.get().executorId().equals("driver");
            if (!runningOnDriver) {
                return true;
            }
            this.shuffleWriteClient.unregisterShuffle(this.getAppId(), shuffleId);
            this.shuffleIdToPartitionNum.remove(shuffleId);
            this.shuffleIdToNumMapTasks.remove(shuffleId);
            ShuffleManagerGrpcService rpcService = this.service;
            if (rpcService != null) {
                rpcService.unregisterShuffle(shuffleId);
            }
        } catch (Exception e) {
            LOG.warn("Errors on unregistering from remote shuffle-servers", e);
        }
        return true;
    }

    /**
     * Configures the block id bit layout in both {@code sparkConf} and {@code rssConf}, using this
     * manager's {@code maxFailures} and {@code speculation} settings.
     *
     * @param sparkConf the Spark configuration to read from and update
     * @param rssConf the RSS configuration to update
     */
    public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
        final int failures = this.maxFailures;
        final boolean speculative = this.speculation;
        configureBlockIdLayout(sparkConf, rssConf, failures, speculative);
    }

    /**
     * Configures the block id bit layout.
     *
     * <p>When {@code RSS_MAX_PARTITIONS} is set in {@code sparkConf}, the layout is derived from it.
     * Otherwise it is taken from the explicit bit-count keys, falling back to the default max
     * partitions when those are absent too.
     *
     * @param sparkConf the Spark configuration to read from and update
     * @param rssConf the RSS configuration to update
     * @param maxFailures value of spark.task.maxFailures
     * @param speculation whether speculative execution is enabled
     */
    @VisibleForTesting
    protected static void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        boolean maxPartitionsConfigured = sparkConf.contains(RssSparkConfig.RSS_MAX_PARTITIONS.key());
        if (!maxPartitionsConfigured) {
            configureBlockIdLayoutFromLayoutConfig(sparkConf, rssConf, maxFailures, speculation);
            return;
        }
        configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
    }

    /**
     * Derives the block id bit layout from {@code RSS_MAX_PARTITIONS} and writes it into both
     * {@code rssConf} and {@code sparkConf} (the latter under "spark."-prefixed keys).
     *
     * <p>Bit allocation:
     * <ul>
     *   <li>attempt bits: significant bits of the max attempt number implied by
     *       {@code maxFailures}/{@code speculation};
     *   <li>partition id bits: significant bits of {@code maxPartitions - 1};
     *   <li>task attempt id bits: partition id bits plus attempt bits;
     *   <li>sequence number bits: whatever remains of 63 bits after the partition id and task
     *       attempt id bits.
     * </ul>
     *
     * @throws IllegalArgumentException if maxPartitions is not larger than 1, or if the task attempt
     *     id would need more than 31 bits
     */
    private static void configureBlockIdLayoutFromMaxPartitions(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        int maxPartitions = sparkConf.getInt(RssSparkConfig.RSS_MAX_PARTITIONS.key(), ((Integer)RssSparkConfig.RSS_MAX_PARTITIONS.defaultValue().get()).intValue());
        if (maxPartitions <= 1) {
            throw new IllegalArgumentException("Value of " + RssSparkConfig.RSS_MAX_PARTITIONS.key() + " must be larger than 1: " + maxPartitions);
        }
        int attemptIdBits = ClientUtils.getNumberOfSignificantBits(ClientUtils.getMaxAttemptNo(maxFailures, speculation));
        int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions - 1);
        int taskAttemptIdBits = partitionIdBits + attemptIdBits;
        int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits;
        if (taskAttemptIdBits > 31) {
            throw new IllegalArgumentException("Cannot support " + RssSparkConfig.RSS_MAX_PARTITIONS.key() + "=" + maxPartitions + " partitions, as this would require to reserve more than 31 bits in the block id for task attempt ids. With spark.maxFailures=" + maxFailures + " and spark.speculation=" + (speculation ? "true" : "false") + " at most " + (1 << 31 - attemptIdBits) + " partitions can be supported.");
        }
        // Sequence numbers need at most 31 bits. Any surplus (rounded up to an even count) is split
        // equally between partition id bits and task attempt id bits, and maxPartitions is raised
        // to match the widened partition id field.
        if (sequenceNoBits > 31) {
            int spareBits = sequenceNoBits - 31;
            spareBits += spareBits % 2;
            taskAttemptIdBits += spareBits / 2;
            // Note: partitionIdBits is incremented inside this expression.
            maxPartitions = 1 << (partitionIdBits += spareBits / 2);
            if (LOG.isInfoEnabled()) {
                LOG.info("Increasing " + RssSparkConfig.RSS_MAX_PARTITIONS.key() + " to " + maxPartitions + ", otherwise we would have to support 2^" + sequenceNoBits + " (more than 2^31) sequence numbers.");
            }
            sequenceNoBits -= spareBits;
            // Write the enlarged value back so later readers of sparkConf see it.
            sparkConf.set(RssSparkConfig.RSS_MAX_PARTITIONS.key(), String.valueOf(maxPartitions));
        }
        rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sequenceNoBits);
        rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, partitionIdBits);
        rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, taskAttemptIdBits);
        sparkConf.set("spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(sequenceNoBits));
        sparkConf.set("spark." + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(partitionIdBits));
        sparkConf.set("spark." + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(taskAttemptIdBits));
    }

    /**
     * Configures the block id layout from explicit bit-count keys.
     *
     * <p>The three bit keys must be given all together or not at all, both as "spark."-prefixed
     * keys in {@code sparkConf} and as keys in {@code rssConf}. Resolution order:
     * <ol>
     *   <li>Spark keys present: copy them into {@code rssConf}.
     *   <li>Otherwise, RSS keys present: copy them into {@code sparkConf}.
     *   <li>Otherwise: set the default max partitions and derive the layout from it.
     * </ol>
     *
     * @throws IllegalArgumentException if only a subset of the keys is provided in either conf
     */
    private static void configureBlockIdLayoutFromLayoutConfig(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        final String sparkPrefix = "spark.";
        final String sparkSeqNoBitsKey = sparkPrefix + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key();
        final String sparkPartIdBitsKey = sparkPrefix + RssClientConf.BLOCKID_PARTITION_ID_BITS.key();
        final String sparkTaskIdBitsKey = sparkPrefix + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key();
        final List<String> sparkKeys = Arrays.asList(sparkSeqNoBitsKey, sparkPartIdBitsKey, sparkTaskIdBitsKey);

        final boolean allSparkKeys = sparkKeys.stream().allMatch(sparkConf::contains);
        if (!allSparkKeys && sparkKeys.stream().anyMatch(sparkConf::contains)) {
            Set<String> wanted = new HashSet<>(sparkKeys);
            String existingKeys = Arrays.stream(sparkConf.getAll())
                    .map(entry -> (String) entry._1)
                    .filter(wanted::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + String.join(", ", sparkKeys) + "), not just a sub-set: " + existingKeys);
        }

        final List<ConfigOption<Integer>> rssKeys = Arrays.asList(
                RssClientConf.BLOCKID_SEQUENCE_NO_BITS,
                RssClientConf.BLOCKID_PARTITION_ID_BITS,
                RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS);
        final boolean allRssKeys = rssKeys.stream().allMatch(rssConf::contains);
        if (!allRssKeys && rssKeys.stream().anyMatch(rssConf::contains)) {
            Set<String> wanted = rssKeys.stream().map(ConfigOption::key).collect(Collectors.toSet());
            String allKeys = rssKeys.stream().map(ConfigOption::key).collect(Collectors.joining(", "));
            String existingKeys = rssConf.getKeySet().stream()
                    .filter(wanted::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + allKeys + "), not just a sub-set: " + existingKeys);
        }

        if (allSparkKeys) {
            rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sparkConf.getInt(sparkSeqNoBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, sparkConf.getInt(sparkPartIdBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, sparkConf.getInt(sparkTaskIdBitsKey, 0));
            return;
        }
        if (allRssKeys) {
            sparkConf.set(sparkSeqNoBitsKey, rssConf.getValue(RssClientConf.BLOCKID_SEQUENCE_NO_BITS));
            sparkConf.set(sparkPartIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_PARTITION_ID_BITS));
            sparkConf.set(sparkTaskIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS));
            return;
        }
        // Nothing configured explicitly: fall back to the default max partitions.
        sparkConf.set(RssSparkConfig.RSS_MAX_PARTITIONS.key(), RssSparkConfig.RSS_MAX_PARTITIONS.defaultValueString());
        configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
    }

    /**
     * Computes the task attempt id that is embedded in block ids, using this manager's
     * {@code maxFailures}, {@code speculation} and block id layout.
     *
     * @param mapIndex the map task index
     * @param attemptNo the attempt number of the map task
     * @return the packed task attempt id
     */
    public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
        final int allowedBits = this.blockIdLayout.taskAttemptIdBits;
        return getTaskAttemptIdForBlockId(mapIndex, attemptNo, this.maxFailures, this.speculation, allowedBits);
    }

    /**
     * Packs a map index and an attempt number into the task attempt id used inside block ids.
     *
     * <p>The attempt number occupies the low {@code attemptBits} bits. {@code attemptBits} is the
     * number of significant bits of the largest attempt number allowed by {@code maxFailures} and
     * {@code speculation}. The map index is shifted above those bits.
     *
     * @param mapIndex the map task index; must be non-negative
     * @param attemptNo the attempt number; must be non-negative and not exceed the max attempt number
     * @param maxFailures value of spark.task.maxFailures
     * @param speculation whether speculative execution is enabled
     * @param maxTaskAttemptIdBits number of bits reserved for the task attempt id
     * @return {@code (long) mapIndex << attemptBits | attemptNo}
     * @throws RssException if an argument is negative, the attempt number is too large, or the
     *     result does not fit into {@code maxTaskAttemptIdBits} bits
     */
    protected static long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
        if (mapIndex < 0 || attemptNo < 0) {
            // A negative attemptNo would be sign-extended by the (long) cast and set every high bit
            // of the result through the OR below, silently producing a corrupt id.
            throw new RssException("Observing negative mapIndex[" + mapIndex + "] or attempt number[" + attemptNo + "], both must be non-negative.");
        }
        int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
        int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
        if (attemptNo > maxAttemptNo) {
            throw new RssException("Observing attempt number " + attemptNo + " while maxFailures is set to " + maxFailures + (speculation ? " with speculation enabled" : "") + ".");
        }
        int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex);
        if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
            throw new RssException("Observing mapIndex[" + mapIndex + "] that would produce a taskAttemptId with " + (mapIndexBits + attemptBits) + " bits which is larger than the allowed " + maxTaskAttemptIdBits + " bits (maxFailures[" + maxFailures + "], speculation[" + speculation + "]). Please consider providing more bits for taskAttemptIds.");
        }
        return (long)mapIndex << attemptBits | (long)attemptNo;
    }

    /**
     * Fetches client configuration from the coordinators for the current user and applies it to
     * {@code sparkConf}.
     *
     * <p>The configuration is applied only when the coordinator answers with
     * {@link StatusCode#SUCCESS}; any other status is logged and ignored. The coordinator client is
     * always closed, including when an exception is thrown.
     *
     * @param sparkConf the Spark configuration to read connection settings from and update
     * @throws RssException if the current user cannot be determined
     */
    protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
        String clientType = (String)sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
        long retryIntervalMs = (Long)sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
        int retryTimes = (Integer)sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
        int heartbeatThread = (Integer)sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
        int timeoutMs = sparkConf.getInt(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(), ((Integer)RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()).intValue());
        CoordinatorClientFactory coordinatorClientFactory = CoordinatorClientFactory.getInstance();
        CoordinatorGrpcRetryableClient coordinatorClient = coordinatorClientFactory.createCoordinatorClient(ClientType.valueOf(clientType), coordinators, retryIntervalMs, retryTimes, heartbeatThread);
        try {
            String user;
            try {
                user = UserGroupInformation.getCurrentUser().getShortUserName();
            }
            catch (Exception e) {
                throw new RssException("Errors on getting current user.", e);
            }
            RssFetchClientConfRequest request = new RssFetchClientConfRequest(timeoutMs, user, Collections.emptyMap());
            RssFetchClientConfResponse response = coordinatorClient.fetchClientConf(request);
            if (response.getStatusCode() == StatusCode.SUCCESS) {
                RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
            } else {
                LOG.warn("Failed to fetch dynamic client conf from coordinators, status code: {}", (Object)response.getStatusCode());
            }
        } finally {
            // Previously the client leaked whenever getting the user or fetching the conf threw.
            coordinatorClient.close();
        }
    }

    /**
     * Drops all map outputs registered for {@code shuffleId} on the driver's
     * {@code MapOutputTrackerMaster}.
     *
     * <p>This is a no-op unless {@link RssSparkShuffleUtils#isStageResubmitSupported()} returns
     * {@code true}. The version-specific reflective method handles are resolved once. If no direct
     * unregister method is available, the shuffle is unregistered and registered again instead.
     *
     * @param shuffleId the shuffle whose map outputs are dropped
     * @throws SparkException if the fallback path runs without a driver-side tracker or
     *     registerShuffle handle
     * @throws RssException if a reflective invocation fails
     */
    @Override
    public void unregisterAllMapOutput(int shuffleId) throws SparkException {
        if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
            return;
        }
        MapOutputTrackerMaster tracker = RssShuffleManagerBase.getMapOutputTrackerMaster();
        // Resolve the handles once. The fields are written BEFORE the volatile set(true), so any
        // thread that observes isInitialized == true also observes both handles. A bare
        // compareAndSet(false, true) would let a concurrent caller see the flag before the handles
        // are assigned and take the fallback path with a null registerShuffle method.
        if (!this.isInitialized.get()) {
            synchronized (this.isInitialized) {
                if (!this.isInitialized.get()) {
                    this.unregisterAllMapOutputMethod = RssShuffleManagerBase.getUnregisterAllMapOutputMethod(tracker);
                    this.registerShuffleMethod = RssShuffleManagerBase.getRegisterShuffleMethod(tracker);
                    this.isInitialized.set(true);
                }
            }
        }
        Method unregisterAll = this.unregisterAllMapOutputMethod;
        if (unregisterAll != null) {
            try {
                unregisterAll.invoke(tracker, shuffleId);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw new RssException("Invoke unregisterAllMapOutput method failed", e);
            }
            return;
        }
        int numMaps = this.getNumMaps(shuffleId);
        int numReduces = this.getPartitionNum(shuffleId);
        RssShuffleManagerBase.defaultUnregisterAllMapOutput(tracker, this.registerShuffleMethod, shuffleId, numMaps, numReduces);
    }

    /**
     * Fallback for trackers without a usable unregister-all method. It unregisters the shuffle,
     * registers it again through the reflective {@code registerShuffle} handle, and bumps the
     * tracker epoch.
     *
     * @param tracker the driver-side map output tracker
     * @param registerShuffle the reflective registerShuffle handle; it takes either
     *     (shuffleId, numMaps) or (shuffleId, numMaps, numReduces), depending on the Spark version
     * @param shuffleId the shuffle to reset
     * @param numMaps number of map tasks of the shuffle
     * @param numReduces number of reduce partitions of the shuffle
     * @throws SparkException if {@code tracker} or {@code registerShuffle} is {@code null}
     * @throws RssException if the reflective call fails
     */
    private static void defaultUnregisterAllMapOutput(MapOutputTrackerMaster tracker, Method registerShuffle, int shuffleId, int numMaps, int numReduces) throws SparkException {
        if (tracker == null || registerShuffle == null) {
            throw new SparkException("default unregisterAllMapOutput should only be called on the driver side");
        }
        tracker.unregisterShuffle(shuffleId);
        try {
            // Exactly one invocation, matched to the arity resolved by getRegisterShuffleMethod.
            // Invoking both the 3-arg and 2-arg forms made the second call fail with an uncaught
            // IllegalArgumentException (wrong number of arguments).
            if (registerShuffle.getParameterCount() == 3) {
                registerShuffle.invoke(tracker, shuffleId, numMaps, numReduces);
            } else {
                registerShuffle.invoke(tracker, shuffleId, numMaps);
            }
        }
        catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke registerShuffle method failed", e);
        }
        tracker.incrementEpoch();
    }

    /**
     * Looks up the tracker method that drops all map outputs of a shuffle for the running Spark
     * version.
     *
     * <ul>
     *   <li>Spark &lt;= 2.3: none; the default fallback is used.
     *   <li>Spark 2.4 and 3.0/3.1: {@code unregisterAllMapOutput(int)}.
     *   <li>Spark &gt;= 3.2, including major versions above 3 (consistent with
     *       {@code getRegisterShuffleMethod}): {@code unregisterAllMapAndMergeOutput(int)}.
     * </ul>
     *
     * @param tracker the driver-side tracker, may be {@code null}
     * @return the method, or {@code null} if tracker is null, the version has no such method, or
     *     the lookup fails (the caller then uses the default fallback)
     */
    private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> klass = tracker.getClass();
        try {
            if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
                LOG.warn("Spark version <= 2.3, fallback to default method");
                return null;
            }
            if (SparkVersionUtils.isSpark2() || SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1) {
                return klass.getDeclaredMethod("unregisterAllMapOutput", Integer.TYPE);
            }
            if (SparkVersionUtils.isSpark3() || SparkVersionUtils.MAJOR_VERSION > 3) {
                return klass.getDeclaredMethod("unregisterAllMapAndMergeOutput", Integer.TYPE);
            }
            LOG.warn("Unknown spark version({}), fallback to default method", (Object)SparkVersionUtils.SPARK_VERSION);
        }
        catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get unregisterAllMapOutput method for spark version({})", (Object)SparkVersionUtils.SPARK_VERSION);
        }
        return null;
    }

    /**
     * Reflectively resolves {@code registerShuffle} on the tracker class. Spark 3.2+ (and any major
     * version above 3) declares it as {@code (int, int, int)}, earlier versions as {@code (int, int)}.
     * Returns {@code null} for a {@code null} tracker or when the method cannot be found.
     */
    private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> trackerClass = tracker.getClass();
        boolean takesNumReduces = SparkVersionUtils.MAJOR_VERSION > 3
                || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2);
        Class<?>[] parameterTypes = takesNumReduces
                ? new Class<?>[]{Integer.TYPE, Integer.TYPE, Integer.TYPE}
                : new Class<?>[]{Integer.TYPE, Integer.TYPE};
        try {
            return trackerClass.getDeclaredMethod("registerShuffle", parameterTypes);
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get registerShuffle method for spark version({})", (Object)SparkVersionUtils.SPARK_VERSION);
            return null;
        }
    }

    /**
     * Returns the current {@link SparkEnv}'s map output tracker when it is a
     * {@link MapOutputTrackerMaster}; returns {@code null} if there is no SparkEnv, no tracker,
     * or the tracker is of another type.
     */
    private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
        SparkEnv env = SparkEnv.get();
        if (env == null) {
            return null;
        }
        MapOutputTracker tracker = env.mapOutputTracker();
        if (tracker instanceof MapOutputTrackerMaster) {
            return (MapOutputTrackerMaster) tracker;
        }
        return null;
    }

    /** Copies every key/value entry of the given Hadoop {@link Configuration} into a new mutable map. */
    private static Map<String, String> parseRemoteStorageConf(Configuration conf) {
        Map<String, String> items = new HashMap<>();
        for (Map.Entry<String, String> kv : conf) {
            items.put(kv.getKey(), kv.getValue());
        }
        return items;
    }

    /**
     * Builds the default {@link RemoteStorageInfo} from the Spark conf.
     *
     * <p>The path is {@code RSS_REMOTE_STORAGE_PATH} (empty string when unset). The conf items start
     * from the local Hadoop configuration when {@code RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED}
     * is set, and are then overlaid with every non-null {@code rss.hadoop.*} entry of the RSS conf,
     * with that prefix removed from the key.
     *
     * @param sparkConf the Spark configuration to read from
     * @return the remote storage path plus Hadoop conf items
     */
    protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkConf) {
        final String hadoopPrefix = "rss.hadoop.";
        // Declared as Map: parseRemoteStorageConf returns Map<String, String>, which cannot be
        // assigned to a HashMap-typed local.
        Map<String, String> confItems = new HashMap<>();
        RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
        if (rssConf.getBoolean(RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED)) {
            confItems = RssShuffleManagerBase.parseRemoteStorageConf(new Configuration(true));
        }
        for (String key : rssConf.getKeySet()) {
            if (!key.startsWith(hadoopPrefix)) {
                continue;
            }
            String val = rssConf.getString(key, null);
            if (val == null) {
                continue;
            }
            // The prefix is literal text, so strip it by length instead of via a regex (replaceFirst
            // would treat '.' as "any character").
            confItems.put(key.substring(hadoopPrefix.length()), val);
        }
        return new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems);
    }

    /**
     * Resolves the shuffle handle info a task should use.
     * <ul>
     *   <li>RPC service enabled and stage retry for write failure enabled: fetched through the
     *       shuffle manager client as a {@link StageAttemptShuffleHandleInfo};</li>
     *   <li>RPC service enabled and partition reassign enabled: fetched through the shuffle manager
     *       client as a {@link MutableShuffleHandleInfo};</li>
     *   <li>otherwise: a {@link SimpleShuffleHandleInfo} built from the handle's own assignment.</li>
     * </ul>
     */
    public ShuffleHandleInfo getShuffleHandleInfo(int stageAttemptId, int stageAttemptNumber, RssShuffleHandle<?, ?, ?> rssHandle, boolean isWritePhase) {
        final int shuffleId = rssHandle.getShuffleId();
        if (shuffleManagerRpcServiceEnabled) {
            if (rssStageRetryForWriteFailureEnabled) {
                return getRemoteShuffleHandleInfoWithStageRetry(stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase);
            }
            if (partitionReassignEnabled) {
                return getRemoteShuffleHandleInfoWithBlockRetry(stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase);
            }
        }
        return new SimpleShuffleHandleInfo(shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
    }

    /**
     * Queries the shuffle manager RPC service for the partition-to-server assignment under stage
     * retry and decodes the returned proto into a {@link StageAttemptShuffleHandleInfo}.
     */
    protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoWithStageRetry(int stageAttemptId, int stageAttemptNumber, int shuffleId, boolean isWritePhase) {
        RssPartitionToShuffleServerRequest request =
                new RssPartitionToShuffleServerRequest(stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase);
        RssReassignOnStageRetryResponse response =
                getOrCreateShuffleManagerClientSupplier().get().getPartitionToShufflerServerWithStageRetry(request);
        return StageAttemptShuffleHandleInfo.fromProto(response.getShuffleHandleInfoProto());
    }

    /**
     * Queries the shuffle manager RPC service for the partition-to-server assignment under block
     * retry and decodes the returned handle into a {@link MutableShuffleHandleInfo}.
     */
    protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBlockRetry(int stageAttemptId, int stageAttemptNumber, int shuffleId, boolean isWritePhase) {
        RssPartitionToShuffleServerRequest request =
                new RssPartitionToShuffleServerRequest(stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase);
        RssReassignOnBlockSendFailureResponse response =
                getOrCreateShuffleManagerClientSupplier().get().getPartitionToShufflerServerWithBlockRetry(request);
        return MutableShuffleHandleInfo.fromProto(response.getHandle());
    }

    /**
     * Lazily creates, once, an {@code ExpiringCloseableSupplier} of gRPC shuffle manager clients.
     * The clients target {@code driver.host} (empty string if unset) on the configured shuffle
     * manager gRPC port, with the configured RPC timeout. Later calls return the same supplier.
     */
    protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerClientSupplier() {
        if (managerClientSupplier != null) {
            return managerClientSupplier;
        }
        RssConf conf = RssSparkConfig.toRssConf(sparkConf);
        String driverHost = conf.getString("driver.host", "");
        int grpcPort = conf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
        long timeoutMs = conf.getLong(RssClientConf.RPC_TIMEOUT_MS);
        managerClientSupplier = ExpiringCloseableSupplier.of(
                () -> ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(ClientType.GRPC, driverHost, grpcPort, timeoutMs));
        return managerClientSupplier;
    }

    /**
     * Returns the shuffle manager client supplier as-is. This is {@code null} until
     * {@code getOrCreateShuffleManagerClientSupplier()} has created it.
     */
    public Supplier<ShuffleManagerClient> getShuffleManagerClientSupplier() {
        return managerClientSupplier;
    }

    /** Returns whatever the shuffle handle info manager stores for {@code shuffleId}. */
    @Override
    public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
        return shuffleHandleInfoManager.get(shuffleId);
    }

    /**
     * Returns how many fetch failures are tolerated: {@code spark.task.maxFailures} (default 4)
     * minus one, clamped to be non-negative.
     */
    @Override
    public int getMaxFetchFailures() {
        // Use the named key instead of leaving it unused and repeating the literal.
        final String taskMaxFailuresKey = "spark.task.maxFailures";
        return Math.max(0, this.sparkConf.getInt(taskMaxFailuresKey, 4) - 1);
    }

    /** Forwards {@code shuffleServerId} to the stage resubmit manager as a failed shuffle server. */
    @Override
    public void addFailuresShuffleServerInfos(String shuffleServerId) {
        rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId);
    }

    /**
     * Handles a stage resubmit:
     * <ul>
     *   <li>requests a new assignment for all partitions of {@code shuffleId}, excluding the
     *       resubmit manager's server blacklist;</li>
     *   <li>installs it as the current handle of the stored {@link StageAttemptShuffleHandleInfo};</li>
     *   <li>records that a stage-retry reassignment was triggered.</li>
     * </ul>
     *
     * @return always {@code true}
     */
    @Override
    public boolean reassignOnStageResubmit(int shuffleId, int stageAttemptId, int stageAttemptNumber) {
        int serverNum = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
        int taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
        int partitionNum = getPartitionNum(shuffleId);
        Set<String> blackList = rssStageResubmitManager.getServerIdBlackList();
        Map<Integer, List<ShuffleServerInfo>> assignment =
                requestShuffleAssignment(shuffleId, partitionNum, 1, serverNum, taskConcurrency, blackList);
        MutableShuffleHandleInfo freshHandle =
                new MutableShuffleHandleInfo(shuffleId, assignment, getRemoteStorageInfo(), partitionSplitMode);
        StageAttemptShuffleHandleInfo stageHandle = (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
        stageHandle.replaceCurrentShuffleHandleInfo(freshHandle);
        LOG.info("The stage retry has been triggered successfully for the shuffleId: {}, attemptNumber: {}", (Object)shuffleId, (Object)stageAttemptNumber);
        reassignTriggeredOnStageRetry.set(true);
        return true;
    }

    /**
     * Reassigns shuffle servers for partitions whose block sends failed, or which are being split,
     * and returns the updated mutable handle.
     *
     * <p>For each (partition, failed server) pair, replacements already recorded on the handle are
     * reused when present. Otherwise new servers are requested through {@code requestReassignServer},
     * whose request path registers the shuffle on them itself. Only when replacements were reused is
     * the partition registered on those servers here.
     *
     * @param stageId stage issuing the request
     * @param stageAttemptNumber attempt number of that stage
     * @param shuffleId shuffle whose assignment is updated
     * @param partitionToFailureServers partition id to the servers that failed for it
     * @param partitionSplit whether this call comes from a partition split rather than a send failure
     * @return the (mutated) handle for {@code shuffleId}
     * @throws RssException if no mutable handle is stored for {@code shuffleId}
     */
    @Override
    public MutableShuffleHandleInfo reassignOnBlockSendFailure(int stageId, int stageAttemptNumber, int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers, boolean partitionSplit) {
        long startTime = System.currentTimeMillis();
        ShuffleHandleInfo handleInfo = this.shuffleHandleInfoManager.get(shuffleId);
        MutableShuffleHandleInfo internalHandle = null;
        if (handleInfo instanceof MutableShuffleHandleInfo) {
            internalHandle = (MutableShuffleHandleInfo)handleInfo;
        } else if (handleInfo instanceof StageAttemptShuffleHandleInfo) {
            internalHandle = (MutableShuffleHandleInfo)((StageAttemptShuffleHandleInfo)handleInfo).getCurrent();
        }
        if (internalHandle == null) {
            throw new RssException("An unexpected error occurred: internalHandle is null, which should not happen");
        }
        // Serialize concurrent reassign requests that target the same handle.
        synchronized (internalHandle) {
            if (!partitionSplit) {
                internalHandle.checkPartitionReassignServerNum(partitionToFailureServers.keySet(), this.partitionReassignMaxServerNum);
            }
            // Reused replacement servers that still need this partition registered on them.
            Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new HashMap<>();
            // failed server id -> partition id -> ids of the servers now assigned (used for logging).
            Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
            for (Map.Entry<Integer, List<ReceivingFailureServer>> entry : partitionToFailureServers.entrySet()) {
                int partitionId = entry.getKey();
                for (ReceivingFailureServer receivingFailureServer : entry.getValue()) {
                    Set<ShuffleServerInfo> updatedReassignServers;
                    String serverId = receivingFailureServer.getServerId();
                    boolean serverHasReplaced = false;
                    if (!partitionSplit && !internalHandle.isPartitionSplit(partitionId)) {
                        Set<ShuffleServerInfo> replacements = internalHandle.getReplacements(serverId);
                        if (CollectionUtils.isEmpty(replacements)) {
                            replacements = this.requestReassignServer(stageId, stageAttemptNumber, shuffleId, internalHandle, partitionId, serverId);
                        } else {
                            serverHasReplaced = true;
                        }
                        updatedReassignServers = internalHandle.updateAssignment(partitionId, serverId, replacements);
                    } else {
                        int requireServerNum = 1;
                        if (this.partitionSplitMode == PartitionSplitMode.LOAD_BALANCE) {
                            requireServerNum = this.partitionSplitLoadBalanceServerNum;
                        }
                        Set<ShuffleServerInfo> replacements = internalHandle.getReplacementsForPartition(partitionId, serverId);
                        if (CollectionUtils.isEmpty(replacements)) {
                            replacements = this.requestReassignServer(stageId, stageAttemptNumber, shuffleId, internalHandle, partitionId, serverId, requireServerNum);
                        } else {
                            serverHasReplaced = true;
                        }
                        updatedReassignServers = internalHandle.updateAssignmentOnPartitionSplit(partitionId, serverId, replacements);
                    }
                    if (updatedReassignServers.isEmpty()) {
                        continue;
                    }
                    Set<String> updatedServerIds = updatedReassignServers.stream()
                            .map(ShuffleServerInfo::getId)
                            .collect(Collectors.toSet());
                    reassignResult.computeIfAbsent(serverId, k -> new HashMap<>())
                            .computeIfAbsent(partitionId, k -> new HashSet<>())
                            .addAll(updatedServerIds);
                    if (!serverHasReplaced) {
                        continue;
                    }
                    for (ShuffleServerInfo serverInfo : updatedReassignServers) {
                        newServerToPartitions.computeIfAbsent(serverInfo, k -> new ArrayList<>()).add(new PartitionRange(partitionId, partitionId));
                    }
                }
            }
            if (!newServerToPartitions.isEmpty()) {
                LOG.info("Register the new partition->servers assignment on reassign. {}", newServerToPartitions);
                this.registerShuffleServers(this.getAppId(), shuffleId, newServerToPartitions, this.getRemoteStorageInfo());
            }
            LOG.info("Finished reassignOnBlockSendFailure request and cost {}(ms). is partition split:{}. Reassign result: {}", new Object[]{System.currentTimeMillis() - startTime, partitionSplit, reassignResult});
            if (partitionSplit) {
                this.reassignTriggeredOnPartitionSplit.set(true);
            } else {
                this.reassignTriggeredOnBlockSendFailure.set(true);
            }
            return internalHandle;
        }
    }

    /**
     * Requests {@code requiredServerNum} replacement servers for {@code partitionId}.
     * Excluded from the request are the handle's globally excluded servers, the servers excluded
     * for this partition, and {@code serverId} itself.
     */
    private Set<ShuffleServerInfo> requestReassignServer(int stageId, int stageAttemptNumber, int shuffleId, MutableShuffleHandleInfo internalHandle, int partitionId, String serverId, int requiredServerNum) {
        Set<String> exclusions = new HashSet<String>(internalHandle.listExcludedServers());
        exclusions.addAll(internalHandle.listExcludedServersForPartition(partitionId));
        exclusions.add(serverId);
        return reassignServerForTask(stageId, stageAttemptNumber, shuffleId, Sets.newHashSet(partitionId), exclusions, requiredServerNum, true);
    }

    /** Requests a single replacement server; delegates to the overload taking {@code requiredServerNum}. */
    private Set<ShuffleServerInfo> requestReassignServer(int stageId, int stageAttemptNumber, int shuffleId, MutableShuffleHandleInfo internalHandle, int partitionId, String serverId) {
        final int singleServer = 1;
        return requestReassignServer(stageId, stageAttemptNumber, shuffleId, internalHandle, partitionId, serverId, singleServer);
    }

    /**
     * Releases the manager's resources.
     *
     * <p>On the driver with partition reassign enabled, first posts a {@link TaskReassignInfoEvent}
     * recording which reassign paths were triggered. It then:
     * <ul>
     *   <li>closes the expiring shuffle manager client supplier;</li>
     *   <li>shuts down the heartbeat executor;</li>
     *   <li>calls {@code unregisterShuffle} for the app id and closes the write client;</li>
     *   <li>closes the data pusher (logging {@link IOException}s);</li>
     *   <li>stops the shuffle manager server.</li>
     * </ul>
     */
    public void stop() {
        if (this.isDriver && this.partitionReassignEnabled) {
            TaskReassignInfoEvent reassignInfoEvent = new TaskReassignInfoEvent(this.reassignTriggeredOnPartitionSplit.get(), this.reassignTriggeredOnBlockSendFailure.get(), this.reassignTriggeredOnStageRetry.get());
            RssSparkShuffleUtils.getActiveSparkContext().listenerBus().post((SparkListenerEvent)reassignInfoEvent);
        }
        // instanceof is false for null, so no separate null check is needed.
        if (this.managerClientSupplier instanceof ExpiringCloseableSupplier) {
            ((ExpiringCloseableSupplier)this.managerClientSupplier).close();
        }
        if (this.heartBeatScheduledExecutorService != null) {
            this.heartBeatScheduledExecutorService.shutdownNow();
        }
        if (this.shuffleWriteClient != null) {
            this.shuffleWriteClient.unregisterShuffle(this.getAppId());
            this.shuffleWriteClient.close();
        }
        if (this.dataPusher != null) {
            try {
                this.dataPusher.close();
            }
            catch (IOException e) {
                LOG.warn("Errors on closing data pusher", (Throwable)e);
            }
        }
        if (this.shuffleManagerServer != null) {
            try {
                this.shuffleManagerServer.stop();
            }
            catch (InterruptedException e) {
                LOG.info("shuffle manager server is interrupted during stop");
                // Restore the interrupt status so code further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Returns the application id currently held in the {@code id} reference. */
    @Override
    public String getAppId() {
        return id.get();
    }

    /** Returns the partition count recorded for {@code shuffleId}, or 0 when none is recorded. */
    @Override
    public int getPartitionNum(int shuffleId) {
        return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
    }

    /** Returns the map task count recorded for {@code shuffleId}, or 0 when none is recorded. */
    @Override
    public int getNumMaps(int shuffleId) {
        return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
    }

    /** Adds {@code blockIds} to the success block ids tracked for {@code taskId}, creating the set on first use. */
    @VisibleForTesting
    public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
        Set<Long> tracked = taskToSuccessBlockIds.get(taskId);
        if (tracked == null) {
            tracked = Sets.newHashSet();
            taskToSuccessBlockIds.put(taskId, tracked);
        }
        tracked.addAll(blockIds);
    }

    /** Stores {@code failedBlockSendTracker} for {@code taskId} unless one is already stored. */
    @VisibleForTesting
    public void addFailedBlockSendTracker(String taskId, FailedBlockSendTracker failedBlockSendTracker) {
        taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
    }

    /**
     * Factory hook, implemented by concrete subclasses, that builds a {@link ShuffleWriteClient}.
     * NOTE(review): the call site is outside this view; presumably its result becomes
     * {@code shuffleWriteClient}.
     */
    protected abstract ShuffleWriteClient createShuffleWriteClient();

    /**
     * Hook that lets subclasses validate {@code sparkConf}; the base implementation accepts
     * everything and does nothing.
     */
    protected void checkSupported(SparkConf sparkConf) {
    }

    /**
     * Builds an assignment in which:
     * <ul>
     *   <li>every partition id maps to its own copy of the list of all {@code servers};</li>
     *   <li>every server maps to single-partition ranges covering all {@code partitionIds}.</li>
     * </ul>
     * All servers share one and the same range list instance.
     */
    private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
        Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
        List<PartitionRange> ranges = new ArrayList<>();
        partitionIds.forEach(pid -> {
            partitionToServers.put(pid, new ArrayList<>(servers));
            ranges.add(new PartitionRange(pid, pid));
        });
        Map<ShuffleServerInfo, List<PartitionRange>> serverToRanges = new HashMap<>();
        servers.forEach(server -> serverToRanges.put(server, ranges));
        return new ShuffleAssignmentsInfo(partitionToServers, serverToRanges);
    }

    /**
     * Requests {@code requiredServerNum} servers, excluding {@code excludedServers}.
     * The coordinator's response is rewritten so that every returned server is registered for all
     * of {@code partitionIds} (see {@code createShuffleAssignmentsInfo}).
     *
     * <p>{@code stageId}, {@code stageAttemptNumber} and {@code reassign} are currently unused.
     *
     * @return the servers taken from the coordinator's response
     * @throws RssException if the assignment request or registration fails
     */
    private Set<ShuffleServerInfo> reassignServerForTask(int stageId, int stageAttemptNumber, int shuffleId, Set<Integer> partitionIds, Set<String> excludedServers, int requiredServerNum, boolean reassign) {
        // Typed reference: the raw AtomicReference made get() return Object, which does not
        // compile against the Set<ShuffleServerInfo> return type.
        AtomicReference<Set<ShuffleServerInfo>> replacementsRef = new AtomicReference<>(new HashSet<>());
        this.requestShuffleAssignment(shuffleId, requiredServerNum, 1, requiredServerNum, 1, excludedServers, shuffleAssignmentsInfo -> {
            if (shuffleAssignmentsInfo == null) {
                return null;
            }
            Set<ShuffleServerInfo> replacements = shuffleAssignmentsInfo.getPartitionToServers().values().stream()
                    .flatMap(List::stream)
                    .collect(Collectors.toSet());
            replacementsRef.set(replacements);
            return this.createShuffleAssignmentsInfo(replacements, partitionIds);
        });
        return replacementsRef.get();
    }

    /**
     * Requests an assignment of {@code partitionNum} partitions from the coordinator, then:
     * <ul>
     *   <li>lets {@code reassignmentHandler} (when non-null) rewrite the response;</li>
     *   <li>registers the resulting server-to-partition ranges on the shuffle servers;</li>
     *   <li>returns the partition-to-server map.</li>
     * </ul>
     *
     * <p>{@code faultyServerIds} is not modified. The exclusion set sent to the coordinator is a
     * copy of it, extended with the stage resubmit manager's blacklist.
     *
     * @throws RssException wrapping any failure, including a {@code null} (rewritten) response
     */
    private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(int shuffleId, int partitionNum, int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set<String> faultyServerIds, Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler) {
        Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(this.sparkConf);
        ClientUtils.validateClientType(this.clientType);
        assignmentTags.add(this.clientType);
        long retryInterval = (Long)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
        int retryTimes = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
        // Copy before extending so the caller's (possibly shared or immutable) set is left untouched.
        Set<String> excludedServerIds = new HashSet<String>(faultyServerIds);
        excludedServerIds.addAll(this.rssStageResubmitManager.getServerIdBlackList());
        try {
            ShuffleAssignmentsInfo response = this.shuffleWriteClient.getShuffleAssignments(this.getAppId(), shuffleId, partitionNum, partitionNumPerRange, assignmentTags, assignmentShuffleServerNumber, estimateTaskConcurrency, excludedServerIds, retryInterval, retryTimes);
            LOG.info("Finished reassign");
            if (reassignmentHandler != null) {
                response = reassignmentHandler.apply(response);
            }
            this.registerShuffleServers(this.getAppId(), shuffleId, response.getServerToPartitionRanges(), this.getRemoteStorageInfo());
            return response.getPartitionToServers();
        }
        catch (Throwable throwable) {
            throw new RssException("registerShuffle failed!", throwable);
        }
    }

    /**
     * Requests an assignment of {@code partitionNum} partitions from the coordinator and registers
     * it on the shuffle servers. Each attempt of the request plus registration runs under
     * {@code RetryUtils.retry}, with the configured assignment retry interval and count.
     *
     * <p>{@code faultyServerIds} is not modified. The exclusion set sent to the coordinator is a
     * copy of it, extended with the stage resubmit manager's blacklist.
     *
     * @return the partition-to-server assignment
     * @throws RssException if all attempts fail
     */
    protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(int shuffleId, int partitionNum, int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set<String> faultyServerIds) {
        Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(this.sparkConf);
        ClientUtils.validateClientType(this.clientType);
        assignmentTags.add(this.clientType);
        long retryInterval = (Long)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
        int retryTimes = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
        // Copy before extending: callers may pass a shared set (e.g. the resubmit blacklist itself)
        // or an immutable one, on which addAll would throw.
        Set<String> excludedServerIds = new HashSet<String>(faultyServerIds);
        excludedServerIds.addAll(this.rssStageResubmitManager.getServerIdBlackList());
        try {
            // NOTE(review): this overload uses the appId field while the other overload uses
            // getAppId() (backed by `id`); confirm both always hold the same value.
            return RetryUtils.retry(() -> {
                ShuffleAssignmentsInfo response = this.shuffleWriteClient.getShuffleAssignments(this.appId, shuffleId, partitionNum, partitionNumPerRange, assignmentTags, assignmentShuffleServerNumber, estimateTaskConcurrency, excludedServerIds, 0L, 0);
                this.registerShuffleServers(this.appId, shuffleId, response.getServerToPartitionRanges(), this.getRemoteStorageInfo());
                return response.getPartitionToServers();
            }, retryInterval, retryTimes);
        }
        catch (Throwable throwable) {
            throw new RssException("getShuffleAssignments or registerShuffle failed!", throwable);
        }
    }

    /**
     * Registers shuffle {@code shuffleId} of {@code appId} on every server in
     * {@code serverToPartitionRanges}, passing each server its own partition ranges.
     * Does nothing for a {@code null} or empty map.
     */
    protected void registerShuffleServers(String appId, int shuffleId, Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges, RemoteStorageInfo remoteStorage) {
        boolean nothingToRegister = serverToPartitionRanges == null || serverToPartitionRanges.isEmpty();
        if (nothingToRegister) {
            return;
        }
        LOG.info("Start to register shuffleId {}", (Object)shuffleId);
        long startMs = System.currentTimeMillis();
        Map<String, String> confMap = sparkConfToMap(getSparkConf());
        for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> serverAndRanges : serverToPartitionRanges.entrySet()) {
            shuffleWriteClient.registerShuffle(serverAndRanges.getKey(), appId, shuffleId, serverAndRanges.getValue(), remoteStorage, ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, null, confMap);
        }
        LOG.info("Finish register shuffleId {} with {} ms", (Object)shuffleId, (Object)(System.currentTimeMillis() - startMs));
    }

    /**
     * Resolves the remote storage for this app via {@code ClientUtils.fetchRemoteStorage}.
     * The inputs are the configured storage type, the dynamic-conf flag, and the storage derived
     * from the Spark conf.
     */
    @VisibleForTesting
    public RemoteStorageInfo getRemoteStorageInfo() {
        String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
        RemoteStorageInfo fallbackStorage = getDefaultRemoteStorageInfo(sparkConf);
        return ClientUtils.fetchRemoteStorage(appId, fallbackStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
    }

    /** Returns whether RSS stage retry is enabled. */
    public boolean isRssStageRetryEnabled() {
        return rssStageRetryEnabled;
    }

    /** Returns whether RSS stage retry on write failure is enabled. */
    public boolean isRssStageRetryForWriteFailureEnabled() {
        return rssStageRetryForWriteFailureEnabled;
    }

    /** Returns whether RSS stage retry on fetch failure is enabled. */
    public boolean isRssStageRetryForFetchFailureEnabled() {
        return rssStageRetryForFetchFailureEnabled;
    }

    /** Returns the Spark configuration this manager was built with. */
    @VisibleForTesting
    public SparkConf getSparkConf() {
        return sparkConf;
    }

    /** Copies every key/value pair of {@code sparkConf} into a new mutable map. */
    public Map<String, String> sparkConfToMap(SparkConf sparkConf) {
        Map<String, String> flattened = new HashMap<>();
        Tuple2<String, String>[] pairs = sparkConf.getAll();
        for (Tuple2<String, String> pair : pairs) {
            flattened.put(pair._1(), pair._2());
        }
        return flattened;
    }

    /** Returns the shuffle write client used by this manager. */
    @Override
    public ShuffleWriteClient getShuffleWriteClient() {
        return shuffleWriteClient;
    }

    /**
     * Registers the application info with the write client. Then, unless the RSS test flag is set
     * or heartbeats were already started, schedules a fixed-rate app heartbeat: first after half of
     * {@code heartbeatInterval}, then every {@code heartbeatInterval} milliseconds. Heartbeat
     * failures are logged and otherwise ignored.
     */
    protected synchronized void startHeartbeat() {
        shuffleWriteClient.registerApplicationInfo(getAppId(), heartbeatTimeout, user);
        boolean testMode = sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false);
        if (testMode || heartbeatStarted) {
            return;
        }
        Runnable heartbeat = () -> {
            try {
                String currentAppId = getAppId();
                shuffleWriteClient.sendAppHeartbeat(currentAppId, heartbeatTimeout);
                LOG.info("Finish send heartbeat to coordinator and servers");
            } catch (Exception e) {
                LOG.warn("Fail to send heartbeat to coordinator and servers", (Throwable)e);
            }
        };
        heartBeatScheduledExecutorService.scheduleAtFixedRate(heartbeat, heartbeatInterval / 2L, heartbeatInterval, TimeUnit.MILLISECONDS);
        heartbeatStarted = true;
    }

    /** Drops the success block ids and the failed-send tracker recorded for {@code taskId}. */
    public void clearTaskMeta(String taskId) {
        taskToSuccessBlockIds.remove(taskId);
        taskToFailedBlockSendTracker.remove(taskId);
    }

    /** Registers the configured coordinator quorum with the write client, using the configured retry settings. */
    @VisibleForTesting
    protected void registerCoordinator() {
        String quorum = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
        LOG.info("Start Registering coordinators {}", (Object)quorum);
        Long retryIntervalMax = (Long)sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
        Integer retryMax = (Integer)sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
        shuffleWriteClient.registerCoordinators(quorum, retryIntervalMax, retryMax);
    }

    /** Returns the block ids that failed to send for {@code taskId}, or an empty set if none are tracked. */
    public Set<Long> getFailedBlockIds(String taskId) {
        FailedBlockSendTracker tracker = getBlockIdsFailedSendTracker(taskId);
        return tracker == null ? Collections.<Long>emptySet() : tracker.getFailedBlockIds();
    }

    /** Returns the successfully sent block ids of {@code taskId}, or an empty set if none are tracked. */
    public Set<Long> getSuccessBlockIds(String taskId) {
        Set<Long> ids = taskToSuccessBlockIds.get(taskId);
        return ids != null ? ids : Collections.<Long>emptySet();
    }

    /** Returns the failed-send tracker stored for {@code taskId}, or {@code null} if there is none. */
    public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
        return taskToFailedBlockSendTracker.get(taskId);
    }

    /** Records {@code taskId} as failed, so that {@code isValidTask} rejects it afterwards; always returns {@code true}. */
    public boolean markFailedTask(String taskId) {
        LOG.info("Mark the task: {} failed.", (Object)taskId);
        failedTaskIds.add(taskId);
        return true;
    }

    /** Returns {@code true} unless {@code taskId} was previously marked as failed. */
    public boolean isValidTask(String taskId) {
        boolean markedFailed = failedTaskIds.contains(taskId);
        return !markedFailed;
    }

    /** Replaces the data pusher used by {@code sendData}. */
    @VisibleForTesting
    public void setDataPusher(DataPusher dataPusher) {
        this.dataPusher = dataPusher;
    }

    /** Returns the current data pusher, which may be {@code null}. */
    public DataPusher getDataPusher() {
        return dataPusher;
    }

    /** Exposes the live task-id to success-block-ids map (not a copy). */
    @VisibleForTesting
    public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
        return taskToSuccessBlockIds;
    }

    /** Exposes the live task-id to failed-send-tracker map (not a copy). */
    @VisibleForTesting
    public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
        return taskToFailedBlockSendTracker;
    }

    /**
     * Hands {@code event} to the data pusher.
     *
     * @return the pusher's future. When there is no pusher or {@code event} is {@code null}, returns
     *     a fresh future that this method never completes.
     *     NOTE(review): a caller waiting on that future without a timeout would block forever;
     *     confirm this is intended.
     */
    public CompletableFuture<Long> sendData(AddBlockEvent event) {
        boolean canSend = dataPusher != null && event != null;
        if (!canSend) {
            return new CompletableFuture<Long>();
        }
        return dataPusher.send(event);
    }
}

