1
2
3
4
5
6
7
8
9
10
11
12
13 package io.netty.testsuite.transport.socket;
14
15 import io.netty.bootstrap.Bootstrap;
16 import io.netty.bootstrap.ServerBootstrap;
17 import io.netty.buffer.ByteBuf;
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelInitializer;
22 import io.netty.channel.SimpleChannelInboundHandler;
23 import io.netty.channel.socket.SocketChannel;
24 import io.netty.handler.traffic.AbstractTrafficShapingHandler;
25 import io.netty.handler.traffic.ChannelTrafficShapingHandler;
26 import io.netty.handler.traffic.GlobalTrafficShapingHandler;
27 import io.netty.handler.traffic.TrafficCounter;
28 import io.netty.util.concurrent.DefaultEventExecutorGroup;
29 import io.netty.util.concurrent.EventExecutorGroup;
30 import io.netty.util.concurrent.Promise;
31 import io.netty.util.internal.logging.InternalLogger;
32 import io.netty.util.internal.logging.InternalLoggerFactory;
33 import org.junit.jupiter.api.AfterAll;
34 import org.junit.jupiter.api.BeforeAll;
35 import org.junit.jupiter.api.Test;
36 import org.junit.jupiter.api.TestInfo;
37 import org.junit.jupiter.api.Timeout;
38
39 import java.io.IOException;
40 import java.util.Arrays;
41 import java.util.Random;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.ScheduledExecutorService;
44 import java.util.concurrent.TimeUnit;
45 import java.util.concurrent.atomic.AtomicReference;
46
47 import static org.junit.jupiter.api.Assertions.assertTrue;
48
49 public class TrafficShapingHandlerTest extends AbstractSocketTest {
50 private static final InternalLogger logger = InternalLoggerFactory.getInstance(TrafficShapingHandlerTest.class);
51 private static final InternalLogger loggerServer = InternalLoggerFactory.getInstance("ServerTSH");
52 private static final InternalLogger loggerClient = InternalLoggerFactory.getInstance("ClientTSH");
53
54 static final int messageSize = 1024;
55 static final int bandwidthFactor = 12;
56 static final int minfactor = 3;
57 static final int maxfactor = bandwidthFactor + bandwidthFactor / 2;
58 static final long stepms = (1000 / bandwidthFactor - 10) / 10 * 10;
59 static final long minimalms = Math.max(stepms / 2, 20) / 10 * 10;
60 static final long check = 10;
61 private static final Random random = new Random();
62 static final byte[] data = new byte[messageSize];
63
64 private static final String TRAFFIC = "traffic";
65 private static String currentTestName;
66 private static int currentTestRun;
67
68 private static EventExecutorGroup group;
69 private static EventExecutorGroup groupForGlobal;
70 private static final ScheduledExecutorService executor = Executors.newScheduledThreadPool(10);
71 static {
72 random.nextBytes(data);
73 }
74
75 @BeforeAll
76 public static void createGroup() {
77 logger.info("Bandwidth: " + minfactor + " <= " + bandwidthFactor + " <= " + maxfactor +
78 " StepMs: " + stepms + " MinMs: " + minimalms + " CheckMs: " + check);
79 group = new DefaultEventExecutorGroup(8);
80 groupForGlobal = new DefaultEventExecutorGroup(8);
81 }
82
83 @AfterAll
84 public static void destroyGroup() throws Exception {
85 group.shutdownGracefully().sync();
86 groupForGlobal.shutdownGracefully().sync();
87 executor.shutdown();
88 }
89
90 private static long[] computeWaitRead(int[] multipleMessage) {
91 long[] minimalWaitBetween = new long[multipleMessage.length + 1];
92 minimalWaitBetween[0] = 0;
93 for (int i = 0; i < multipleMessage.length; i++) {
94 if (multipleMessage[i] > 1) {
95 minimalWaitBetween[i + 1] = (multipleMessage[i] - 1) * stepms + minimalms;
96 } else {
97 minimalWaitBetween[i + 1] = 10;
98 }
99 }
100 return minimalWaitBetween;
101 }
102
103 private static long[] computeWaitWrite(int[] multipleMessage) {
104 long[] minimalWaitBetween = new long[multipleMessage.length + 1];
105 for (int i = 0; i < multipleMessage.length; i++) {
106 if (multipleMessage[i] > 1) {
107 minimalWaitBetween[i] = (multipleMessage[i] - 1) * stepms + minimalms;
108 } else {
109 minimalWaitBetween[i] = 10;
110 }
111 }
112 return minimalWaitBetween;
113 }
114
115 private static long[] computeWaitAutoRead(int []autoRead) {
116 long [] minimalWaitBetween = new long[autoRead.length + 1];
117 minimalWaitBetween[0] = 0;
118 for (int i = 0; i < autoRead.length; i++) {
119 if (autoRead[i] != 0) {
120 if (autoRead[i] > 0) {
121 minimalWaitBetween[i + 1] = -1;
122 } else {
123 minimalWaitBetween[i + 1] = check;
124 }
125 } else {
126 minimalWaitBetween[i + 1] = 0;
127 }
128 }
129 return minimalWaitBetween;
130 }
131
132 @Test
133 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
134 public void testNoTrafficShapping(TestInfo testInfo) throws Throwable {
135 currentTestName = "TEST NO TRAFFIC";
136 currentTestRun = 0;
137 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
138 @Override
139 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
140 testNoTrafficShapping(serverBootstrap, bootstrap);
141 }
142 });
143 }
144
145 public void testNoTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
146 int[] autoRead = null;
147 int[] multipleMessage = { 1, 2, 1 };
148 long[] minimalWaitBetween = null;
149 testTrafficShapping0(sb, cb, false, false, false, false, autoRead, minimalWaitBetween, multipleMessage);
150 }
151
152 @Test
153 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
154 public void testWriteTrafficShapping(TestInfo testInfo) throws Throwable {
155 currentTestName = "TEST WRITE";
156 currentTestRun = 0;
157 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
158 @Override
159 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
160 testWriteTrafficShapping(serverBootstrap, bootstrap);
161 }
162 });
163 }
164
165 public void testWriteTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
166 int[] autoRead = null;
167 int[] multipleMessage = { 1, 2, 1, 1 };
168 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
169 testTrafficShapping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
170 }
171
172 @Test
173 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
174 public void testReadTrafficShapping(TestInfo testInfo) throws Throwable {
175 currentTestName = "TEST READ";
176 currentTestRun = 0;
177 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
178 @Override
179 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
180 testReadTrafficShapping(serverBootstrap, bootstrap);
181 }
182 });
183 }
184
185 public void testReadTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
186 int[] autoRead = null;
187 int[] multipleMessage = { 1, 2, 1, 1 };
188 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
189 testTrafficShapping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
190 }
191
192 @Test
193 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
194 public void testWrite1TrafficShapping(TestInfo testInfo) throws Throwable {
195 currentTestName = "TEST WRITE";
196 currentTestRun = 0;
197 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
198 @Override
199 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
200 testWrite1TrafficShapping(serverBootstrap, bootstrap);
201 }
202 });
203 }
204
205 public void testWrite1TrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
206 int[] autoRead = null;
207 int[] multipleMessage = { 1, 1, 1 };
208 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
209 testTrafficShapping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
210 }
211
212 @Test
213 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
214 public void testRead1TrafficShapping(TestInfo testInfo) throws Throwable {
215 currentTestName = "TEST READ";
216 currentTestRun = 0;
217 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
218 @Override
219 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
220 testRead1TrafficShapping(serverBootstrap, bootstrap);
221 }
222 });
223 }
224
225 public void testRead1TrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
226 int[] autoRead = null;
227 int[] multipleMessage = { 1, 1, 1 };
228 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
229 testTrafficShapping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
230 }
231
232 @Test
233 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
234 public void testWriteGlobalTrafficShapping(TestInfo testInfo) throws Throwable {
235 currentTestName = "TEST GLOBAL WRITE";
236 currentTestRun = 0;
237 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
238 @Override
239 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
240 testWriteGlobalTrafficShapping(serverBootstrap, bootstrap);
241 }
242 });
243 }
244
245 public void testWriteGlobalTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
246 int[] autoRead = null;
247 int[] multipleMessage = { 1, 2, 1, 1 };
248 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
249 testTrafficShapping0(sb, cb, false, false, true, true, autoRead, minimalWaitBetween, multipleMessage);
250 }
251
252 @Test
253 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
254 public void testReadGlobalTrafficShapping(TestInfo testInfo) throws Throwable {
255 currentTestName = "TEST GLOBAL READ";
256 currentTestRun = 0;
257 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
258 @Override
259 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
260 testReadGlobalTrafficShapping(serverBootstrap, bootstrap);
261 }
262 });
263 }
264
265 public void testReadGlobalTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
266 int[] autoRead = null;
267 int[] multipleMessage = { 1, 2, 1, 1 };
268 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
269 testTrafficShapping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
270 }
271
272 @Test
273 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
274 public void testAutoReadTrafficShapping(TestInfo testInfo) throws Throwable {
275 currentTestName = "TEST AUTO READ";
276 currentTestRun = 0;
277 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
278 @Override
279 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
280 testAutoReadTrafficShapping(serverBootstrap, bootstrap);
281 }
282 });
283 }
284
285 public void testAutoReadTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
286 int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
287 int[] multipleMessage = new int[autoRead.length];
288 Arrays.fill(multipleMessage, 1);
289 long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
290 testTrafficShapping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
291 }
292
293 @Test
294 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
295 public void testAutoReadGlobalTrafficShapping(TestInfo testInfo) throws Throwable {
296 currentTestName = "TEST AUTO READ GLOBAL";
297 currentTestRun = 0;
298 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
299 @Override
300 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
301 testAutoReadGlobalTrafficShapping(serverBootstrap, bootstrap);
302 }
303 });
304 }
305
306 public void testAutoReadGlobalTrafficShapping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
307 int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
308 int[] multipleMessage = new int[autoRead.length];
309 Arrays.fill(multipleMessage, 1);
310 long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
311 testTrafficShapping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
312 }
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332 private static void testTrafficShapping0(
333 ServerBootstrap sb, Bootstrap cb, final boolean additionalExecutor,
334 final boolean limitRead, final boolean limitWrite, final boolean globalLimit, int[] autoRead,
335 long[] minimalWaitBetween, int[] multipleMessage) throws Throwable {
336
337 currentTestRun++;
338 logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
339 " Exec: " + additionalExecutor + " Read: " + limitRead + " Write: " + limitWrite + " Global: "
340 + globalLimit);
341 final ServerHandler sh = new ServerHandler(autoRead, multipleMessage);
342 Promise<Boolean> promise = group.next().newPromise();
343 final ClientHandler ch = new ClientHandler(promise, minimalWaitBetween, multipleMessage,
344 autoRead);
345
346 final AbstractTrafficShapingHandler handler;
347 if (limitRead) {
348 if (globalLimit) {
349 handler = new GlobalTrafficShapingHandler(groupForGlobal, 0, bandwidthFactor * messageSize, check);
350 } else {
351 handler = new ChannelTrafficShapingHandler(0, bandwidthFactor * messageSize, check);
352 }
353 } else if (limitWrite) {
354 if (globalLimit) {
355 handler = new GlobalTrafficShapingHandler(groupForGlobal, bandwidthFactor * messageSize, 0, check);
356 } else {
357 handler = new ChannelTrafficShapingHandler(bandwidthFactor * messageSize, 0, check);
358 }
359 } else {
360 handler = null;
361 }
362
363 sb.childHandler(new ChannelInitializer<SocketChannel>() {
364 @Override
365 protected void initChannel(SocketChannel c) throws Exception {
366 if (limitRead) {
367 c.pipeline().addLast(TRAFFIC, handler);
368 }
369 c.pipeline().addLast(sh);
370 }
371 });
372 cb.handler(new ChannelInitializer<SocketChannel>() {
373 @Override
374 protected void initChannel(SocketChannel c) throws Exception {
375 if (limitWrite) {
376 c.pipeline().addLast(TRAFFIC, handler);
377 }
378 c.pipeline().addLast(ch);
379 }
380 });
381
382 Channel sc = sb.bind().sync().channel();
383 Channel cc = cb.connect(sc.localAddress()).sync().channel();
384
385 int totalNb = 0;
386 for (int i = 1; i < multipleMessage.length; i++) {
387 totalNb += multipleMessage[i];
388 }
389 Long start = TrafficCounter.milliSecondFromNano();
390 int nb = multipleMessage[0];
391 for (int i = 0; i < nb; i++) {
392 cc.write(cc.alloc().buffer().writeBytes(data));
393 }
394 cc.flush();
395
396 promise.await();
397 Long stop = TrafficCounter.milliSecondFromNano();
398 assertTrue(promise.isSuccess(), "Error during execution of TrafficShapping: " + promise.cause());
399
400 float average = (totalNb * messageSize) / (float) (stop - start);
401 logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
402 " Average of traffic: " + average + " compare to " + bandwidthFactor);
403 sh.channel.close().sync();
404 ch.channel.close().sync();
405 sc.close().sync();
406 if (autoRead != null) {
407
408 Thread.sleep(minimalms);
409 }
410
411 if (autoRead == null && minimalWaitBetween != null) {
412 assertTrue(average <= maxfactor,
413 "Overall Traffic not ok since > " + maxfactor + ": " + average);
414 if (additionalExecutor) {
415
416 assertTrue(average >= 0.25, "Overall Traffic not ok since < 0.25: " + average);
417 } else {
418 assertTrue(average >= minfactor,
419 "Overall Traffic not ok since < " + minfactor + ": " + average);
420 }
421 }
422 if (handler != null && globalLimit) {
423 ((GlobalTrafficShapingHandler) handler).release();
424 }
425
426 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
427 throw sh.exception.get();
428 }
429 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
430 throw ch.exception.get();
431 }
432 if (sh.exception.get() != null) {
433 throw sh.exception.get();
434 }
435 if (ch.exception.get() != null) {
436 throw ch.exception.get();
437 }
438 }
439
440 private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
441 volatile Channel channel;
442 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
443 volatile int step;
444
445 private long currentLastTime = TrafficCounter.milliSecondFromNano();
446 private final long[] minimalWaitBetween;
447 private final int[] multipleMessage;
448 private final int[] autoRead;
449 final Promise<Boolean> promise;
450
451 ClientHandler(Promise<Boolean> promise, long[] minimalWaitBetween, int[] multipleMessage,
452 int[] autoRead) {
453 this.minimalWaitBetween = minimalWaitBetween;
454 this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
455 this.promise = promise;
456 this.autoRead = autoRead;
457 }
458
459 @Override
460 public void channelActive(ChannelHandlerContext ctx) throws Exception {
461 channel = ctx.channel();
462 }
463
464 @Override
465 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
466 long lastTimestamp = 0;
467 loggerClient.debug("Step: " + step + " Read: " + in.readableBytes() / 8 + " blocks");
468 while (in.isReadable()) {
469 lastTimestamp = in.readLong();
470 multipleMessage[step]--;
471 }
472 if (multipleMessage[step] > 0) {
473
474 return;
475 }
476 long minimalWait = minimalWaitBetween != null? minimalWaitBetween[step] : 0;
477 int ar = 0;
478 if (autoRead != null) {
479 if (step > 0 && autoRead[step - 1] != 0) {
480 ar = autoRead[step - 1];
481 }
482 }
483 loggerClient.info("Step: " + step + " Interval: " + (lastTimestamp - currentLastTime) + " compareTo "
484 + minimalWait + " (" + ar + ')');
485 assertTrue(lastTimestamp - currentLastTime >= minimalWait,
486 "The interval of time is incorrect:" + (lastTimestamp - currentLastTime) + " not> " + minimalWait);
487 currentLastTime = lastTimestamp;
488 step++;
489 if (multipleMessage.length > step) {
490 int nb = multipleMessage[step];
491 for (int i = 0; i < nb; i++) {
492 channel.write(channel.alloc().buffer().writeBytes(data));
493 }
494 channel.flush();
495 } else {
496 promise.setSuccess(true);
497 }
498 }
499
500 @Override
501 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
502 if (exception.compareAndSet(null, cause)) {
503 cause.printStackTrace();
504 promise.setFailure(cause);
505 ctx.close();
506 }
507 }
508 }
509
510 private static class ServerHandler extends SimpleChannelInboundHandler<ByteBuf> {
511 private final int[] autoRead;
512 private final int[] multipleMessage;
513 volatile Channel channel;
514 volatile int step;
515 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
516
517 ServerHandler(int[] autoRead, int[] multipleMessage) {
518 this.autoRead = autoRead;
519 this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
520 }
521
522 @Override
523 public void channelActive(ChannelHandlerContext ctx) throws Exception {
524 channel = ctx.channel();
525 }
526
527 @Override
528 public void channelRead0(final ChannelHandlerContext ctx, ByteBuf in) throws Exception {
529 byte[] actual = new byte[in.readableBytes()];
530 int nb = actual.length / messageSize;
531 loggerServer.info("Step: " + step + " Read: " + nb + " blocks");
532 in.readBytes(actual);
533 long timestamp = TrafficCounter.milliSecondFromNano();
534 int isAutoRead = 0;
535 int laststep = step;
536 for (int i = 0; i < nb; i++) {
537 multipleMessage[step]--;
538 if (multipleMessage[step] == 0) {
539
540 if (autoRead != null) {
541 isAutoRead = autoRead[step];
542 }
543 step++;
544 }
545 }
546 if (laststep != step) {
547
548 if (autoRead != null && isAutoRead != 2) {
549 if (isAutoRead != 0) {
550 loggerServer.info("Step: " + step + " Set AutoRead: " + (isAutoRead > 0));
551 channel.config().setAutoRead(isAutoRead > 0);
552 } else {
553 loggerServer.info("Step: " + step + " AutoRead: NO");
554 }
555 }
556 }
557 Thread.sleep(10);
558 loggerServer.debug("Step: " + step + " Write: " + nb);
559 for (int i = 0; i < nb; i++) {
560 channel.write(Unpooled.copyLong(timestamp));
561 }
562 channel.flush();
563 if (laststep != step) {
564
565 if (isAutoRead != 0) {
566 if (isAutoRead < 0) {
567 final int exactStep = step;
568 long wait = isAutoRead == -1? minimalms : stepms + minimalms;
569 if (isAutoRead == -3) {
570 wait = stepms * 3;
571 }
572 executor.schedule(new Runnable() {
573 @Override
574 public void run() {
575 loggerServer.info("Step: " + exactStep + " Reset AutoRead");
576 channel.config().setAutoRead(true);
577 }
578 }, wait, TimeUnit.MILLISECONDS);
579 } else {
580 if (isAutoRead > 1) {
581 loggerServer.debug("Step: " + step + " Will Set AutoRead: True");
582 final int exactStep = step;
583 executor.schedule(new Runnable() {
584 @Override
585 public void run() {
586 loggerServer.info("Step: " + exactStep + " Set AutoRead: True");
587 channel.config().setAutoRead(true);
588 }
589 }, stepms + minimalms, TimeUnit.MILLISECONDS);
590 }
591 }
592 }
593 }
594 }
595
596 @Override
597 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
598 if (exception.compareAndSet(null, cause)) {
599 cause.printStackTrace();
600 ctx.close();
601 }
602 }
603 }
604 }