/*
 * Decompiled with CFR 0.152.
 */
package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslContext;
import io.netty.handler.ssl.OpenSslContextOption;
import io.netty.handler.ssl.OpenSslPrivateKeyMethod;
import io.netty.handler.ssl.OpenSslTestUtils;
import io.netty.handler.ssl.OpenSslX509KeyManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslContextOption;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.X509Certificate;
import java.security.spec.MGF1ParameterSpec;
import java.security.spec.PSSParameterSpec;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLHandshakeException;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests {@link OpenSslPrivateKeyMethod}: an OpenSSL (BoringSSL) server whose certificate comes from a keyless
 * {@link OpenSslX509KeyManagerFactory} and whose signing is delegated to a custom key method, talking TLSv1.2 to a
 * JDK client. Every test runs twice: with and without a delegated-task executor on the {@link SslHandler}s.
 */
@RunWith(Parameterized.class)
public class OpenSslPrivateKeyMethodTest {
    private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
    /** Upper bound for every wait on a handshake or on the PING/PONG exchange. */
    private static final long TIMEOUT_SECONDS = 5L;

    private static EventLoopGroup GROUP;
    private static SelfSignedCertificate CERT;
    private static ExecutorService EXECUTOR;

    /** When {@code true}, {@link SslHandler}s are created with {@link #EXECUTOR} as their delegated-task executor. */
    private final boolean delegate;

    @Parameterized.Parameters(name = "{index}: delegate = {0}")
    public static Collection<Object[]> parameters() {
        List<Object[]> dst = new ArrayList<Object[]>();
        dst.add(new Object[] { true });
        dst.add(new Object[] { false });
        return dst;
    }

    /**
     * Skips the whole class unless BoringSSL is in use and the test cipher is supported by both providers, then
     * creates the shared event loop group, self-signed certificate and {@link DelegateThread}-based executor.
     */
    @BeforeClass
    public static void init() throws Exception {
        OpenSslTestUtils.checkShouldUseKeyManagerFactory();
        Assume.assumeTrue(OpenSsl.isBoringSSL());
        assumeCipherAvailable(SslProvider.OPENSSL);
        assumeCipherAvailable(SslProvider.JDK);

        GROUP = new DefaultEventLoopGroup();
        CERT = new SelfSignedCertificate();
        EXECUTOR = Executors.newCachedThreadPool(new ThreadFactory() {
            @Override
            public Thread newThread(Runnable r) {
                // Marks executor threads so assertThread() can tell where sign() ran.
                return new DelegateThread(r);
            }
        });
    }

    /**
     * Releases the shared resources created by {@link #init()}.
     *
     * <p>JUnit runs {@code @AfterClass} even when {@code @BeforeClass} aborted (e.g. a failed assumption after the
     * BoringSSL check), so each field may still be {@code null} here and must be checked individually.
     */
    @AfterClass
    public static void destroy() {
        if (GROUP != null) {
            GROUP.shutdownGracefully();
        }
        if (CERT != null) {
            CERT.delete();
        }
        if (EXECUTOR != null) {
            EXECUTOR.shutdown();
        }
    }

    public OpenSslPrivateKeyMethodTest(boolean delegate) {
        this.delegate = delegate;
    }

