1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.util.CharsetUtil;
21 import io.netty.util.concurrent.Future;
22 import io.netty.util.concurrent.ScheduledFuture;
23
24 import java.util.Locale;
25 import java.util.concurrent.TimeUnit;
26
27 import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
28
29
30
31
32
33
34
35
36 public abstract class AbstractSniHandler<T> extends SslClientHelloHandler<T> {
37
38 private static String extractSniHostname(ByteBuf in) {
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59 int offset = in.readerIndex();
60 int endOffset = in.writerIndex();
61 offset += 34;
62
63 if (endOffset - offset >= 6) {
64 final int sessionIdLength = in.getUnsignedByte(offset);
65 offset += sessionIdLength + 1;
66
67 final int cipherSuitesLength = in.getUnsignedShort(offset);
68 offset += cipherSuitesLength + 2;
69
70 final int compressionMethodLength = in.getUnsignedByte(offset);
71 offset += compressionMethodLength + 1;
72
73 final int extensionsLength = in.getUnsignedShort(offset);
74 offset += 2;
75 final int extensionsLimit = offset + extensionsLength;
76
77
78 if (extensionsLimit <= endOffset) {
79 while (extensionsLimit - offset >= 4) {
80 final int extensionType = in.getUnsignedShort(offset);
81 offset += 2;
82
83 final int extensionLength = in.getUnsignedShort(offset);
84 offset += 2;
85
86 if (extensionsLimit - offset < extensionLength) {
87 break;
88 }
89
90
91
92 if (extensionType == 0) {
93 offset += 2;
94 if (extensionsLimit - offset < 3) {
95 break;
96 }
97
98 final int serverNameType = in.getUnsignedByte(offset);
99 offset++;
100
101 if (serverNameType == 0) {
102 final int serverNameLength = in.getUnsignedShort(offset);
103 offset += 2;
104
105 if (extensionsLimit - offset < serverNameLength) {
106 break;
107 }
108
109 final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
110 return hostname.toLowerCase(Locale.US);
111 } else {
112
113 break;
114 }
115 }
116
117 offset += extensionLength;
118 }
119 }
120 }
121 return null;
122 }
123
124 protected final long handshakeTimeoutMillis;
125 private ScheduledFuture<?> timeoutFuture;
126 private String hostname;
127
128
129
130
131 protected AbstractSniHandler(long handshakeTimeoutMillis) {
132 this(0, handshakeTimeoutMillis);
133 }
134
135
136
137
138
139 protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) {
140 super(maxClientHelloLength);
141 this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
142 }
143
144 public AbstractSniHandler() {
145 this(0, 0L);
146 }
147
148 @Override
149 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
150 if (ctx.channel().isActive()) {
151 checkStartTimeout(ctx);
152 }
153 }
154
155 @Override
156 public void channelActive(ChannelHandlerContext ctx) throws Exception {
157 ctx.fireChannelActive();
158 checkStartTimeout(ctx);
159 }
160
161 private void checkStartTimeout(final ChannelHandlerContext ctx) {
162 if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) {
163 return;
164 }
165 timeoutFuture = ctx.executor().schedule(new Runnable() {
166 @Override
167 public void run() {
168 if (ctx.channel().isActive()) {
169 SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException(
170 "handshake timed out after " + handshakeTimeoutMillis + "ms");
171 ctx.fireUserEventTriggered(new SniCompletionEvent(exception));
172 ctx.close();
173 }
174 }
175 }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
176 }
177
178 @Override
179 protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
180 hostname = clientHello == null ? null : extractSniHostname(clientHello);
181
182 return lookup(ctx, hostname);
183 }
184
185 @Override
186 protected void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception {
187 if (timeoutFuture != null) {
188 timeoutFuture.cancel(false);
189 }
190 try {
191 onLookupComplete(ctx, hostname, future);
192 } finally {
193 fireSniCompletionEvent(ctx, hostname, future);
194 }
195 }
196
197
198
199
200
201
202
203 protected abstract Future<T> lookup(ChannelHandlerContext ctx, String hostname) throws Exception;
204
205
206
207
208
209
210 protected abstract void onLookupComplete(ChannelHandlerContext ctx,
211 String hostname, Future<T> future) throws Exception;
212
213 private static void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<?> future) {
214 Throwable cause = future.cause();
215 if (cause == null) {
216 ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
217 } else {
218 ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
219 }
220 }
221 }