1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.buffer.CompositeByteBuf;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelConfig;
25 import io.netty.channel.ChannelFutureListener;
26 import io.netty.channel.ChannelHandlerContext;
27 import io.netty.channel.ChannelInboundHandlerAdapter;
28 import io.netty.channel.ChannelInitializer;
29 import io.netty.channel.ChannelOption;
30 import io.netty.util.ReferenceCountUtil;
31 import org.junit.jupiter.api.Test;
32 import org.junit.jupiter.api.TestInfo;
33 import org.junit.jupiter.api.Timeout;
34
35 import java.io.IOException;
36 import java.util.Random;
37 import java.util.concurrent.CountDownLatch;
38 import java.util.concurrent.TimeUnit;
39 import java.util.concurrent.atomic.AtomicReference;
40
41 import static org.junit.jupiter.api.Assertions.assertEquals;
42
43 public class CompositeBufferGatheringWriteTest extends AbstractSocketTest {
44 private static final int EXPECTED_BYTES = 20;
45
46 @Test
47 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
48 public void testSingleCompositeBufferWrite(TestInfo testInfo) throws Throwable {
49 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
50 @Override
51 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
52 testSingleCompositeBufferWrite(serverBootstrap, bootstrap);
53 }
54 });
55 }
56
57 public void testSingleCompositeBufferWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
58 Channel serverChannel = null;
59 Channel clientChannel = null;
60 try {
61 final CountDownLatch latch = new CountDownLatch(1);
62 final AtomicReference<Object> clientReceived = new AtomicReference<Object>();
63 sb.childHandler(new ChannelInitializer<Channel>() {
64 @Override
65 protected void initChannel(Channel ch) throws Exception {
66 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
67 @Override
68 public void channelActive(ChannelHandlerContext ctx) throws Exception {
69 ctx.writeAndFlush(newCompositeBuffer(ctx.alloc()))
70 .addListener(ChannelFutureListener.CLOSE);
71 }
72 });
73 }
74 });
75 cb.handler(new ChannelInitializer<Channel>() {
76 @Override
77 protected void initChannel(Channel ch) throws Exception {
78 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
79 private ByteBuf aggregator;
80 @Override
81 public void handlerAdded(ChannelHandlerContext ctx) {
82 aggregator = ctx.alloc().buffer(EXPECTED_BYTES);
83 }
84
85 @Override
86 public void channelRead(ChannelHandlerContext ctx, Object msg) {
87 try {
88 if (msg instanceof ByteBuf) {
89 aggregator.writeBytes((ByteBuf) msg);
90 }
91 } finally {
92 ReferenceCountUtil.release(msg);
93 }
94 }
95
96 @Override
97 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
98
99 if (!(cause instanceof IOException)) {
100 clientReceived.set(cause);
101 latch.countDown();
102 }
103 }
104
105 @Override
106 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
107 if (clientReceived.compareAndSet(null, aggregator)) {
108 try {
109 assertEquals(EXPECTED_BYTES, aggregator.readableBytes());
110 } catch (Throwable cause) {
111 aggregator.release();
112 aggregator = null;
113 clientReceived.set(cause);
114 } finally {
115 latch.countDown();
116 }
117 }
118 }
119 });
120 }
121 });
122
123 serverChannel = sb.bind().syncUninterruptibly().channel();
124 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
125
126 ByteBuf expected = newCompositeBuffer(clientChannel.alloc());
127 latch.await();
128 Object received = clientReceived.get();
129 if (received instanceof ByteBuf) {
130 ByteBuf actual = (ByteBuf) received;
131 assertEquals(expected, actual);
132 expected.release();
133 actual.release();
134 } else {
135 expected.release();
136 throw (Throwable) received;
137 }
138 } finally {
139 if (clientChannel != null) {
140 clientChannel.close().sync();
141 }
142 if (serverChannel != null) {
143 serverChannel.close().sync();
144 }
145 }
146 }
147
148 @Test
149 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
150 public void testCompositeBufferPartialWriteDoesNotCorruptData(TestInfo testInfo) throws Throwable {
151 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
152 @Override
153 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
154 testCompositeBufferPartialWriteDoesNotCorruptData(serverBootstrap, bootstrap);
155 }
156 });
157 }
158
159 protected void compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ChannelConfig config,
160 int soSndBuf) {
161 }
162
163 public void testCompositeBufferPartialWriteDoesNotCorruptData(ServerBootstrap sb, Bootstrap cb) throws Throwable {
164
165
166
167
168 Channel serverChannel = null;
169 Channel clientChannel = null;
170 try {
171 Random r = new Random();
172 final int soSndBuf = 1024;
173 ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
174 final ByteBuf expectedContent = alloc.buffer(soSndBuf * 2);
175 expectedContent.writeBytes(newRandomBytes(expectedContent.writableBytes(), r));
176 final CountDownLatch latch = new CountDownLatch(1);
177 final AtomicReference<Object> clientReceived = new AtomicReference<Object>();
178 sb.childOption(ChannelOption.SO_SNDBUF, soSndBuf)
179 .childHandler(new ChannelInitializer<Channel>() {
180 @Override
181 protected void initChannel(Channel ch) throws Exception {
182 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
183 @Override
184 public void channelActive(ChannelHandlerContext ctx) throws Exception {
185 compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ctx.channel().config(),
186 soSndBuf);
187
188 int offset = soSndBuf - 100;
189 ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex(), offset));
190
191
192 CompositeByteBuf compositeByteBuf = ctx.alloc().compositeBuffer();
193 compositeByteBuf.addComponent(true,
194 expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
195 offset += 50;
196 compositeByteBuf.addComponent(true,
197 expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 200));
198 offset += 200;
199 ctx.write(compositeByteBuf);
200
201
202
203 ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
204 offset += 50;
205
206
207 ctx.writeAndFlush(expectedContent.retainedSlice(expectedContent.readerIndex() + offset,
208 expectedContent.readableBytes() - expectedContent.readerIndex() - offset))
209 .addListener(ChannelFutureListener.CLOSE);
210 }
211
212 @Override
213 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
214
215 if (!(cause instanceof IOException)) {
216 clientReceived.set(cause);
217 latch.countDown();
218 }
219 }
220 });
221 }
222 });
223 cb.handler(new ChannelInitializer<Channel>() {
224 @Override
225 protected void initChannel(Channel ch) throws Exception {
226 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
227 private ByteBuf aggregator;
228 @Override
229 public void handlerAdded(ChannelHandlerContext ctx) {
230 aggregator = ctx.alloc().buffer(expectedContent.readableBytes());
231 }
232
233 @Override
234 public void channelRead(ChannelHandlerContext ctx, Object msg) {
235 try {
236 if (msg instanceof ByteBuf) {
237 aggregator.writeBytes((ByteBuf) msg);
238 }
239 } finally {
240 ReferenceCountUtil.release(msg);
241 }
242 }
243
244 @Override
245 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
246
247 if (!(cause instanceof IOException)) {
248 clientReceived.set(cause);
249 latch.countDown();
250 }
251 }
252
253 @Override
254 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
255 if (clientReceived.compareAndSet(null, aggregator)) {
256 try {
257 assertEquals(expectedContent.readableBytes(), aggregator.readableBytes());
258 } catch (Throwable cause) {
259 aggregator.release();
260 aggregator = null;
261 clientReceived.set(cause);
262 } finally {
263 latch.countDown();
264 }
265 }
266 }
267 });
268 }
269 });
270
271 serverChannel = sb.bind().syncUninterruptibly().channel();
272 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
273
274 latch.await();
275 Object received = clientReceived.get();
276 if (received instanceof ByteBuf) {
277 ByteBuf actual = (ByteBuf) received;
278 assertEquals(expectedContent, actual);
279 expectedContent.release();
280 actual.release();
281 } else {
282 expectedContent.release();
283 throw (Throwable) received;
284 }
285 } finally {
286 if (clientChannel != null) {
287 clientChannel.close().sync();
288 }
289 if (serverChannel != null) {
290 serverChannel.close().sync();
291 }
292 }
293 }
294
295 private static ByteBuf newCompositeBuffer(ByteBufAllocator alloc) {
296 CompositeByteBuf compositeByteBuf = alloc.compositeBuffer();
297 compositeByteBuf.addComponent(true, alloc.directBuffer(4).writeInt(100));
298 compositeByteBuf.addComponent(true, alloc.directBuffer(8).writeLong(123));
299 compositeByteBuf.addComponent(true, alloc.directBuffer(8).writeLong(456));
300 assertEquals(EXPECTED_BYTES, compositeByteBuf.readableBytes());
301 return compositeByteBuf;
302 }
303
304 private static byte[] newRandomBytes(int size, Random r) {
305 byte[] bytes = new byte[size];
306 r.nextBytes(bytes);
307 return bytes;
308 }
309 }