/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.gjex.grpc.interceptor;

import com.flipkart.gjex.core.context.GJEXContext;
import com.flipkart.gjex.core.filter.RequestParams;
import com.flipkart.gjex.core.filter.grpc.GrpcFilter;
import com.flipkart.gjex.core.filter.grpc.GrpcFilterConfig;
import com.flipkart.gjex.core.filter.grpc.MethodFilters;
import com.flipkart.gjex.core.logging.Logging;
import com.flipkart.gjex.core.util.NetworkUtils;
import com.flipkart.gjex.core.util.Pair;
import com.flipkart.gjex.grpc.utils.AnnotationUtils;
import io.grpc.BindableService;
import io.grpc.Context;
import io.grpc.ForwardingServerCall;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import javax.inject.Named;
import javax.inject.Singleton;
import org.apache.commons.collections4.CollectionUtils;

/**
 * gRPC {@link ServerInterceptor} that runs the {@link GrpcFilter} chain registered for the invoked method.
 *
 * <p>Filter chains are built up front by {@link #registerFilters}. For each intercepted call the interceptor
 * obtains per-call filter instances via {@code GrpcFilter#getInstance()}, copies the request headers those
 * filters ask for into a {@link Context} (under {@code GJEXContext.getHeadersKey()}), and invokes the filters
 * on request messages, response headers, response messages and exceptions.
 *
 * <p>NOTE(review): {@code filtersMap} is a plain {@link HashMap}; this assumes {@code registerFilters}
 * completes before any call is intercepted — TODO confirm against the server bootstrap.
 */
@Singleton
@Named(value="FilterInterceptor")
public class FilterInterceptor implements ServerInterceptor, Logging {

    /** Filter chain per method, keyed by the normalized (lower-cased) name "serviceName/methodName". */
    private final Map<String, List<GrpcFilter>> filtersMap = new HashMap<>();

    /**
     * Builds the filter chain for every {@link MethodFilters}-annotated method of the given services.
     *
     * <p>Each chain starts with the configured global filters ({@code grpcFilterConfig.getGlobalFilterClasses()})
     * followed by the filters listed in the method's annotation. Every filter is passed through
     * {@code configure(grpcFilterConfig)}; a {@code null} result drops it, and at most one filter per class
     * is kept in a chain.
     *
     * @param grpcFilters      available filter instances; if several share a class, the first one wins
     * @param services         services whose annotated methods get filter chains
     * @param grpcFilterConfig configuration used to configure every filter
     * @throws RuntimeException if a global filter class cannot be loaded, or if an annotation names a filter
     *                          class with no instance in {@code grpcFilters}
     */
    public void registerFilters(List<GrpcFilter> grpcFilters, List<BindableService> services, GrpcFilterConfig grpcFilterConfig) {
        Map<Class<?>, GrpcFilter> classToInstanceMap = grpcFilters.stream()
                .collect(Collectors.toMap(filter -> filter.getClass(), filter -> filter, (existing, replacement) -> existing));
        for (BindableService service : services) {
            List<Pair<?, Method>> annotatedMethods = AnnotationUtils.getAnnotatedMethods(service.getClass(), MethodFilters.class);
            if (annotatedMethods == null) {
                continue;
            }
            // bindService() may build a fresh definition on each call; resolve the service name once per service.
            String serviceName = service.bindService().getServiceDescriptor().getName();
            for (Pair<?, Method> pair : annotatedMethods) {
                Method method = pair.getValue();
                List<GrpcFilter> filtersForMethod;
                try {
                    filtersForMethod = this.addAllStaticFilters(grpcFilterConfig, classToInstanceMap);
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException("Failed to load filter class: " + e.getMessage(), e);
                }
                for (Class<?> filterClass : method.getAnnotation(MethodFilters.class).value()) {
                    GrpcFilter boundFilter = classToInstanceMap.get(filterClass);
                    if (boundFilter == null) {
                        throw new RuntimeException("Filter instance not bound for Filter class :" + filterClass.getName());
                    }
                    // Skips filters already contributed by the global list, so none runs twice per call.
                    addIfAbsent(filtersForMethod, boundFilter.configure(grpcFilterConfig));
                }
                this.filtersMap.put(normalizeMethodName(serviceName + "/" + method.getName()), filtersForMethod);
            }
        }
    }

