1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx.extensions;
17
18 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19
20 import io.netty.channel.ChannelDuplexHandler;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.CodecException;
24 import io.netty.handler.codec.http.HttpHeaderNames;
25 import io.netty.handler.codec.http.HttpRequest;
26 import io.netty.handler.codec.http.HttpResponse;
27
28 import java.util.ArrayList;
29 import java.util.Arrays;
30 import java.util.Iterator;
31 import java.util.List;
32
33
34
35
36
37
38
39
40
41
42
43 public class WebSocketClientExtensionHandler extends ChannelDuplexHandler {
44
45 private final List<WebSocketClientExtensionHandshaker> extensionHandshakers;
46
47
48
49
50
51
52
53
54 public WebSocketClientExtensionHandler(WebSocketClientExtensionHandshaker... extensionHandshakers) {
55 this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
56 }
57
58 @Override
59 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
60 if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) {
61 HttpRequest request = (HttpRequest) msg;
62 String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
63 List<WebSocketExtensionData> extraExtensions =
64 new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
65 for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) {
66 extraExtensions.add(extensionHandshaker.newRequestData());
67 }
68 String newHeaderValue = WebSocketExtensionUtil
69 .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
70
71 request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
72 }
73
74 super.write(ctx, msg, promise);
75 }
76
77 @Override
78 public void channelRead(ChannelHandlerContext ctx, Object msg)
79 throws Exception {
80 if (msg instanceof HttpResponse) {
81 HttpResponse response = (HttpResponse) msg;
82
83 if (WebSocketExtensionUtil.isWebsocketUpgrade(response.headers())) {
84 String extensionsHeader = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
85
86 if (extensionsHeader != null) {
87 List<WebSocketExtensionData> extensions =
88 WebSocketExtensionUtil.extractExtensions(extensionsHeader);
89 List<WebSocketClientExtension> validExtensions =
90 new ArrayList<WebSocketClientExtension>(extensions.size());
91 int rsv = 0;
92
93 for (WebSocketExtensionData extensionData : extensions) {
94 Iterator<WebSocketClientExtensionHandshaker> extensionHandshakersIterator =
95 extensionHandshakers.iterator();
96 WebSocketClientExtension validExtension = null;
97
98 while (validExtension == null && extensionHandshakersIterator.hasNext()) {
99 WebSocketClientExtensionHandshaker extensionHandshaker =
100 extensionHandshakersIterator.next();
101 validExtension = extensionHandshaker.handshakeExtension(extensionData);
102 }
103
104 if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
105 rsv = rsv | validExtension.rsv();
106 validExtensions.add(validExtension);
107 } else {
108 throw new CodecException(
109 "invalid WebSocket Extension handshake for \"" + extensionsHeader + '"');
110 }
111 }
112
113 for (WebSocketClientExtension validExtension : validExtensions) {
114 WebSocketExtensionDecoder decoder = validExtension.newExtensionDecoder();
115 WebSocketExtensionEncoder encoder = validExtension.newExtensionEncoder();
116 ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
117 ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
118 }
119 }
120
121 ctx.pipeline().remove(ctx.name());
122 }
123 }
124
125 super.channelRead(ctx, msg);
126 }
127 }
128