1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http2;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.util.ReferenceCountUtil;
24 import io.netty.util.internal.UnstableApi;
25
26 import java.util.ArrayDeque;
27 import java.util.Iterator;
28 import java.util.Map;
29 import java.util.Queue;
30 import java.util.TreeMap;
31
32 import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS;
33 import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
34 import static io.netty.handler.codec.http2.Http2Exception.connectionError;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58 @UnstableApi
59 public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder {
60
61
62
63
64 public static final class Http2ChannelClosedException extends Http2Exception {
65 private static final long serialVersionUID = 4768543442094476971L;
66
67 public Http2ChannelClosedException() {
68 super(Http2Error.REFUSED_STREAM, "Connection closed");
69 }
70 }
71
72 private static final class GoAwayDetail {
73 private final int lastStreamId;
74 private final long errorCode;
75 private final byte[] debugData;
76
77 GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
78 this.lastStreamId = lastStreamId;
79 this.errorCode = errorCode;
80 this.debugData = debugData.clone();
81 }
82 }
83
84
85
86
87
88 public static final class Http2GoAwayException extends Http2Exception {
89 private static final long serialVersionUID = 1326785622777291198L;
90 private final GoAwayDetail goAwayDetail;
91
92 public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
93 this(new GoAwayDetail(lastStreamId, errorCode, debugData));
94 }
95
96 Http2GoAwayException(GoAwayDetail goAwayDetail) {
97 super(Http2Error.STREAM_CLOSED);
98 this.goAwayDetail = goAwayDetail;
99 }
100
101 public int lastStreamId() {
102 return goAwayDetail.lastStreamId;
103 }
104
105 public long errorCode() {
106 return goAwayDetail.errorCode;
107 }
108
109 public byte[] debugData() {
110 return goAwayDetail.debugData.clone();
111 }
112 }
113
114
115
116
117
118 private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
119 private int maxConcurrentStreams;
120 private boolean closed;
121 private GoAwayDetail goAwayDetail;
122
123 public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
124 this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
125 }
126
127 public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) {
128 super(delegate);
129 maxConcurrentStreams = initialMaxConcurrentStreams;
130 connection().addListener(new Http2ConnectionAdapter() {
131
132 @Override
133 public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
134 goAwayDetail = new GoAwayDetail(
135
136 lastStreamId, errorCode,
137 ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false));
138 cancelGoAwayStreams(goAwayDetail);
139 }
140
141 @Override
142 public void onStreamClosed(Http2Stream stream) {
143 tryCreatePendingStreams();
144 }
145 });
146 }
147
148
149
150
151 public int numBufferedStreams() {
152 return pendingStreams.size();
153 }
154
155 @Override
156 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
157 int padding, boolean endStream, ChannelPromise promise) {
158 return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT,
159 false, padding, endStream, promise);
160 }
161
162 @Override
163 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
164 int streamDependency, short weight, boolean exclusive,
165 int padding, boolean endOfStream, ChannelPromise promise) {
166 if (closed) {
167 return promise.setFailure(new Http2ChannelClosedException());
168 }
169 if (isExistingStream(streamId) || canCreateStream()) {
170 return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
171 exclusive, padding, endOfStream, promise);
172 }
173 if (goAwayDetail != null) {
174 return promise.setFailure(new Http2GoAwayException(goAwayDetail));
175 }
176 PendingStream pendingStream = pendingStreams.get(streamId);
177 if (pendingStream == null) {
178 pendingStream = new PendingStream(ctx, streamId);
179 pendingStreams.put(streamId, pendingStream);
180 }
181 pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive,
182 padding, endOfStream, promise));
183 return promise;
184 }
185
186 @Override
187 public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode,
188 ChannelPromise promise) {
189 if (isExistingStream(streamId)) {
190 return super.writeRstStream(ctx, streamId, errorCode, promise);
191 }
192
193
194 PendingStream stream = pendingStreams.remove(streamId);
195 if (stream != null) {
196
197
198
199
200 stream.close(null);
201 promise.setSuccess();
202 } else {
203 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
204 }
205 return promise;
206 }
207
208 @Override
209 public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
210 int padding, boolean endOfStream, ChannelPromise promise) {
211 if (isExistingStream(streamId)) {
212 return super.writeData(ctx, streamId, data, padding, endOfStream, promise);
213 }
214 PendingStream pendingStream = pendingStreams.get(streamId);
215 if (pendingStream != null) {
216 pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise));
217 } else {
218 ReferenceCountUtil.safeRelease(data);
219 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
220 }
221 return promise;
222 }
223
224 @Override
225 public void remoteSettings(Http2Settings settings) throws Http2Exception {
226
227
228 super.remoteSettings(settings);
229
230
231 maxConcurrentStreams = connection().local().maxActiveStreams();
232
233
234 tryCreatePendingStreams();
235 }
236
237 @Override
238 public void close() {
239 try {
240 if (!closed) {
241 closed = true;
242
243
244 Http2ChannelClosedException e = new Http2ChannelClosedException();
245 while (!pendingStreams.isEmpty()) {
246 PendingStream stream = pendingStreams.pollFirstEntry().getValue();
247 stream.close(e);
248 }
249 }
250 } finally {
251 super.close();
252 }
253 }
254
255 private void tryCreatePendingStreams() {
256 while (!pendingStreams.isEmpty() && canCreateStream()) {
257 Map.Entry<Integer, PendingStream> entry = pendingStreams.pollFirstEntry();
258 PendingStream pendingStream = entry.getValue();
259 try {
260 pendingStream.sendFrames();
261 } catch (Throwable t) {
262 pendingStream.close(t);
263 }
264 }
265 }
266
267 private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
268 Iterator<PendingStream> iter = pendingStreams.values().iterator();
269 Exception e = new Http2GoAwayException(goAwayDetail);
270 while (iter.hasNext()) {
271 PendingStream stream = iter.next();
272 if (stream.streamId > goAwayDetail.lastStreamId) {
273 iter.remove();
274 stream.close(e);
275 }
276 }
277 }
278
279
280
281
282 private boolean canCreateStream() {
283 return connection().local().numActiveStreams() < maxConcurrentStreams;
284 }
285
286 private boolean isExistingStream(int streamId) {
287 return streamId <= connection().local().lastStreamCreated();
288 }
289
290 private static final class PendingStream {
291 final ChannelHandlerContext ctx;
292 final int streamId;
293 final Queue<Frame> frames = new ArrayDeque<Frame>(2);
294
295 PendingStream(ChannelHandlerContext ctx, int streamId) {
296 this.ctx = ctx;
297 this.streamId = streamId;
298 }
299
300 void sendFrames() {
301 for (Frame frame : frames) {
302 frame.send(ctx, streamId);
303 }
304 }
305
306 void close(Throwable t) {
307 for (Frame frame : frames) {
308 frame.release(t);
309 }
310 }
311 }
312
313 private abstract static class Frame {
314 final ChannelPromise promise;
315
316 Frame(ChannelPromise promise) {
317 this.promise = promise;
318 }
319
320
321
322
323 void release(Throwable t) {
324 if (t == null) {
325 promise.setSuccess();
326 } else {
327 promise.setFailure(t);
328 }
329 }
330
331 abstract void send(ChannelHandlerContext ctx, int streamId);
332 }
333
334 private final class HeadersFrame extends Frame {
335 final Http2Headers headers;
336 final int streamDependency;
337 final short weight;
338 final boolean exclusive;
339 final int padding;
340 final boolean endOfStream;
341
342 HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive,
343 int padding, boolean endOfStream, ChannelPromise promise) {
344 super(promise);
345 this.headers = headers;
346 this.streamDependency = streamDependency;
347 this.weight = weight;
348 this.exclusive = exclusive;
349 this.padding = padding;
350 this.endOfStream = endOfStream;
351 }
352
353 @Override
354 void send(ChannelHandlerContext ctx, int streamId) {
355 writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise);
356 }
357 }
358
359 private final class DataFrame extends Frame {
360 final ByteBuf data;
361 final int padding;
362 final boolean endOfStream;
363
364 DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) {
365 super(promise);
366 this.data = data;
367 this.padding = padding;
368 this.endOfStream = endOfStream;
369 }
370
371 @Override
372 void release(Throwable t) {
373 super.release(t);
374 ReferenceCountUtil.safeRelease(data);
375 }
376
377 @Override
378 void send(ChannelHandlerContext ctx, int streamId) {
379 writeData(ctx, streamId, data, padding, endOfStream, promise);
380 }
381 }
382 }