1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  package io.netty.handler.codec.http2;
16  
17  import io.netty.util.collection.IntCollections;
18  import io.netty.util.collection.IntObjectHashMap;
19  import io.netty.util.collection.IntObjectMap;
20  import io.netty.util.internal.DefaultPriorityQueue;
21  import io.netty.util.internal.EmptyPriorityQueue;
22  import io.netty.util.internal.MathUtil;
23  import io.netty.util.internal.PriorityQueue;
24  import io.netty.util.internal.PriorityQueueNode;
25  import io.netty.util.internal.SystemPropertyUtil;
26  import io.netty.util.internal.UnstableApi;
27  
28  import java.io.Serializable;
29  import java.util.ArrayList;
30  import java.util.Comparator;
31  import java.util.Iterator;
32  import java.util.List;
33  
34  import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
35  import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MIN_ALLOCATION_CHUNK;
36  import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
37  import static io.netty.handler.codec.http2.Http2CodecUtil.streamableBytes;
38  import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
39  import static io.netty.handler.codec.http2.Http2Exception.connectionError;
40  import static io.netty.util.internal.ObjectUtil.checkPositive;
41  import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
42  import static java.lang.Integer.MAX_VALUE;
43  import static java.lang.Math.max;
44  import static java.lang.Math.min;
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
/**
 * A {@link StreamByteDistributor} that shares bytes between streams using weighted fair queueing over the stream
 * dependency tree.
 * <p>
 * Every stream, plus the connection stream (the root of the tree), has a {@link State} node. A node whose subtree
 * contains at least one active stream sits in its parent's {@code pseudoTimeQueue}, ordered by
 * {@code pseudoTimeToWrite}. {@link #distribute(int, Writer)} walks down from the root, each time picking the queued
 * child with the smallest pseudo time. After the write it advances that child's pseudo time by the bytes written
 * multiplied by {@code totalQueuedWeights / weight}. A heavier child therefore advances more slowly, is picked again
 * sooner, and receives a proportionally larger share of the bytes.
 * <p>
 * Priority information for stream IDs that have no {@link Http2Stream} object (not yet created, or already removed)
 * is kept in {@code stateOnlyMap}. Its size is bounded by {@code maxStateOnlySize}. When the bound is exceeded,
 * {@link StateOnlyComparator} decides which entry is evicted.
 * <p>
 * NOTE(review): nothing in this class synchronizes. Presumably all calls, including the connection listener
 * callbacks, happen on a single thread; confirm with the owning codec before sharing an instance across threads.
 */
@UnstableApi
public final class WeightedFairQueueByteDistributor implements StreamByteDistributor {
    /**
     * Initial capacity of a {@link State}'s {@code children} map when it is first allocated.
     * Read from the {@code io.netty.http2.childrenMapSize} system property (default {@code 2}) and clamped to at
     * least {@code 1}.
     */
    static final int INITIAL_CHILDREN_MAP_SIZE =
            max(1, SystemPropertyUtil.getInt("io.netty.http2.childrenMapSize", 2));
    /**
     * Default bound on how many {@link State} objects are retained for stream IDs that have no
     * {@link Http2Stream} object. Used by {@link #WeightedFairQueueByteDistributor(Http2Connection)}.
     */
    private static final int DEFAULT_MAX_STATE_ONLY_SIZE = 5;

    /** Key under which each {@link Http2Stream}'s {@link State} is stored as a stream property. */
    private final Http2Connection.PropertyKey stateKey;
    /**
     * {@link State} objects for stream IDs that currently have no {@link Http2Stream} object, keyed by stream ID.
     * This is {@code IntCollections.emptyMap()} when {@code maxStateOnlySize} is {@code 0}.
     */
    private final IntObjectMap<State> stateOnlyMap;
    /**
     * The same {@link State} objects as {@link #stateOnlyMap}, ordered by {@link StateOnlyComparator}.
     * The head of the queue is the first candidate for eviction.
     */
    private final PriorityQueue<State> stateOnlyRemovalQueue;
    private final Http2Connection connection;
    /** {@link State} of the connection stream: the root of the dependency tree. */
    private final State connectionState;
    /**
     * Number of bytes a stream may be given in one distribution step on top of its pseudo-time-based share.
     * See {@link #allocationQuantum(int)}.
     */
    private int allocationQuantum = DEFAULT_MIN_ALLOCATION_CHUNK;
    /** Maximum number of entries kept in {@link #stateOnlyMap} / {@link #stateOnlyRemovalQueue}. */
    private final int maxStateOnlySize;

    /**
     * Creates a distributor that retains priority state for at most {@value #DEFAULT_MAX_STATE_ONLY_SIZE}
     * stream IDs without an {@link Http2Stream} object.
     *
     * @param connection the connection whose streams receive bytes.
     */
    public WeightedFairQueueByteDistributor(Http2Connection connection) {
        this(connection, DEFAULT_MAX_STATE_ONLY_SIZE);
    }

