1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.handler.stream.ChunkedWriteHandler;
39 import io.netty.testsuite.util.TestUtils;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.concurrent.GenericFutureListener;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.TestInfo;
46 import org.junit.jupiter.api.Timeout;
47 import org.junit.jupiter.params.ParameterizedTest;
48 import org.junit.jupiter.params.provider.MethodSource;
49
50 import javax.net.ssl.SSLEngine;
51 import java.io.File;
52 import java.io.IOException;
53 import java.security.cert.CertificateException;
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.List;
57 import java.util.Random;
58 import java.util.concurrent.CountDownLatch;
59 import java.util.concurrent.ExecutorService;
60 import java.util.concurrent.Executors;
61 import java.util.concurrent.TimeUnit;
62 import java.util.concurrent.atomic.AtomicInteger;
63 import java.util.concurrent.atomic.AtomicReference;
64
65 import static org.hamcrest.MatcherAssert.assertThat;
66 import static org.hamcrest.Matchers.anyOf;
67 import static org.hamcrest.Matchers.is;
68 import static org.hamcrest.Matchers.not;
69 import static org.hamcrest.Matchers.sameInstance;
70 import static org.junit.jupiter.api.Assertions.assertEquals;
71 import static org.junit.jupiter.api.Assertions.assertSame;
72
73 public class SocketSslEchoTest extends AbstractSocketTest {
74
75 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
76
77 private static final int FIRST_MESSAGE_SIZE = 16384;
78 private static final Random random = new Random();
79 private static final File CERT_FILE;
80 private static final File KEY_FILE;
81 static final byte[] data = new byte[1048576];
82
83 static {
84 random.nextBytes(data);
85
86 SelfSignedCertificate ssc;
87 try {
88 ssc = new SelfSignedCertificate();
89 } catch (CertificateException e) {
90 throw new Error(e);
91 }
92 CERT_FILE = ssc.certificate();
93 KEY_FILE = ssc.privateKey();
94 }
95
96 protected enum RenegotiationType {
97 NONE,
98 CLIENT_INITIATED,
99 SERVER_INITIATED,
100 }
101
102 protected static class Renegotiation {
103 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104
105 final RenegotiationType type;
106 final String cipherSuite;
107
108 Renegotiation(RenegotiationType type, String cipherSuite) {
109 this.type = type;
110 this.cipherSuite = cipherSuite;
111 }
112
113 @Override
114 public String toString() {
115 if (type == RenegotiationType.NONE) {
116 return "NONE";
117 }
118
119 return type + "(" + cipherSuite + ')';
120 }
121 }
122
123 public static Collection<Object[]> data() throws Exception {
124 List<SslContext> serverContexts = new ArrayList<SslContext>();
125 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126 .sslProvider(SslProvider.JDK)
127
128 .protocols("TLSv1.2")
129 .build());
130
131 List<SslContext> clientContexts = new ArrayList<SslContext>();
132 clientContexts.add(SslContextBuilder.forClient()
133 .sslProvider(SslProvider.JDK)
134 .trustManager(CERT_FILE)
135
136 .protocols("TLSv1.2")
137 .build());
138
139 boolean hasOpenSsl = OpenSsl.isAvailable();
140 if (hasOpenSsl) {
141 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
142 .sslProvider(SslProvider.OPENSSL)
143
144 .protocols("TLSv1.2")
145 .build());
146 clientContexts.add(SslContextBuilder.forClient()
147 .sslProvider(SslProvider.OPENSSL)
148 .trustManager(CERT_FILE)
149
150 .protocols("TLSv1.2")
151 .build());
152 } else {
153 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
154 }
155
156 List<Object[]> params = new ArrayList<Object[]>();
157 for (SslContext sc: serverContexts) {
158 for (SslContext cc: clientContexts) {
159 for (RenegotiationType rt: RenegotiationType.values()) {
160 if (rt != RenegotiationType.NONE &&
161 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
162
163 continue;
164 }
165
166 final Renegotiation r;
167 switch (rt) {
168 case NONE:
169 r = Renegotiation.NONE;
170 break;
171 case SERVER_INITIATED:
172 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
173 break;
174 case CLIENT_INITIATED:
175 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
176 break;
177 default:
178 throw new Error();
179 }
180
181 for (int i = 0; i < 32; i++) {
182 params.add(new Object[] {
183 sc, cc, r,
184 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
185 }
186 }
187 }
188 }
189
190 return params;
191 }
192
193 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
194 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
195 private final AtomicInteger clientSendCounter = new AtomicInteger();
196 private final AtomicInteger clientRecvCounter = new AtomicInteger();
197 private final AtomicInteger serverRecvCounter = new AtomicInteger();
198
199 private final AtomicInteger clientNegoCounter = new AtomicInteger();
200 private final AtomicInteger serverNegoCounter = new AtomicInteger();
201
202 private volatile Channel clientChannel;
203 private volatile Channel serverChannel;
204
205 private volatile SslHandler clientSslHandler;
206 private volatile SslHandler serverSslHandler;
207
208 private final EchoClientHandler clientHandler =
209 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
210
211 private final EchoServerHandler serverHandler =
212 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
213
214 private SslContext serverCtx;
215 private SslContext clientCtx;
216 private Renegotiation renegotiation;
217 private boolean serverUsesDelegatedTaskExecutor;
218 private boolean clientUsesDelegatedTaskExecutor;
219 private boolean autoRead;
220 private boolean useChunkedWriteHandler;
221 private boolean useCompositeByteBuf;
222
223 @AfterAll
224 public static void compressHeapDumps() throws Exception {
225 TestUtils.compressHeapDumps();
226 }
227
228 @ParameterizedTest(name =
229 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
230 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
231 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
232 @MethodSource("data")
233 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
234 public void testSslEcho(
235 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
236 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
237 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
238 TestInfo testInfo) throws Throwable {
239 this.serverCtx = serverCtx;
240 this.clientCtx = clientCtx;
241 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
242 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
243 this.renegotiation = renegotiation;
244 this.autoRead = autoRead;
245 this.useChunkedWriteHandler = useChunkedWriteHandler;
246 this.useCompositeByteBuf = useCompositeByteBuf;
247 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
248 @Override
249 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
250 testSslEcho(serverBootstrap, bootstrap);
251 }
252 });
253 }
254
255 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
256 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
257 reset();
258
259 sb.childOption(ChannelOption.AUTO_READ, autoRead);
260 cb.option(ChannelOption.AUTO_READ, autoRead);
261
262 sb.childHandler(new ChannelInitializer<Channel>() {
263 @Override
264 public void initChannel(Channel sch) {
265 serverChannel = sch;
266
267 if (serverUsesDelegatedTaskExecutor) {
268 SSLEngine sse = serverCtx.newEngine(sch.alloc());
269 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
270 } else {
271 serverSslHandler = serverCtx.newHandler(sch.alloc());
272 }
273 serverSslHandler.setHandshakeTimeoutMillis(0);
274
275 sch.pipeline().addLast("ssl", serverSslHandler);
276 if (useChunkedWriteHandler) {
277 sch.pipeline().addLast(new ChunkedWriteHandler());
278 }
279 sch.pipeline().addLast("serverHandler", serverHandler);
280 }
281 });
282
283 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
284 cb.handler(new ChannelInitializer<Channel>() {
285 @Override
286 public void initChannel(Channel sch) {
287 clientChannel = sch;
288
289 if (clientUsesDelegatedTaskExecutor) {
290 SSLEngine cse = clientCtx.newEngine(sch.alloc());
291 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
292 } else {
293 clientSslHandler = clientCtx.newHandler(sch.alloc());
294 }
295 clientSslHandler.setHandshakeTimeoutMillis(0);
296
297 sch.pipeline().addLast("ssl", clientSslHandler);
298 if (useChunkedWriteHandler) {
299 sch.pipeline().addLast(new ChunkedWriteHandler());
300 }
301 sch.pipeline().addLast("clientHandler", clientHandler);
302 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
303 @Override
304 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
305 if (evt instanceof SslHandshakeCompletionEvent) {
306 clientHandshakeEventLatch.countDown();
307 }
308 ctx.fireUserEventTriggered(evt);
309 }
310 });
311 }
312 });
313
314 final Channel sc = sb.bind().sync().channel();
315 cb.connect(sc.localAddress()).sync();
316
317 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
318
319
320 clientHandshakeFuture.sync();
321 clientHandshakeEventLatch.await();
322
323 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
324 clientSendCounter.set(FIRST_MESSAGE_SIZE);
325
326 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
327 Future<Channel> renegoFuture = null;
328 while (clientSendCounter.get() < data.length) {
329 int clientSendCounterVal = clientSendCounter.get();
330 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
331 ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
332 if (useCompositeByteBuf) {
333 buf = Unpooled.compositeBuffer().addComponent(true, buf);
334 }
335
336 ChannelFuture future = clientChannel.writeAndFlush(buf);
337 clientSendCounter.set(clientSendCounterVal += length);
338 future.sync();
339
340 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
341 needsRenegotiation = false;
342 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
343 renegoFuture = clientSslHandler.renegotiate();
344 logStats("CLIENT RENEGOTIATES");
345 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
346 }
347 }
348
349
350 while (clientRecvCounter.get() < data.length) {
351 if (serverException.get() != null) {
352 break;
353 }
354 if (clientException.get() != null) {
355 break;
356 }
357
358 Thread.sleep(50);
359 }
360
361 while (serverRecvCounter.get() < data.length) {
362 if (serverException.get() != null) {
363 break;
364 }
365 if (clientException.get() != null) {
366 break;
367 }
368
369 Thread.sleep(50);
370 }
371
372
373 if (renegoFuture != null) {
374 renegoFuture.sync();
375 }
376 if (serverHandler.renegoFuture != null) {
377 serverHandler.renegoFuture.sync();
378 }
379
380 serverChannel.close().awaitUninterruptibly();
381 clientChannel.close().awaitUninterruptibly();
382 sc.close().awaitUninterruptibly();
383 delegatedTaskExecutor.shutdown();
384
385 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
386 throw serverException.get();
387 }
388 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
389 throw clientException.get();
390 }
391 if (serverException.get() != null) {
392 throw serverException.get();
393 }
394 if (clientException.get() != null) {
395 throw clientException.get();
396 }
397
398
399 try {
400 switch (renegotiation.type) {
401 case SERVER_INITIATED:
402 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
403 assertThat(serverNegoCounter.get(), is(2));
404 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
405 break;
406 case CLIENT_INITIATED:
407 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
408 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
409 assertThat(clientNegoCounter.get(), is(2));
410 break;
411 case NONE:
412 assertThat(serverNegoCounter.get(), is(1));
413 assertThat(clientNegoCounter.get(), is(1));
414 }
415 } finally {
416 logStats("STATS");
417 }
418 }
419
420 private void reset() {
421 clientException.set(null);
422 serverException.set(null);
423
424 clientSendCounter.set(0);
425 clientRecvCounter.set(0);
426 serverRecvCounter.set(0);
427
428 clientNegoCounter.set(0);
429 serverNegoCounter.set(0);
430
431 clientChannel = null;
432 serverChannel = null;
433
434 clientSslHandler = null;
435 serverSslHandler = null;
436 }
437
438 void logStats(String message) {
439 logger.debug(
440 "{}:\n" +
441 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
442 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
443 message,
444 clientSendCounter, clientRecvCounter, clientNegoCounter,
445 clientSslHandler.engine().getSession().getCipherSuite(),
446 serverRecvCounter, serverNegoCounter,
447 serverSslHandler.engine().getSession().getCipherSuite());
448 }
449
450 @Sharable
451 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
452
453 protected final AtomicInteger recvCounter;
454 protected final AtomicInteger negoCounter;
455 protected final AtomicReference<Throwable> exception;
456
457 EchoHandler(
458 AtomicInteger recvCounter, AtomicInteger negoCounter,
459 AtomicReference<Throwable> exception) {
460
461 this.recvCounter = recvCounter;
462 this.negoCounter = negoCounter;
463 this.exception = exception;
464 }
465
466 @Override
467 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
468
469
470 if (!autoRead) {
471 ctx.read();
472 }
473 ctx.fireChannelReadComplete();
474 }
475
476 @Override
477 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
478 if (evt instanceof SslHandshakeCompletionEvent) {
479 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
480 if (handshakeEvt.cause() != null) {
481 logger.warn("Handshake failed:", handshakeEvt.cause());
482 }
483 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
484 negoCounter.incrementAndGet();
485 logStats("HANDSHAKEN");
486 }
487 ctx.fireUserEventTriggered(evt);
488 }
489
490 @Override
491 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
492 if (logger.isWarnEnabled()) {
493 logger.warn("Unexpected exception from the client side:", cause);
494 }
495
496 exception.compareAndSet(null, cause);
497 ctx.close();
498 }
499 }
500
501 private class EchoClientHandler extends EchoHandler {
502
503 EchoClientHandler(
504 AtomicInteger recvCounter, AtomicInteger negoCounter,
505 AtomicReference<Throwable> exception) {
506
507 super(recvCounter, negoCounter, exception);
508 }
509
510 @Override
511 public void handlerAdded(final ChannelHandlerContext ctx) {
512 if (!autoRead) {
513 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
514 new GenericFutureListener<Future<? super Channel>>() {
515 @Override
516 public void operationComplete(Future<? super Channel> future) {
517 ctx.read();
518 }
519 });
520 }
521 }
522
523 @Override
524 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
525 byte[] actual = new byte[in.readableBytes()];
526 in.readBytes(actual);
527
528 int lastIdx = recvCounter.get();
529 for (int i = 0; i < actual.length; i ++) {
530 assertEquals(data[i + lastIdx], actual[i]);
531 }
532
533 recvCounter.addAndGet(actual.length);
534 }
535 }
536
537 private class EchoServerHandler extends EchoHandler {
538 volatile Future<Channel> renegoFuture;
539
540 EchoServerHandler(
541 AtomicInteger recvCounter, AtomicInteger negoCounter,
542 AtomicReference<Throwable> exception) {
543
544 super(recvCounter, negoCounter, exception);
545 }
546
547 @Override
548 public final void channelRegistered(ChannelHandlerContext ctx) {
549 renegoFuture = null;
550 }
551
552 @Override
553 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
554 if (!autoRead) {
555 ctx.read();
556 }
557 ctx.fireChannelActive();
558 }
559
560 @Override
561 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
562 byte[] actual = new byte[in.readableBytes()];
563 in.readBytes(actual);
564
565 int lastIdx = recvCounter.get();
566 for (int i = 0; i < actual.length; i ++) {
567 assertEquals(data[i + lastIdx], actual[i]);
568 }
569
570 ByteBuf buf = Unpooled.wrappedBuffer(actual);
571 if (useCompositeByteBuf) {
572 buf = Unpooled.compositeBuffer().addComponent(true, buf);
573 }
574 ctx.writeAndFlush(buf);
575
576 recvCounter.addAndGet(actual.length);
577
578
579 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
580 recvCounter.get() > data.length / 2 && renegoFuture == null) {
581
582 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
583
584 Future<Channel> hf = sslHandler.handshakeFuture();
585 assertThat(hf.isDone(), is(true));
586
587 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
588 logStats("SERVER RENEGOTIATES");
589 renegoFuture = sslHandler.renegotiate();
590 assertThat(renegoFuture, is(not(sameInstance(hf))));
591 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
592 assertThat(renegoFuture.isDone(), is(false));
593 }
594 }
595 }
596 }