/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.vajramexecutor.krystex.batching;

import com.flipkart.krystal.annos.InvocableOutsideGraph;
import com.flipkart.krystal.core.VajramID;
import com.flipkart.krystal.facets.Dependency;
import com.flipkart.krystal.facets.Facet;
import com.flipkart.krystal.facets.FacetType;
import com.flipkart.krystal.facets.resolution.ResolverDefinition;
import com.flipkart.krystal.krystex.kryon.DefaultDependentChain;
import com.flipkart.krystal.krystex.kryon.DependentChain;
import com.flipkart.krystal.krystex.kryon.DependentChainStart;
import com.flipkart.krystal.krystex.logicdecoration.LogicExecutionContext;
import com.flipkart.krystal.krystex.logicdecoration.OutputLogicDecorator;
import com.flipkart.krystal.krystex.logicdecoration.OutputLogicDecoratorConfig;
import com.flipkart.krystal.vajram.IOVajramDef;
import com.flipkart.krystal.vajram.batching.InputBatcher;
import com.flipkart.krystal.vajram.batching.InputBatcherImpl;
import com.flipkart.krystal.vajram.exec.VajramDefinition;
import com.flipkart.krystal.vajram.facets.resolution.InputResolver;
import com.flipkart.krystal.vajram.facets.specs.DependencySpec;
import com.flipkart.krystal.vajram.facets.specs.FacetSpec;
import com.flipkart.krystal.vajramexecutor.krystex.VajramKryonGraph;
import com.flipkart.krystal.vajramexecutor.krystex.batching.InputBatcherConfig;
import com.flipkart.krystal.vajramexecutor.krystex.batching.InputBatchingDecorator;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Batching configuration for a vajram, keyed on {@link DependentChain}s.
 *
 * @param shouldBatch predicate evaluated against a {@link LogicExecutionContext}; the factories
 *     below make it true for the dependent chains the config covers
 * @param instanceIdGenerator computes the decorator instance id for a {@link LogicExecutionContext}
 * @param decoratorFactory creates the {@link OutputLogicDecorator} for a decorator context
 */
