/*
 * Decompiled with CFR 0.152.
 */
package com.twosigma.waiter.courier;

import com.twosigma.waiter.courier.CourierGrpc;
import com.twosigma.waiter.courier.CourierReply;
import com.twosigma.waiter.courier.CourierRequest;
import com.twosigma.waiter.courier.CourierSummary;
import com.twosigma.waiter.courier.StateReply;
import com.twosigma.waiter.courier.StateRequest;
import com.twosigma.waiter.courier.Variant;
import io.grpc.BindableService;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public class GrpcServer {
    /** Logger shared by the server, its interceptors and the service implementation. */
    private static final Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());

    /** gRPC context key under which the correlation id / timestamp pair of a call is stored. */
    private static final Context.Key<CidTimestamp> CID_TIMESTAMP = Context.key("CID.TIMESTAMP");

    /**
     * Recorded call states: correlation id -> call timestamp -> state labels in the order they
     * were tracked. The map is read and written inside blocks synchronized on the map itself.
     */
    private static final Map<String, Map<Long, List<String>>> requestCidToStateList = new HashMap<>();

    /** The gRPC server instance; stays {@code null} until {@code start(int)} has assigned it. */
    private Server server;

    /**
     * Appends {@code state} to the states recorded for the call identified by
     * {@code correlationId} and {@code timestamp}, creating the bookkeeping entries on first use.
     *
     * @param correlationId correlation id of the call; when {@code null} nothing is recorded
     * @param timestamp     timestamp that distinguishes calls sharing the same correlation id
     * @param state         label of the lifecycle event to record
     * @return a snapshot of the states recorded so far for this call (including {@code state}),
     *         or a new empty list when {@code correlationId} is {@code null}
     */
    private static List<String> trackState(String correlationId, long timestamp, String state) {
        if (correlationId == null) {
            return new ArrayList<>();
        }
        synchronized (requestCidToStateList) {
            List<String> stateList = requestCidToStateList
                    .computeIfAbsent(correlationId, cid -> new HashMap<>())
                    .computeIfAbsent(timestamp, ts -> new ArrayList<>());
            stateList.add(state);
            // Hand out a copy made under the lock: callers read the result (e.g. to log it)
            // after the lock is released, while other threads may still append to the live list.
            return new ArrayList<>(stateList);
        }
    }

    /**
     * Looks up the states recorded for {@code correlationId}. When several calls were tracked
     * under the same correlation id, the states of the one with the smallest timestamp are used.
     *
     * @param correlationId correlation id to look up
     * @return a snapshot of the recorded states, or a new empty list when nothing is recorded
     *         for {@code correlationId}
     */
    private static List<String> trackState(String correlationId) {
        synchronized (requestCidToStateList) {
            Map<Long, List<String>> stateEntries = requestCidToStateList.get(correlationId);
            if (stateEntries == null || stateEntries.isEmpty()) {
                return new ArrayList<>();
            }
            // Only the smallest timestamp is needed, so there is no reason to sort every key.
            Long requestTime = Collections.min(stateEntries.keySet());
            // Copy under the lock: the caller iterates the result after the lock is released,
            // while the live list may still be appended to by other threads.
            return new ArrayList<>(stateEntries.get(requestTime));
        }
    }

    /**
     * Blocks the calling thread for {@code durationMillis} milliseconds.
     * Non-positive durations return immediately. If the thread is interrupted while sleeping,
     * the interruption is logged, the thread's interrupt flag is restored and the method returns
     * early.
     *
     * @param durationMillis how long to sleep, in milliseconds
     */
    private static void sleep(long durationMillis) {
        if (durationMillis <= 0L) {
            return;
        }
        try {
            Thread.sleep(durationMillis);
        } catch (InterruptedException ex) {
            // Thread.sleep clears the interrupt flag when it throws; set it again so code
            // further up the stack can still observe that the thread was interrupted.
            Thread.currentThread().interrupt();
            LOGGER.warning("sleep of " + durationMillis + " ms was interrupted: " + ex);
        }
    }

    /**
     * Installs a cancel handler on {@code responseObserver} that logs the cancellation and
     * records a {@code CANCEL_HANDLER} state for the call found in the current gRPC context.
     * Observers that are not {@link ServerCallStreamObserver} instances are left untouched.
     *
     * @param responseObserver response observer of the call being handled
     */
    private static void registerOnCancelHandler(StreamObserver<?> responseObserver) {
        if (!(responseObserver instanceof ServerCallStreamObserver)) {
            return;
        }
        // Read the call identity now; the handler itself is invoked later.
        final CidTimestamp callIdentity = CID_TIMESTAMP.get();
        final String cid = callIdentity.correlationId;
        final long callTimestamp = callIdentity.timestamp;
        ServerCallStreamObserver<?> serverObserver = (ServerCallStreamObserver<?>) responseObserver;
        serverObserver.setOnCancelHandler(new Runnable() {
            @Override
            public void run() {
                LOGGER.info("CancelHandler for " + callIdentity + " was triggered");
                GrpcServer.trackState(cid, callTimestamp, "CANCEL_HANDLER");
            }
        });
    }

    /**
     * Builds and starts the gRPC server on {@code port} with the courier service and both
     * interceptors, then registers a JVM shutdown hook that stops the server.
     *
     * @param port TCP port to listen on
     * @throws IOException if the server cannot be started
     */
    void start(int port) throws IOException {
        LOGGER.info("starting gRPC server on port " + port);

        // NOTE(review): GrpcServerInterceptor reads the context value that
        // CorrelationIdInterceptor installs, so the registration order below matters;
        // confirm against the ServerBuilder.intercept documentation before reordering.
        ServerBuilder<?> builder = ServerBuilder.forPort(port);
        builder = builder.addService(new CourierImpl());
        builder = builder.intercept(new GrpcServerInterceptor());
        builder = builder.intercept(new CorrelationIdInterceptor());
        this.server = builder.build().start();

        LOGGER.info("gRPC server started, listening on " + port);

        Runnable shutdownTask = () -> {
            // Plain stderr output, exactly as before.
            System.err.println("*** shutting down gRPC server since JVM is shutting down");
            this.stop();
            System.err.println("*** server shut down");
        };
        Runtime.getRuntime().addShutdownHook(new Thread(shutdownTask));
    }

    /** Calls {@code shutdown()} on the server; does nothing if no server has been assigned. */
    private void stop() {
        if (this.server == null) {
            return;
        }
        this.server.shutdown();
    }

    /**
     * Waits for the server to terminate; returns immediately if no server has been assigned.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    void blockUntilShutdown() throws InterruptedException {
        if (this.server == null) {
            return;
        }
        this.server.awaitTermination();
    }

    private static class GrpcServerInterceptor
    implements ServerInterceptor {
        private GrpcServerInterceptor() {
        }

        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, final Metadata requestMetadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
            this.logMetadata(requestMetadata, "request");
            CidTimestamp cidTimestamp = (CidTimestamp)CID_TIMESTAMP.get();
            final String correlationId = cidTimestamp.correlationId;
            final long timestamp = cidTimestamp.timestamp;
            GrpcServer.trackState(correlationId, timestamp, "INIT");
            ForwardingServerCall.SimpleForwardingServerCall wrapperCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall){

                public void sendHeaders(Metadata responseHeaders) {
                    LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
                    this.logMetadata(requestMetadata, "response");
                    if (correlationId != null) {
                        LOGGER.info("response linked to cid: " + correlationId);
                        responseHeaders.put(Metadata.Key.of((String)"x-cid", (Metadata.AsciiMarshaller)Metadata.ASCII_STRING_MARSHALLER), (Object)correlationId);
                        GrpcServer.trackState(correlationId, timestamp, "SEND_HEADERS");
                    }
                    super.sendHeaders(responseHeaders);
                }

                public void sendMessage(RespT response) {
                    LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]");
                    GrpcServer.trackState(correlationId, timestamp, "SEND_MESSAGE");
                    super.sendMessage(response);
                }

                public void close(Status status, Metadata trailers) {
                    LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers);
                    GrpcServer.trackState(correlationId, timestamp, "CLOSE");
                    super.close(status, trailers);
                }
            };
            final ServerCall.Listener listener = serverCallHandler.startCall((ServerCall)wrapperCall, requestMetadata);
            return new ServerCall.Listener<ReqT>(){

                public void onMessage(ReqT message) {
                    LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
                    GrpcServer.trackState(correlationId, timestamp, "RECEIVE_MESSAGE");
                    listener.onMessage(message);
                }

                public void onHalfClose() {
                    LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
                    GrpcServer.trackState(correlationId, timestamp, "HALF_CLOSE");
                    listener.onHalfClose();
                }

                public void onCancel() {
                    LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
                    List stateList = GrpcServer.trackState(correlationId, timestamp, "CANCEL");
                    LOGGER.info(correlationId + " states: " + stateList);
                    listener.onCancel();
                }

                public void onComplete() {
                    LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
                    List stateList = GrpcServer.trackState(correlationId, timestamp, "COMPLETE");
                    LOGGER.info(correlationId + " states: " + stateList);
                    listener.onComplete();
                }

                public void onReady() {
                    LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
                    GrpcServer.trackState(correlationId, timestamp, "READY");
                    listener.onReady();
                }
            };
        }

        private void logMetadata(Metadata metadata, String label) {
            Set metadataKeys = metadata.keys();
            LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
            for (String key : metadataKeys) {
                String value = (String)metadata.get(Metadata.Key.of((String)key, (Metadata.AsciiMarshaller)Metadata.ASCII_STRING_MARSHALLER));
                LOGGER.info(label + " metadata " + key + " = " + value);
            }
        }
    }

    public class CorrelationIdInterceptor
    implements ServerInterceptor {
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, Metadata requestMetadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
            Metadata.Key xCidKey = Metadata.Key.of((String)"x-cid", (Metadata.AsciiMarshaller)Metadata.ASCII_STRING_MARSHALLER);
            String correlationId = (String)requestMetadata.get(xCidKey);
            if (correlationId == null) {
                correlationId = "courier-" + System.nanoTime();
            }
            long timestamp = System.nanoTime();
            Context currentContext = Context.current();
            Context newContext = currentContext.withValue(CID_TIMESTAMP, (Object)new CidTimestamp(correlationId, timestamp));
            return Contexts.interceptCall((Context)newContext, serverCall, (Metadata)requestMetadata, serverCallHandler);
        }
    }

    private static class CourierImpl
    extends CourierGrpc.CourierImplBase {
        private CourierImpl() {
        }

        @Override
        public void retrieveState(StateRequest request, StreamObserver<StateReply> responseObserver) {
            String correlationId = request.getCid();
            LOGGER.info("received StateRequest{cid=" + correlationId + "}");
            GrpcServer.registerOnCancelHandler(responseObserver);
            StateReply.Builder builder = StateReply.newBuilder().setCid(correlationId);
            List stateList = GrpcServer.trackState(correlationId);
            LOGGER.info("cid " + correlationId + " has states: " + stateList);
            if (stateList != null) {
                for (String state : stateList) {
                    builder.addState(state);
                }
            }
            StateReply reply = builder.build();
            LOGGER.info("Sending StateReply for cid=" + reply.getCid());
            responseObserver.onNext((Object)reply);
            responseObserver.onCompleted();
        }

        @Override
        public void sendPackage(CourierRequest request, StreamObserver<CourierReply> responseObserver) {
            LOGGER.info("received CourierRequest{id=" + request.getId() + ", from=" + request.getFrom() + ", message.length=" + request.getMessage().length() + "}");
            GrpcServer.registerOnCancelHandler(responseObserver);
            GrpcServer.sleep(request.getSleepDurationMillis());
            if (Variant.SEND_ERROR.equals((Object)request.getVariant())) {
                StatusRuntimeException error = Status.CANCELLED.withCause((Throwable)new RuntimeException(request.getId())).withDescription("Cancelled by server").asRuntimeException();
                LOGGER.info("Sending cancelled by server error");
                responseObserver.onError((Throwable)error);
            } else if (Variant.EXIT_PRE_RESPONSE.equals((Object)request.getVariant())) {
                GrpcServer.sleep(1000L);
                LOGGER.info("Exiting server abruptly");
                System.exit(1);
            } else {
                CourierReply reply = CourierReply.newBuilder().setId(request.getId()).setMessage(request.getMessage()).setResponse("received").build();
                LOGGER.info("Sending CourierReply for id=" + reply.getId());
                responseObserver.onNext((Object)reply);
                responseObserver.onCompleted();
            }
        }

        @Override
        public StreamObserver<CourierRequest> collectPackages(final StreamObserver<CourierSummary> responseObserver) {
            GrpcServer.registerOnCancelHandler(responseObserver);
            return new StreamObserver<CourierRequest>(){
                private long numMessages = 0L;
                private long totalLength = 0L;

                public void onNext(CourierRequest request) {
                    LOGGER.info("Received CourierRequest id=" + request.getId());
                    ++this.numMessages;
                    this.totalLength += (long)request.getMessage().length();
                    LOGGER.info("Summary of collected packages: numMessages=" + this.numMessages + " with totalLength=" + this.totalLength);
                    GrpcServer.sleep(request.getSleepDurationMillis());
                    if (Variant.EXIT_PRE_RESPONSE.equals((Object)request.getVariant())) {
                        GrpcServer.sleep(1000L);
                        LOGGER.info("Exiting server abruptly");
                        System.exit(1);
                    } else if (Variant.SEND_ERROR.equals((Object)request.getVariant())) {
                        StatusRuntimeException error = Status.CANCELLED.withCause((Throwable)new RuntimeException(request.getId())).withDescription("Cancelled by server").asRuntimeException();
                        LOGGER.info("Sending cancelled by server error");
                        responseObserver.onError((Throwable)error);
                    } else {
                        CourierSummary courierSummary = CourierSummary.newBuilder().setNumMessages(this.numMessages).setTotalLength(this.totalLength).build();
                        LOGGER.info("Sending CourierSummary for id=" + request.getId());
                        responseObserver.onNext((Object)courierSummary);
                    }
                    if (Variant.EXIT_POST_RESPONSE.equals((Object)request.getVariant())) {
                        GrpcServer.sleep(1000L);
                        LOGGER.info("Exiting server abruptly");
                        System.exit(1);
                    }
                }

                public void onError(Throwable th) {
                    LOGGER.severe("Error in collecting packages: " + th.getMessage());
                    responseObserver.onError(th);
                }

                public void onCompleted() {
                    LOGGER.severe("Completed collecting packages");
                    responseObserver.onCompleted();
                }
            };
        }

        @Override
        public StreamObserver<CourierRequest> aggregatePackages(final StreamObserver<CourierSummary> responseObserver) {
            GrpcServer.registerOnCancelHandler(responseObserver);
            return new StreamObserver<CourierRequest>(){
                private long numMessages = 0L;
                private long totalLength = 0L;

                public void onNext(CourierRequest request) {
                    LOGGER.info("Received CourierRequest id=" + request.getId());
                    ++this.numMessages;
                    this.totalLength += (long)request.getMessage().length();
                    LOGGER.info("Summary of collected packages: numMessages=" + this.numMessages + " with totalLength=" + this.totalLength);
                    GrpcServer.sleep(request.getSleepDurationMillis());
                    if (Variant.EXIT_PRE_RESPONSE.equals((Object)request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals((Object)request.getVariant())) {
                        GrpcServer.sleep(1000L);
                        LOGGER.info("Exiting server abruptly");
                        System.exit(1);
                    } else if (Variant.SEND_ERROR.equals((Object)request.getVariant())) {
                        StatusRuntimeException error = Status.CANCELLED.withCause((Throwable)new RuntimeException(request.getId())).withDescription("Cancelled by server").asRuntimeException();
                        LOGGER.info("Sending cancelled by server error");
                        responseObserver.onError((Throwable)error);
                    }
                }

                public void onError(Throwable th) {
                    LOGGER.severe("Error in aggregating packages: " + th.getMessage());
                    responseObserver.onError(th);
                }

                public void onCompleted() {
                    LOGGER.severe("Completed aggregating packages");
                    CourierSummary courierSummary = CourierSummary.newBuilder().setNumMessages(this.numMessages).setTotalLength(this.totalLength).build();
                    LOGGER.info("Sending aggregated CourierSummary");
                    responseObserver.onNext((Object)courierSummary);
                    responseObserver.onCompleted();
                }
            };
        }
    }

    /** Immutable pair of a call's correlation id and the timestamp assigned to that call. */
    private static class CidTimestamp {
        /** Correlation id of the call. */
        private final String correlationId;
        /** Timestamp assigned to the call when the pair was created. */
        private final long timestamp;

        private CidTimestamp(String correlationId, long timestamp) {
            this.correlationId = correlationId;
            this.timestamp = timestamp;
        }

        /** Renders the pair as {@code CidTimestamp{correlationId='...', timestamp=...}}. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("CidTimestamp{");
            text.append("correlationId='").append(this.correlationId).append('\'');
            text.append(", timestamp=").append(this.timestamp);
            return text.append('}').toString();
        }
    }
}

