1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBufAllocator;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.handler.codec.DecoderException;
21 import io.netty.util.AsyncMapping;
22 import io.netty.util.DomainNameMapping;
23 import io.netty.util.Mapping;
24 import io.netty.util.ReferenceCountUtil;
25 import io.netty.util.concurrent.Future;
26 import io.netty.util.concurrent.Promise;
27 import io.netty.util.internal.ObjectUtil;
28 import io.netty.util.internal.PlatformDependent;
29
30
31
32
33
34
35
36
37 public class SniHandler extends AbstractSniHandler<SslContext> {
38 private static final Selection EMPTY_SELECTION = new Selection(null, null);
39
40 protected final AsyncMapping<String, SslContext> mapping;
41
42 private volatile Selection selection = EMPTY_SELECTION;
43
44
45
46
47
48
49
50 public SniHandler(Mapping<? super String, ? extends SslContext> mapping) {
51 this(new AsyncMappingAdapter(mapping));
52 }
53
54
55
56
57
58
59
60
61
62 public SniHandler(Mapping<? super String, ? extends SslContext> mapping,
63 int maxClientHelloLength, long handshakeTimeoutMillis) {
64 this(new AsyncMappingAdapter(mapping), maxClientHelloLength, handshakeTimeoutMillis);
65 }
66
67
68
69
70
71
72
73 public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
74 this((Mapping<String, ? extends SslContext>) mapping);
75 }
76
77
78
79
80
81
82
83 @SuppressWarnings("unchecked")
84 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping) {
85 this(mapping, 0, 0L);
86 }
87
88
89
90
91
92
93
94
95
96 @SuppressWarnings("unchecked")
97 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping,
98 int maxClientHelloLength, long handshakeTimeoutMillis) {
99 super(maxClientHelloLength, handshakeTimeoutMillis);
100 this.mapping = (AsyncMapping<String, SslContext>) ObjectUtil.checkNotNull(mapping, "mapping");
101 }
102
103
104
105
106
107
108
109
110 public SniHandler(Mapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
111 this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis);
112 }
113
114
115
116
117
118
119
120
121 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
122 this(mapping, 0, handshakeTimeoutMillis);
123 }
124
125
126
127
128 public String hostname() {
129 return selection.hostname;
130 }
131
132
133
134
135 public SslContext sslContext() {
136 return selection.context;
137 }
138
139
140
141
142
143
144
145 @Override
146 protected Future<SslContext> lookup(ChannelHandlerContext ctx, String hostname) throws Exception {
147 return mapping.map(hostname, ctx.executor().<SslContext>newPromise());
148 }
149
150 @Override
151 protected final void onLookupComplete(ChannelHandlerContext ctx,
152 String hostname, Future<SslContext> future) throws Exception {
153 if (!future.isSuccess()) {
154 final Throwable cause = future.cause();
155 if (cause instanceof Error) {
156 throw (Error) cause;
157 }
158 throw new DecoderException("failed to get the SslContext for " + hostname, cause);
159 }
160
161 SslContext sslContext = future.getNow();
162 selection = new Selection(sslContext, hostname);
163 try {
164 replaceHandler(ctx, hostname, sslContext);
165 } catch (Throwable cause) {
166 selection = EMPTY_SELECTION;
167 PlatformDependent.throwException(cause);
168 }
169 }
170
171
172
173
174
175
176
177
178
179
180 protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception {
181 SslHandler sslHandler = null;
182 try {
183 sslHandler = newSslHandler(sslContext, ctx.alloc());
184 ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
185 sslHandler = null;
186 } finally {
187
188
189
190 if (sslHandler != null) {
191 ReferenceCountUtil.safeRelease(sslHandler.engine());
192 }
193 }
194 }
195
196
197
198
199
200 protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
201 SslHandler sslHandler = context.newHandler(allocator);
202 sslHandler.setHandshakeTimeoutMillis(handshakeTimeoutMillis);
203 return sslHandler;
204 }
205
206 private static final class AsyncMappingAdapter implements AsyncMapping<String, SslContext> {
207 private final Mapping<? super String, ? extends SslContext> mapping;
208
209 private AsyncMappingAdapter(Mapping<? super String, ? extends SslContext> mapping) {
210 this.mapping = ObjectUtil.checkNotNull(mapping, "mapping");
211 }
212
213 @Override
214 public Future<SslContext> map(String input, Promise<SslContext> promise) {
215 final SslContext context;
216 try {
217 context = mapping.map(input);
218 } catch (Throwable cause) {
219 return promise.setFailure(cause);
220 }
221 return promise.setSuccess(context);
222 }
223 }
224
225 private static final class Selection {
226 final SslContext context;
227 final String hostname;
228
229 Selection(SslContext context, String hostname) {
230 this.context = context;
231 this.hostname = hostname;
232 }
233 }
234 }