    /**
     * Skips the test unless {@link #RFC_CIPHER_NAME} is supported by {@code provider}.
     *
     * @param provider {@link SslProvider#JDK} checks the default JDK {@link SSLContext}; anything else asks
     *                 {@link OpenSsl#isCipherSuiteAvailable(String)}
     * @throws NoSuchAlgorithmException if the default JDK {@link SSLContext} cannot be obtained
     */
    private static void assumeCipherAvailable(SslProvider provider) throws NoSuchAlgorithmException {
        boolean cipherSupported = false;
        if (provider == SslProvider.JDK) {
            SSLEngine engine = SSLContext.getDefault().createSSLEngine();
            for (String c : engine.getSupportedCipherSuites()) {
                if (RFC_CIPHER_NAME.equals(c)) {
                    cipherSupported = true;
                    break;
                }
            }
        } else {
            cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME);
        }
        Assume.assumeTrue("Unsupported cipher: " + RFC_CIPHER_NAME, cipherSupported);
    }

    /**
     * Creates an {@link SslHandler} from {@code sslCtx}, using {@code executor} for delegated tasks when it is
     * non-{@code null} and the context's default behaviour otherwise.
     */
    private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
        if (executor == null) {
            return sslCtx.newHandler(allocator);
        }
        return sslCtx.newHandler(allocator, executor);
    }

    /**
     * Builds a TLSv1.2 OpenSSL server context restricted to {@link #RFC_CIPHER_NAME}. The key manager only holds the
     * certificate (keyless), so all private-key operations go through {@code method}.
     */
    private SslContext buildServerContext(OpenSslPrivateKeyMethod method) throws Exception {
        List<String> ciphers = Collections.singletonList(RFC_CIPHER_NAME);
        OpenSslX509KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert());
        return SslContextBuilder.forServer(kmf)
                .sslProvider(SslProvider.OPENSSL)
                .ciphers(ciphers)
                .protocols("TLSv1.2")
                .option(OpenSslContextOption.PRIVATE_KEY_METHOD, method)
                .build();
    }

    /** Builds a TLSv1.2 JDK client context restricted to {@link #RFC_CIPHER_NAME} that trusts any certificate. */
    private SslContext buildClientContext() throws Exception {
        return SslContextBuilder.forClient()
                .sslProvider(SslProvider.JDK)
                .ciphers(Collections.singletonList(RFC_CIPHER_NAME))
                .protocols("TLSv1.2")
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .build();
    }

    /** Returns the delegated-task executor for this parameterisation, or {@code null} when not delegating. */
    private Executor delegateExecutor() {
        return delegate ? EXECUTOR : null;
    }

    /**
     * Asserts that the current thread is a {@link DelegateThread} exactly when delegation is enabled and
     * {@link OpenSslContext#USE_TASKS} is set.
     */
    private void assertThread() {
        if (delegate && OpenSslContext.USE_TASKS) {
            Assert.assertEquals(DelegateThread.class, Thread.currentThread().getClass());
        } else {
            Assert.assertNotEquals(DelegateThread.class, Thread.currentThread().getClass());
        }
    }

    /** Returns the local address used by both tests; each test unbinds it before returning. */
    private LocalAddress newLocalAddress() {
        return new LocalAddress(
                "test-" + SslProvider.OPENSSL + '-' + SslProvider.JDK + '-' + RFC_CIPHER_NAME + '-' + delegate);
    }

    /**
     * Completes a PING/PONG exchange over TLSv1.2 where the server's signature is produced by a custom
     * {@link OpenSslPrivateKeyMethod} that delegates to {@link Signature}, and asserts that it was invoked.
     */
    @Test
    public void testPrivateKeyMethod() throws Exception {
        final AtomicBoolean signCalled = new AtomicBoolean();
        final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() {
            @Override
            public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
                signCalled.set(true);
                assertThread();
                Assert.assertEquals(CERT.cert().getPublicKey(),
                        engine.getSession().getLocalCertificates()[0].getPublicKey());

                // Perform the signature in the JDK, matching the algorithm requested by the handshake.
                final Signature signature;
                if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256) {
                    signature = Signature.getInstance("SHA256withRSA");
                } else if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256) {
                    signature = Signature.getInstance("RSASSA-PSS");
                    signature.setParameter(
                            new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1));
                } else {
                    throw new AssertionError("Unexpected signature algorithm " + signatureAlgorithm);
                }
                signature.initSign(CERT.key());
                signature.update(input);
                return signature.sign();
            }

            @Override
            public byte[] decrypt(SSLEngine engine, byte[] input) {
                throw new UnsupportedOperationException();
            }
        });
        try {
            // Built inside the try so the server context is released even if this throws.
            final SslContext sslClientContext = buildClientContext();
            try {
                final Promise<Object> serverPromise = GROUP.next().newPromise();
                final Promise<Object> clientPromise = GROUP.next().newPromise();
                Channel server = server(newLocalAddress(), serverInitializer(sslServerContext, serverPromise));
                try {
                    Channel client = client(server, clientInitializer(sslClientContext, clientPromise));
                    try {
                        client.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 'P', 'I', 'N', 'G' }))
                                .syncUninterruptibly();
                        Assert.assertTrue("client timeout", clientPromise.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
                        Assert.assertTrue("server timeout", serverPromise.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
                        // Rethrow any failure (or cancellation) recorded by the pipeline handlers.
                        clientPromise.sync();
                        serverPromise.sync();
                        Assert.assertTrue(signCalled.get());
                    } finally {
                        client.close().sync();
                    }
                } finally {
                    server.close().sync();
                }
            } finally {
                ReferenceCountUtil.release(sslClientContext);
            }
        } finally {
            ReferenceCountUtil.release(sslServerContext);
        }
    }

    /**
     * Server pipeline for {@link #testPrivateKeyMethod()}: TLS, then a handler that answers the first inbound
     * message with PONG, completes {@code promise} and closes the connection. Errors fail {@code promise}, and an
     * early disconnect cancels it.
     */
    private ChannelInitializer<Channel> serverInitializer(final SslContext sslCtx, final Promise<Object> promise) {
        return new ChannelInitializer<Channel>() {
            @Override
            protected void initChannel(Channel ch) {
                ChannelPipeline pipeline = ch.pipeline();
                pipeline.addLast(newSslHandler(sslCtx, ch.alloc(), delegateExecutor()));
                pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                    @Override
                    public void channelInactive(ChannelHandlerContext ctx) {
                        promise.cancel(true);
                        ctx.fireChannelInactive();
                    }

                    @Override
                    public void channelRead0(ChannelHandlerContext ctx, Object msg) {
                        if (promise.trySuccess(null)) {
                            ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 'P', 'O', 'N', 'G' }));
                        }
                        ctx.close();
                    }

                    @Override
                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                        if (!promise.tryFailure(cause)) {
                            ctx.fireExceptionCaught(cause);
                        }
                    }
                });
            }
        };
    }

    /**
     * Client pipeline for {@link #testPrivateKeyMethod()}: TLS, then a handler that completes {@code promise} on the
     * first inbound message and closes the connection. Errors fail {@code promise}, and an early disconnect
     * cancels it.
     */
    private ChannelInitializer<Channel> clientInitializer(final SslContext sslCtx, final Promise<Object> promise) {
        return new ChannelInitializer<Channel>() {
            @Override
            protected void initChannel(Channel ch) {
                ChannelPipeline pipeline = ch.pipeline();
                pipeline.addLast(newSslHandler(sslCtx, ch.alloc(), delegateExecutor()));
                pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                    @Override
                    public void channelInactive(ChannelHandlerContext ctx) {
                        promise.cancel(true);
                        ctx.fireChannelInactive();
                    }

                    @Override
                    public void channelRead0(ChannelHandlerContext ctx, Object msg) {
                        promise.trySuccess(null);
                        ctx.close();
                    }

                    @Override
                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                        if (!promise.tryFailure(cause)) {
                            ctx.fireExceptionCaught(cause);
                        }
                    }
                });
            }
        };
    }

    @Test
    public void testPrivateKeyMethodFailsBecauseOfException() throws Exception {
        testPrivateKeyMethodFails(false);
    }

    @Test
    public void testPrivateKeyMethodFailsBecauseOfNull() throws Exception {
        testPrivateKeyMethodFails(true);
    }

    /**
     * Verifies that both handshakes fail when the key method's {@code sign} cannot produce a signature. The server
     * side must fail with an {@link SSLHandshakeException}.
     *
     * @param returnNull {@code true} to make {@code sign} return {@code null}; {@code false} to make it throw a
     *                   {@link SignatureException}
     */
    private void testPrivateKeyMethodFails(final boolean returnNull) throws Exception {
        final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() {
            @Override
            public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
                assertThread();
                if (returnNull) {
                    return null;
                }
                throw new SignatureException();
            }

            @Override
            public byte[] decrypt(SSLEngine engine, byte[] input) {
                throw new UnsupportedOperationException();
            }
        });
        try {
            // Acquire everything else inside the try blocks so earlier contexts are always released.
            final SslContext sslClientContext = buildClientContext();
            try {
                SslHandler serverSslHandler =
                        newSslHandler(sslServerContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor());
                SslHandler clientSslHandler =
                        newSslHandler(sslClientContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor());
                Channel server = server(newLocalAddress(), serverSslHandler);
                try {
                    Channel client = client(server, clientSslHandler);
                    try {
                        // Bounded waits so a missing handshake completion fails the test instead of hanging it.
                        Assert.assertTrue("client handshake timeout",
                                clientSslHandler.handshakeFuture().await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
                        Assert.assertTrue("server handshake timeout",
                                serverSslHandler.handshakeFuture().await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
                        Throwable clientCause = clientSslHandler.handshakeFuture().cause();
                        Throwable serverCause = serverSslHandler.handshakeFuture().cause();
                        Assert.assertNotNull(clientCause);
                        MatcherAssert.assertThat(serverCause, Matchers.instanceOf(SSLHandshakeException.class));
                    } finally {
                        client.close().sync();
                    }
                } finally {
                    server.close().sync();
                }
            } finally {
                ReferenceCountUtil.release(sslClientContext);
            }
        } finally {
            ReferenceCountUtil.release(sslServerContext);
        }
    }

    /** Binds a {@link LocalServerChannel} on {@link #GROUP} at {@code address}, with {@code handler} as child handler. */
    private static Channel server(LocalAddress address, ChannelHandler handler) throws Exception {
        ServerBootstrap bootstrap = new ServerBootstrap()
                .channel(LocalServerChannel.class)
                .group(GROUP)
                .childHandler(handler);
        return bootstrap.bind(address).sync().channel();
    }

    /** Connects a {@link LocalChannel} on {@link #GROUP} to {@code server}'s local address, using {@code handler}. */
    private static Channel client(Channel server, ChannelHandler handler) throws Exception {
        SocketAddress remoteAddress = server.localAddress();
        Bootstrap bootstrap = new Bootstrap()
                .channel(LocalChannel.class)
                .group(GROUP)
                .handler(handler);
        return bootstrap.connect(remoteAddress).sync().channel();
    }

    /** Marker thread type created by {@link #EXECUTOR}, used by {@link #assertThread()}. */
    private static final class DelegateThread extends Thread {
        DelegateThread(Runnable target) {
            super(target);
        }
    }
}