    /**
     * Creates a distributor for {@code connection}.
     * <p>
     * Allocates a stream property key on {@code connection}, installs the root {@link State} on the connection
     * stream, and registers a listener that keeps the {@link State} tree in sync as streams are added, activated,
     * closed and removed.
     *
     * @param connection the connection whose streams receive bytes.
     * @param maxStateOnlySize how many {@link State} objects to retain for stream IDs that have no
     *        {@link Http2Stream} object. Must be {@code >= 0} (validated by {@code checkPositiveOrZero}).
     *        {@code 0} disables retention: priority updates for such IDs are ignored and removed streams are
     *        detached from the tree immediately.
     */
    public WeightedFairQueueByteDistributor(Http2Connection connection, int maxStateOnlySize) {
        checkPositiveOrZero(maxStateOnlySize, "maxStateOnlySize");
        if (maxStateOnlySize == 0) {
            stateOnlyMap = IntCollections.emptyMap();
            stateOnlyRemovalQueue = EmptyPriorityQueue.instance();
        } else {
            stateOnlyMap = new IntObjectHashMap<State>(maxStateOnlySize);
            // +2 because updateDependencyTree may add State objects for both the child and the parent before it
            // trims the queue back down to maxStateOnlySize.
            stateOnlyRemovalQueue = new DefaultPriorityQueue<State>(StateOnlyComparator.INSTANCE, maxStateOnlySize + 2);
        }
        this.maxStateOnlySize = maxStateOnlySize;

        this.connection = connection;
        stateKey = connection.newKey();
        final Http2Stream connectionStream = connection.connectionStream();
        connectionStream.setProperty(stateKey, connectionState = new State(connectionStream, 16));

        // Keep the State tree in sync with the lifecycle of the connection's streams.
        connection.addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamAdded(Http2Stream stream) {
                // Reuse a State created earlier by updateDependencyTree for this ID (before the stream existed).
                State state = stateOnlyMap.remove(stream.id());
                if (state == null) {
                    state = new State(stream);
                    // Only the new stream changes parent, so a single event slot is enough.
                    List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1);
                    connectionState.takeChild(state, false, events);
                    notifyParentChanged(events);
                } else {
                    // The State is now reachable through the stream property, so it no longer counts against the
                    // state-only budget.
                    stateOnlyRemovalQueue.removeTyped(state);
                    state.stream = stream;
                }
                switch (stream.state()) {
                    case RESERVED_REMOTE:
                    case RESERVED_LOCAL:
                        state.setStreamReservedOrActivated();
                        // wasStreamReservedOrActivated() feeds StateOnlyComparator, but this State is not in
                        // stateOnlyRemovalQueue at this point, so no re-prioritization is needed.
                        break;
                    default:
                        break;
                }
                stream.setProperty(stateKey, state);
            }

            @Override
            public void onStreamActive(Http2Stream stream) {
                state(stream).setStreamReservedOrActivated();
                // No stateOnlyRemovalQueue re-prioritization needed: a State attached to a live stream is not in
                // that queue (onStreamAdded removed it).
            }

            @Override
            public void onStreamClosed(Http2Stream stream) {
                // Marks the State inactive (with 0 streamable bytes) and drops its stream reference.
                state(stream).close();
            }