public record DepChainBatcherConfig(
        Predicate<LogicExecutionContext> shouldBatch,
        Function<LogicExecutionContext, String> instanceIdGenerator,
        Function<OutputLogicDecoratorConfig.OutputLogicDecoratorContext, OutputLogicDecorator> decoratorFactory) {

    /**
     * A config whose predicate is always {@code false}, whose instance id is always the empty
     * string, and whose decorator is {@link OutputLogicDecorator#NO_OP}.
     */
    public static final DepChainBatcherConfig NO_BATCHING =
            new DepChainBatcherConfig(_l -> false, _l -> "", _l -> OutputLogicDecorator.NO_OP);

    /**
     * Creates a config that matches every logic execution. Each decorator instance is scoped to a
     * single dependent chain.
     *
     * <p>The instance id is derived from the execution's dependant chain. The decorator only accepts
     * a dependent chain equal to that same chain.
     *
     * @param inputBatcherSupplier supplies a new batcher each time a decorator is created
     */
    public static DepChainBatcherConfig simple(Supplier<InputBatcher> inputBatcherSupplier) {
        return new DepChainBatcherConfig(
                logicExecutionContext -> true,
                logicExecutionContext ->
                        generateInstanceId(logicExecutionContext.dependants()).toString(),
                decoratorContext ->
                        new InputBatchingDecorator(
                                decoratorContext.instanceId(),
                                inputBatcherSupplier.get(),
                                dependentChain ->
                                        decoratorContext
                                                .logicExecutionContext()
                                                .dependants()
                                                .equals(dependentChain)));
    }

    /**
     * Varargs convenience for {@link #sharedBatcher(Supplier, String, ImmutableSet)}.
     *
     * @throws NullPointerException if any element of {@code dependentChains} is null
     */
    public static DepChainBatcherConfig sharedBatcher(
            Supplier<InputBatcher> inputBatcherSupplier,
            String instanceId,
            DependentChain... dependentChains) {
        return sharedBatcher(inputBatcherSupplier, instanceId, ImmutableSet.copyOf(dependentChains));
    }

    /**
     * Creates a config in which all of {@code dependentChains} share the fixed id
     * {@code instanceId}.
     *
     * @param inputBatcherSupplier supplies a new batcher each time a decorator is created
     * @param instanceId the id returned for every matching execution
     * @param dependentChains the chains this config matches; both the predicate and the
     *     decorator's chain filter test membership in this set
     */
    public static DepChainBatcherConfig sharedBatcher(
            Supplier<InputBatcher> inputBatcherSupplier,
            String instanceId,
            ImmutableSet<DependentChain> dependentChains) {
        return new DepChainBatcherConfig(
                logicExecutionContext -> dependentChains.contains(logicExecutionContext.dependants()),
                logicExecutionContext -> instanceId,
                decoratorContext ->
                        new InputBatchingDecorator(
                                instanceId, inputBatcherSupplier.get(), dependentChains::contains));
    }

    /** Same as the three-argument overload with no disabled dependent chains. */
    public static void autoRegisterSharedBatchers(
            VajramKryonGraph graph, BatchSizeSupplier batchSizeSupplier) {
        autoRegisterSharedBatchers(graph, batchSizeSupplier, ImmutableSet.of());
    }

    /**
     * Discovers IO vajrams and registers shared input batchers for them on {@code graph}.
     *
     * <p>The traversal starts from every vajram annotated with {@link InvocableOutsideGraph}. It
     * collects the dependent chains that reach IO vajrams and groups them by a per-IO-vajram depth.
     * For each IO vajram that has at least one batched facet, it registers one shared batcher per
     * depth. Each batcher's instance id is {@code "<vajramId>:depth(<n>)"}.
     *
     * @param graph the graph to scan and register batchers on
     * @param batchSizeSupplier batch size to use for each IO vajram
     * @param disabledDependentChains chains that are neither batched nor traversed further
     */
    public static void autoRegisterSharedBatchers(
            VajramKryonGraph graph,
            BatchSizeSupplier batchSizeSupplier,
            ImmutableSet<DependentChain> disabledDependentChains) {
        Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes =
                getIoVajrams(graph, disabledDependentChains);
        Map<VajramID, ImmutableList<DepChainBatcherConfig>> depChainBatcherConfigs =
                new LinkedHashMap<>();
        ioNodes.forEach(
                (vajramId, depChainsByDepth) -> {
                    if (!isBatchingNeededForIoVajram(graph, vajramId)) {
                        return;
                    }
                    List<DepChainBatcherConfig> configs = new ArrayList<>(depChainsByDepth.size());
                    depChainsByDepth.forEach(
                            (depth, depChains) ->
                                    configs.add(
                                            sharedBatcher(
                                                    () -> new InputBatcherImpl(
                                                            batchSizeSupplier.getBatchSize(vajramId)),
                                                    vajramId.id() + ":depth(" + depth + ")",
                                                    // copyOf preserves the LinkedHashSet's insertion order
                                                    ImmutableSet.copyOf(depChains))));
                    depChainBatcherConfigs.put(vajramId, ImmutableList.copyOf(configs));
                });
        graph.registerInputBatchers(
                new InputBatcherConfig(ImmutableMap.copyOf(depChainBatcherConfigs)));
    }

    /**
     * Builds a textual id for a dependent chain.
     *
     * <p>The id is the start's {@code toString()}, followed by {@code ">"} and the dependency for
     * each hop. The first hop after the start also includes the kryon id, as in
     * {@code "<kryonId>:"}.
     *
     * @throws UnsupportedOperationException for chain implementations other than
     *     {@link DependentChainStart} and {@link DefaultDependentChain}
     */
    private static StringBuilder generateInstanceId(DependentChain dependentChain) {
        if (dependentChain instanceof DependentChainStart start) {
            return new StringBuilder(start.toString());
        }
        if (dependentChain instanceof DefaultDependentChain chain) {
            StringBuilder id = generateInstanceId(chain.incomingDependentChain()).append('>');
            if (chain.incomingDependentChain() instanceof DependentChainStart) {
                id.append(chain.kryonId().id()).append(':');
            }
            return id.append(chain.latestDependency());
        }
        throw new UnsupportedOperationException(
                "Unsupported DependentChain type: " + dependentChain.getClass().getName());
    }

    /** Returns true if any facet spec of the given vajram reports {@code isBatched()}. */
    private static boolean isBatchingNeededForIoVajram(VajramKryonGraph graph, VajramID ioNode) {
        for (FacetSpec facetSpec : graph.getVajramDefinition(ioNode).facetSpecs()) {
            if (facetSpec.isBatched()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Maps each reachable IO vajram to its dependent chains, grouped by depth. Traversal starts
     * from every externally invocable vajram.
     */
    private static Map<VajramID, Map<Integer, Set<DependentChain>>> getIoVajrams(
            VajramKryonGraph graph, ImmutableSet<DependentChain> disabledDependentChains) {
        Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes = new HashMap<>();
        for (VajramDefinition rootNode : externallyInvocableVajrams(graph)) {
            // A fresh depth table per root: depths from one root's traversal do not carry over.
            Map<VajramID, Integer> ioNodeDepths = new HashMap<>();
            dfs(
                    rootNode,
                    graph,
                    ioNodes,
                    0,
                    graph.kryonDefinitionRegistry().getDependentChainsStart(),
                    ioNodeDepths,
                    disabledDependentChains);
        }
        return ioNodes;
    }

    /** Vajrams in the graph whose tags carry the {@link InvocableOutsideGraph} annotation. */
    private static Iterable<VajramDefinition> externallyInvocableVajrams(VajramKryonGraph graph) {
        return graph.vajramDefinitions().values().stream()
                .filter(v -> v.vajramTags().getAnnotationByType(InvocableOutsideGraph.class).isPresent())
                .toList();
    }

    /**
     * Walks the dependencies of {@code rootNode} depth-first and records every dependent chain that
     * ends at an IO vajram into {@code ioNodes}, keyed by the IO vajram's current depth.
     *
     * <p>Dependencies are visited prerequisites first (see {@link #getOrderedInputDef}). While a
     * dependency is being traversed, the depths of IO vajrams under its prerequisite dependencies
     * are temporarily raised by one. This places IO calls that must wait for a prerequisite in a
     * different depth bucket.
     *
     * @param depth carried through the recursion; the bucket used for an IO child is read from
     *     {@code ioNodeDepths}, not from this value
     */
    private static void dfs(
            VajramDefinition rootNode,
            VajramKryonGraph graph,
            Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes,
            int depth,
            DependentChain incomingDepChain,
            Map<VajramID, Integer> ioNodeDepths,
            ImmutableSet<DependentChain> disabledDependentChains) {
        Map<Facet, List<Facet>> inputDefGraph = new HashMap<>();
        VajramID vajramId = rootNode.vajramId();
        for (Facet inputDef : getOrderedInputDef(rootNode, inputDefGraph)) {
            if (!(inputDef instanceof DependencySpec<?, ?, ?> dependency)) {
                continue;
            }
            DependentChain dependentChain = incomingDepChain.extend(vajramId, dependency);
            // Check before touching ioNodeDepths. Skipping after the increment below would leave the
            // depths permanently raised, because the matching decrement would never run.
            if (disabledDependentChains.contains(dependentChain)) {
                continue;
            }
            List<ResolverDefinition> resolverDefinitions =
                    getInputResolverDefinition(rootNode, dependency);
            List<VajramID> prerequisiteVajramIds = new ArrayList<>();
            for (Facet prerequisite : inputDefGraph.getOrDefault(inputDef, List.of())) {
                VajramID prerequisiteVajramId =
                        dependencyInputInChildNode(resolverDefinitions, prerequisite);
                if (prerequisiteVajramId != null) {
                    prerequisiteVajramIds.add(prerequisiteVajramId);
                }
            }

            for (VajramID prerequisiteVajramId : prerequisiteVajramIds) {
                shiftIoNodeDepths(
                        graph.getVajramDefinition(prerequisiteVajramId), graph, ioNodeDepths, 1);
            }

            VajramDefinition childNode = graph.getVajramDefinition(dependency.onVajramId());
            int childDepth = depth;
            if (childNode.def() instanceof IOVajramDef) {
                childDepth = ioNodeDepths.computeIfAbsent(childNode.vajramId(), _v -> 0);
                ioNodes.computeIfAbsent(childNode.vajramId(), k -> new HashMap<>())
                        .computeIfAbsent(childDepth, k -> new LinkedHashSet<>())
                        .add(dependentChain);
            }
            dfs(childNode, graph, ioNodes, childDepth, dependentChain, ioNodeDepths,
                    disabledDependentChains);

            for (VajramID prerequisiteVajramId : prerequisiteVajramIds) {
                shiftIoNodeDepths(
                        graph.getVajramDefinition(prerequisiteVajramId), graph, ioNodeDepths, -1);
            }
        }
    }

    /**
     * Adds {@code delta} to the recorded depth of {@code node}, if it is an IO vajram. Then recurses
     * into every vajram reachable through its dependency facets, with no visited set, so a vajram
     * reachable by several paths is shifted once per path.
     *
     * <p>An absent entry is initialised to 0 regardless of {@code delta}. NOTE(review): a +1
     * followed by a -1 on a previously unseen IO vajram therefore leaves it at -1; confirm this is
     * intended.
     */
    private static void shiftIoNodeDepths(
            VajramDefinition node, VajramKryonGraph graph, Map<VajramID, Integer> ioNodeDepths,
            int delta) {
        if (node.def() instanceof IOVajramDef) {
            ioNodeDepths.compute(
                    node.vajramId(), (_vid, current) -> current == null ? 0 : current + delta);
        }
        for (Facet facet : node.facetSpecs()) {
            if (facet instanceof DependencySpec<?, ?, ?> dependency) {
                shiftIoNodeDepths(
                        graph.getVajramDefinition(dependency.onVajramId()), graph, ioNodeDepths, delta);
            }
        }
    }

    /**
     * Returns the vajram that {@code inputDefinition} depends on, if it is a dependency facet used
     * as a source by any of {@code depInputs}. Otherwise returns null.
     */
    private static @Nullable VajramID dependencyInputInChildNode(
            List<ResolverDefinition> depInputs, Facet inputDefinition) {
        if (!(inputDefinition instanceof DependencySpec<?, ?, ?> dependency)) {
            return null;
        }
        for (ResolverDefinition depInput : depInputs) {
            if (depInput.sources().contains(inputDefinition)) {
                return dependency.onVajramId();
            }
        }
        return null;
    }

    /** Definitions of {@code rootNode}'s resolvers whose target dependency id equals {@code dependency}'s. */
    private static List<ResolverDefinition> getInputResolverDefinition(
            VajramDefinition rootNode, DependencySpec<?, ?, ?> dependency) {
        return rootNode.inputResolvers().values().stream()
                .map(InputResolver::definition)
                .filter(definition -> definition.target().dependency().id() == dependency.id())
                .collect(ArrayList::new, ArrayList::add, ArrayList::addAll);
    }

    /**
     * Returns the dependency facets of {@code rootNode}, ordered so that a dependency's
     * prerequisites come before it.
     *
     * <p>As a side effect, this fills {@code graph}. For each resolver it adds an edge from the
     * resolver's target dependency facet to every dependency facet the resolver reads as a source.
     */
    private static Collection<Facet> getOrderedInputDef(
            VajramDefinition rootNode, Map<Facet, List<Facet>> graph) {
        var inputDefinitions = rootNode.facetSpecs();
        for (InputResolver resolver : rootNode.inputResolvers().values()) {
            ResolverDefinition resolverDefinition = resolver.definition();
            // The target facet is independent of the inner loop, so resolve it once per resolver.
            Facet dependingFacet =
                    getInputDefinitionDep(
                            (Facet) resolverDefinition.target().dependency(), inputDefinitions);
            if (dependingFacet == null) {
                continue;
            }
            for (Facet facet : inputDefinitions) {
                if (FacetType.DEPENDENCY.equals(facet.facetType())
                        && resolverDefinition.sources().contains(facet)) {
                    graph.computeIfAbsent(dependingFacet, k -> new ArrayList<>()).add(facet);
                }
            }
        }
        Set<Facet> visited = new HashSet<>();
        Queue<Facet> ordered = new ArrayDeque<>();
        for (Facet facet : inputDefinitions) {
            if (FacetType.DEPENDENCY.equals(facet.facetType()) && !visited.contains(facet)) {
                topologicalSortUtil(facet, visited, graph, ordered);
            }
        }
        return ordered;
    }

    /** First dependency facet in {@code inputDefinitions} with the same id as {@code dep}, else null. */
    private static @Nullable Facet getInputDefinitionDep(
            Facet dep, ImmutableCollection<? extends Facet> inputDefinitions) {
        for (Facet facet : inputDefinitions) {
            if (FacetType.DEPENDENCY.equals(facet.facetType()) && facet.id() == dep.id()) {
                return facet;
            }
        }
        return null;
    }

    /**
     * Depth-first post-order walk from {@code vid} over {@code graph}. Each unvisited dependency
     * facet is appended to {@code stack} after everything reachable from it.
     */
    static void topologicalSortUtil(
            Facet vid, Set<Facet> visited, Map<Facet, List<Facet>> graph, Queue<Facet> stack) {
        visited.add(vid);
        for (Facet next : graph.getOrDefault(vid, List.of())) {
            if (!visited.contains(next)) {
                topologicalSortUtil(next, visited, graph, stack);
            }
        }
        if (FacetType.DEPENDENCY.equals(vid.facetType())) {
            stack.add(vid);
        }
    }

    /** Supplies the batch size to use for a given IO vajram. */
    @FunctionalInterface
    public static interface BatchSizeSupplier {
        public int getBatchSize(VajramID var1);
    }
}

