1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandler;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInboundHandlerAdapter;
25 import io.netty.channel.ChannelOutboundInvoker;
26 import io.netty.channel.ChannelPipeline;
27 import io.netty.channel.ChannelPromise;
28 import io.netty.handler.codec.http.DefaultFullHttpResponse;
29 import io.netty.handler.codec.http.EmptyHttpHeaders;
30 import io.netty.handler.codec.http.FullHttpRequest;
31 import io.netty.handler.codec.http.FullHttpResponse;
32 import io.netty.handler.codec.http.HttpClientCodec;
33 import io.netty.handler.codec.http.HttpContentDecompressor;
34 import io.netty.handler.codec.http.HttpHeaderNames;
35 import io.netty.handler.codec.http.HttpHeaders;
36 import io.netty.handler.codec.http.HttpObject;
37 import io.netty.handler.codec.http.HttpObjectAggregator;
38 import io.netty.handler.codec.http.HttpRequestEncoder;
39 import io.netty.handler.codec.http.HttpResponse;
40 import io.netty.handler.codec.http.HttpResponseDecoder;
41 import io.netty.handler.codec.http.HttpScheme;
42 import io.netty.handler.codec.http.LastHttpContent;
43 import io.netty.util.NetUtil;
44 import io.netty.util.ReferenceCountUtil;
45 import io.netty.util.internal.ObjectUtil;
46
47 import java.net.URI;
48 import java.nio.channels.ClosedChannelException;
49 import java.util.Locale;
50 import java.util.concurrent.Future;
51 import java.util.concurrent.TimeUnit;
52 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
53
54
55
56
57 public abstract class WebSocketClientHandshaker {
58
59 private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
60 private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
61 protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
62
63 private final URI uri;
64
65 private final WebSocketVersion version;
66
67 private volatile boolean handshakeComplete;
68
69 private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;
70
71 private volatile int forceCloseInit;
72
73 private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
74 AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
75
76 private volatile boolean forceCloseComplete;
77
78 private final String expectedSubprotocol;
79
80 private volatile String actualSubprotocol;
81
82 protected final HttpHeaders customHeaders;
83
84 private final int maxFramePayloadLength;
85
86 private final boolean absoluteUpgradeUrl;
87
88 protected final boolean generateOriginHeader;
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
106 HttpHeaders customHeaders, int maxFramePayloadLength) {
107 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
108 }
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
128 HttpHeaders customHeaders, int maxFramePayloadLength,
129 long forceCloseTimeoutMillis) {
130 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
131 }
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
154 HttpHeaders customHeaders, int maxFramePayloadLength,
155 long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
156 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
157 absoluteUpgradeUrl, true);
158 }
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
184 HttpHeaders customHeaders, int maxFramePayloadLength,
185 long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
186 this.uri = uri;
187 this.version = version;
188 expectedSubprotocol = subprotocol;
189 this.customHeaders = customHeaders;
190 this.maxFramePayloadLength = maxFramePayloadLength;
191 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
192 this.absoluteUpgradeUrl = absoluteUpgradeUrl;
193 this.generateOriginHeader = generateOriginHeader;
194 }
195
196
197
198
199 public URI uri() {
200 return uri;
201 }
202
203
204
205
206 public WebSocketVersion version() {
207 return version;
208 }
209
210
211
212
213 public int maxFramePayloadLength() {
214 return maxFramePayloadLength;
215 }
216
217
218
219
220 public boolean isHandshakeComplete() {
221 return handshakeComplete;
222 }
223
224 private void setHandshakeComplete() {
225 handshakeComplete = true;
226 }
227
228
229
230
231 public String expectedSubprotocol() {
232 return expectedSubprotocol;
233 }
234
235
236
237
238
239 public String actualSubprotocol() {
240 return actualSubprotocol;
241 }
242
243 private void setActualSubprotocol(String actualSubprotocol) {
244 this.actualSubprotocol = actualSubprotocol;
245 }
246
247 public long forceCloseTimeoutMillis() {
248 return forceCloseTimeoutMillis;
249 }
250
251
252
253
254
255 protected boolean isForceCloseComplete() {
256 return forceCloseComplete;
257 }
258
259
260
261
262
263
264
265 public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
266 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
267 return this;
268 }
269
270
271
272
273
274
275
276 public ChannelFuture handshake(Channel channel) {
277 ObjectUtil.checkNotNull(channel, "channel");
278 return handshake(channel, channel.newPromise());
279 }
280
281
282
283
284
285
286
287
288
289 public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
290 ChannelPipeline pipeline = channel.pipeline();
291 HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
292 if (decoder == null) {
293 HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
294 if (codec == null) {
295 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
296 "an HttpResponseDecoder or HttpClientCodec"));
297 return promise;
298 }
299 }
300
301 if (uri.getHost() == null) {
302 if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {
303 promise.setFailure(new IllegalArgumentException("Cannot generate the 'host' header value," +
304 " webSocketURI should contain host or passed through customHeaders"));
305 return promise;
306 }
307
308 if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
309 final String originName;
310 if (version == WebSocketVersion.V07 || version == WebSocketVersion.V08) {
311 originName = HttpHeaderNames.SEC_WEBSOCKET_ORIGIN.toString();
312 } else {
313 originName = HttpHeaderNames.ORIGIN.toString();
314 }
315
316 promise.setFailure(new IllegalArgumentException("Cannot generate the '" + originName + "' header" +
317 " value, webSocketURI should contain host or disable generateOriginHeader or pass value" +
318 " through customHeaders"));
319 return promise;
320 }
321 }
322
323 FullHttpRequest request = newHandshakeRequest();
324
325 channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
326 @Override
327 public void operationComplete(ChannelFuture future) {
328 if (future.isSuccess()) {
329 ChannelPipeline p = future.channel().pipeline();
330 ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
331 if (ctx == null) {
332 ctx = p.context(HttpClientCodec.class);
333 }
334 if (ctx == null) {
335 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
336 "an HttpRequestEncoder or HttpClientCodec"));
337 return;
338 }
339 p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
340
341 promise.setSuccess();
342 } else {
343 promise.setFailure(future.cause());
344 }
345 }
346 });
347 return promise;
348 }
349
350
351
352
353 protected abstract FullHttpRequest newHandshakeRequest();
354
355
356
357
358
359
360
361
362
363 public final void finishHandshake(Channel channel, FullHttpResponse response) {
364 verify(response);
365
366
367
368 String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
369 receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
370 String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
371 boolean protocolValid = false;
372
373 if (expectedProtocol.isEmpty() && receivedProtocol == null) {
374
375 protocolValid = true;
376 setActualSubprotocol(expectedSubprotocol);
377 } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
378
379 for (String protocol : expectedProtocol.split(",")) {
380 if (protocol.trim().equals(receivedProtocol)) {
381 protocolValid = true;
382 setActualSubprotocol(receivedProtocol);
383 break;
384 }
385 }
386 }
387
388 if (!protocolValid) {
389 throw new WebSocketClientHandshakeException(String.format(
390 "Invalid subprotocol. Actual: %s. Expected one of: %s",
391 receivedProtocol, expectedSubprotocol), response);
392 }
393
394 setHandshakeComplete();
395
396 final ChannelPipeline p = channel.pipeline();
397
398 HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
399 if (decompressor != null) {
400 p.remove(decompressor);
401 }
402
403
404 HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
405 if (aggregator != null) {
406 p.remove(aggregator);
407 }
408
409 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
410 if (ctx == null) {
411 ctx = p.context(HttpClientCodec.class);
412 if (ctx == null) {
413 throw new IllegalStateException("ChannelPipeline does not contain " +
414 "an HttpRequestEncoder or HttpClientCodec");
415 }
416 final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
417
418 codec.removeOutboundHandler();
419
420 p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
421
422
423
424
425 channel.eventLoop().execute(new Runnable() {
426 @Override
427 public void run() {
428 p.remove(codec);
429 }
430 });
431 } else {
432 if (p.get(HttpRequestEncoder.class) != null) {
433
434 p.remove(HttpRequestEncoder.class);
435 }
436 final ChannelHandlerContext context = ctx;
437 p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
438
439
440
441
442 channel.eventLoop().execute(new Runnable() {
443 @Override
444 public void run() {
445 p.remove(context.handler());
446 }
447 });
448 }
449 }
450
451
452
453
454
455
456
457
458
459
460
461 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
462 return processHandshake(channel, response, channel.newPromise());
463 }
464
465
466
467
468
469
470
471
472
473
474
475
476
477 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
478 final ChannelPromise promise) {
479 if (response instanceof FullHttpResponse) {
480 try {
481 finishHandshake(channel, (FullHttpResponse) response);
482 promise.setSuccess();
483 } catch (Throwable cause) {
484 promise.setFailure(cause);
485 }
486 } else {
487 ChannelPipeline p = channel.pipeline();
488 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
489 if (ctx == null) {
490 ctx = p.context(HttpClientCodec.class);
491 if (ctx == null) {
492 return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
493 "an HttpResponseDecoder or HttpClientCodec"));
494 }
495 }
496
497 String aggregatorCtx = ctx.name();
498
499 if (version == WebSocketVersion.V00) {
500
501
502 aggregatorCtx = "httpAggregator";
503 p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
504 }
505
506 p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
507
508 private FullHttpResponse fullHttpResponse;
509
510 @Override
511 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
512 if (msg instanceof HttpObject) {
513 try {
514 handleHandshakeResponse(ctx, (HttpObject) msg);
515 } finally {
516 ReferenceCountUtil.release(msg);
517 }
518 } else {
519 super.channelRead(ctx, msg);
520 }
521 }
522
523 @Override
524 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
525
526 ctx.pipeline().remove(this);
527 promise.setFailure(cause);
528 }
529
530 @Override
531 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
532 try {
533
534 if (!promise.isDone()) {
535 promise.tryFailure(new ClosedChannelException());
536 }
537 ctx.fireChannelInactive();
538 } finally {
539 releaseFullHttpResponse();
540 }
541 }
542
543 @Override
544 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
545 releaseFullHttpResponse();
546 }
547
548 private void handleHandshakeResponse(ChannelHandlerContext ctx, HttpObject response) {
549 if (response instanceof FullHttpResponse) {
550 ctx.pipeline().remove(this);
551 tryFinishHandshake((FullHttpResponse) response);
552 return;
553 }
554
555 if (response instanceof LastHttpContent) {
556 assert fullHttpResponse != null;
557 FullHttpResponse handshakeResponse = fullHttpResponse;
558 fullHttpResponse = null;
559 try {
560 ctx.pipeline().remove(this);
561 tryFinishHandshake(handshakeResponse);
562 } finally {
563 handshakeResponse.release();
564 }
565 return;
566 }
567
568 if (response instanceof HttpResponse) {
569 HttpResponse httpResponse = (HttpResponse) response;
570 fullHttpResponse = new DefaultFullHttpResponse(httpResponse.protocolVersion(),
571 httpResponse.status(), Unpooled.EMPTY_BUFFER, httpResponse.headers(),
572 EmptyHttpHeaders.INSTANCE);
573 if (httpResponse.decoderResult().isFailure()) {
574 fullHttpResponse.setDecoderResult(httpResponse.decoderResult());
575 }
576 }
577 }
578
579 private void tryFinishHandshake(FullHttpResponse fullHttpResponse) {
580 try {
581 finishHandshake(channel, fullHttpResponse);
582 promise.setSuccess();
583 } catch (Throwable cause) {
584 promise.setFailure(cause);
585 }
586 }
587
588 private void releaseFullHttpResponse() {
589 if (fullHttpResponse != null) {
590 fullHttpResponse.release();
591 fullHttpResponse = null;
592 }
593 }
594 });
595 try {
596 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
597 } catch (Throwable cause) {
598 promise.setFailure(cause);
599 }
600 }
601 return promise;
602 }
603
604
605
606
607 protected abstract void verify(FullHttpResponse response);
608
609
610
611
612 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
613
614
615
616
617 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
618
619
620
621
622
623
624
625
626
627
628
629
630 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
631 ObjectUtil.checkNotNull(channel, "channel");
632 return close(channel, frame, channel.newPromise());
633 }
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
649 ObjectUtil.checkNotNull(channel, "channel");
650 return close0(channel, channel, frame, promise);
651 }
652
653
654
655
656
657
658
659
660
661 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
662 ObjectUtil.checkNotNull(ctx, "ctx");
663 return close(ctx, frame, ctx.newPromise());
664 }
665
666
667
668
669
670
671
672
673
674
675
676 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
677 ObjectUtil.checkNotNull(ctx, "ctx");
678 return close0(ctx, ctx.channel(), frame, promise);
679 }
680
681 private ChannelFuture close0(final ChannelOutboundInvoker invoker, final Channel channel,
682 CloseWebSocketFrame frame, ChannelPromise promise) {
683 invoker.writeAndFlush(frame, promise);
684 final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
685 final WebSocketClientHandshaker handshaker = this;
686 if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
687 return promise;
688 }
689
690 promise.addListener(new ChannelFutureListener() {
691 @Override
692 public void operationComplete(ChannelFuture future) {
693
694
695
696
697 if (future.isSuccess() && channel.isActive() &&
698 FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
699 final Future<?> forceCloseFuture = channel.eventLoop().schedule(new Runnable() {
700 @Override
701 public void run() {
702 if (channel.isActive()) {
703 invoker.close();
704 forceCloseComplete = true;
705 }
706 }
707 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
708
709 channel.closeFuture().addListener(new ChannelFutureListener() {
710 @Override
711 public void operationComplete(ChannelFuture future) throws Exception {
712 forceCloseFuture.cancel(false);
713 }
714 });
715 }
716 }
717 });
718 return promise;
719 }
720
721
722
723
724 protected String upgradeUrl(URI wsURL) {
725 if (absoluteUpgradeUrl) {
726 return wsURL.toString();
727 }
728
729 String path = wsURL.getRawPath();
730 path = path == null || path.isEmpty() ? "/" : path;
731 String query = wsURL.getRawQuery();
732 return query != null && !query.isEmpty() ? path + '?' + query : path;
733 }
734
735 static CharSequence websocketHostValue(URI wsURL) {
736 int port = wsURL.getPort();
737 if (port == -1) {
738 return wsURL.getHost();
739 }
740 String host = wsURL.getHost();
741 String scheme = wsURL.getScheme();
742 if (port == HttpScheme.HTTP.port()) {
743 return HttpScheme.HTTP.name().contentEquals(scheme)
744 || WebSocketScheme.WS.name().contentEquals(scheme) ?
745 host : NetUtil.toSocketAddressString(host, port);
746 }
747 if (port == HttpScheme.HTTPS.port()) {
748 return HttpScheme.HTTPS.name().contentEquals(scheme)
749 || WebSocketScheme.WSS.name().contentEquals(scheme) ?
750 host : NetUtil.toSocketAddressString(host, port);
751 }
752
753
754
755 return NetUtil.toSocketAddressString(host, port);
756 }
757
758 static CharSequence websocketOriginValue(URI wsURL) {
759 String scheme = wsURL.getScheme();
760 final String schemePrefix;
761 int port = wsURL.getPort();
762 final int defaultPort;
763 if (WebSocketScheme.WSS.name().contentEquals(scheme)
764 || HttpScheme.HTTPS.name().contentEquals(scheme)
765 || (scheme == null && port == WebSocketScheme.WSS.port())) {
766
767 schemePrefix = HTTPS_SCHEME_PREFIX;
768 defaultPort = WebSocketScheme.WSS.port();
769 } else {
770 schemePrefix = HTTP_SCHEME_PREFIX;
771 defaultPort = WebSocketScheme.WS.port();
772 }
773
774
775 String host = wsURL.getHost().toLowerCase(Locale.US);
776
777 if (port != defaultPort && port != -1) {
778
779
780 return schemePrefix + NetUtil.toSocketAddressString(host, port);
781 }
782 return schemePrefix + host;
783 }
784 }