1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.stomp.websocket;
17
18 import io.netty.channel.ChannelFuture;
19 import io.netty.channel.ChannelFutureListener;
20 import io.netty.channel.ChannelHandler.Sharable;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.SimpleChannelInboundHandler;
23 import io.netty.handler.codec.DecoderResult;
24 import io.netty.handler.codec.stomp.DefaultStompFrame;
25 import io.netty.handler.codec.stomp.StompCommand;
26 import io.netty.handler.codec.stomp.StompFrame;
27 import io.netty.util.CharsetUtil;
28
29 import java.util.HashSet;
30 import java.util.Iterator;
31 import java.util.Map.Entry;
32 import java.util.Set;
33 import java.util.UUID;
34 import java.util.concurrent.ConcurrentHashMap;
35 import java.util.concurrent.ConcurrentMap;
36
37 import static io.netty.handler.codec.stomp.StompHeaders.*;
38
39 @Sharable
40 public class StompChatHandler extends SimpleChannelInboundHandler<StompFrame> {
41
42 private final ConcurrentMap<String, Set<StompSubscription>> chatDestinations =
43 new ConcurrentHashMap<String, Set<StompSubscription>>();
44
45 @Override
46 protected void channelRead0(ChannelHandlerContext ctx, StompFrame inboundFrame) throws Exception {
47 DecoderResult decoderResult = inboundFrame.decoderResult();
48 if (decoderResult.isFailure()) {
49 sendErrorFrame("rejected frame", decoderResult.toString(), ctx);
50 return;
51 }
52
53 switch (inboundFrame.command()) {
54 case STOMP:
55 case CONNECT:
56 onConnect(ctx, inboundFrame);
57 break;
58 case SUBSCRIBE:
59 onSubscribe(ctx, inboundFrame);
60 break;
61 case SEND:
62 onSend(ctx, inboundFrame);
63 break;
64 case UNSUBSCRIBE:
65 onUnsubscribe(ctx, inboundFrame);
66 break;
67 case DISCONNECT:
68 onDisconnect(ctx, inboundFrame);
69 break;
70 default:
71 sendErrorFrame("unsupported command",
72 "Received unsupported command " + inboundFrame.command(), ctx);
73 }
74 }
75
76 private void onSubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
77 String destination = inboundFrame.headers().getAsString(DESTINATION);
78 String subscriptionId = inboundFrame.headers().getAsString(ID);
79
80 if (destination == null || subscriptionId == null) {
81 sendErrorFrame("missed header", "Required 'destination' or 'id' header missed", ctx);
82 return;
83 }
84
85 Set<StompSubscription> subscriptions = chatDestinations.get(destination);
86 if (subscriptions == null) {
87 subscriptions = new HashSet<StompSubscription>();
88 Set<StompSubscription> previousSubscriptions = chatDestinations.putIfAbsent(destination, subscriptions);
89 if (previousSubscriptions != null) {
90 subscriptions = previousSubscriptions;
91 }
92 }
93
94 final StompSubscription subscription = new StompSubscription(subscriptionId, destination, ctx.channel());
95 if (subscriptions.contains(subscription)) {
96 sendErrorFrame("duplicate subscription",
97 "Received duplicate subscription id=" + subscriptionId, ctx);
98 return;
99 }
100
101 subscriptions.add(subscription);
102 ctx.channel().closeFuture().addListener(new ChannelFutureListener() {
103 @Override
104 public void operationComplete(ChannelFuture future) {
105 chatDestinations.get(subscription.destination()).remove(subscription);
106 }
107 });
108
109 String receiptId = inboundFrame.headers().getAsString(RECEIPT);
110 if (receiptId != null) {
111 StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
112 receiptFrame.headers().set(RECEIPT_ID, receiptId);
113 ctx.writeAndFlush(receiptFrame);
114 }
115 }
116
117 private void onSend(ChannelHandlerContext ctx, StompFrame inboundFrame) {
118 String destination = inboundFrame.headers().getAsString(DESTINATION);
119 if (destination == null) {
120 sendErrorFrame("missed header", "required 'destination' header missed", ctx);
121 return;
122 }
123
124 Set<StompSubscription> subscriptions = chatDestinations.get(destination);
125 for (StompSubscription subscription : subscriptions) {
126 subscription.channel().writeAndFlush(transformToMessage(inboundFrame, subscription));
127 }
128 }
129
130 private void onUnsubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
131 String subscriptionId = inboundFrame.headers().getAsString(SUBSCRIPTION);
132 for (Entry<String, Set<StompSubscription>> entry : chatDestinations.entrySet()) {
133 Iterator<StompSubscription> iterator = entry.getValue().iterator();
134 while (iterator.hasNext()) {
135 StompSubscription subscription = iterator.next();
136 if (subscription.id().equals(subscriptionId) && subscription.channel().equals(ctx.channel())) {
137 iterator.remove();
138 return;
139 }
140 }
141 }
142 }
143
144 private static void onConnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
145 String acceptVersions = inboundFrame.headers().getAsString(ACCEPT_VERSION);
146 StompVersion handshakeAcceptVersion = ctx.channel().attr(StompVersion.CHANNEL_ATTRIBUTE_KEY).get();
147 if (acceptVersions == null || !acceptVersions.contains(handshakeAcceptVersion.version())) {
148 sendErrorFrame("invalid version",
149 "Received invalid version, expected " + handshakeAcceptVersion.version(), ctx);
150 return;
151 }
152
153 StompFrame connectedFrame = new DefaultStompFrame(StompCommand.CONNECTED);
154 connectedFrame.headers()
155 .set(VERSION, handshakeAcceptVersion.version())
156 .set(SERVER, "Netty-Server")
157 .set(HEART_BEAT, "0,0");
158 ctx.writeAndFlush(connectedFrame);
159 }
160
161 private static void onDisconnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
162 String receiptId = inboundFrame.headers().getAsString(RECEIPT);
163 if (receiptId == null) {
164 ctx.close();
165 return;
166 }
167
168 StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
169 receiptFrame.headers().set(RECEIPT_ID, receiptId);
170 ctx.writeAndFlush(receiptFrame).addListener(ChannelFutureListener.CLOSE);
171 }
172
173 private static void sendErrorFrame(String message, String description, ChannelHandlerContext ctx) {
174 StompFrame errorFrame = new DefaultStompFrame(StompCommand.ERROR);
175 errorFrame.headers().set(MESSAGE, message);
176
177 if (description != null) {
178 errorFrame.content().writeCharSequence(description, CharsetUtil.UTF_8);
179 }
180
181 ctx.writeAndFlush(errorFrame).addListener(ChannelFutureListener.CLOSE);
182 }
183
184 private static StompFrame transformToMessage(StompFrame sendFrame, StompSubscription subscription) {
185 StompFrame messageFrame = new DefaultStompFrame(StompCommand.MESSAGE, sendFrame.content().retainedDuplicate());
186 String id = UUID.randomUUID().toString();
187 messageFrame.headers()
188 .set(MESSAGE_ID, id)
189 .set(SUBSCRIPTION, subscription.id())
190 .set(CONTENT_LENGTH, Integer.toString(messageFrame.content().readableBytes()));
191
192 CharSequence contentType = sendFrame.headers().get(CONTENT_TYPE);
193 if (contentType != null) {
194 messageFrame.headers().set(CONTENT_TYPE, contentType);
195 }
196
197 return messageFrame;
198 }
199 }