package com.flipkart.krystal.vajramexecutor.krystex;

import static com.flipkart.krystal.utils.Futures.linkFutures;
import static com.google.common.collect.ImmutableList.toImmutableList;

import com.flipkart.krystal.config.ConfigProvider;
import com.flipkart.krystal.config.NestedConfig;
import com.flipkart.krystal.data.Facets;
import com.flipkart.krystal.krystex.OutputLogic;
import com.flipkart.krystal.krystex.OutputLogicDefinition;
import com.flipkart.krystal.krystex.kryon.DependantChain;
import com.flipkart.krystal.krystex.logicdecoration.FlushCommand;
import com.flipkart.krystal.krystex.logicdecoration.InitiateActiveDepChains;
import com.flipkart.krystal.krystex.logicdecoration.LogicDecoratorCommand;
import com.flipkart.krystal.krystex.logicdecoration.OutputLogicDecorator;
import com.flipkart.krystal.vajram.facets.FacetValuesAdaptor;
import com.flipkart.krystal.vajram.modulation.FacetsConverter;
import com.flipkart.krystal.vajram.modulation.InputModulator;
import com.flipkart.krystal.vajram.modulation.ModulatedFacets;
import com.flipkart.krystal.vajram.modulation.UnmodulatedFacets;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Predicate;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * An {@link OutputLogicDecorator} that routes requests through an {@link InputModulator} before
 * the decorated output logic is executed.
 *
 * <p>Incoming facets are converted into {@link UnmodulatedFacets} and added to the modulator.
 * Callers immediately receive a future per request, taken from an internal cache. The wrapped
 * logic is invoked only for the {@link ModulatedFacets} the modulator emits. These come either
 * from the return value of {@link InputModulator#add} or through the {@link
 * InputModulator#onModulation} callback. The results of that invocation are linked into the
 * cached futures.
 *
 * <p>The decorator can also force modulation through {@link #executeCommand}. It tracks the
 * applicable active dependant chains. Once every one of them has sent a {@link FlushCommand},
 * {@link InputModulator#modulate()} is invoked.
 *
 * <p>NOTE(review): this class keeps its state in unsynchronized collections ({@link HashMap},
 * {@link LinkedHashSet}). It presumably relies on being confined to a single thread; TODO
 * confirm.
 *
 * @param <I> facets that need modulation
 * @param <C> facets common to all requests in a modulated group
 */
public final class InputModulationDecorator<
        I /*InputsNeedingModulation*/ extends FacetValuesAdaptor,
        C /*CommonInputs*/ extends FacetValuesAdaptor>
    implements OutputLogicDecorator {

  public static final String DECORATOR_TYPE = InputModulationDecorator.class.getName();

  /** Identifier of this decorator instance; also used to namespace its config keys. */
  private final String instanceId;

  private final InputModulator<I, C> inputModulator;

  /** Converts raw {@link Facets} into the modulation-aware {@link UnmodulatedFacets} form. */
  private final FacetsConverter<I, C> facetsConverter;

  /** Decides which dependant chains this decorator waits on before forcing modulation. */
  private final Predicate<DependantChain> isApplicableToDependantChain;

  /**
   * Futures handed out to callers, keyed by the request's facets.
   *
   * <p>Entries are created when a request arrives and linked to the wrapped logic's result once
   * it executes. NOTE(review): entries are never removed in this class, so the map grows with
   * the number of distinct requests. Presumably the decorator's lifetime is bounded (e.g. per
   * request graph); TODO confirm.
   */
  private final Map<Facets, CompletableFuture<@Nullable Object>> futureCache = new HashMap<>();

  /** Active dependant chains for which {@link #isApplicableToDependantChain} holds. */
  private ImmutableSet<DependantChain> activeDependantChains = ImmutableSet.of();

  /** Dependant chains that have sent a {@link FlushCommand} since the last forced modulation. */
  private final Set<DependantChain> flushedDependantChains = new LinkedHashSet<>();

  /**
   * Creates a decorator instance.
   *
   * @param instanceId id of this decorator; returned by {@link #getId()} and used as the config
   *     namespace {@code input_modulation.<instanceId>.}
   * @param inputModulator the modulator that accumulates requests and emits modulated groups
   * @param facetsConverter converts incoming {@link Facets} into {@link UnmodulatedFacets}
   * @param isApplicableToDependantChain filter selecting which active dependant chains must flush
   *     before modulation is forced
   */
  public InputModulationDecorator(
      String instanceId,
      InputModulator<I, C> inputModulator,
      FacetsConverter<I, C> facetsConverter,
      Predicate<DependantChain> isApplicableToDependantChain) {
    this.instanceId = instanceId;
    this.inputModulator = inputModulator;
    this.facetsConverter = facetsConverter;
    this.isApplicableToDependantChain = isApplicableToDependantChain;
  }

  /**
   * Wraps {@code logicToDecorate} so that its inputs pass through the input modulator.
   *
   * <p>Each call registers a callback via {@link InputModulator#onModulation}. That callback
   * executes the wrapped logic for every modulated group the modulator emits. NOTE(review):
   * whether a later registration replaces or adds to an earlier one depends on the {@link
   * InputModulator} implementation.
   *
   * <p>The returned logic does the following for each batch of inputs:
   *
   * <ol>
   *   <li>adds every request to the modulator;
   *   <li>primes a future per request in {@link #futureCache};
   *   <li>executes the wrapped logic for any groups returned immediately by {@link
   *       InputModulator#add};
   *   <li>returns the primed futures keyed by each request's facets.
   * </ol>
   *
   * <p>NOTE(review): {@code ImmutableMap.toImmutableMap} rejects duplicate keys. If an input
   * batch can contain equal {@link Facets} twice, the returned logic would throw. TODO confirm
   * callers de-duplicate.
   *
   * @param logicToDecorate the logic whose execution is deferred until modulation
   * @param originalLogicDefinition unused by this decorator
   * @return the decorated logic
   */
  @Override
  public OutputLogic<Object> decorateLogic(
      OutputLogic<Object> logicToDecorate, OutputLogicDefinition<Object> originalLogicDefinition) {
    inputModulator.onModulation(
        requests -> requests.forEach(request -> modulateInputsList(logicToDecorate, request)));
    return inputsList -> {
      List<UnmodulatedFacets<I, C>> requests = inputsList.stream().map(facetsConverter).toList();
      // The modulator may return groups that are ready right away (e.g. a size threshold hit).
      List<ModulatedFacets<I, C>> modulatedFacets =
          requests.stream()
              .map(
                  unmodulatedInput ->
                      inputModulator.add(
                          unmodulatedInput.modulatedInputs(), unmodulatedInput.commonFacets()))
              .flatMap(Collection::stream)
              .toList();
      // Prime the cache before executing anything, so every request has a future to return.
      requests.forEach(
          request ->
              futureCache.computeIfAbsent(
                  request.toFacetValues(), e -> new CompletableFuture<@Nullable Object>()));
      for (ModulatedFacets<I, C> modulatedFacet : modulatedFacets) {
        modulateInputsList(logicToDecorate, modulatedFacet);
      }
      return requests.stream()
          .map(UnmodulatedFacets::toFacetValues)
          .collect(
              ImmutableMap.<Facets, Facets, CompletableFuture<@Nullable Object>>toImmutableMap(
                  Function.identity(), this::getPrimedFuture));
    };
  }

  /**
   * Returns the future previously primed in {@link #futureCache} for {@code facets}.
   *
   * @throws AssertionError if no future was primed for {@code facets}, which indicates a bug in
   *     {@link #decorateLogic}
   */
  private CompletableFuture<@Nullable Object> getPrimedFuture(Facets facets) {
    return Optional.ofNullable(futureCache.get(facets))
        .orElseThrow(
            () ->
                new AssertionError(
                    "Future cache has not been primed with a value for the given facets."
                        + " This should never happen"));
  }

  /**
   * Handles decorator commands.
   *
   * <ul>
   *   <li>{@link InitiateActiveDepChains}: replaces the set of dependant chains to wait on with
   *       those accepted by {@link #isApplicableToDependantChain}.
   *   <li>{@link FlushCommand}: records the flushed chain. When every applicable active chain has
   *       flushed, this forces {@link InputModulator#modulate()} and resets the flush record.
   *       Note that if no applicable chains are active, any flush forces modulation.
   * </ul>
   *
   * <p>Other command types are ignored.
   */
  @Override
  public void executeCommand(LogicDecoratorCommand logicDecoratorCommand) {
    if (logicDecoratorCommand instanceof InitiateActiveDepChains initiateActiveDepChains) {
      LinkedHashSet<DependantChain> allActiveDepChains =
          new LinkedHashSet<>(initiateActiveDepChains.dependantsChains());
      // Retain only the ones which are applicable for this input modulation decorator
      allActiveDepChains.removeIf(isApplicableToDependantChain.negate());
      this.activeDependantChains = ImmutableSet.copyOf(allActiveDepChains);
    } else if (logicDecoratorCommand instanceof FlushCommand flushCommand) {
      flushedDependantChains.add(flushCommand.dependantsChain());
      if (flushedDependantChains.containsAll(activeDependantChains)) {
        inputModulator.modulate();
        flushedDependantChains.clear();
      }
    }
  }

  /**
   * Executes the wrapped logic for one modulated group and links each result into {@link
   * #futureCache}.
   *
   * <p>The group is split back into one {@link UnmodulatedFacets} per modulated input, each
   * paired with the group's common facets. Results for facets that have no cached future yet get
   * a new one created on demand.
   */
  private void modulateInputsList(
      OutputLogic<Object> logicToDecorate, ModulatedFacets<I, C> modulatedFacets) {
    ImmutableList<UnmodulatedFacets<I, C>> requests =
        modulatedFacets.modInputs().stream()
            .map(each -> new UnmodulatedFacets<>(each, modulatedFacets.commonFacets()))
            .collect(toImmutableList());
    logicToDecorate
        .execute(requests.stream().map(UnmodulatedFacets::toFacetValues).collect(toImmutableList()))
        .forEach(
            (inputs, resultFuture) -> {
              //noinspection RedundantTypeArguments: To Handle nullChecker errors
              linkFutures(
                  resultFuture,
                  futureCache.<CompletableFuture<@Nullable Object>>computeIfAbsent(
                      inputs, request -> new CompletableFuture<@Nullable Object>()));
            });
  }

  /**
   * Forwards config updates to the modulator.
   *
   * <p>Keys are scoped under the {@code input_modulation.<instanceId>.} prefix.
   */
  @Override
  public void onConfigUpdate(ConfigProvider configProvider) {
    inputModulator.onConfigUpdate(
        new NestedConfig(String.format("input_modulation.%s.", instanceId), configProvider));
  }

  /** Returns the instance id supplied at construction. */
  @Override
  public String getId() {
    return instanceId;
  }
}
