1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelHandler.Sharable;
26 import io.netty.channel.ChannelInitializer;
27 import io.netty.channel.SimpleChannelInboundHandler;
28 import io.netty.channel.socket.SocketChannel;
29 import io.netty.handler.ssl.JdkSslClientContext;
30 import io.netty.handler.ssl.JdkSslServerContext;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.util.SelfSignedCertificate;
34 import io.netty.util.internal.logging.InternalLogger;
35 import io.netty.util.internal.logging.InternalLoggerFactory;
36
37 import org.junit.jupiter.api.TestInfo;
38 import org.junit.jupiter.api.Timeout;
39 import org.junit.jupiter.params.ParameterizedTest;
40 import org.junit.jupiter.params.provider.MethodSource;
41
42 import javax.net.ssl.SSLEngine;
43 import javax.net.ssl.SSLSessionContext;
44
45 import java.io.File;
46 import java.io.IOException;
47 import java.net.InetSocketAddress;
48 import java.security.cert.CertificateException;
49 import java.util.Collection;
50 import java.util.Collections;
51 import java.util.Enumeration;
52 import java.util.HashSet;
53 import java.util.Set;
54 import java.util.concurrent.TimeUnit;
55 import java.util.concurrent.atomic.AtomicReference;
56
57 import static org.junit.jupiter.api.Assertions.assertEquals;
58
59 public class SocketSslSessionReuseTest extends AbstractSocketTest {
60
61 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
62
63 private static final File CERT_FILE;
64 private static final File KEY_FILE;
65
66 static {
67 SelfSignedCertificate ssc;
68 try {
69 ssc = new SelfSignedCertificate();
70 } catch (CertificateException e) {
71 throw new Error(e);
72 }
73 CERT_FILE = ssc.certificate();
74 KEY_FILE = ssc.privateKey();
75 }
76
77 public static Collection<Object[]> data() throws Exception {
78 return Collections.singletonList(new Object[] {
79 new JdkSslServerContext(CERT_FILE, KEY_FILE),
80 new JdkSslClientContext(CERT_FILE)
81 });
82 }
83
84 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
85 @MethodSource("data")
86 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
87 public void testSslSessionReuse(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
88 throws Throwable {
89 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
90 @Override
91 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
92 testSslSessionReuse(sb, cb, serverCtx, clientCtx);
93 }
94 });
95 }
96
97 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
98 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
99 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
100 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
101 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
102
103 sb.childHandler(new ChannelInitializer<SocketChannel>() {
104 @Override
105 protected void initChannel(SocketChannel sch) throws Exception {
106 SSLEngine engine = serverCtx.newEngine(sch.alloc());
107 engine.setUseClientMode(false);
108 engine.setEnabledProtocols(protocols);
109
110 sch.pipeline().addLast(new SslHandler(engine));
111 sch.pipeline().addLast(sh);
112 }
113 });
114 final Channel sc = sb.bind().sync().channel();
115
116 cb.handler(new ChannelInitializer<SocketChannel>() {
117 @Override
118 protected void initChannel(SocketChannel sch) throws Exception {
119 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
120 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
121 engine.setUseClientMode(true);
122 engine.setEnabledProtocols(protocols);
123
124 sch.pipeline().addLast(new SslHandler(engine));
125 sch.pipeline().addLast(ch);
126 }
127 });
128
129 try {
130 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
131 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
132 Channel cc = cb.connect(sc.localAddress()).sync().channel();
133 cc.writeAndFlush(msg).sync();
134 cc.closeFuture().sync();
135 rethrowHandlerExceptions(sh, ch);
136 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
137
138 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
139 cc = cb.connect(sc.localAddress()).sync().channel();
140 cc.writeAndFlush(msg).sync();
141 cc.closeFuture().sync();
142 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
143 rethrowHandlerExceptions(sh, ch);
144 } finally {
145 sc.close().awaitUninterruptibly();
146 }
147 }
148
149 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
150 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
151 throw sh.exception.get();
152 }
153 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
154 throw ch.exception.get();
155 }
156 if (sh.exception.get() != null) {
157 throw sh.exception.get();
158 }
159 if (ch.exception.get() != null) {
160 throw ch.exception.get();
161 }
162 }
163
164 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
165 Set<String> idSet = new HashSet<String>();
166 byte[] id;
167 while (sessionIds.hasMoreElements()) {
168 id = sessionIds.nextElement();
169 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
170 }
171 return idSet;
172 }
173
174 @Sharable
175 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
176 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
177 private final boolean server;
178 private final boolean autoRead;
179
180 ReadAndDiscardHandler(boolean server, boolean autoRead) {
181 this.server = server;
182 this.autoRead = autoRead;
183 }
184
185 @Override
186 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
187 byte[] actual = new byte[in.readableBytes()];
188 in.readBytes(actual);
189 ctx.close();
190 }
191
192 @Override
193 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
194 try {
195 ctx.flush();
196 } finally {
197 if (!autoRead) {
198 ctx.read();
199 }
200 }
201 }
202
203 @Override
204 public void exceptionCaught(ChannelHandlerContext ctx,
205 Throwable cause) throws Exception {
206 if (logger.isWarnEnabled()) {
207 logger.warn(
208 "Unexpected exception from the " +
209 (server? "server" : "client") + " side", cause);
210 }
211
212 exception.compareAndSet(null, cause);
213 ctx.close();
214 }
215 }
216 }