1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.logging.LogLevel;
28 import io.netty.handler.logging.LoggingHandler;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.SslContext;
31 import io.netty.handler.ssl.SslContextBuilder;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.handler.ssl.SslProvider;
35 import io.netty.handler.ssl.util.SelfSignedCertificate;
36 import io.netty.util.internal.PlatformDependent;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39 import org.junit.jupiter.api.TestInfo;
40 import org.junit.jupiter.api.Timeout;
41 import org.junit.jupiter.params.ParameterizedTest;
42 import org.junit.jupiter.params.provider.MethodSource;
43
44 import javax.net.ssl.SSLPeerUnverifiedException;
45 import javax.net.ssl.SSLSession;
46 import java.io.File;
47 import java.io.IOException;
48 import java.security.cert.CertificateException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.CountDownLatch;
53 import java.util.concurrent.Executor;
54 import java.util.concurrent.ExecutorService;
55 import java.util.concurrent.Executors;
56 import java.util.concurrent.TimeUnit;
57 import java.util.concurrent.atomic.AtomicReference;
58
59 import static org.junit.jupiter.api.Assertions.assertEquals;
60 import static org.junit.jupiter.api.Assertions.assertFalse;
61 import static org.junit.jupiter.api.Assertions.fail;
62
63 public class SocketSslGreetingTest extends AbstractSocketTest {
64
65 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
66
67 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
68 private static final File CERT_FILE;
69 private static final File KEY_FILE;
70
71 static {
72 SelfSignedCertificate ssc;
73 try {
74 ssc = new SelfSignedCertificate();
75 } catch (CertificateException e) {
76 throw new Error(e);
77 }
78 CERT_FILE = ssc.certificate();
79 KEY_FILE = ssc.privateKey();
80 }
81
82 public static Collection<Object[]> data() throws Exception {
83 List<SslContext> serverContexts = new ArrayList<SslContext>();
84 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
85
86 List<SslContext> clientContexts = new ArrayList<SslContext>();
87 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
88
89 boolean hasOpenSsl = OpenSsl.isAvailable();
90 if (hasOpenSsl) {
91 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
92 .sslProvider(SslProvider.OPENSSL).build());
93 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
94 .trustManager(CERT_FILE).build());
95 } else {
96 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
97 }
98
99 List<Object[]> params = new ArrayList<Object[]>();
100 for (SslContext sc: serverContexts) {
101 for (SslContext cc: clientContexts) {
102 params.add(new Object[] { sc, cc, true });
103 params.add(new Object[] { sc, cc, false });
104 }
105 }
106 return params;
107 }
108
109 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
110 if (executor == null) {
111 return sslCtx.newHandler(allocator);
112 } else {
113 return sslCtx.newHandler(allocator, executor);
114 }
115 }
116
117
118 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
119 @MethodSource("data")
120 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
121 public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
122 TestInfo testInfo) throws Throwable {
123 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
124 @Override
125 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
126 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
127 }
128 });
129 }
130
131 public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
132 final SslContext clientCtx, boolean delegate) throws Throwable {
133 final ServerHandler sh = new ServerHandler();
134 final ClientHandler ch = new ClientHandler();
135
136 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
137 try {
138 sb.childHandler(new ChannelInitializer<Channel>() {
139 @Override
140 public void initChannel(Channel sch) throws Exception {
141 ChannelPipeline p = sch.pipeline();
142 p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
143 p.addLast(new LoggingHandler(LOG_LEVEL));
144 p.addLast(sh);
145 }
146 });
147
148 cb.handler(new ChannelInitializer<Channel>() {
149 @Override
150 public void initChannel(Channel sch) throws Exception {
151 ChannelPipeline p = sch.pipeline();
152 p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
153 p.addLast(new LoggingHandler(LOG_LEVEL));
154 p.addLast(ch);
155 }
156 });
157
158 Channel sc = sb.bind().sync().channel();
159 Channel cc = cb.connect(sc.localAddress()).sync().channel();
160
161 ch.latch.await();
162
163 sh.channel.close().awaitUninterruptibly();
164 cc.close().awaitUninterruptibly();
165 sc.close().awaitUninterruptibly();
166
167 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
168 throw sh.exception.get();
169 }
170 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
171 throw ch.exception.get();
172 }
173 if (sh.exception.get() != null) {
174 throw sh.exception.get();
175 }
176 if (ch.exception.get() != null) {
177 throw ch.exception.get();
178 }
179 } finally {
180 if (executorService != null) {
181 executorService.shutdown();
182 }
183 }
184 }
185
186 private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
187
188 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
189 final CountDownLatch latch = new CountDownLatch(1);
190
191 @Override
192 public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
193 assertEquals('a', buf.readByte());
194 assertFalse(buf.isReadable());
195 latch.countDown();
196 ctx.close();
197 }
198
199 @Override
200 public void exceptionCaught(ChannelHandlerContext ctx,
201 Throwable cause) throws Exception {
202 if (logger.isWarnEnabled()) {
203 logger.warn("Unexpected exception from the client side", cause);
204 }
205
206 exception.compareAndSet(null, cause);
207 ctx.close();
208 }
209 }
210
211 private static class ServerHandler extends SimpleChannelInboundHandler<String> {
212 volatile Channel channel;
213 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
214
215 @Override
216 protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
217
218 }
219
220 @Override
221 public void channelActive(ChannelHandlerContext ctx)
222 throws Exception {
223 channel = ctx.channel();
224 channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
225 }
226
227 @Override
228 public void exceptionCaught(ChannelHandlerContext ctx,
229 Throwable cause) throws Exception {
230 if (logger.isWarnEnabled()) {
231 logger.warn("Unexpected exception from the server side", cause);
232 }
233
234 exception.compareAndSet(null, cause);
235 ctx.close();
236 }
237
238 @Override
239 public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
240 if (evt instanceof SslHandshakeCompletionEvent) {
241 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
242 if (event.isSuccess()) {
243 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
244 try {
245 session.getPeerCertificates();
246 fail();
247 } catch (SSLPeerUnverifiedException e) {
248
249 }
250 try {
251 session.getPeerCertificateChain();
252 fail();
253 } catch (SSLPeerUnverifiedException e) {
254
255 } catch (UnsupportedOperationException e) {
256
257
258 if (PlatformDependent.javaVersion() < 15) {
259 throw e;
260 }
261 }
262 try {
263 session.getPeerPrincipal();
264 fail();
265 } catch (SSLPeerUnverifiedException e) {
266
267 }
268 }
269 }
270 ctx.fireUserEventTriggered(evt);
271 }
272 }
273 }