1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.autobahn;
17
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.StringUtil;

import java.util.logging.Level;
import java.util.logging.Logger;

import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpVersion.*;
47
48
49
50
51 public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
52 private static final Logger logger = Logger.getLogger(AutobahnServerHandler.class.getName());
53
54 private WebSocketServerHandshaker handshaker;
55
56 @Override
57 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
58 if (msg instanceof HttpRequest) {
59 handleHttpRequest(ctx, (HttpRequest) msg);
60 } else if (msg instanceof WebSocketFrame) {
61 handleWebSocketFrame(ctx, (WebSocketFrame) msg);
62 } else {
63 throw new IllegalStateException("unknown message: " + msg);
64 }
65 }
66
67 @Override
68 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
69 ctx.flush();
70 }
71
72 private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
73 throws Exception {
74
75 if (!req.decoderResult().isSuccess()) {
76 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST, ctx.alloc().buffer(0)));
77 return;
78 }
79
80
81 if (!GET.equals(req.method())) {
82 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
83 return;
84 }
85
86
87 WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
88 getWebSocketLocation(req), null, false, Integer.MAX_VALUE);
89 handshaker = wsFactory.newHandshaker(req);
90 if (handshaker == null) {
91 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
92 } else {
93 handshaker.handshake(ctx.channel(), req);
94 }
95 }
96
97 private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
98 if (logger.isLoggable(Level.FINE)) {
99 logger.fine(String.format(
100 "Channel %s received %s", ctx.channel().hashCode(), StringUtil.simpleClassName(frame)));
101 }
102
103 if (frame instanceof CloseWebSocketFrame) {
104 handshaker.close(ctx, (CloseWebSocketFrame) frame);
105 } else if (frame instanceof PingWebSocketFrame) {
106 ctx.write(new PongWebSocketFrame(frame.isFinalFragment(), frame.rsv(), frame.content()));
107 } else if (frame instanceof TextWebSocketFrame ||
108 frame instanceof BinaryWebSocketFrame ||
109 frame instanceof ContinuationWebSocketFrame) {
110 ctx.write(frame);
111 } else if (frame instanceof PongWebSocketFrame) {
112 frame.release();
113
114 } else {
115 throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass()
116 .getName()));
117 }
118 }
119
120 private static void sendHttpResponse(
121 ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) {
122
123 if (res.status().code() != 200) {
124 ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
125 res.content().writeBytes(buf);
126 buf.release();
127 setContentLength(res, res.content().readableBytes());
128 }
129
130
131 ChannelFuture f = ctx.channel().writeAndFlush(res);
132 if (!isKeepAlive(req) || res.status().code() != 200) {
133 f.addListener(ChannelFutureListener.CLOSE);
134 }
135 }
136
137 @Override
138 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
139 ctx.close();
140 }
141
142 private static String getWebSocketLocation(HttpRequest req) {
143 return "ws://" + req.headers().get(HttpHeaderNames.HOST);
144 }
145 }