1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.PooledByteBufAllocator;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInitializer;
24 import io.netty.channel.ChannelOption;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.LineBasedFrameDecoder;
28 import io.netty.handler.codec.string.StringDecoder;
29 import io.netty.handler.codec.string.StringEncoder;
30 import io.netty.handler.logging.LogLevel;
31 import io.netty.handler.logging.LoggingHandler;
32 import io.netty.handler.ssl.OpenSsl;
33 import io.netty.handler.ssl.SslContext;
34 import io.netty.handler.ssl.SslContextBuilder;
35 import io.netty.handler.ssl.SslHandler;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.util.concurrent.DefaultEventExecutorGroup;
39 import io.netty.util.concurrent.EventExecutorGroup;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.internal.logging.InternalLogger;
42 import io.netty.util.internal.logging.InternalLoggerFactory;
43
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.BeforeAll;
46 import org.junit.jupiter.api.TestInfo;
47 import org.junit.jupiter.api.Timeout;
48 import org.junit.jupiter.params.ParameterizedTest;
49 import org.junit.jupiter.params.provider.MethodSource;
50
51 import javax.net.ssl.SSLEngine;
52 import java.io.File;
53 import java.io.IOException;
54 import java.security.cert.CertificateException;
55 import java.util.ArrayList;
56 import java.util.Collection;
57 import java.util.List;
58 import java.util.concurrent.TimeUnit;
59 import java.util.concurrent.atomic.AtomicReference;
60
61 import static org.junit.jupiter.api.Assertions.assertEquals;
62 import static org.junit.jupiter.api.Assertions.assertNotNull;
63 import static org.junit.jupiter.api.Assertions.assertTrue;
64
65 public class SocketStartTlsTest extends AbstractSocketTest {
66 private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67
68 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69
70 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71 private static final File CERT_FILE;
72 private static final File KEY_FILE;
73 private static EventExecutorGroup executor;
74
75 static {
76 SelfSignedCertificate ssc;
77 try {
78 ssc = new SelfSignedCertificate();
79 } catch (CertificateException e) {
80 throw new Error(e);
81 }
82 CERT_FILE = ssc.certificate();
83 KEY_FILE = ssc.privateKey();
84 }
85
86 public static Collection<Object[]> data() throws Exception {
87 List<SslContext> serverContexts = new ArrayList<SslContext>();
88 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
89
90 List<SslContext> clientContexts = new ArrayList<SslContext>();
91 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
92
93 boolean hasOpenSsl = OpenSsl.isAvailable();
94 if (hasOpenSsl) {
95 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
96 .sslProvider(SslProvider.OPENSSL).build());
97 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
98 .trustManager(CERT_FILE).build());
99 } else {
100 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
101 }
102
103 List<Object[]> params = new ArrayList<Object[]>();
104 for (SslContext sc: serverContexts) {
105 for (SslContext cc: clientContexts) {
106 params.add(new Object[] { sc, cc });
107 }
108 }
109 return params;
110 }
111
112 @BeforeAll
113 public static void createExecutor() {
114 executor = new DefaultEventExecutorGroup(2);
115 }
116
117 @AfterAll
118 public static void shutdownExecutor() throws Exception {
119 executor.shutdownGracefully().sync();
120 }
121
122 @ParameterizedTest(name = PARAMETERIZED_NAME)
123 @MethodSource("data")
124 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
125 public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
126 throws Throwable {
127 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128 @Override
129 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130 testStartTls(sb, cb, serverCtx, clientCtx);
131 }
132 });
133 }
134
135 public void testStartTls(ServerBootstrap sb, Bootstrap cb,
136 SslContext serverCtx, SslContext clientCtx) throws Throwable {
137 testStartTls(sb, cb, serverCtx, clientCtx, true);
138 }
139
140 @ParameterizedTest(name = PARAMETERIZED_NAME)
141 @MethodSource("data")
142 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
143 public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
144 TestInfo testInfo) throws Throwable {
145 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
146 @Override
147 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
148 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
149 }
150 });
151 }
152
153 public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
154 SslContext serverCtx, SslContext clientCtx) throws Throwable {
155 testStartTls(sb, cb, serverCtx, clientCtx, false);
156 }
157
158 private void testStartTls(ServerBootstrap sb, Bootstrap cb,
159 SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
160 sb.childOption(ChannelOption.AUTO_READ, autoRead);
161 cb.option(ChannelOption.AUTO_READ, autoRead);
162
163 final EventExecutorGroup executor = SocketStartTlsTest.executor;
164 SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
165 SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
166
167 final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
168 final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
169
170 sb.childHandler(new ChannelInitializer<Channel>() {
171 @Override
172 public void initChannel(Channel sch) throws Exception {
173 ChannelPipeline p = sch.pipeline();
174 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
175 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
176 p.addLast(executor, sh);
177 }
178 });
179
180 cb.handler(new ChannelInitializer<Channel>() {
181 @Override
182 public void initChannel(Channel sch) throws Exception {
183 ChannelPipeline p = sch.pipeline();
184 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
185 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
186 p.addLast(executor, ch);
187 }
188 });
189
190 Channel sc = sb.bind().sync().channel();
191 Channel cc = cb.connect(sc.localAddress()).sync().channel();
192
193 while (cc.isActive()) {
194 if (sh.exception.get() != null) {
195 break;
196 }
197 if (ch.exception.get() != null) {
198 break;
199 }
200
201 Thread.sleep(50);
202 }
203
204 while (sh.channel.isActive()) {
205 if (sh.exception.get() != null) {
206 break;
207 }
208 if (ch.exception.get() != null) {
209 break;
210 }
211
212 Thread.sleep(50);
213 }
214
215 sh.channel.close().awaitUninterruptibly();
216 cc.close().awaitUninterruptibly();
217 sc.close().awaitUninterruptibly();
218
219 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
220 throw sh.exception.get();
221 }
222 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
223 throw ch.exception.get();
224 }
225 if (sh.exception.get() != null) {
226 throw sh.exception.get();
227 }
228 if (ch.exception.get() != null) {
229 throw ch.exception.get();
230 }
231 }
232
233 private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
234 private final SslHandler sslHandler;
235 private final boolean autoRead;
236 private Future<Channel> handshakeFuture;
237 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
238
239 StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
240 engine.setUseClientMode(true);
241 sslHandler = new SslHandler(engine);
242 this.autoRead = autoRead;
243 }
244
245 @Override
246 public void channelActive(ChannelHandlerContext ctx)
247 throws Exception {
248 if (!autoRead) {
249 ctx.read();
250 }
251 ctx.writeAndFlush("StartTlsRequest\n");
252 }
253
254 @Override
255 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
256 if ("StartTlsResponse".equals(msg)) {
257 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
258 handshakeFuture = sslHandler.handshakeFuture();
259 ctx.writeAndFlush("EncryptedRequest\n");
260 return;
261 }
262
263 assertEquals("EncryptedResponse", msg);
264 assertNotNull(handshakeFuture);
265 assertTrue(handshakeFuture.isSuccess());
266 ctx.close();
267 }
268
269 @Override
270 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
271 if (!autoRead) {
272 ctx.read();
273 }
274 }
275
276 @Override
277 public void exceptionCaught(ChannelHandlerContext ctx,
278 Throwable cause) throws Exception {
279 if (logger.isWarnEnabled()) {
280 logger.warn("Unexpected exception from the client side", cause);
281 }
282
283 exception.compareAndSet(null, cause);
284 ctx.close();
285 }
286 }
287
288 private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
289 private final SslHandler sslHandler;
290 private final boolean autoRead;
291 volatile Channel channel;
292 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
293
294 StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
295 engine.setUseClientMode(false);
296 sslHandler = new SslHandler(engine, true);
297 this.autoRead = autoRead;
298 }
299
300 @Override
301 public void channelActive(ChannelHandlerContext ctx) throws Exception {
302 channel = ctx.channel();
303 if (!autoRead) {
304 ctx.read();
305 }
306 }
307
308 @Override
309 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
310 if ("StartTlsRequest".equals(msg)) {
311 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
312 ctx.writeAndFlush("StartTlsResponse\n");
313 return;
314 }
315
316 assertEquals("EncryptedRequest", msg);
317 ctx.writeAndFlush("EncryptedResponse\n");
318 }
319
320 @Override
321 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
322 if (!autoRead) {
323 ctx.read();
324 }
325 }
326
327 @Override
328 public void exceptionCaught(ChannelHandlerContext ctx,
329 Throwable cause) throws Exception {
330 if (logger.isWarnEnabled()) {
331 logger.warn("Unexpected exception from the server side", cause);
332 }
333
334 exception.compareAndSet(null, cause);
335 ctx.close();
336 }
337 }
338 }