1 /*
2 * Copyright 2017 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelOutboundHandler;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.ByteToMessageDecoder;
24 import io.netty.handler.codec.DecoderException;
25 import io.netty.handler.codec.TooLongFrameException;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.FutureListener;
28 import io.netty.util.internal.ObjectUtil;
29 import io.netty.util.internal.PlatformDependent;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.util.List;
35
36 /**
37 * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
38 */
39 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40
41 /**
42 * The maximum length of client hello message as defined by
43 * <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
44 */
45 public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46
47 private static final InternalLogger logger =
48 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49
50 private final int maxClientHelloLength;
51 private boolean handshakeFailed;
52 private boolean suppressRead;
53 private boolean readPending;
54 private ByteBuf handshakeBuffer;
55
56 public SslClientHelloHandler() {
57 this(MAX_CLIENT_HELLO_LENGTH);
58 }
59
60 protected SslClientHelloHandler(int maxClientHelloLength) {
61 // 16MB is the maximum as per RFC:
62 // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
63 this.maxClientHelloLength =
64 ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65 }
66
67 @Override
68 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69 if (!suppressRead && !handshakeFailed) {
70 try {
71 int readerIndex = in.readerIndex();
72 int readableBytes = in.readableBytes();
73 int handshakeLength = -1;
74
75 // Check if we have enough data to determine the record type and length.
76 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77 final int contentType = in.getUnsignedByte(readerIndex);
78 switch (contentType) {
79 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80 // fall-through
81 case SslUtils.SSL_CONTENT_TYPE_ALERT:
82 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
83
84 // Not an SSL/TLS packet
85 if (len == SslUtils.NOT_ENCRYPTED) {
86 handshakeFailed = true;
87 NotSslRecordException e = new NotSslRecordException(
88 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89 in.skipBytes(in.readableBytes());
90 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91 SslUtils.handleHandshakeFailure(ctx, e, true);
92 throw e;
93 }
94 if (len == SslUtils.NOT_ENOUGH_DATA) {
95 // Not enough data
96 return;
97 }
98 // No ClientHello
99 select(ctx, null);
100 return;
101 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103 // SSLv3 or TLS
104 if (majorVersion == 3) {
105 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106 SslUtils.SSL_RECORD_HEADER_LENGTH;
107
108 if (readableBytes < packetLength) {
109 // client hello incomplete; try again to decode once more data is ready.
110 return;
111 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112 select(ctx, null);
113 return;
114 }
115
116 final int endOffset = readerIndex + packetLength;
117
118 // Let's check if we already parsed the handshake length or not.
119 if (handshakeLength == -1) {
120 if (readerIndex + 4 > endOffset) {
121 // Need more data to read HandshakeType and handshakeLength (4 bytes)
122 return;
123 }
124
125 final int handshakeType = in.getUnsignedByte(readerIndex +
126 SslUtils.SSL_RECORD_HEADER_LENGTH);
127
128 // Check if this is a clientHello(1)
129 // See https://tools.ietf.org/html/rfc5246#section-7.4
130 if (handshakeType != 1) {
131 select(ctx, null);
132 return;
133 }
134
135 // Read the length of the handshake as it may arrive in fragments
136 // See https://tools.ietf.org/html/rfc5246#section-7.4
137 handshakeLength = in.getUnsignedMedium(readerIndex +
138 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139
140 if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141 TooLongFrameException e = new TooLongFrameException(
142 "ClientHello length exceeds " + maxClientHelloLength +
143 ": " + handshakeLength);
144 in.skipBytes(in.readableBytes());
145 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146 SslUtils.handleHandshakeFailure(ctx, e, true);
147 throw e;
148 }
149 // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
150 readerIndex += 4;
151 packetLength -= 4;
152
153 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154 // We have everything we need in one packet.
155 // Skip the record header
156 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157 select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158 return;
159 } else {
160 if (handshakeBuffer == null) {
161 handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162 } else {
163 // Clear the buffer so we can aggregate into it again.
164 handshakeBuffer.clear();
165 }
166 }
167 }
168
169 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
170 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171 packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172 readerIndex += packetLength;
173 readableBytes -= packetLength;
174 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175 ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176 handshakeBuffer = null;
177
178 select(ctx, clientHello);
179 return;
180 }
181 break;
182 }
183 // fall-through
184 default:
185 // not tls, ssl or application data
186 select(ctx, null);
187 return;
188 }
189 }
190 } catch (NotSslRecordException e) {
191 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
192 throw e;
193 } catch (TooLongFrameException e) {
194 // Just rethrow as in this case we also closed the channel
195 throw e;
196 } catch (Exception e) {
197 // unexpected encoding, ignore sni and use default
198 if (logger.isDebugEnabled()) {
199 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200 }
201 select(ctx, null);
202 }
203 }
204 }
205
206 private void releaseHandshakeBuffer() {
207 releaseIfNotNull(handshakeBuffer);
208 handshakeBuffer = null;
209 }
210
211 private static void releaseIfNotNull(ByteBuf buffer) {
212 if (buffer != null) {
213 buffer.release();
214 }
215 }
216
217 private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218 final Future<T> future;
219 try {
220 future = lookup(ctx, clientHello);
221 if (future.isDone()) {
222 onLookupComplete(ctx, future);
223 } else {
224 suppressRead = true;
225 final ByteBuf finalClientHello = clientHello;
226 future.addListener(new FutureListener<T>() {
227 @Override
228 public void operationComplete(Future<T> future) {
229 releaseIfNotNull(finalClientHello);
230 try {
231 suppressRead = false;
232 try {
233 onLookupComplete(ctx, future);
234 } catch (DecoderException err) {
235 ctx.fireExceptionCaught(err);
236 } catch (Exception cause) {
237 ctx.fireExceptionCaught(new DecoderException(cause));
238 } catch (Throwable cause) {
239 ctx.fireExceptionCaught(cause);
240 }
241 } finally {
242 if (readPending) {
243 readPending = false;
244 ctx.read();
245 }
246 }
247 }
248 });
249
250 // Ownership was transferred to the FutureListener.
251 clientHello = null;
252 }
253 } catch (Throwable cause) {
254 PlatformDependent.throwException(cause);
255 } finally {
256 releaseIfNotNull(clientHello);
257 }
258 }
259
260 @Override
261 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
262 releaseHandshakeBuffer();
263
264 super.handlerRemoved0(ctx);
265 }
266
267 /**
268 * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
269 * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
270 *
271 * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
272 *
273 * <pre>
274 * struct {
275 * ProtocolVersion client_version;
276 * Random random;
277 * SessionID session_id;
278 * CipherSuite cipher_suites<2..2^16-2>;
279 * CompressionMethod compression_methods<1..2^8-1>;
280 * select (extensions_present) {
281 * case false:
282 * struct {};
283 * case true:
284 * Extension extensions<0..2^16-1>;
285 * };
286 * } ClientHello;
287 * </pre>
288 *
289 * @see #onLookupComplete(ChannelHandlerContext, Future)
290 */
291 protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
292
293 /**
294 * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
295 *
296 * @see #lookup(ChannelHandlerContext, ByteBuf)
297 */
298 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
299
300 @Override
301 public void read(ChannelHandlerContext ctx) throws Exception {
302 if (suppressRead) {
303 readPending = true;
304 } else {
305 ctx.read();
306 }
307 }
308
309 @Override
310 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
311 ctx.bind(localAddress, promise);
312 }
313
314 @Override
315 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
316 ChannelPromise promise) throws Exception {
317 ctx.connect(remoteAddress, localAddress, promise);
318 }
319
320 @Override
321 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
322 ctx.disconnect(promise);
323 }
324
325 @Override
326 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
327 ctx.close(promise);
328 }
329
330 @Override
331 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
332 ctx.deregister(promise);
333 }
334
335 @Override
336 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
337 ctx.write(msg, promise);
338 }
339
340 @Override
341 public void flush(ChannelHandlerContext ctx) throws Exception {
342 ctx.flush();
343 }
344 }