package com.flipkart.krystal.vajramexecutor.krystex.inputinjection;

import static com.flipkart.krystal.core.VajramID.vajramID;
import static com.flipkart.krystal.facets.FacetType.INJECTION;

import com.flipkart.krystal.core.VajramID;
import com.flipkart.krystal.data.Errable;
import com.flipkart.krystal.data.ExecutionItem;
import com.flipkart.krystal.data.FacetValues;
import com.flipkart.krystal.data.FacetValuesBuilder;
import com.flipkart.krystal.data.Failure;
import com.flipkart.krystal.except.StackTracelessException;
import com.flipkart.krystal.krystex.commands.DirectForwardReceive;
import com.flipkart.krystal.krystex.commands.ForwardReceiveBatch;
import com.flipkart.krystal.krystex.commands.KryonCommand;
import com.flipkart.krystal.krystex.kryon.Kryon;
import com.flipkart.krystal.krystex.kryon.KryonCommandResponse;
import com.flipkart.krystal.krystex.kryon.VajramKryonDefinition;
import com.flipkart.krystal.krystex.request.InvocationId;
import com.flipkart.krystal.vajram.VajramDef;
import com.flipkart.krystal.vajram.exec.VajramDefinition;
import com.flipkart.krystal.vajram.facets.specs.DefaultFacetSpec;
import com.flipkart.krystal.vajram.facets.specs.FacetSpec;
import com.flipkart.krystal.vajram.inputinjection.VajramInjectionProvider;
import com.flipkart.krystal.vajramexecutor.krystex.VajramKryonGraph;
import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * A {@link Kryon} decorator that fills in {@code INJECTION}-typed facets before delegating commands
 * to the wrapped kryon.
 *
 * <p>Injection is attempted only for {@link ForwardReceiveBatch} and {@link DirectForwardReceive}
 * commands whose target vajram's metadata reports that input injection is needed and whose
 * definition is a {@link VajramDef}. Every other command is forwarded unchanged.
 *
 * <p>An injectable facet is only populated when its current value is absent (per {@code
 * valueOpt()}), so values already supplied by the caller are not overwritten. If injection fails
 * (no provider configured, the provider throws, or the provider returns a {@link Failure}), the
 * failure is stored as the facet's value and logged once.
 */
@Slf4j
class InjectingDecoratedKryon implements Kryon<KryonCommand, KryonCommandResponse> {

  private final Kryon<KryonCommand, KryonCommandResponse> kryon;
  private final VajramKryonGraph vajramKryonGraph;
  private final @Nullable VajramInjectionProvider injectionProvider;

  /**
   * @param kryon the kryon that actually executes commands
   * @param vajramKryonGraph used to look up the {@link VajramDefinition} of a command's vajram
   * @param injectionProvider source of injected facet values; when {@code null}, every injection
   *     attempt yields a "Dependency injector is null" failure
   */
  InjectingDecoratedKryon(
      Kryon<KryonCommand, KryonCommandResponse> kryon,
      VajramKryonGraph vajramKryonGraph,
      @Nullable VajramInjectionProvider injectionProvider) {
    this.kryon = kryon;
    this.vajramKryonGraph = vajramKryonGraph;
    this.injectionProvider = injectionProvider;
  }

  /** Returns the definition of the wrapped kryon. */
  @Override
  public VajramKryonDefinition getKryonDefinition() {
    return kryon.getKryonDefinition();
  }

  /**
   * Injects missing facets for eligible forward commands, then delegates to the wrapped kryon.
   * Commands that are not eligible for injection are delegated as-is.
   */
  @Override
  public CompletableFuture<KryonCommandResponse> executeCommand(KryonCommand kryonCommand) {
    VajramDefinition vajramDefinition =
        vajramKryonGraph.getVajramDefinition(vajramID(kryonCommand.vajramID().id()));
    if (vajramDefinition.metadata().isInputInjectionNeeded()
        && vajramDefinition.def() instanceof VajramDef<?>) {
      if (kryonCommand instanceof ForwardReceiveBatch forwardBatch) {
        return injectFacets(forwardBatch, vajramDefinition);
      } else if (kryonCommand instanceof DirectForwardReceive forwardReceive) {
        return injectFacets(forwardReceive, vajramDefinition);
      }
    }
    return kryon.executeCommand(kryonCommand);
  }

  /**
   * Injects missing facets into each execution item's facet builder in place, then forwards the
   * same command to the wrapped kryon.
   */
  private CompletableFuture<KryonCommandResponse> injectFacets(
      DirectForwardReceive forwardReceive, VajramDefinition vajramDefinition) {
    Set<FacetSpec<?, ?>> injectableFacets = injectableFacetsOf(vajramDefinition);
    for (ExecutionItem executionItem : forwardReceive.executionItems()) {
      injectFacetsOfVajram(vajramDefinition, injectableFacets, executionItem.facetValues());
    }
    return kryon.executeCommand(forwardReceive);
  }

