1 /*
2 * Copyright 2014 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package io.netty.handler.codec.http.websocketx.extensions;
17
18 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.ChannelDuplexHandler;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelPromise;
27 import io.netty.handler.codec.http.DefaultHttpRequest;
28 import io.netty.handler.codec.http.DefaultHttpResponse;
29 import io.netty.handler.codec.http.HttpHeaderNames;
30 import io.netty.handler.codec.http.HttpHeaders;
31 import io.netty.handler.codec.http.HttpRequest;
32 import io.netty.handler.codec.http.HttpResponse;
33 import io.netty.handler.codec.http.HttpResponseStatus;
34 import io.netty.handler.codec.http.LastHttpContent;
35 import io.netty.util.internal.UnstableApi;
36
37 import java.util.ArrayDeque;
38 import java.util.ArrayList;
39 import java.util.Arrays;
40 import java.util.Collections;
41 import java.util.Iterator;
42 import java.util.List;
43 import java.util.Queue;
44
45 /**
46 * This handler negotiates and initializes the WebSocket Extensions.
47 *
48 * It negotiates the extensions based on the client desired order,
49 * ensures that the successfully negotiated extensions are consistent between them,
50 * and initializes the channel pipeline with the extension decoder and encoder.
51 *
52 * Find a basic implementation for compression extensions at
53 * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler</tt>.
54 */
55 public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
56
57 private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
58
59 private final Queue<List<WebSocketServerExtension>> validExtensions =
60 new ArrayDeque<List<WebSocketServerExtension>>(4);
61
62 /**
63 * Constructor
64 *
65 * @param extensionHandshakers
66 * The extension handshaker in priority order. A handshaker could be repeated many times
67 * with fallback configuration.
68 */
69 public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
70 this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
71 }
72
73 @Override
74 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
75 // JDK type checks vs non-implemented interfaces costs O(N), where
76 // N is the number of interfaces already implemented by the concrete type that's being tested.
77 // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
78 // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
79 // singleton and/or concrete types, to save performing such slow type checks.
80 if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
81 if (msg instanceof DefaultHttpRequest) {
82 // fast-path
83 onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
84 } else if (msg instanceof HttpRequest) {
85 // slow path
86 onHttpRequestChannelRead(ctx, (HttpRequest) msg);
87 } else {
88 super.channelRead(ctx, msg);
89 }
90 } else {
91 super.channelRead(ctx, msg);
92 }
93 }
94
95 /**
96 * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
97 * eg:<br>
98 * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
99 * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
100 * types too, can override it like this:
101 * <pre>
102 * public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
103 * if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
104 * if (msg instanceof CustomHttpRequest) {
105 * onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
106 * } else {
107 * // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
108 * // or have to delegate it to super.channelRead (that can perform redundant checks).
109 * // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
110 * // ...
111 * super.channelRead(ctx, msg);
112 * }
113 * } else {
114 * // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
115 * ctx.fireChannelRead(msg);
116 * }
117 * }
118 * </pre>
119 * <strong>IMPORTANT:</strong>
120 * It already call {@code super.channelRead(ctx, request)} before returning.
121 */
122 @UnstableApi
123 protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
124 List<WebSocketServerExtension> validExtensionsList = null;
125
126 if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
127 String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
128
129 if (extensionsHeader != null) {
130 List<WebSocketExtensionData> extensions =
131 WebSocketExtensionUtil.extractExtensions(extensionsHeader);
132 int rsv = 0;
133
134 for (WebSocketExtensionData extensionData : extensions) {
135 Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
136 extensionHandshakers.iterator();
137 WebSocketServerExtension validExtension = null;
138
139 while (validExtension == null && extensionHandshakersIterator.hasNext()) {
140 WebSocketServerExtensionHandshaker extensionHandshaker =
141 extensionHandshakersIterator.next();
142 validExtension = extensionHandshaker.handshakeExtension(extensionData);
143 }
144
145 if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
146 if (validExtensionsList == null) {
147 validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
148 }
149 rsv = rsv | validExtension.rsv();
150 validExtensionsList.add(validExtension);
151 }
152 }
153 }
154 }
155
156 if (validExtensionsList == null) {
157 validExtensionsList = Collections.emptyList();
158 }
159 validExtensions.offer(validExtensionsList);
160 super.channelRead(ctx, request);
161 }
162
163 @Override
164 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
165 if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
166 if (msg instanceof DefaultHttpResponse) {
167 onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
168 } else if (msg instanceof HttpResponse) {
169 onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
170 } else {
171 super.write(ctx, msg, promise);
172 }
173 } else {
174 super.write(ctx, msg, promise);
175 }
176 }
177
178 /**
179 * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
180 * eg:<br>
181 * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
182 * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
183 * <pre>
184 * public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
185 * if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
186 * if (msg instanceof CustomHttpResponse) {
187 * onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
188 * } else {
189 * // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
190 * // or have to delegate it to super.write (that can perform redundant checks).
191 * // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
192 * // ...
193 * super.write(ctx, msg, promise);
194 * }
195 * } else {
196 * // given that msg isn't a HttpResponse type we can just skip calling super.write
197 * ctx.write(msg, promise);
198 * }
199 * }
200 * </pre>
201 * <strong>IMPORTANT:</strong>
202 * It already call {@code super.write(ctx, response, promise)} before returning.
203 */
204 @UnstableApi
205 protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
206 throws Exception {
207 List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
208 // checking the status is faster than looking at headers so we do this first
209 if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(response.status())) {
210 handlePotentialUpgrade(ctx, promise, response, validExtensionsList);
211 }
212 super.write(ctx, response, promise);
213 }
214
215 private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
216 ChannelPromise promise, HttpResponse httpResponse,
217 final List<WebSocketServerExtension> validExtensionsList) {
218 HttpHeaders headers = httpResponse.headers();
219
220 if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
221 if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
222 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
223 List<WebSocketExtensionData> extraExtensions =
224 new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
225 for (WebSocketServerExtension extension : validExtensionsList) {
226 extraExtensions.add(extension.newReponseData());
227 }
228 String newHeaderValue = WebSocketExtensionUtil
229 .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
230 promise.addListener(new ChannelFutureListener() {
231 @Override
232 public void operationComplete(ChannelFuture future) {
233 if (future.isSuccess()) {
234 for (WebSocketServerExtension extension : validExtensionsList) {
235 WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
236 WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
237 String name = ctx.name();
238 ctx.pipeline()
239 .addAfter(name, decoder.getClass().getName(), decoder)
240 .addAfter(name, encoder.getClass().getName(), encoder);
241 }
242 }
243 }
244 });
245
246 if (newHeaderValue != null) {
247 headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
248 }
249 }
250
251 promise.addListener(new ChannelFutureListener() {
252 @Override
253 public void operationComplete(ChannelFuture future) {
254 if (future.isSuccess()) {
255 ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
256 }
257 }
258 });
259 }
260 }
261 }