1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  package io.netty.handler.codec.http.websocketx;
17  
18  import java.nio.channels.ClosedChannelException;
19  import java.util.Collections;
20  import java.util.LinkedHashSet;
21  import java.util.Set;
22  
23  import io.netty.buffer.Unpooled;
24  import io.netty.channel.Channel;
25  import io.netty.channel.ChannelFuture;
26  import io.netty.channel.ChannelFutureListener;
27  import io.netty.channel.ChannelHandler;
28  import io.netty.channel.ChannelHandlerContext;
29  import io.netty.channel.ChannelInboundHandlerAdapter;
30  import io.netty.channel.ChannelOutboundInvoker;
31  import io.netty.channel.ChannelPipeline;
32  import io.netty.channel.ChannelPromise;
33  import io.netty.handler.codec.http.DefaultFullHttpRequest;
34  import io.netty.handler.codec.http.EmptyHttpHeaders;
35  import io.netty.handler.codec.http.FullHttpRequest;
36  import io.netty.handler.codec.http.FullHttpResponse;
37  import io.netty.handler.codec.http.HttpContentCompressor;
38  import io.netty.handler.codec.http.HttpHeaders;
39  import io.netty.handler.codec.http.HttpObject;
40  import io.netty.handler.codec.http.HttpObjectAggregator;
41  import io.netty.handler.codec.http.HttpRequest;
42  import io.netty.handler.codec.http.HttpRequestDecoder;
43  import io.netty.handler.codec.http.HttpResponseEncoder;
44  import io.netty.handler.codec.http.HttpServerCodec;
45  import io.netty.handler.codec.http.HttpUtil;
46  import io.netty.handler.codec.http.LastHttpContent;
47  import io.netty.util.ReferenceCountUtil;
48  import io.netty.util.internal.EmptyArrays;
49  import io.netty.util.internal.ObjectUtil;
50  import io.netty.util.internal.logging.InternalLogger;
51  import io.netty.util.internal.logging.InternalLoggerFactory;
52  
53  
54  
55  
56  public abstract class WebSocketServerHandshaker {
57      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58  
59      private final String uri;
60  
61      private final String[] subprotocols;
62  
63      private final WebSocketVersion version;
64  
65      private final WebSocketDecoderConfig decoderConfig;
66  
67      private String selectedSubprotocol;
68  
69      
70  
71  
72      public static final String SUB_PROTOCOL_WILDCARD = "*";
73  
74      
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87      protected WebSocketServerHandshaker(
88              WebSocketVersion version, String uri, String subprotocols,
89              int maxFramePayloadLength) {
90          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91              .maxFramePayloadLength(maxFramePayloadLength)
92              .build());
93      }
94  
95      
96  
97  
98  
99  
100 
101 
102 
103 
104 
105 
106 
107 
108     protected WebSocketServerHandshaker(
109             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110         this.version = version;
111         this.uri = uri;
112         if (subprotocols != null) {
113             String[] subprotocolArray = subprotocols.split(",");
114             for (int i = 0; i < subprotocolArray.length; i++) {
115                 subprotocolArray[i] = subprotocolArray[i].trim();
116             }
117             this.subprotocols = subprotocolArray;
118         } else {
119             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120         }
121         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122     }
123 
124     
125 
126 
127     public String uri() {
128         return uri;
129     }
130 
131     
132 
133 
134     public Set<String> subprotocols() {
135         Set<String> ret = new LinkedHashSet<String>();
136         Collections.addAll(ret, subprotocols);
137         return ret;
138     }
139 
140     
141 
142 
143     public WebSocketVersion version() {
144         return version;
145     }
146 
147     
148 
149 
150 
151 
152     public int maxFramePayloadLength() {
153         return decoderConfig.maxFramePayloadLength();
154     }
155 
156     
157 
158 
159 
160 
161     public WebSocketDecoderConfig decoderConfig() {
162         return decoderConfig;
163     }
164 
165     
166 
167 
168 
169 
170 
171 
172 
173 
174 
175 
176     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
177         return handshake(channel, req, null, channel.newPromise());
178     }
179 
180     
181 
182 
183 
184 
185 
186 
187 
188 
189 
190 
191 
192 
193 
194 
195 
196     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
197                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
198 
199         if (logger.isDebugEnabled()) {
200             logger.debug("{} WebSocket version {} server handshake", channel, version());
201         }
202         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
203         ChannelPipeline p = channel.pipeline();
204         if (p.get(HttpObjectAggregator.class) != null) {
205             p.remove(HttpObjectAggregator.class);
206         }
207         if (p.get(HttpContentCompressor.class) != null) {
208             p.remove(HttpContentCompressor.class);
209         }
210         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
211         final String encoderName;
212         if (ctx == null) {
213             
214             ctx = p.context(HttpServerCodec.class);
215             if (ctx == null) {
216                 promise.setFailure(
217                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
218                 response.release();
219                 return promise;
220             }
221             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
222             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
223             encoderName = ctx.name();
224         } else {
225             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
226 
227             encoderName = p.context(HttpResponseEncoder.class).name();
228             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
229         }
230         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
231             @Override
232             public void operationComplete(ChannelFuture future) throws Exception {
233                 if (future.isSuccess()) {
234                     ChannelPipeline p = future.channel().pipeline();
235                     p.remove(encoderName);
236                     promise.setSuccess();
237                 } else {
238                     promise.setFailure(future.cause());
239                 }
240             }
241         });
242         return promise;
243     }
244 
245     
246 
247 
248 
249 
250 
251 
252 
253 
254 
255 
256     public ChannelFuture handshake(Channel channel, HttpRequest req) {
257         return handshake(channel, req, null, channel.newPromise());
258     }
259 
260     
261 
262 
263 
264 
265 
266 
267 
268 
269 
270 
271 
272 
273 
274 
275 
276     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
277                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
278         if (req instanceof FullHttpRequest) {
279             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
280         }
281 
282         if (logger.isDebugEnabled()) {
283             logger.debug("{} WebSocket version {} server handshake", channel, version());
284         }
285 
286         ChannelPipeline p = channel.pipeline();
287         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
288         if (ctx == null) {
289             
290             ctx = p.context(HttpServerCodec.class);
291             if (ctx == null) {
292                 promise.setFailure(
293                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
294                 return promise;
295             }
296         }
297 
298         String aggregatorCtx = ctx.name();
299         if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
300             version == WebSocketVersion.V00) {
301             
302             
303             aggregatorCtx = "httpAggregator";
304             p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
305         }
306 
307         p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
308 
309             private FullHttpRequest fullHttpRequest;
310 
311             @Override
312             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
313                 if (msg instanceof HttpObject) {
314                     try {
315                         handleHandshakeRequest(ctx, (HttpObject) msg);
316                     } finally {
317                         ReferenceCountUtil.release(msg);
318                     }
319                 } else {
320                     super.channelRead(ctx, msg);
321                 }
322             }
323 
324             @Override
325             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
326                 
327                 ctx.pipeline().remove(this);
328                 promise.tryFailure(cause);
329                 ctx.fireExceptionCaught(cause);
330             }
331 
332             @Override
333             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
334                 try {
335                     
336                     if (!promise.isDone()) {
337                         promise.tryFailure(new ClosedChannelException());
338                     }
339                     ctx.fireChannelInactive();
340                 } finally {
341                     releaseFullHttpRequest();
342                 }
343             }
344 
345             @Override
346             public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
347                 releaseFullHttpRequest();
348             }
349 
350             private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
351                 if (httpObject instanceof FullHttpRequest) {
352                     ctx.pipeline().remove(this);
353                     handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
354                     return;
355                 }
356 
357                 if (httpObject instanceof LastHttpContent) {
358                     assert fullHttpRequest != null;
359                     FullHttpRequest handshakeRequest = fullHttpRequest;
360                     fullHttpRequest = null;
361                     try {
362                         ctx.pipeline().remove(this);
363                         handshake(channel, handshakeRequest, responseHeaders, promise);
364                     } finally {
365                         handshakeRequest.release();
366                     }
367                     return;
368                 }
369 
370                 if (httpObject instanceof HttpRequest) {
371                     HttpRequest httpRequest = (HttpRequest) httpObject;
372                     fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
373                         httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
374                     if (httpRequest.decoderResult().isFailure()) {
375                         fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
376                     }
377                 }
378             }
379 
380             private void releaseFullHttpRequest() {
381                 if (fullHttpRequest != null) {
382                     fullHttpRequest.release();
383                     fullHttpRequest = null;
384                 }
385             }
386         });
387         try {
388             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
389         } catch (Throwable cause) {
390             promise.setFailure(cause);
391         }
392         return promise;
393     }
394 
395     
396 
397 
398     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
399                                          HttpHeaders responseHeaders);
400     
401 
402 
403 
404 
405 
406 
407 
408 
409 
410 
411     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
412         ObjectUtil.checkNotNull(channel, "channel");
413         return close(channel, frame, channel.newPromise());
414     }
415 
416     
417 
418 
419 
420 
421 
422 
423 
424 
425 
426 
427 
428 
429     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
430         return close0(channel, frame, promise);
431     }
432 
433     
434 
435 
436 
437 
438 
439 
440 
441     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
442         ObjectUtil.checkNotNull(ctx, "ctx");
443         return close(ctx, frame, ctx.newPromise());
444     }
445 
446     
447 
448 
449 
450 
451 
452 
453 
454 
455 
456     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
457         ObjectUtil.checkNotNull(ctx, "ctx");
458         return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
459     }
460 
461     private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
462         return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
463     }
464 
465     
466 
467 
468 
469 
470 
471 
472     protected String selectSubprotocol(String requestedSubprotocols) {
473         if (requestedSubprotocols == null || subprotocols.length == 0) {
474             return null;
475         }
476 
477         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
478         for (String p: requestedSubprotocolArray) {
479             String requestedSubprotocol = p.trim();
480 
481             for (String supportedSubprotocol: subprotocols) {
482                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
483                         || requestedSubprotocol.equals(supportedSubprotocol)) {
484                     selectedSubprotocol = requestedSubprotocol;
485                     return requestedSubprotocol;
486                 }
487             }
488         }
489 
490         
491         return null;
492     }
493 
494     
495 
496 
497 
498 
499 
500     public String selectedSubprotocol() {
501         return selectedSubprotocol;
502     }
503 
504     
505 
506 
507     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
508 
509     
510 
511 
512     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
513 
514 }