/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.krystex.node;

import com.flipkart.krystal.data.Inputs;
import com.flipkart.krystal.krystex.KrystalExecutor;
import com.flipkart.krystal.krystex.MainLogicDefinition;
import com.flipkart.krystal.krystex.commands.ExecuteWithInputs;
import com.flipkart.krystal.krystex.commands.Flush;
import com.flipkart.krystal.krystex.commands.NodeCommand;
import com.flipkart.krystal.krystex.commands.NodeRequestCommand;
import com.flipkart.krystal.krystex.commands.SkipNode;
import com.flipkart.krystal.krystex.decoration.InitiateActiveDepChains;
import com.flipkart.krystal.krystex.decoration.LogicDecorationOrdering;
import com.flipkart.krystal.krystex.decoration.LogicExecutionContext;
import com.flipkart.krystal.krystex.decoration.MainLogicDecorator;
import com.flipkart.krystal.krystex.decoration.MainLogicDecoratorConfig;
import com.flipkart.krystal.krystex.node.DependantChain;
import com.flipkart.krystal.krystex.node.DependantChainStart;
import com.flipkart.krystal.krystex.node.DisabledDependantChainException;
import com.flipkart.krystal.krystex.node.KrystalNodeExecutorConfig;
import com.flipkart.krystal.krystex.node.KrystalNodeExecutorMetrics;
import com.flipkart.krystal.krystex.node.Node;
import com.flipkart.krystal.krystex.node.NodeDefinition;
import com.flipkart.krystal.krystex.node.NodeDefinitionRegistry;
import com.flipkart.krystal.krystex.node.NodeExecutionConfig;
import com.flipkart.krystal.krystex.node.NodeId;
import com.flipkart.krystal.krystex.node.NodeRegistry;
import com.flipkart.krystal.krystex.node.NodeResponse;
import com.flipkart.krystal.krystex.request.RequestId;
import com.flipkart.krystal.utils.Futures;
import com.flipkart.krystal.utils.MultiLeasePool;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link KrystalExecutor} implementation which registers node execution requests via
 * {@link #executeNode}, and starts executing them when {@link #flush()} is called.
 *
 * <p>All bookkeeping mutations performed by {@link #executeNode}, {@link #flush()} and
 * {@link #close()} are submitted to an {@link ExecutorService} leased from the pool passed to the
 * constructor.
 *
 * <p>NOTE(review): the bookkeeping collections are plain {@code LinkedHashMap}/{@code
 * LinkedHashSet} instances; this presumably relies on the leased executor being single-threaded -
 * confirm against the pool configuration.
 */
public final class KrystalNodeExecutor implements KrystalExecutor {
    private static final Logger log = LoggerFactory.getLogger(KrystalNodeExecutor.class);
    private final NodeDefinitionRegistry nodeDefinitionRegistry;
    private final LogicDecorationOrdering logicDecorationOrdering;
    private final MultiLeasePool.Lease<? extends ExecutorService> commandQueueLease;
    private final String instanceId;
    private final ImmutableMap<String, List<MainLogicDecoratorConfig>> requestScopedLogicDecoratorConfigs;
    private final ImmutableSet<DependantChain> disabledDependantChains;
    /** Decorator type -> decorator instance id -> decorator instance. */
    private final Map<String, Map<String, MainLogicDecorator>> requestScopedMainDecorators = new LinkedHashMap<>();
    private final NodeRegistry nodeRegistry = new NodeRegistry();
    private final KrystalNodeExecutorMetrics krystalNodeMetrics;
    private volatile boolean closed;
    /** Every request ever accepted by this executor, keyed by its request id. */
    private final Map<RequestId, NodeResult> allRequests = new LinkedHashMap<>();
    /** Requests which have been accepted but not yet submitted by {@link #flush()}. */
    private final Set<RequestId> unFlushedRequests = new LinkedHashSet<>();
    /** The (non-disabled) dependant chains through which each node has been reached. */
    private final Map<NodeId, Set<DependantChain>> dependantChainsPerNode = new LinkedHashMap<>();

    /**
     * Creates an executor whose decoration ordering, request scoped decorator configs and disabled
     * dependant chains are read from {@code config}.
     */
    public KrystalNodeExecutor(
            NodeDefinitionRegistry nodeDefinitionRegistry,
            MultiLeasePool<? extends ExecutorService> commandQueuePool,
            KrystalNodeExecutorConfig config,
            String instanceId) {
        this(
                nodeDefinitionRegistry,
                commandQueuePool,
                config.logicDecorationOrdering(),
                config.requestScopedLogicDecoratorConfigs(),
                config.disabledDependantChains(),
                instanceId);
    }

    /**
     * Creates an executor and leases a command queue from {@code commandQueuePool}. The lease is
     * released by {@link #close()} once all accepted requests have completed.
     *
     * @param requestScopedLogicDecoratorConfigs decorator configs keyed by decorator type; copied
     *     defensively
     * @param instanceId prefix used when building the {@link RequestId} of each request
     */
    public KrystalNodeExecutor(
            NodeDefinitionRegistry nodeDefinitionRegistry,
            MultiLeasePool<? extends ExecutorService> commandQueuePool,
            LogicDecorationOrdering logicDecorationOrdering,
            Map<String, List<MainLogicDecoratorConfig>> requestScopedLogicDecoratorConfigs,
            ImmutableSet<DependantChain> disabledDependantChains,
            String instanceId) {
        this.nodeDefinitionRegistry = nodeDefinitionRegistry;
        this.logicDecorationOrdering = logicDecorationOrdering;
        this.commandQueueLease = commandQueuePool.lease();
        this.instanceId = instanceId;
        this.requestScopedLogicDecoratorConfigs = ImmutableMap.copyOf(requestScopedLogicDecoratorConfigs);
        this.disabledDependantChains = disabledDependantChains;
        this.krystalNodeMetrics = new KrystalNodeExecutorMetrics();
    }

    /**
     * Resolves the request scoped decorators applicable to the given execution context.
     *
     * <p>Configs declared on the node's main logic definition are visited before the executor level
     * configs. For every config whose {@code shouldDecorate} predicate passes, the decorator
     * instance is created on first use (and cached per decorator type and instance id), and is sent
     * an {@link InitiateActiveDepChains} command. Only the first applicable decorator of each
     * decorator type is returned.
     *
     * @return applicable decorators keyed by decorator type
     */
    private ImmutableMap<String, MainLogicDecorator> getRequestScopedDecorators(LogicExecutionContext logicExecutionContext) {
        NodeId nodeId = logicExecutionContext.nodeId();
        NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
        Map<String, MainLogicDecorator> decorators = new LinkedHashMap<>();
        Stream.concat(
                        nodeDefinition.getMainLogicDefinition().getRequestScopedLogicDecoratorConfigs().entrySet().stream(),
                        this.requestScopedLogicDecoratorConfigs.entrySet().stream())
                .forEach(entry -> {
                    String decoratorType = entry.getKey();
                    // Iterate over a copy of the configured list.
                    List<MainLogicDecoratorConfig> decoratorConfigs = new ArrayList<>(entry.getValue());
                    for (MainLogicDecoratorConfig decoratorConfig : decoratorConfigs) {
                        // The instance id is generated before the predicate is tested, as in the original flow.
                        String decoratorInstanceId = decoratorConfig.instanceIdGenerator().apply(logicExecutionContext);
                        if (!decoratorConfig.shouldDecorate().test(logicExecutionContext)) {
                            continue;
                        }
                        MainLogicDecorator mainLogicDecorator = this.requestScopedMainDecorators
                                .computeIfAbsent(decoratorType, t -> new LinkedHashMap<>())
                                .computeIfAbsent(
                                        decoratorInstanceId,
                                        _i -> decoratorConfig.factory().apply(
                                                new MainLogicDecoratorConfig.DecoratorContext(decoratorInstanceId, logicExecutionContext)));
                        // NOTE(review): dependantChainsPerNode is cleared when a flush completes; if no entry
                        // exists for nodeId here, ImmutableSet.copyOf(null) throws a NullPointerException - confirm
                        // this can not happen.
                        mainLogicDecorator.executeCommand(
                                new InitiateActiveDepChains(nodeId, ImmutableSet.copyOf(this.dependantChainsPerNode.get(nodeId))));
                        decorators.putIfAbsent(decoratorType, mainLogicDecorator);
                    }
                });
        return ImmutableMap.copyOf(decorators);
    }

    /**
     * Registers a request to execute the given node. The request is only submitted for execution
     * when {@link #flush()} is called.
     *
     * @return a future which completes with the node's result. It completes exceptionally with an
     *     {@link IllegalArgumentException} if a request with the same execution id has already been
     *     registered with this executor.
     * @throws RejectedExecutionException if this executor has been closed
     * @throws IllegalArgumentException if {@code executionConfig} or its execution id is null
     */
    @Override
    public <T> CompletableFuture<T> executeNode(NodeId nodeId, Inputs inputs, NodeExecutionConfig executionConfig) {
        if (this.closed) {
            throw new RejectedExecutionException("KrystalNodeExecutor is already closed");
        }
        Preconditions.checkArgument(executionConfig != null, "executionConfig can not be null");
        String executionId = executionConfig.executionId();
        Preconditions.checkArgument(executionId != null, "executionConfig.executionId can not be null");
        RequestId requestId = new RequestId("%s:%s".formatted(this.instanceId, executionId));
        CompletableFuture<Object> result = this.enqueueCommand(() -> {
            this.createDependencyNodes(nodeId, DependantChainStart.instance(), executionConfig);
            CompletableFuture<Object> future = new CompletableFuture<>();
            if (this.allRequests.containsKey(requestId)) {
                future.completeExceptionally(new IllegalArgumentException(
                        "Received duplicate requests for same instanceId '%s' and execution Id '%s'".formatted(this.instanceId, executionId)));
            } else {
                this.allRequests.put(requestId, new NodeResult(nodeId, inputs, executionConfig, future));
                this.unFlushedRequests.add(requestId);
            }
            return future;
        }).thenCompose(Function.identity());
        // The result type of a node is not tracked statically; the caller chooses T.
        @SuppressWarnings("unchecked")
        CompletableFuture<T> typedResult = (CompletableFuture<T>) (CompletableFuture<?>) result;
        return typedResult;
    }

    /**
     * Recursively creates the {@link Node} for {@code nodeId} and for all of its dependencies, and
     * records {@code dependantChain} as a chain through which {@code nodeId} is reached. Nothing is
     * created or recorded (for the node or its dependencies) if {@code dependantChain} is disabled
     * either at the executor level or in {@code executionConfig}.
     */
    private void createDependencyNodes(NodeId nodeId, DependantChain dependantChain, NodeExecutionConfig executionConfig) {
        NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
        if (Sets.union(this.disabledDependantChains, executionConfig.disabledDependantChains()).contains(dependantChain)) {
            return;
        }
        this.nodeRegistry.createIfAbsent(
                nodeId, _n -> new Node(nodeDefinition, this, this::getRequestScopedDecorators, this.logicDecorationOrdering));
        ImmutableMap<String, NodeId> dependencyNodes = nodeDefinition.dependencyNodes();
        dependencyNodes.forEach((dependencyName, depNodeId) ->
                this.createDependencyNodes(depNodeId, DependantChain.extend(dependantChain, nodeId, dependencyName), executionConfig));
        this.dependantChainsPerNode.computeIfAbsent(nodeId, _n -> new LinkedHashSet<>()).add(dependantChain);
    }

    /**
     * Submits a task to the command queue which obtains the command from {@code nodeCommand} and
     * executes it.
     *
     * @return a future which completes with the response to the supplied command
     */
    CompletableFuture<NodeResponse> enqueueNodeCommand(Supplier<NodeRequestCommand> nodeCommand) {
        return this.enqueueCommand(() -> this._executeCommand(nodeCommand.get())).thenCompose(Function.identity());
    }

    /**
     * Executes the command on the calling thread, without going through the command queue, and
     * records the bypass in the metrics.
     */
    CompletableFuture<NodeResponse> executeCommand(NodeCommand nodeCommand) {
        this.krystalNodeMetrics.commandQueueBypassed();
        return this._executeCommand(nodeCommand);
    }

    /**
     * Validates the command and dispatches it to the node it is addressed to.
     *
     * @return the node's response for a {@link NodeRequestCommand}; a failed future if validation
     *     fails; and a future failed with {@link UnsupportedOperationException} for a {@link Flush}
     *     command (after the flush has been forwarded to the node), since a flush returns no data
     * @throws UnsupportedOperationException if the command is of an unknown type
     */
    private CompletableFuture<NodeResponse> _executeCommand(NodeCommand nodeCommand) {
        try {
            this.validate(nodeCommand);
        } catch (Throwable e) {
            return CompletableFuture.failedFuture(e);
        }
        if (nodeCommand instanceof NodeRequestCommand nodeRequestCommand) {
            return this.nodeRegistry.get(nodeCommand.nodeId()).executeRequestCommand(nodeRequestCommand);
        }
        if (nodeCommand instanceof Flush flush) {
            this.nodeRegistry.get(flush.nodeId()).executeCommand(flush);
            return CompletableFuture.failedFuture(new UnsupportedOperationException("No data returned for flush command"));
        }
        throw new UnsupportedOperationException("Unknown NodeCommand type %s".formatted(nodeCommand.getClass()));
    }

    /**
     * Checks that a {@link NodeRequestCommand} does not target a disabled dependant chain. Commands
     * of other types are not validated.
     *
     * @throws DisabledDependantChainException if the command's dependant chain is disabled at the
     *     executor level or in the execution config of the originating request
     * @throws IllegalStateException if the request from which the command originated is not
     *     registered with this executor
     */
    private void validate(NodeCommand nodeCommand) {
        if (!(nodeCommand instanceof NodeRequestCommand nodeRequestCommand)) {
            return;
        }
        RequestId requestId = nodeRequestCommand.requestId();
        DependantChain dependantChain = null;
        if (nodeCommand instanceof ExecuteWithInputs executeWithInputs) {
            dependantChain = executeWithInputs.dependantChain();
        } else if (nodeCommand instanceof SkipNode skipNode) {
            dependantChain = skipNode.dependantChain();
        }
        NodeResult originatingRequest = this.allRequests.get(requestId.originatedFrom());
        if (originatingRequest == null) {
            // Fail with a descriptive error instead of an anonymous NullPointerException.
            throw new IllegalStateException(
                    "No request is registered for '%s', from which request '%s' originated".formatted(requestId.originatedFrom(), requestId));
        }
        if (Sets.union(this.disabledDependantChains, originatingRequest.executionConfig().disabledDependantChains()).contains(dependantChain)) {
            throw new DisabledDependantChainException(dependantChain);
        }
    }

    /**
     * Enqueues a task which submits every un-flushed request for execution and then sends a
     * {@link Flush} command to each distinct node targeted by those requests. Once all of the
     * resulting futures are done, the un-flushed request set and the recorded dependant chains are
     * cleared.
     */
    @Override
    public void flush() {
        this.enqueueCommand(() -> {
            this.unFlushedRequests.forEach(this::submitRequest);
            List<CompletableFuture<?>> futures = new ArrayList<>();
            this.unFlushedRequests.stream()
                    .map(this.allRequests::get)
                    .map(NodeResult::future)
                    .forEach(futures::add);
            this.unFlushedRequests.stream()
                    .map(requestId -> this.allRequests.get(requestId).nodeId())
                    .distinct()
                    .forEach(nodeId -> futures.add(this.executeCommand(new Flush(nodeId))));
            // The Flush futures are always failed (see _executeCommand), hence whenComplete and not
            // thenRun.
            // NOTE(review): this callback runs whenever the last future completes, possibly after new
            // requests have been registered - confirm that clearing everything here is intended.
            return CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).whenComplete((_v, _t) -> {
                this.dependantChainsPerNode.clear();
                this.unFlushedRequests.clear();
            });
        });
    }

    /**
     * Submits a single registered request for execution, unless its future is already done, and
     * links the outcome of the execution to the request's future.
     */
    private void submitRequest(RequestId requestId) {
        NodeResult nodeResult = this.allRequests.get(requestId);
        NodeId nodeId = nodeResult.nodeId();
        if (nodeResult.future().isDone()) {
            return;
        }
        NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
        // Only the inputs which are not resolved via dependency nodes are passed with the command.
        ImmutableSet<String> nonDependencyInputNames = nodeDefinition.getMainLogicDefinition().inputNames().stream()
                .filter(inputName -> !nodeDefinition.dependencyNodes().containsKey(inputName))
                .collect(ImmutableSet.toImmutableSet());
        CompletableFuture<Object> submissionResult = this.executeCommand(
                        new ExecuteWithInputs(nodeId, nonDependencyInputNames, nodeResult.inputs(), DependantChainStart.instance(), requestId))
                .thenApply(NodeResponse::response)
                .thenApply(valueOrError -> {
                    if (valueOrError.error().isPresent()) {
                        throw new RuntimeException(valueOrError.error().get());
                    }
                    return valueOrError.value().orElse(null);
                });
        Futures.linkFutures(submissionResult, nodeResult.future());
    }

    /** Returns the metrics collected by this executor. */
    public KrystalNodeExecutorMetrics getKrystalNodeMetrics() {
        return this.krystalNodeMetrics;
    }

    /**
     * Marks this executor as closed, flushes any pending requests, and enqueues a task which
     * releases the command queue lease once the futures of all accepted requests are done. Calling
     * this method on an already closed executor has no effect.
     */
    @Override
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        this.flush();
        this.enqueueCommand(() -> CompletableFuture.allOf(
                        this.allRequests.values().stream().map(NodeResult::future).toArray(CompletableFuture[]::new))
                .whenComplete((unused, throwable) -> this.commandQueueLease.close()));
    }

    /**
     * Runs {@code command} on the leased command queue executor, recording the queued command in
     * the metrics when the task starts running.
     *
     * @return a future which completes with the value returned by {@code command}
     */
    private <T> CompletableFuture<T> enqueueCommand(Supplier<T> command) {
        return CompletableFuture.supplyAsync(() -> {
            this.krystalNodeMetrics.commandQueued();
            return command.get();
        }, this.commandQueueLease.get());
    }

    /** A request accepted by {@link #executeNode}, along with the future handed to its caller. */
    private record NodeResult(NodeId nodeId, Inputs inputs, NodeExecutionConfig executionConfig, CompletableFuture<Object> future) {
    }
}

