1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandler.Sharable;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInitializer;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.DecoderException;
28 import io.netty.handler.ssl.JdkSslClientContext;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.OpenSslServerContext;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.handler.ssl.util.SelfSignedCertificate;
35 import io.netty.util.concurrent.Future;
36 import io.netty.util.internal.logging.InternalLogger;
37 import io.netty.util.internal.logging.InternalLoggerFactory;
38 import org.junit.jupiter.api.TestInfo;
39 import org.junit.jupiter.api.Timeout;
40 import org.junit.jupiter.api.condition.DisabledIf;
41 import org.junit.jupiter.params.ParameterizedTest;
42 import org.junit.jupiter.params.provider.MethodSource;
43
44 import java.io.File;
45 import java.nio.channels.ClosedChannelException;
46 import java.security.cert.CertificateException;
47 import java.util.ArrayList;
48 import java.util.Collection;
49 import java.util.List;
50 import java.util.concurrent.Executor;
51 import java.util.concurrent.ExecutorService;
52 import java.util.concurrent.Executors;
53 import java.util.concurrent.TimeUnit;
54 import java.util.concurrent.atomic.AtomicReference;
55
56 import javax.net.ssl.SSLHandshakeException;
57
58 import static org.junit.jupiter.api.Assertions.assertSame;
59 import static org.junit.jupiter.api.Assertions.assertTrue;
60 import static org.junit.jupiter.api.Assertions.fail;
61 import static org.junit.jupiter.api.Assumptions.assumeFalse;
62 import static org.junit.jupiter.api.Assumptions.assumeTrue;
63
64 public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
65 private static final InternalLogger logger = InternalLoggerFactory.getInstance(
66 SocketSslClientRenegotiateTest.class);
67 private static final File CERT_FILE;
68 private static final File KEY_FILE;
69
70 static {
71 SelfSignedCertificate ssc;
72 try {
73 ssc = new SelfSignedCertificate();
74 } catch (CertificateException e) {
75 throw new Error(e);
76 }
77 CERT_FILE = ssc.certificate();
78 KEY_FILE = ssc.privateKey();
79 }
80
81 private static boolean openSslNotAvailable() {
82 return !OpenSsl.isAvailable();
83 }
84
85 public static Collection<Object[]> data() throws Exception {
86 List<SslContext> serverContexts = new ArrayList<SslContext>();
87 List<SslContext> clientContexts = new ArrayList<SslContext>();
88 clientContexts.add(new JdkSslClientContext(CERT_FILE));
89
90 boolean hasOpenSsl = OpenSsl.isAvailable();
91 if (hasOpenSsl) {
92 OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
93 serverContexts.add(context);
94 } else {
95 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
96 }
97
98 List<Object[]> params = new ArrayList<Object[]>();
99 for (SslContext sc: serverContexts) {
100 for (SslContext cc: clientContexts) {
101 for (int i = 0; i < 32; i++) {
102 params.add(new Object[] { sc, cc, true});
103 params.add(new Object[] { sc, cc, false});
104 }
105 }
106 }
107
108 return params;
109 }
110
111 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
112 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
113
114 private volatile Channel clientChannel;
115 private volatile Channel serverChannel;
116
117 private volatile SslHandler clientSslHandler;
118 private volatile SslHandler serverSslHandler;
119
120 private final TestHandler clientHandler = new TestHandler(clientException);
121
122 private final TestHandler serverHandler = new TestHandler(serverException);
123
124 @DisabledIf("openSslNotAvailable")
125 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
126 @MethodSource("data")
127 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
128 public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx,
129 final boolean delegate, TestInfo testInfo) throws Throwable {
130
131 assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
132 assumeTrue(OpenSsl.isAvailable());
133 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
134 @Override
135 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
136 testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate);
137 }
138 });
139 }
140
141 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
142 if (executor == null) {
143 return sslCtx.newHandler(allocator);
144 } else {
145 return sslCtx.newHandler(allocator, executor);
146 }
147 }
148
149 public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
150 final SslContext clientCtx, boolean delegate) throws Throwable {
151 reset();
152
153 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
154
155 try {
156 sb.childHandler(new ChannelInitializer<Channel>() {
157 @Override
158 @SuppressWarnings("deprecation")
159 public void initChannel(Channel sch) throws Exception {
160 serverChannel = sch;
161 serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService);
162
163 serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
164 sch.pipeline().addLast("ssl", serverSslHandler);
165 sch.pipeline().addLast("handler", serverHandler);
166 }
167 });
168
169 cb.handler(new ChannelInitializer<Channel>() {
170 @Override
171 @SuppressWarnings("deprecation")
172 public void initChannel(Channel sch) throws Exception {
173 clientChannel = sch;
174 clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService);
175
176 clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
177 sch.pipeline().addLast("ssl", clientSslHandler);
178 sch.pipeline().addLast("handler", clientHandler);
179 }
180 });
181
182 Channel sc = sb.bind().sync().channel();
183 cb.connect(sc.localAddress()).sync();
184
185 Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
186 clientHandshakeFuture.sync();
187
188 String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
189
190 clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
191 clientSslHandler.renegotiate().await();
192 serverChannel.close().awaitUninterruptibly();
193 clientChannel.close().awaitUninterruptibly();
194 sc.close().awaitUninterruptibly();
195 try {
196 if (serverException.get() != null) {
197 throw serverException.get();
198 }
199 fail();
200 } catch (DecoderException e) {
201 assertTrue(e.getCause() instanceof SSLHandshakeException);
202 }
203 if (clientException.get() != null) {
204 throw clientException.get();
205 }
206 } finally {
207 if (executorService != null) {
208 executorService.shutdown();
209 }
210 }
211 }
212
213 private void reset() {
214 clientException.set(null);
215 serverException.set(null);
216 clientHandler.handshakeCounter = 0;
217 serverHandler.handshakeCounter = 0;
218 clientChannel = null;
219 serverChannel = null;
220
221 clientSslHandler = null;
222 serverSslHandler = null;
223 }
224
225 @Sharable
226 private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
227
228 protected final AtomicReference<Throwable> exception;
229 private int handshakeCounter;
230
231 TestHandler(AtomicReference<Throwable> exception) {
232 this.exception = exception;
233 }
234
235 @Override
236 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
237 ctx.flush();
238 }
239
240 @Override
241 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242 exception.compareAndSet(null, cause);
243 ctx.close();
244 }
245
246 @Override
247 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
248 if (evt instanceof SslHandshakeCompletionEvent) {
249 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
250 if (handshakeCounter == 0) {
251 handshakeCounter++;
252 if (handshakeEvt.cause() != null) {
253 logger.warn("Handshake failed:", handshakeEvt.cause());
254 }
255 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
256 } else {
257 if (ctx.channel().parent() == null) {
258 assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
259 }
260 }
261 }
262 }
263
264 @Override
265 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
266 }
267 }