1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.http.websocketx.benchmarkserver;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.SimpleChannelInboundHandler;
24 import io.netty.handler.codec.http.DefaultFullHttpResponse;
25 import io.netty.handler.codec.http.FullHttpRequest;
26 import io.netty.handler.codec.http.FullHttpResponse;
27 import io.netty.handler.codec.http.HttpHeaderNames;
28 import io.netty.handler.codec.http.HttpResponseStatus;
29 import io.netty.handler.codec.http.HttpUtil;
30 import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
31 import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
32 import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
33 import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
34 import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
35 import io.netty.handler.codec.http.websocketx.WebSocketFrame;
36 import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
37 import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
38
39 import static io.netty.handler.codec.http.HttpMethod.*;
40 import static io.netty.handler.codec.http.HttpResponseStatus.*;
41
42
43
44
45 public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
46
47 private static final String WEBSOCKET_PATH = "/websocket";
48
49 private WebSocketServerHandshaker handshaker;
50
51 @Override
52 public void channelRead0(ChannelHandlerContext ctx, Object msg) {
53 if (msg instanceof FullHttpRequest) {
54 handleHttpRequest(ctx, (FullHttpRequest) msg);
55 } else if (msg instanceof WebSocketFrame) {
56 handleWebSocketFrame(ctx, (WebSocketFrame) msg);
57 }
58 }
59
60 @Override
61 public void channelReadComplete(ChannelHandlerContext ctx) {
62 ctx.flush();
63 }
64
65 private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
66
67 if (!req.decoderResult().isSuccess()) {
68 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), BAD_REQUEST,
69 ctx.alloc().buffer(0)));
70 return;
71 }
72
73
74 if (!GET.equals(req.method())) {
75 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), FORBIDDEN,
76 ctx.alloc().buffer(0)));
77 return;
78 }
79
80
81 if ("/".equals(req.uri())) {
82 ByteBuf content = WebSocketServerBenchmarkPage.getContent(getWebSocketLocation(req));
83 FullHttpResponse res = new DefaultFullHttpResponse(req.protocolVersion(), OK, content);
84
85 res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8");
86 HttpUtil.setContentLength(res, content.readableBytes());
87
88 sendHttpResponse(ctx, req, res);
89 return;
90 }
91
92 if ("/favicon.ico".equals(req.uri())) {
93 FullHttpResponse res = new DefaultFullHttpResponse(req.protocolVersion(), NOT_FOUND,
94 ctx.alloc().buffer(0));
95 sendHttpResponse(ctx, req, res);
96 return;
97 }
98
99
100 WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
101 getWebSocketLocation(req), null, true, 5 * 1024 * 1024);
102 handshaker = wsFactory.newHandshaker(req);
103 if (handshaker == null) {
104 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
105 } else {
106 handshaker.handshake(ctx.channel(), req);
107 }
108 }
109
110 private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
111
112
113 if (frame instanceof CloseWebSocketFrame) {
114 handshaker.close(ctx, (CloseWebSocketFrame) frame.retain());
115 return;
116 }
117 if (frame instanceof PingWebSocketFrame) {
118 ctx.write(new PongWebSocketFrame(frame.content().retain()));
119 return;
120 }
121 if (frame instanceof TextWebSocketFrame) {
122
123 ctx.write(frame.retain());
124 return;
125 }
126 if (frame instanceof BinaryWebSocketFrame) {
127
128 ctx.write(frame.retain());
129 }
130 }
131
132 private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
133
134 HttpResponseStatus responseStatus = res.status();
135 if (responseStatus.code() != 200) {
136 ByteBufUtil.writeUtf8(res.content(), responseStatus.toString());
137 HttpUtil.setContentLength(res, res.content().readableBytes());
138 }
139
140 boolean keepAlive = HttpUtil.isKeepAlive(req) && responseStatus.code() == 200;
141 HttpUtil.setKeepAlive(res, keepAlive);
142 ChannelFuture future = ctx.write(res);
143 if (!keepAlive) {
144 future.addListener(ChannelFutureListener.CLOSE);
145 }
146 }
147
148 @Override
149 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
150 cause.printStackTrace();
151 ctx.close();
152 }
153
154 private static String getWebSocketLocation(FullHttpRequest req) {
155 String location = req.headers().get(HttpHeaderNames.HOST) + WEBSOCKET_PATH;
156 if (WebSocketServer.SSL) {
157 return "wss://" + location;
158 } else {
159 return "ws://" + location;
160 }
161 }
162 }