1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.http.websocketx.server;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPipeline;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.DefaultFullHttpResponse;
26 import io.netty.handler.codec.http.FullHttpRequest;
27 import io.netty.handler.codec.http.FullHttpResponse;
28 import io.netty.handler.codec.http.HttpHeaderNames;
29 import io.netty.handler.codec.http.HttpHeaderValues;
30 import io.netty.handler.codec.http.HttpRequest;
31 import io.netty.handler.codec.http.HttpResponseStatus;
32 import io.netty.handler.codec.http.HttpUtil;
33 import io.netty.handler.ssl.SslHandler;
34
35 import static io.netty.handler.codec.http.HttpHeaderNames.*;
36 import static io.netty.handler.codec.http.HttpMethod.*;
37 import static io.netty.handler.codec.http.HttpResponseStatus.*;
38
39
40
41
42 public class WebSocketIndexPageHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
43
44 private final String websocketPath;
45
46 public WebSocketIndexPageHandler(String websocketPath) {
47 this.websocketPath = websocketPath;
48 }
49
50 @Override
51 protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
52
53 if (!req.decoderResult().isSuccess()) {
54 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), BAD_REQUEST,
55 ctx.alloc().buffer(0)));
56 return;
57 }
58
59
60 if (req.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
61 ctx.fireChannelRead(req.retain());
62 return;
63 }
64
65
66 if (!GET.equals(req.method())) {
67 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), FORBIDDEN,
68 ctx.alloc().buffer(0)));
69 return;
70 }
71
72
73 if ("/".equals(req.uri()) || "/index.html".equals(req.uri())) {
74 String webSocketLocation = getWebSocketLocation(ctx.pipeline(), req, websocketPath);
75 ByteBuf content = WebSocketServerIndexPage.getContent(webSocketLocation);
76 FullHttpResponse res = new DefaultFullHttpResponse(req.protocolVersion(), OK, content);
77
78 res.headers().set(CONTENT_TYPE, "text/html; charset=UTF-8");
79 HttpUtil.setContentLength(res, content.readableBytes());
80
81 sendHttpResponse(ctx, req, res);
82 } else {
83 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), NOT_FOUND,
84 ctx.alloc().buffer(0)));
85 }
86 }
87
88 @Override
89 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
90 cause.printStackTrace();
91 ctx.close();
92 }
93
94 private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
95
96 HttpResponseStatus responseStatus = res.status();
97 if (responseStatus.code() != 200) {
98 ByteBufUtil.writeUtf8(res.content(), responseStatus.toString());
99 HttpUtil.setContentLength(res, res.content().readableBytes());
100 }
101
102 boolean keepAlive = HttpUtil.isKeepAlive(req) && responseStatus.code() == 200;
103 HttpUtil.setKeepAlive(res, keepAlive);
104 ChannelFuture future = ctx.writeAndFlush(res);
105 if (!keepAlive) {
106 future.addListener(ChannelFutureListener.CLOSE);
107 }
108 }
109
110 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
111 String protocol = "ws";
112 if (cp.get(SslHandler.class) != null) {
113
114 protocol = "wss";
115 }
116 return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
117 }
118 }