            @Override
            public void onStreamRemoved(Http2Stream stream) {
                // After removal the stream's property storage can no longer be used to find this State. Depending on
                // available room and its StateOnlyComparator rank, the State is either retained in stateOnlyMap or
                // detached from the dependency tree.
                State state = state(stream);

                // close() normally clears this reference, but a stream may be removed without having been closed
                // first (presumably e.g. a reserved stream). Clear it here so the stream object is not retained.
                state.stream = null;

                if (WeightedFairQueueByteDistributor.this.maxStateOnlySize == 0) {
                    // Retention disabled: detach right away.
                    state.parent.removeChild(state);
                    return;
                }
                if (stateOnlyRemovalQueue.size() == WeightedFairQueueByteDistributor.this.maxStateOnlySize) {
                    // Full: compare against the current eviction candidate.
                    State stateToRemove = stateOnlyRemovalQueue.peek();
                    if (StateOnlyComparator.INSTANCE.compare(stateToRemove, state) >= 0) {
                        // The eviction candidate ranks at least as high as the removed stream's State, so discard
                        // the removed stream's State instead of the candidate.
                        state.parent.removeChild(state);
                        return;
                    }
                    // Otherwise evict the candidate to make room.
                    stateOnlyRemovalQueue.poll();
                    stateToRemove.parent.removeChild(stateToRemove);
                    stateOnlyMap.remove(stateToRemove.streamId);
                }
                stateOnlyRemovalQueue.add(state);
                stateOnlyMap.put(state.streamId, state);
            }
        });
    }

    /**
     * {@inheritDoc}
     * <p>
     * The stream is treated as active (eligible to receive bytes) only when {@code state.hasFrame()} is
     * {@code true} and {@code state.windowSize()} is non-negative.
     */
    @Override
    public void updateStreamableBytes(StreamState state) {
        state(state.stream()).updateStreamableBytes(streamableBytes(state),
                                                    state.hasFrame() && state.windowSize() >= 0);
    }

    /**
     * {@inheritDoc}
     * <p>
     * If the child or parent ID has no {@link State} yet, a state-only {@link State} is created for it. A newly
     * created parent is placed directly under the connection stream. When {@code maxStateOnlySize} is {@code 0}
     * the update is ignored instead. If {@code parentStreamId} is currently a descendant of {@code childStreamId},
     * the parent is first moved up to the child's current parent so that no cycle is formed. Afterwards the
     * state-only set is trimmed back to {@code maxStateOnlySize}.
     */
    @Override
    public void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive) {
        State state = state(childStreamId);
        if (state == null) {
            // No stream and no retained State for this ID. Keeping one requires room in the state-only
            // structures; with maxStateOnlySize == 0 there is none, so ignore the update.
            if (maxStateOnlySize == 0) {
                return;
            }
            state = new State(childStreamId);
            stateOnlyRemovalQueue.add(state);
            stateOnlyMap.put(childStreamId, state);
        }

        State newParent = state(parentStreamId);
        if (newParent == null) {
            // Same reasoning as for the child: without state-only room the update cannot be recorded.
            if (maxStateOnlySize == 0) {
                return;
            }
            newParent = new State(parentStreamId);
            stateOnlyRemovalQueue.add(newParent);
            stateOnlyMap.put(parentStreamId, newParent);
            // Only newParent changes parent here, so a single event slot is enough.
            List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1);
            connectionState.takeChild(newParent, false, events);
            notifyParentChanged(events);
        }

        // A State with activeCountForTree == 0 is not in its parent's pseudoTimeQueue, so its weight is not part of
        // parent.totalQueuedWeights. Only adjust the sum when it is.
        if (state.activeCountForTree != 0 && state.parent != null) {
            state.parent.totalQueuedWeights += weight - state.weight;
        }
        state.weight = weight;

        // Re-parent if the parent changes, or if an exclusive dependency must adopt newParent's other children.
        if (newParent != state.parent || exclusive && newParent.children.size() != 1) {
            final List<ParentChangedEvent> events;
            if (newParent.isDescendantOf(state)) {
                // Avoid a cycle: lift newParent up to state's current parent before making state its child.
                events = new ArrayList<ParentChangedEvent>(2 + (exclusive ? newParent.children.size() : 0));
                state.parent.takeChild(newParent, false, events);
            } else {
                events = new ArrayList<ParentChangedEvent>(1 + (exclusive ? newParent.children.size() : 0));
            }
            newParent.takeChild(state, exclusive, events);
            notifyParentChanged(events);
        }

        // Tree position (depth) is part of StateOnlyComparator's ordering, and notifyParentChanged has just
        // re-prioritized the moved States. So only now, after the tree update, trim any States created above
        // that pushed the queue past its limit.
        while (stateOnlyRemovalQueue.size() > maxStateOnlySize) {
            State stateToRemove = stateOnlyRemovalQueue.poll();
            stateToRemove.parent.removeChild(stateToRemove);
            stateOnlyMap.remove(stateToRemove.streamId);
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * Repeatedly gives one stream at a time a chance to write, starting at the connection stream. It continues
     * while some stream is active and either bytes remain or the previous pass changed the number of active
     * streams. Once {@code maxBytes} is used up, further passes offer {@code 0} bytes per write.
     *
     * @return {@code true} if some stream in the tree is still active afterwards.
     */
    @Override
    public boolean distribute(int maxBytes, Writer writer) throws Http2Exception {
        // Nothing in the tree is active, so there is nothing to write.
        if (connectionState.activeCountForTree == 0) {
            return false;
        }

        // A change in the active count counts as progress, so writes continue (with 0 bytes each) even after
        // maxBytes reaches 0, until the active set stops changing.
        int oldIsActiveCountForTree;
        do {
            oldIsActiveCountForTree = connectionState.activeCountForTree;
            // connectionState is assumed never to be active itself, so go straight to its children.
            maxBytes -= distributeToChildren(maxBytes, writer, connectionState);
        } while (connectionState.activeCountForTree != 0 &&
                (maxBytes > 0 || oldIsActiveCountForTree != connectionState.activeCountForTree));

        return connectionState.activeCountForTree != 0;
    }

    /**
     * Sets how many bytes a stream may receive in one distribution step on top of its pseudo-time-based share.
     * Each step still gives the chosen stream no more than the bytes remaining in the current
     * {@link #distribute(int, Writer)} call.
     *
     * @param allocationQuantum the extra allocation in bytes; must be positive (validated by {@code checkPositive}).
     */
    public void allocationQuantum(int allocationQuantum) {
        checkPositive(allocationQuantum, "allocationQuantum");
        this.allocationQuantum = allocationQuantum;
    }

    /**
     * Gives {@code state} a chance to write if it is itself active. Otherwise recurses into its children.
     *
     * @return the number of bytes written, at most {@code maxBytes}.
     */
    private int distribute(int maxBytes, Writer writer, State state) throws Http2Exception {
        if (state.isActive()) {
            int nsent = min(maxBytes, state.streamableBytes);
            state.write(nsent, writer);
            if (nsent == 0 && maxBytes != 0) {
                // The stream was offered bytes but had none to write. Treat it as inactive until the next
                // updateStreamableBytes call; subsequent passes then descend past it (distributeToChildren), so
                // its descendants can use the bytes it could not.
                state.updateStreamableBytes(state.streamableBytes, false);
            }
            return nsent;
        }

        return distributeToChildren(maxBytes, writer, state);
    }

    /**
     * Picks the queued child of {@code state} with the smallest {@code pseudoTimeToWrite} and gives it a chance to
     * write. Afterwards it advances {@code state.pseudoTime} by the bytes written and recomputes the child's
     * {@code pseudoTimeToWrite}.
     * <p>
     * The child is offered at most {@code maxBytes}. When another child is queued, the offer is also capped at
     * {@code (next.pseudoTimeToWrite - child.pseudoTimeToWrite) * child.weight / totalQueuedWeights} plus
     * {@link #allocationQuantum}. The child is re-queued in the {@code finally} block (so also when the write
     * throws), but only if its subtree is still active.
     * <p>
     * Precondition: {@code state}'s pseudo-time queue is non-empty, i.e. {@code state.activeCountForTree != 0}.
     * Otherwise {@code pollPseudoTimeQueue} has no element to return.
     *
     * @return the number of bytes written.
     */
    private int distributeToChildren(int maxBytes, Writer writer, State state) throws Http2Exception {
        // Captured before polling so it still includes childState's weight (updatePseudoTime relies on that).
        long oldTotalQueuedWeights = state.totalQueuedWeights;
        State childState = state.pollPseudoTimeQueue();
        State nextChildState = state.peekPseudoTimeQueue();
        // While set, activeCountChangeForTree will not re-queue childState; the finally block does it instead.
        childState.setDistributing();
        try {
            assert nextChildState == null || nextChildState.pseudoTimeToWrite >= childState.pseudoTimeToWrite :
                "nextChildState[" + nextChildState.streamId + "].pseudoTime(" + nextChildState.pseudoTimeToWrite +
                ") < " + " childState[" + childState.streamId + "].pseudoTime(" + childState.pseudoTimeToWrite + ')';
            int nsent = distribute(nextChildState == null ? maxBytes :
                            min(maxBytes, (int) min((nextChildState.pseudoTimeToWrite - childState.pseudoTimeToWrite) *
                                               childState.weight / oldTotalQueuedWeights + allocationQuantum, MAX_VALUE)
                               ),
                               writer,
                               childState);
            state.pseudoTime += nsent;
            childState.updatePseudoTime(state, nsent, oldTotalQueuedWeights);
            return nsent;
        } finally {
            // Done in finally so the flag is never left set if the write throws.
            childState.unsetDistributing();
            // Re-queueing is deferred to here, as the recursion unwinds, instead of happening during the write.
            // That way childState never has to be removed from the queue again because of a write.
            if (childState.activeCountForTree != 0) {
                state.offerPseudoTimeQueue(childState);
            }
        }
    }

    /** Returns the {@link State} stored on {@code stream}'s property under {@link #stateKey}. */
    private State state(Http2Stream stream) {
        return stream.getProperty(stateKey);
    }

    /**
     * Looks up the {@link State} for {@code streamId}. It comes from the {@link Http2Stream} if the connection has
     * one, otherwise from {@link #stateOnlyMap}.
     *
     * @return the {@link State}, or {@code null} if neither source has one.
     */
    private State state(int streamId) {
        Http2Stream stream = connection.stream(streamId);
        return stream != null ? state(stream) : stateOnlyMap.get(streamId);
    }

    /**
     * Returns whether {@code childId} is a direct child of {@code parentId} with the given {@code weight}.
     * Package-private, presumably for tests. Throws {@link NullPointerException} if {@code parentId} has no
     * {@link State}.
     */
    boolean isChild(int childId, int parentId, short weight) {
        State parent = state(parentId);
        State child;
        return parent.children.containsKey(childId) &&
                (child = state(childId)).parent == parent && child.weight == weight;
    }

    /**
     * Returns the number of direct children of {@code streamId} in the dependency tree, or {@code 0} if the ID has
     * no {@link State}. Package-private, presumably for tests.
     */
    int numChildren(int streamId) {
        State state = state(streamId);
        return state == null ? 0 : state.children.size();
    }

    /**
     * Completes the bookkeeping for a batch of parent changes produced by {@link State#takeChild} or
     * {@link State#removeChild}. For each moved {@link State} it:
     * <ul>
     *   <li>re-prioritizes it in {@link #stateOnlyRemovalQueue} (its depth may have changed);</li>
     *   <li>if it has a new parent and an active subtree, enqueues it in the new parent's pseudo-time queue
     *       and adds its active count to the new parent and all of the parent's ancestors.</li>
     * </ul>
     * {@code State.setParent} has already removed it from the old parent's queue and counts.
     */
    void notifyParentChanged(List<ParentChangedEvent> events) {
        for (int i = 0; i < events.size(); ++i) {
            ParentChangedEvent event = events.get(i);
            stateOnlyRemovalQueue.priorityChanged(event.state);
            if (event.state.parent != null && event.state.activeCountForTree != 0) {
                event.state.parent.offerAndInitializePseudoTime(event.state);
                event.state.parent.activeCountChangeForTree(event.state.activeCountForTree);
            }
        }
    }

    /**
     * Ordering for {@link #stateOnlyRemovalQueue}: the smallest {@link State} is at the head and is evicted first.
     * Keys, in order of precedence:
     * <ol>
     *   <li>States whose stream was reserved or activated sort before (are evicted before) States that only ever
     *       carried priority information.</li>
     *   <li>Deeper States sort before shallower ones. A detached State has depth {@link Integer#MAX_VALUE}.</li>
     *   <li>Smaller stream IDs sort before larger ones.</li>
     * </ol>
     * The int subtractions cannot overflow: depths and stream IDs are non-negative.
     */
    private static final class StateOnlyComparator implements Comparator<State>, Serializable {
        private static final long serialVersionUID = -4806936913002105966L;

        static final StateOnlyComparator INSTANCE = new StateOnlyComparator();

        @Override
        public int compare(State o1, State o2) {
            // Reserved/activated first.
            boolean o1Actived = o1.wasStreamReservedOrActivated();
            if (o1Actived != o2.wasStreamReservedOrActivated()) {
                return o1Actived ? -1 : 1;
            }
            // Then greater depth first (closer to the root ranks higher and is kept longer).
            int x = o2.dependencyTreeDepth - o1.dependencyTreeDepth;

            // Tie-break on stream ID: the smaller ID is evicted first.
            return x != 0 ? x : o1.streamId - o2.streamId;
        }
    }

    /**
     * Orders children in a pseudo-time queue by {@code pseudoTimeToWrite}, ascending. The head is the next child to
     * be given a chance to write.
     */
    private static final class StatePseudoTimeComparator implements Comparator<State>, Serializable {
        private static final long serialVersionUID = -1437548640227161828L;

        static final StatePseudoTimeComparator INSTANCE = new StatePseudoTimeComparator();

        @Override
        public int compare(State o1, State o2) {
            return MathUtil.compare(o1.pseudoTimeToWrite, o2.pseudoTimeToWrite);
        }
    }

    /**
     * A node of the dependency tree: priority information plus scheduling bookkeeping for one stream ID.
     * A node can be in two queues at once: its parent's pseudo-time queue and {@link #stateOnlyRemovalQueue}.
     */
    private final class State implements PriorityQueueNode {
        /** This stream itself currently has bytes eligible to be written. */
        private static final byte STATE_IS_ACTIVE = 0x1;
        /** This node was polled by {@code distributeToChildren} and its write is in progress. */
        private static final byte STATE_IS_DISTRIBUTING = 0x2;
        /** The stream was reserved or activated at some point; consulted by {@link StateOnlyComparator}. */
        private static final byte STATE_STREAM_ACTIVATED = 0x4;

        /**
         * The associated stream. {@code null} for state-only entries, after {@link #close()}, and after the stream
         * was removed from the connection.
         */
        Http2Stream stream;
        /** Parent in the dependency tree; {@code null} for the root and for detached States. */
        State parent;
        /** Direct children by stream ID; starts as the shared empty map and is allocated on first insert. */
        IntObjectMap<State> children = IntCollections.emptyMap();
        /** Children whose subtree has at least one active node, ordered by {@code pseudoTimeToWrite}. */
        private final PriorityQueue<State> pseudoTimeQueue;
        final int streamId;
        /** Bytes this stream could write, as last passed to {@link #updateStreamableBytes(int, boolean)}. */
        int streamableBytes;
        /**
         * Depth below the root ({@code 0} for the root), or {@link Integer#MAX_VALUE} when detached.
         * Only read by {@link StateOnlyComparator}.
         */
        int dependencyTreeDepth;
        /** Number of nodes in the subtree rooted here, including this one, for which {@link #isActive()} holds. */
        int activeCountForTree;
        /** Position in the parent's {@link #pseudoTimeQueue}; maintained by the queue via {@link PriorityQueueNode}. */
        private int pseudoTimeQueueIndex = INDEX_NOT_IN_QUEUE;
        /** Position in {@link #stateOnlyRemovalQueue}; maintained by the queue via {@link PriorityQueueNode}. */
        private int stateOnlyQueueIndex = INDEX_NOT_IN_QUEUE;
        /**
         * Virtual time at which this node should next be given a chance to write. This is the sort key of the
         * parent's pseudo-time queue.
         */
        long pseudoTimeToWrite;
        /**
         * Virtual clock for this node's children. It advances by every byte written through this node; children
         * entering the queue start from this value.
         */
        long pseudoTime;
        /** Sum of the weights of the children currently in {@link #pseudoTimeQueue}. */
        long totalQueuedWeights;
        /** Bit set of {@code STATE_IS_ACTIVE}, {@code STATE_IS_DISTRIBUTING} and {@code STATE_STREAM_ACTIVATED}. */
        private byte flags;
        short weight = DEFAULT_PRIORITY_WEIGHT;

        /** Creates a state-only node for a stream ID with no {@link Http2Stream} object. */
        State(int streamId) {
            this(streamId, null, 0);
        }

        State(Http2Stream stream) {
            this(stream, 0);
        }

        /** @param initialSize initial capacity of this node's pseudo-time queue. */
        State(Http2Stream stream, int initialSize) {
            this(stream.id(), stream, initialSize);
        }

        State(int streamId, Http2Stream stream, int initialSize) {
            this.stream = stream;
            this.streamId = streamId;
            pseudoTimeQueue = new DefaultPriorityQueue<State>(StatePseudoTimeComparator.INSTANCE, initialSize);
        }

        /** Returns whether {@code state} is a strict ancestor of this node. */
        boolean isDescendantOf(State state) {
            State next = parent;
            while (next != null) {
                if (next == state) {
                    return true;
                }
                next = next.parent;
            }
            return false;
        }

        /** Same as {@link #takeChild(Iterator, State, boolean, List)} with no iterator. */
        void takeChild(State child, boolean exclusive, List<ParentChangedEvent> events) {
            takeChild(null, child, exclusive, events);
        }

        /**
         * Makes {@code child} a direct child of this node. If the parent actually changes, a
         * {@link ParentChangedEvent} is appended to {@code events}. If {@code exclusive} is set, every other current
         * child of this node is moved under {@code child}. The caller must pass {@code events} to
         * {@code notifyParentChanged} to finish the update.
         *
         * @param childItr if non-null, an iterator (currently at {@code child}) over the children map {@code child}
         *        is being moved out of. It is used to remove {@code child} without modifying that map underneath
         *        the caller's iteration.
         */
        void takeChild(Iterator<IntObjectMap.PrimitiveEntry<State>> childItr, State child, boolean exclusive,
                       List<ParentChangedEvent> events) {
            State oldParent = child.parent;

            if (oldParent != this) {
                events.add(new ParentChangedEvent(child, oldParent));
                child.setParent(this);
                // Remove child from its old map: through the caller's iterator if one is in progress, otherwise
                // directly from the old parent.
                if (childItr != null) {
                    childItr.remove();
                } else if (oldParent != null) {
                    oldParent.children.remove(child.streamId);
                }

                // Allocate the children map on first use.
                initChildrenIfEmpty();

                final State oldChild = children.put(child.streamId, child);
                assert oldChild == null : "A stream with the same stream ID was already in the child map.";
            }

            if (exclusive && !children.isEmpty()) {
                // Exclusive dependency: every other child of this node becomes a child of `child`, i.e. a
                // grandchild of this node.
                Iterator<IntObjectMap.PrimitiveEntry<State>> itr = removeAllChildrenExcept(child).entries().iterator();
                while (itr.hasNext()) {
                    child.takeChild(itr, itr.next().value(), false, events);
                }
            }
        }

        /**
         * Removes {@code child} from this node's children and detaches it. Does nothing if it is not a child.
         * The child's own children are moved up to this node. Each moved node's weight is rescaled to
         * {@code max(1, weight * child.weight / sumOfChildWeights)}. Calls {@code notifyParentChanged} itself.
         */
        void removeChild(State child) {
            if (children.remove(child.streamId) != null) {
                List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1 + child.children.size());
                events.add(new ParentChangedEvent(child, child.parent));
                child.setParent(null);

                if (!child.children.isEmpty()) {
                    // Move the grandchildren up to this node.
                    Iterator<IntObjectMap.PrimitiveEntry<State>> itr = child.children.entries().iterator();
                    long totalWeight = child.getTotalWeight();
                    do {
                        // Split the removed child's weight between its children in proportion to their weights.
                        State dependency = itr.next().value();
                        dependency.weight = (short) max(1, dependency.weight * child.weight / totalWeight);
                        takeChild(itr, dependency, false, events);
                    } while (itr.hasNext());
                }

                notifyParentChanged(events);
            }
        }

        /** Returns the sum of the weights of this node's direct children. */
        private long getTotalWeight() {
            long totalWeight = 0L;
            for (State state : children.values()) {
                totalWeight += state.weight;
            }
            return totalWeight;
        }

        /**
         * Replaces {@link #children} with a fresh map that contains only {@code stateToRetain}, if it was a child.
         * Used for exclusive dependencies.
         *
         * @return the previous children map, with {@code stateToRetain} removed from it.
         */
        private IntObjectMap<State> removeAllChildrenExcept(State stateToRetain) {
            stateToRetain = children.remove(stateToRetain.streamId);
            IntObjectMap<State> prevChildren = children;
            // Always allocate a new map: the exclusive child will be put in it either here or by the caller
            // right after this returns.
            initChildren();
            if (stateToRetain != null) {
                children.put(stateToRetain.streamId, stateToRetain);
            }
            return prevChildren;
        }

        /**
         * Sets {@link #parent} and recomputes {@link #dependencyTreeDepth}. If this subtree is active, it is first
         * removed from the old parent's pseudo-time queue and its active count is subtracted from the old parent's
         * chain. {@code notifyParentChanged} later adds it to the new parent.
         * <p>
         * NOTE(review): only this node's depth is recomputed. Descendants keep their old
         * {@code dependencyTreeDepth}, so their depths can become stale after a move. This only affects
         * {@link StateOnlyComparator}'s ranking.
         */
        private void setParent(State newParent) {
            // If activeCountForTree == 0 this node is not in its parent's pseudo-time queue.
            if (activeCountForTree != 0 && parent != null) {
                parent.removePseudoTimeQueue(this);
                parent.activeCountChangeForTree(-activeCountForTree);
            }
            parent = newParent;
            // Detached nodes get MAX_VALUE depth so StateOnlyComparator evicts them first (other keys being equal).
            dependencyTreeDepth = newParent == null ? MAX_VALUE : newParent.dependencyTreeDepth + 1;
        }

        /** Allocates a real children map if this node still uses the shared empty map. */
        private void initChildrenIfEmpty() {
            if (children == IntCollections.<State>emptyMap()) {
                initChildren();
            }
        }

        private void initChildren() {
            children = new IntObjectHashMap<State>(INITIAL_CHILDREN_MAP_SIZE);
        }

        /**
         * Asks {@code writer} to write {@code numBytes} for this node's stream. Any {@link Throwable} is wrapped in
         * an {@code INTERNAL_ERROR} connection error.
         */
        void write(int numBytes, Writer writer) throws Http2Exception {
            assert stream != null;
            try {
                writer.write(stream, numBytes);
            } catch (Throwable t) {
                throw connectionError(INTERNAL_ERROR, t, "byte distribution write error");
            }
        }

        /**
         * Adds {@code increment} to {@link #activeCountForTree} and propagates it to every ancestor. Also keeps the
         * parent's pseudo-time queue in sync:
         * <ul>
         *   <li>when the count drops to 0, this node leaves the parent's queue;</li>
         *   <li>when it rises from 0, this node joins the queue with a fresh {@code pseudoTimeToWrite}, unless it is
         *       currently distributing. In that case {@code distributeToChildren} re-queues it when it finishes.</li>
         * </ul>
         */
        void activeCountChangeForTree(int increment) {
            assert activeCountForTree + increment >= 0;
            activeCountForTree += increment;
            if (parent != null) {
                assert activeCountForTree != increment ||
                       pseudoTimeQueueIndex == INDEX_NOT_IN_QUEUE ||
                       parent.pseudoTimeQueue.containsTyped(this) :
                     "State[" + streamId + "].activeCountForTree changed from 0 to " + increment + " is in a " +
                     "pseudoTimeQueue, but not in parent[ " + parent.streamId + "]'s pseudoTimeQueue";
                if (activeCountForTree == 0) {
                    parent.removePseudoTimeQueue(this);
                } else if (activeCountForTree == increment && !isDistributing()) {
                    // The count went from 0 to non-zero, so this node is not yet in its parent's queue: enqueue it.
                    // A node that is currently being distributed must not have its pseudo time reset here. It was
                    // polled from the parent's queue, and distributeToChildren's finally block puts it back (if its
                    // subtree is still active) once the write completes.
                    parent.offerAndInitializePseudoTime(this);
                }
                parent.activeCountChangeForTree(increment);
            }
        }

        /**
         * Records {@code newStreamableBytes}. If the active flag flips, updates the active counts of this node and
         * all its ancestors.
         */
        void updateStreamableBytes(int newStreamableBytes, boolean isActive) {
            if (isActive() != isActive) {
                if (isActive) {
                    activeCountChangeForTree(1);
                    setActive();
                } else {
                    activeCountChangeForTree(-1);
                    unsetActive();
                }
            }

            streamableBytes = newStreamableBytes;
        }

        /**
         * Advances {@link #pseudoTimeToWrite} after this node was given {@code nsent} bytes.
         * Assumes {@code totalQueuedWeights} is the parent's total captured while this node was still queued,
         * i.e. it includes this node's weight.
         */
        void updatePseudoTime(State parentState, int nsent, long totalQueuedWeights) {
            assert streamId != CONNECTION_STREAM_ID && nsent >= 0;
            // If this node's pseudo time ran ahead of the parent's clock it was over-accounted earlier, so restart
            // from the parent's clock. Then advance by nsent scaled by totalQueuedWeights / weight: heavier nodes
            // advance more slowly and are picked again sooner.
            pseudoTimeToWrite = min(pseudoTimeToWrite, parentState.pseudoTime) + nsent * totalQueuedWeights / weight;
        }

        /**
         * Enqueues {@code state} in this node's pseudo-time queue, starting its {@code pseudoTimeToWrite} at this
         * node's current {@link #pseudoTime}. Presumably this keeps a newly active child from being scheduled on
         * a stale pseudo time.
         */
        void offerAndInitializePseudoTime(State state) {
            state.pseudoTimeToWrite = pseudoTime;
            offerPseudoTimeQueue(state);
        }

        /** Enqueues {@code state} without touching its pseudo time, and adds its weight to the queued total. */
        void offerPseudoTimeQueue(State state) {
            pseudoTimeQueue.offer(state);
            totalQueuedWeights += state.weight;
        }

        /**
         * Removes and returns the queued child with the smallest {@code pseudoTimeToWrite}, subtracting its weight
         * from {@link #totalQueuedWeights}. The queue must not be empty.
         */
        State pollPseudoTimeQueue() {
            State state = pseudoTimeQueue.poll();
            // The caller guarantees the queue is non-empty, so state is not null here.
            totalQueuedWeights -= state.weight;
            return state;
        }

        /** Removes {@code state} if it is queued; the weight total is only adjusted when it actually was. */
        void removePseudoTimeQueue(State state) {
            if (pseudoTimeQueue.removeTyped(state)) {
                totalQueuedWeights -= state.weight;
            }
        }

        State peekPseudoTimeQueue() {
            return pseudoTimeQueue.peek();
        }

        /** Marks this node inactive with {@code 0} streamable bytes and drops the stream reference. */
        void close() {
            updateStreamableBytes(0, false);
            stream = null;
        }

        boolean wasStreamReservedOrActivated() {
            return (flags & STATE_STREAM_ACTIVATED) != 0;
        }

        void setStreamReservedOrActivated() {
            flags |= STATE_STREAM_ACTIVATED;
        }

        boolean isActive() {
            return (flags & STATE_IS_ACTIVE) != 0;
        }

        private void setActive() {
            flags |= STATE_IS_ACTIVE;
        }

        private void unsetActive() {
            flags &= ~STATE_IS_ACTIVE;
        }

        boolean isDistributing() {
            return (flags & STATE_IS_DISTRIBUTING) != 0;
        }

        void setDistributing() {
            flags |= STATE_IS_DISTRIBUTING;
        }

        void unsetDistributing() {
            flags &= ~STATE_IS_DISTRIBUTING;
        }

        /**
         * A node can be in {@link #stateOnlyRemovalQueue} and in its parent's pseudo-time queue at the same time,
         * so a separate index is kept for each. Any queue other than {@link #stateOnlyRemovalQueue} is treated as
         * a pseudo-time queue.
         */
        @Override
        public int priorityQueueIndex(DefaultPriorityQueue<?> queue) {
            return queue == stateOnlyRemovalQueue ? stateOnlyQueueIndex : pseudoTimeQueueIndex;
        }

        @Override
        public void priorityQueueIndex(DefaultPriorityQueue<?> queue, int i) {
            if (queue == stateOnlyRemovalQueue) {
                stateOnlyQueueIndex = i;
            } else {
                pseudoTimeQueueIndex = i;
            }
        }

        /** Debug representation of this node followed by, recursively, the children in its pseudo-time queue. */
        @Override
        public String toString() {
            // activeCountForTree serves as a rough estimate of how many nodes will be printed.
            StringBuilder sb = new StringBuilder(256 * (activeCountForTree > 0 ? activeCountForTree : 1));
            toString(sb);
            return sb.toString();
        }

        private void toString(StringBuilder sb) {
            sb.append("{streamId ").append(streamId)
                    .append(" streamableBytes ").append(streamableBytes)
                    .append(" activeCountForTree ").append(activeCountForTree)
                    .append(" pseudoTimeQueueIndex ").append(pseudoTimeQueueIndex)
                    .append(" pseudoTimeToWrite ").append(pseudoTimeToWrite)
                    .append(" pseudoTime ").append(pseudoTime)
                    .append(" flags ").append(flags)
                    .append(" pseudoTimeQueue.size() ").append(pseudoTimeQueue.size())
                    .append(" stateOnlyQueueIndex ").append(stateOnlyQueueIndex)
                    .append(" parent.streamId ").append(parent == null ? -1 : parent.streamId).append("} [");

            if (!pseudoTimeQueue.isEmpty()) {
                for (State s : pseudoTimeQueue) {
                    s.toString(sb);
                    sb.append(", ");
                }
                // Drop the trailing ", ".
                sb.setLength(sb.length() - 2);
            }
            sb.append(']');
        }
    }

    /**
     * Records that {@code state} was moved away from {@code oldParent}. Events are collected by
     * {@code State.takeChild}/{@code State.removeChild} and applied by {@code notifyParentChanged}.
     */
    private static final class ParentChangedEvent {
        final State state;
        /** NOTE(review): not read anywhere in this class; only {@link #state} is used by notifyParentChanged. */
        final State oldParent;

        /**
         * @param state the node whose parent changed.
         * @param oldParent its parent before the change; may be {@code null} for a newly created node.
         */
        ParentChangedEvent(State state, State oldParent) {
            this.state = state;
            this.oldParent = oldParent;
        }
    }
}