1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package io.netty.handler.codec.http.websocketx;
55
56 import io.netty.buffer.ByteBuf;
57 import io.netty.buffer.Unpooled;
58 import io.netty.channel.ChannelFutureListener;
59 import io.netty.channel.ChannelHandlerContext;
60 import io.netty.handler.codec.ByteToMessageDecoder;
61 import io.netty.handler.codec.TooLongFrameException;
62 import io.netty.util.internal.ObjectUtil;
63 import io.netty.util.internal.logging.InternalLogger;
64 import io.netty.util.internal.logging.InternalLoggerFactory;
65
66 import java.nio.ByteOrder;
67 import java.util.List;
68
69 import static io.netty.buffer.ByteBufUtil.readBytes;
70
71
72
73
74
75 public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76 implements WebSocketFrameDecoder {
77
78 enum State {
79 READING_FIRST,
80 READING_SECOND,
81 READING_SIZE,
82 MASKING_KEY,
83 PAYLOAD,
84 CORRUPT
85 }
86
87 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88
89 private static final byte OPCODE_CONT = 0x0;
90 private static final byte OPCODE_TEXT = 0x1;
91 private static final byte OPCODE_BINARY = 0x2;
92 private static final byte OPCODE_CLOSE = 0x8;
93 private static final byte OPCODE_PING = 0x9;
94 private static final byte OPCODE_PONG = 0xA;
95
96 private final WebSocketDecoderConfig config;
97
98 private int fragmentedFramesCount;
99 private boolean frameFinalFlag;
100 private boolean frameMasked;
101 private int frameRsv;
102 private int frameOpcode;
103 private long framePayloadLength;
104 private int mask;
105 private int framePayloadLen1;
106 private boolean receivedClosingHandshake;
107 private State state = State.READING_FIRST;
108
109
110
111
112
113
114
115
116
117
118
119
120
121 public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122 this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123 }
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140 public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141 boolean allowMaskMismatch) {
142 this(WebSocketDecoderConfig.newBuilder()
143 .expectMaskedFrames(expectMaskedFrames)
144 .allowExtensions(allowExtensions)
145 .maxFramePayloadLength(maxFramePayloadLength)
146 .allowMaskMismatch(allowMaskMismatch)
147 .build());
148 }
149
150
151
152
153
154
155
156 public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157 this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158 }
159
160 @Override
161 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162
163 if (receivedClosingHandshake) {
164 in.skipBytes(actualReadableBytes());
165 return;
166 }
167
168 switch (state) {
169 case READING_FIRST:
170 if (!in.isReadable()) {
171 return;
172 }
173
174 framePayloadLength = 0;
175
176
177 byte b = in.readByte();
178 frameFinalFlag = (b & 0x80) != 0;
179 frameRsv = (b & 0x70) >> 4;
180 frameOpcode = b & 0x0F;
181
182 if (logger.isTraceEnabled()) {
183 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184 }
185
186 state = State.READING_SECOND;
187 case READING_SECOND:
188 if (!in.isReadable()) {
189 return;
190 }
191
192 b = in.readByte();
193 frameMasked = (b & 0x80) != 0;
194 framePayloadLen1 = b & 0x7F;
195
196 if (frameRsv != 0 && !config.allowExtensions()) {
197 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198 return;
199 }
200
201 if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203 return;
204 }
205
206 if (frameOpcode > 7) {
207
208
209 if (!frameFinalFlag) {
210 protocolViolation(ctx, in, "fragmented control frame");
211 return;
212 }
213
214
215 if (framePayloadLen1 > 125) {
216 protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217 return;
218 }
219
220
221 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222 || frameOpcode == OPCODE_PONG)) {
223 protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224 return;
225 }
226
227
228
229
230 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231 protocolViolation(ctx, in, "received close control frame with payload len 1");
232 return;
233 }
234 } else {
235
236 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237 || frameOpcode == OPCODE_BINARY)) {
238 protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239 return;
240 }
241
242
243 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244 protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245 return;
246 }
247
248
249 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
250 protocolViolation(ctx, in,
251 "received non-continuation data frame while inside fragmented message");
252 return;
253 }
254 }
255
256 state = State.READING_SIZE;
257 case READING_SIZE:
258
259
260 if (framePayloadLen1 == 126) {
261 if (in.readableBytes() < 2) {
262 return;
263 }
264 framePayloadLength = in.readUnsignedShort();
265 if (framePayloadLength < 126) {
266 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267 return;
268 }
269 } else if (framePayloadLen1 == 127) {
270 if (in.readableBytes() < 8) {
271 return;
272 }
273 framePayloadLength = in.readLong();
274 if (framePayloadLength < 0) {
275 protocolViolation(ctx, in, "invalid data frame length (negative length)");
276 return;
277 }
278
279 if (framePayloadLength < 65536) {
280 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
281 return;
282 }
283 } else {
284 framePayloadLength = framePayloadLen1;
285 }
286
287 if (framePayloadLength > config.maxFramePayloadLength()) {
288 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
289 "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
290 return;
291 }
292
293 if (logger.isTraceEnabled()) {
294 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
295 }
296
297 state = State.MASKING_KEY;
298 case MASKING_KEY:
299 if (frameMasked) {
300 if (in.readableBytes() < 4) {
301 return;
302 }
303 mask = in.readInt();
304 }
305 state = State.PAYLOAD;
306 case PAYLOAD:
307 if (in.readableBytes() < framePayloadLength) {
308 return;
309 }
310
311 ByteBuf payloadBuffer = Unpooled.EMPTY_BUFFER;
312 try {
313 if (framePayloadLength > 0) {
314 payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315 }
316
317
318
319 state = State.READING_FIRST;
320
321
322 if (frameMasked & framePayloadLength > 0) {
323 unmask(payloadBuffer);
324 }
325
326
327
328 if (frameOpcode == OPCODE_PING) {
329 out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
330 payloadBuffer = null;
331 return;
332 }
333 if (frameOpcode == OPCODE_PONG) {
334 out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
335 payloadBuffer = null;
336 return;
337 }
338 if (frameOpcode == OPCODE_CLOSE) {
339 receivedClosingHandshake = true;
340 checkCloseFrameBody(ctx, payloadBuffer);
341 out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
342 payloadBuffer = null;
343 return;
344 }
345
346
347
348 if (frameFinalFlag) {
349
350
351 fragmentedFramesCount = 0;
352 } else {
353
354 fragmentedFramesCount++;
355 }
356
357
358 if (frameOpcode == OPCODE_TEXT) {
359 out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
360 payloadBuffer = null;
361 return;
362 } else if (frameOpcode == OPCODE_BINARY) {
363 out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
364 payloadBuffer = null;
365 return;
366 } else if (frameOpcode == OPCODE_CONT) {
367 out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
368 payloadBuffer));
369 payloadBuffer = null;
370 return;
371 } else {
372 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
373 + frameOpcode);
374 }
375 } finally {
376 if (payloadBuffer != null) {
377 payloadBuffer.release();
378 }
379 }
380 case CORRUPT:
381 if (in.isReadable()) {
382
383
384 in.readByte();
385 }
386 return;
387 default:
388 throw new Error("Shouldn't reach here.");
389 }
390 }
391
392 private void unmask(ByteBuf frame) {
393 int i = frame.readerIndex();
394 int end = frame.writerIndex();
395
396 ByteOrder order = frame.order();
397
398 int intMask = mask;
399
400 long longMask = intMask & 0xFFFFFFFFL;
401 longMask |= longMask << 32;
402
403 for (int lim = end - 7; i < lim; i += 8) {
404 frame.setLong(i, frame.getLong(i) ^ longMask);
405 }
406
407 if (i < end - 3) {
408 frame.setInt(i, frame.getInt(i) ^ (int) longMask);
409 i += 4;
410 }
411
412 if (order == ByteOrder.LITTLE_ENDIAN) {
413 intMask = Integer.reverseBytes(intMask);
414 }
415
416 int maskOffset = 0;
417 for (; i < end; i++) {
418 frame.setByte(i, frame.getByte(i) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3));
419 }
420 }
421
422 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
423 protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
424 }
425
426 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
427 protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
428 }
429
430 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
431 state = State.CORRUPT;
432 int readableBytes = in.readableBytes();
433 if (readableBytes > 0) {
434
435
436 in.skipBytes(readableBytes);
437 }
438 if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
439 Object closeMessage;
440 if (receivedClosingHandshake) {
441 closeMessage = Unpooled.EMPTY_BUFFER;
442 } else {
443 WebSocketCloseStatus closeStatus = ex.closeStatus();
444 String reasonText = ex.getMessage();
445 if (reasonText == null) {
446 reasonText = closeStatus.reasonText();
447 }
448 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
449 }
450 ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
451 }
452 throw ex;
453 }
454
455 private static int toFrameLength(long l) {
456 if (l > Integer.MAX_VALUE) {
457 throw new TooLongFrameException("Length:" + l);
458 } else {
459 return (int) l;
460 }
461 }
462
463
464 protected void checkCloseFrameBody(
465 ChannelHandlerContext ctx, ByteBuf buffer) {
466 if (buffer == null || !buffer.isReadable()) {
467 return;
468 }
469 if (buffer.readableBytes() < 2) {
470 protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
471 }
472
473
474 int statusCode = buffer.getShort(buffer.readerIndex());
475 if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
476 protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
477 }
478
479
480 if (buffer.readableBytes() > 2) {
481 try {
482 new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
483 } catch (CorruptedWebSocketFrameException ex) {
484 protocolViolation(ctx, buffer, ex);
485 }
486 }
487 }
488 }