    /**
     * Wraps the call so that the registered filters observe it. Calls to methods without a registered chain
     * are passed to {@code next} untouched.
     */
    @Override
    public <Req, Res> ServerCall.Listener<Req> interceptCall(final ServerCall<Req, Res> call, Metadata headers, ServerCallHandler<Req, Res> next) {
        final String methodName = normalizeMethodName(call.getMethodDescriptor().getFullMethodName());
        List<GrpcFilter> grpcFilterReferences = this.filtersMap.get(methodName);
        if (grpcFilterReferences == null || grpcFilterReferences.isEmpty()) {
            return next.startCall(call, headers);
        }
        // Per-call filter instances obtained from the registered filters.
        final List<GrpcFilter> grpcFilters = grpcFilterReferences.stream()
                .map(GrpcFilter::getInstance)
                .collect(Collectors.toList());

        // Collect the request headers the filters want forwarded to downstream code via the Context.
        Metadata forwardHeaders = new Metadata();
        for (GrpcFilter filter : grpcFilters) {
            for (Metadata.Key key : filter.getForwardHeaderKeys()) {
                Object value = headers.get(key);
                if (value != null) {
                    forwardHeaders.put(key, value);
                }
            }
        }
        // Only install a new Context when there is something to forward; null means "leave Context alone".
        final Context contextWithHeaders = forwardHeaders.keys().isEmpty()
                ? null
                : Context.current().withValue(GJEXContext.getHeadersKey(), forwardHeaders);

        ServerCall.Listener<Req> listener = next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<Req, Res>(call) {

            @Override
            public void sendMessage(Res response) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    grpcFilters.forEach(filter -> filter.doProcessResponse(response));
                    super.sendMessage(response);
                } finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }

            @Override
            public void sendHeaders(Metadata responseHeaders) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    grpcFilters.forEach(filter -> filter.doProcessResponseHeaders(responseHeaders));
                    super.sendHeaders(responseHeaders);
                } finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }
        }, headers);

        final RequestParams requestParams = RequestParams.builder()
                .clientIp(getClientIp(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)))
                .resourcePath(methodName)
                .method(call.getMethodDescriptor().getType().name())
                .metadata(headers)
                .build();

        return new ForwardingServerCallListener.SimpleForwardingServerCallListener<Req>(listener) {

            @Override
            public void onHalfClose() {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    super.onHalfClose();
                } catch (RuntimeException ex) {
                    FilterInterceptor.this.handleException(call, ex);
                    grpcFilters.forEach(filter -> filter.doHandleException(ex));
                } finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }

            @Override
            public void onMessage(Req request) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    for (GrpcFilter filter : grpcFilters) {
                        filter.doProcessRequest(request, requestParams);
                    }
                    super.onMessage(request);
                } catch (StatusException | RuntimeException ex) {
                    // RuntimeExceptions (e.g. a filter throwing StatusRuntimeException) are handled like in
                    // onHalfClose/onCancel so their status reaches the client and filters are notified.
                    FilterInterceptor.this.handleException(call, ex);
                    grpcFilters.forEach(filter -> filter.doHandleException(ex));
                } finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }

            @Override
            public void onCancel() {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    super.onCancel();
                } catch (RuntimeException ex) {
                    FilterInterceptor.this.handleException(call, ex);
                    grpcFilters.forEach(filter -> filter.doHandleException(ex));
                } finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }
        };
    }

    /**
     * Logs {@code e} and closes {@code call} with a status derived from it.
     *
     * <p>{@link StatusRuntimeException} and {@link StatusException} keep their own status and trailers. The
     * status is used as-is because their {@code getMessage()} is prefixed with the status code, and re-using it
     * as the description would duplicate that code. Any other exception becomes {@link Status#INTERNAL}
     * described by its message. If closing fails with {@link IllegalStateException} (e.g. the call was already
     * closed), that failure is logged as a warning and otherwise ignored.
     */
    private <Req, Res> void handleException(ServerCall<Req, Res> call, Exception e) {
        this.error("Closing gRPC call due to RuntimeException.", e);
        Status returnStatus;
        Metadata trailers;
        if (e instanceof StatusRuntimeException) {
            StatusRuntimeException statusRuntimeException = (StatusRuntimeException) e;
            returnStatus = statusRuntimeException.getStatus();
            trailers = statusRuntimeException.getTrailers();
        } else if (e instanceof StatusException) {
            StatusException statusException = (StatusException) e;
            returnStatus = statusException.getStatus();
            trailers = statusException.getTrailers();
        } else {
            returnStatus = Status.INTERNAL.withDescription(e.getMessage());
            trailers = null;
        }
        try {
            call.close(returnStatus, trailers != null ? trailers : new Metadata());
        } catch (IllegalStateException ie) {
            this.warn("Exception while attempting to close ServerCall stream: " + ie.getMessage());
        }
    }

    /** Attaches {@code context} and returns the previously current one; no-op returning null for a null context. */
    private Context attachContext(Context context) {
        return context == null ? null : context.attach();
    }

    /** Restores {@code previousContext}; no-op when {@code currentContext} is null (nothing was attached). */
    private void detachContext(Context currentContext, Context previousContext) {
        if (currentContext != null) {
            currentContext.detach(previousContext);
        }
    }

    /**
     * Returns a new, mutable list with the configured global filters that have a bound instance.
     * Global filter classes with no bound instance are skipped silently (unlike annotation filters,
     * which fail fast in {@link #registerFilters}).
     *
     * @throws ClassNotFoundException if a configured global filter class name cannot be loaded
     */
    private List<GrpcFilter> addAllStaticFilters(GrpcFilterConfig grpcFilterConfig, Map<Class<?>, GrpcFilter> classToInstanceMap) throws ClassNotFoundException {
        List<GrpcFilter> filtersForMethod = new ArrayList<>();
        Collection<String> filterClasses = grpcFilterConfig.getGlobalFilterClasses();
        if (CollectionUtils.isEmpty(filterClasses)) {
            return filtersForMethod;
        }
        for (String filterClass : filterClasses) {
            GrpcFilter boundFilter = classToInstanceMap.get(Class.forName(filterClass));
            if (boundFilter != null) {
                addIfAbsent(filtersForMethod, boundFilter.configure(grpcFilterConfig));
            }
        }
        return filtersForMethod;
    }

    /** Appends {@code filter} unless it is null or a filter of the same class is already in {@code filters}. */
    private static void addIfAbsent(List<GrpcFilter> filters, GrpcFilter filter) {
        if (filter == null) {
            return;
        }
        for (GrpcFilter existing : filters) {
            if (existing.getClass().equals(filter.getClass())) {
                return;
            }
        }
        filters.add(filter);
    }

    /** Lower-cases a method name locale-independently, so registration and lookup keys always agree. */
    private static String normalizeMethodName(String methodName) {
        return methodName.toLowerCase(Locale.ROOT);
    }

    /**
     * Returns the textual IP address of the remote peer, or {@code "0.0.0.0"} when no address is available.
     *
     * <p>For an {@link InetSocketAddress} the literal address is used. {@code getHostName()} is deliberately
     * avoided because it may trigger a blocking reverse-DNS lookup on every call. Other address types are
     * delegated to {@code NetworkUtils.extractIPAddress} on their {@code toString()} form.
     */
    protected static String getClientIp(SocketAddress socketAddress) {
        if (socketAddress == null) {
            return "0.0.0.0";
        }
        if (socketAddress instanceof InetSocketAddress) {
            InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress;
            InetAddress address = inetSocketAddress.getAddress();
            // Unresolved addresses have no InetAddress; getHostString() returns the given name without a lookup.
            return address != null ? address.getHostAddress() : inetSocketAddress.getHostString();
        }
        return NetworkUtils.extractIPAddress(socketAddress.toString());
    }
}

