1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.SimpleChannelInboundHandler;
26 import java.util.concurrent.TimeUnit;
27 import org.junit.jupiter.api.Test;
28 import org.junit.jupiter.api.TestInfo;
29 import org.junit.jupiter.api.Timeout;
30
31 import java.io.IOException;
32 import java.util.concurrent.atomic.AtomicReference;
33
34 import static org.junit.jupiter.api.Assertions.assertEquals;
35 import static org.junit.jupiter.api.Assertions.assertTrue;
36
37 public class SocketCancelWriteTest extends AbstractSocketTest {
38
39 @Test
40 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
41 public void testCancelWrite(TestInfo testInfo) throws Throwable {
42 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
43 @Override
44 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
45 testCancelWrite(serverBootstrap, bootstrap);
46 }
47 });
48 }
49
50 public void testCancelWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
51 final TestHandler sh = new TestHandler();
52 final TestHandler ch = new TestHandler();
53 final ByteBuf a = Unpooled.buffer().writeByte('a');
54 final ByteBuf b = Unpooled.buffer().writeByte('b');
55 final ByteBuf c = Unpooled.buffer().writeByte('c');
56 final ByteBuf d = Unpooled.buffer().writeByte('d');
57 final ByteBuf e = Unpooled.buffer().writeByte('e');
58
59 cb.handler(ch);
60 sb.childHandler(sh);
61
62 Channel sc = sb.bind().sync().channel();
63 Channel cc = cb.connect(sc.localAddress()).sync().channel();
64
65 ChannelFuture f = cc.write(a);
66 assertTrue(f.cancel(false));
67 cc.writeAndFlush(b);
68 cc.write(c);
69 ChannelFuture f2 = cc.write(d);
70 assertTrue(f2.cancel(false));
71 cc.writeAndFlush(e);
72
73 while (sh.counter < 3) {
74 if (sh.exception.get() != null) {
75 break;
76 }
77 if (ch.exception.get() != null) {
78 break;
79 }
80 Thread.sleep(50);
81 }
82 sh.channel.close().sync();
83 ch.channel.close().sync();
84 sc.close().sync();
85
86 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
87 throw sh.exception.get();
88 }
89 if (sh.exception.get() != null) {
90 throw sh.exception.get();
91 }
92 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
93 throw ch.exception.get();
94 }
95 if (ch.exception.get() != null) {
96 throw ch.exception.get();
97 }
98 assertEquals(0, ch.counter);
99 assertEquals(Unpooled.wrappedBuffer(new byte[]{'b', 'c', 'e'}), sh.received);
100 }
101
102 private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
103 volatile Channel channel;
104 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
105 volatile int counter;
106 final ByteBuf received = Unpooled.buffer();
107 @Override
108 public void channelActive(ChannelHandlerContext ctx)
109 throws Exception {
110 channel = ctx.channel();
111 }
112
113 @Override
114 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
115 counter += in.readableBytes();
116 received.writeBytes(in);
117 }
118
119 @Override
120 public void exceptionCaught(ChannelHandlerContext ctx,
121 Throwable cause) throws Exception {
122 if (exception.compareAndSet(null, cause)) {
123 ctx.close();
124 }
125 }
126 }
127 }