  /**
   * Builds a new batch whose executable invocations carry injected facets, and forwards it (with
   * the original vajram id, dependent chain and skipped invocations) to the wrapped kryon.
   */
  private CompletableFuture<KryonCommandResponse> injectFacets(
      ForwardReceiveBatch forwardBatch, VajramDefinition vajramDefinition) {
    Map<InvocationId, ? extends FacetValues> requestIdToFacets =
        forwardBatch.executableInvocations();
    Set<FacetSpec<?, ?>> injectableFacets = injectableFacetsOf(vajramDefinition);

    ImmutableMap.Builder<InvocationId, FacetValues> newRequests = ImmutableMap.builder();
    for (Entry<InvocationId, ? extends FacetValues> entry : requestIdToFacets.entrySet()) {
      // Injection writes into a builder, so convert the incoming facet values first.
      FacetValuesBuilder facetsBuilder = entry.getValue()._asBuilder();
      newRequests.put(
          entry.getKey(), injectFacetsOfVajram(vajramDefinition, injectableFacets, facetsBuilder));
    }
    return kryon.executeCommand(
        new ForwardReceiveBatch(
            forwardBatch.vajramID(),
            newRequests.build(),
            forwardBatch.dependentChain(),
            forwardBatch.invocationsToSkip()));
  }

  /**
   * Returns the facets of {@code vajramDefinition} whose type is {@code INJECTION}, in the
   * iteration order of {@code facetSpecs()}.
   */
  private static Set<FacetSpec<?, ?>> injectableFacetsOf(VajramDefinition vajramDefinition) {
    Set<FacetSpec<?, ?>> injectableFacets = new LinkedHashSet<>();
    vajramDefinition
        .facetSpecs()
        .forEach(
            facetSpec -> {
              if (INJECTION.equals(facetSpec.facetType())) {
                injectableFacets.add(facetSpec);
              }
            });
    return injectableFacets;
  }

  /**
   * Populates every injectable facet that has no value yet in {@code facetsBuilder}.
   *
   * <p>Only {@link DefaultFacetSpec}s can be read/written here; other spec types are skipped. A
   * failed injection is stored as the facet's value and logged.
   *
   * @return the same {@code facetsBuilder}, for convenience
   */
  // Raw types are deliberate: setFacetValue cannot accept an Errable<Object> through a
  // wildcard-typed spec, and the injected value's type is only known at runtime.
  @SuppressWarnings({"rawtypes", "unchecked"})
  private FacetValuesBuilder injectFacetsOfVajram(
      VajramDefinition vajramDefinition,
      Set<FacetSpec<?, ?>> injectableFacets,
      FacetValuesBuilder facetsBuilder) {
    for (FacetSpec facetSpec : injectableFacets) {
      if (!(facetSpec instanceof DefaultFacetSpec defaultFacetSpec)) {
        continue;
      }
      Errable<?> facetValue = defaultFacetSpec.getFacetValue(facetsBuilder);
      if (facetValue.valueOpt().isPresent()) {
        // The calling vajram already resolved this input; don't override it.
        continue;
      }
      Errable<Object> injectedValue = getInjectedValue(vajramDefinition.vajramId(), facetSpec);
      // Set exactly once, whether the injection succeeded or failed.
      defaultFacetSpec.setFacetValue(facetsBuilder, injectedValue);
      if (injectedValue instanceof Failure<Object> f) {
        // Single logging point for every kind of injection failure.
        log.error(
            "Could not inject input {} of vajram {}",
            facetSpec,
            kryon.getKryonDefinition().vajramID().id(),
            f.error());
      }
    }
    return facetsBuilder;
  }

  /**
   * Asks the injection provider for the value of {@code facetDef} on the given vajram.
   *
   * <p>A missing provider, or any {@link Throwable} thrown by the provider, is converted to an
   * error {@link Errable} instead of propagating. Logging is left to the caller so each failure is
   * reported once.
   */
  @SuppressWarnings({"rawtypes", "unchecked"})
  private Errable<Object> getInjectedValue(VajramID vajramId, FacetSpec facetDef) {
    VajramInjectionProvider inputInjector = this.injectionProvider;
    if (inputInjector == null) {
      return Errable.withError(new StackTracelessException("Dependency injector is null"));
    }
    try {
      return (Errable<Object>) inputInjector.get(vajramId, facetDef);
    } catch (Throwable e) {
      return Errable.withError(e);
    }
  }
}
