1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.embedded;
17
18 import java.net.SocketAddress;
19 import java.nio.channels.ClosedChannelException;
20 import java.util.ArrayDeque;
21 import java.util.Queue;
22 import java.util.concurrent.TimeUnit;
23
24 import io.netty.channel.AbstractChannel;
25 import io.netty.channel.Channel;
26 import io.netty.channel.ChannelConfig;
27 import io.netty.channel.ChannelFuture;
28 import io.netty.channel.ChannelFutureListener;
29 import io.netty.channel.ChannelHandler;
30 import io.netty.channel.ChannelHandlerContext;
31 import io.netty.channel.ChannelId;
32 import io.netty.channel.ChannelInitializer;
33 import io.netty.channel.ChannelMetadata;
34 import io.netty.channel.ChannelOutboundBuffer;
35 import io.netty.channel.ChannelPipeline;
36 import io.netty.channel.ChannelPromise;
37 import io.netty.channel.DefaultChannelConfig;
38 import io.netty.channel.DefaultChannelPipeline;
39 import io.netty.channel.EventLoop;
40 import io.netty.channel.RecvByteBufAllocator;
41 import io.netty.util.ReferenceCountUtil;
42 import io.netty.util.internal.ObjectUtil;
43 import io.netty.util.internal.PlatformDependent;
44 import io.netty.util.internal.RecyclableArrayList;
45 import io.netty.util.internal.logging.InternalLogger;
46 import io.netty.util.internal.logging.InternalLoggerFactory;
47
48
49
50
51 public class EmbeddedChannel extends AbstractChannel {
52
53 private static final SocketAddress LOCAL_ADDRESS = new EmbeddedSocketAddress();
54 private static final SocketAddress REMOTE_ADDRESS = new EmbeddedSocketAddress();
55
56 private static final ChannelHandler[] EMPTY_HANDLERS = new ChannelHandler[0];
57 private enum State { OPEN, ACTIVE, CLOSED }
58
59 private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
60
61 private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
62 private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
63
64 private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
65 private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() {
66 @Override
67 public void operationComplete(ChannelFuture future) throws Exception {
68 recordException(future);
69 }
70 };
71
72 private final ChannelMetadata metadata;
73 private final ChannelConfig config;
74
75 private Queue<Object> inboundMessages;
76 private Queue<Object> outboundMessages;
77 private Throwable lastException;
78 private State state;
79
80
81
82
83 public EmbeddedChannel() {
84 this(EMPTY_HANDLERS);
85 }
86
87
88
89
90
91
92 public EmbeddedChannel(ChannelId channelId) {
93 this(channelId, EMPTY_HANDLERS);
94 }
95
96
97
98
99
100
101 public EmbeddedChannel(ChannelHandler... handlers) {
102 this(EmbeddedChannelId.INSTANCE, handlers);
103 }
104
105
106
107
108
109
110
111
112 public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) {
113 this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers);
114 }
115
116
117
118
119
120
121
122
123
124
125 public EmbeddedChannel(boolean register, boolean hasDisconnect, ChannelHandler... handlers) {
126 this(EmbeddedChannelId.INSTANCE, register, hasDisconnect, handlers);
127 }
128
129
130
131
132
133
134
135
136 public EmbeddedChannel(ChannelId channelId, ChannelHandler... handlers) {
137 this(channelId, false, handlers);
138 }
139
140
141
142
143
144
145
146
147
148
149 public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, ChannelHandler... handlers) {
150 this(channelId, true, hasDisconnect, handlers);
151 }
152
153
154
155
156
157
158
159
160
161
162
163
164 public EmbeddedChannel(ChannelId channelId, boolean register, boolean hasDisconnect,
165 ChannelHandler... handlers) {
166 this(null, channelId, register, hasDisconnect, handlers);
167 }
168
169
170
171
172
173
174
175
176
177
178
179
180
181 public EmbeddedChannel(Channel parent, ChannelId channelId, boolean register, boolean hasDisconnect,
182 final ChannelHandler... handlers) {
183 super(parent, channelId);
184 metadata = metadata(hasDisconnect);
185 config = new DefaultChannelConfig(this);
186 setup(register, handlers);
187 }
188
189
190
191
192
193
194
195
196
197
198
199 public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelConfig config,
200 final ChannelHandler... handlers) {
201 super(null, channelId);
202 metadata = metadata(hasDisconnect);
203 this.config = ObjectUtil.checkNotNull(config, "config");
204 setup(true, handlers);
205 }
206
207 private static ChannelMetadata metadata(boolean hasDisconnect) {
208 return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
209 }
210
211 private void setup(boolean register, final ChannelHandler... handlers) {
212 ObjectUtil.checkNotNull(handlers, "handlers");
213 ChannelPipeline p = pipeline();
214 p.addLast(new ChannelInitializer<Channel>() {
215 @Override
216 protected void initChannel(Channel ch) throws Exception {
217 ChannelPipeline pipeline = ch.pipeline();
218 for (ChannelHandler h: handlers) {
219 if (h == null) {
220 break;
221 }
222 pipeline.addLast(h);
223 }
224 }
225 });
226 if (register) {
227 ChannelFuture future = loop.register(this);
228 assert future.isDone();
229 }
230 }
231
232
233
234
235 public void register() throws Exception {
236 ChannelFuture future = loop.register(this);
237 assert future.isDone();
238 Throwable cause = future.cause();
239 if (cause != null) {
240 PlatformDependent.throwException(cause);
241 }
242 }
243
244 @Override
245 protected final DefaultChannelPipeline newChannelPipeline() {
246 return new EmbeddedChannelPipeline(this);
247 }
248
249 @Override
250 public ChannelMetadata metadata() {
251 return metadata;
252 }
253
254 @Override
255 public ChannelConfig config() {
256 return config;
257 }
258
259 @Override
260 public boolean isOpen() {
261 return state != State.CLOSED;
262 }
263
264 @Override
265 public boolean isActive() {
266 return state == State.ACTIVE;
267 }
268
269
270
271
272 public Queue<Object> inboundMessages() {
273 if (inboundMessages == null) {
274 inboundMessages = new ArrayDeque<Object>();
275 }
276 return inboundMessages;
277 }
278
279
280
281
282 @Deprecated
283 public Queue<Object> lastInboundBuffer() {
284 return inboundMessages();
285 }
286
287
288
289
290 public Queue<Object> outboundMessages() {
291 if (outboundMessages == null) {
292 outboundMessages = new ArrayDeque<Object>();
293 }
294 return outboundMessages;
295 }
296
297
298
299
300 @Deprecated
301 public Queue<Object> lastOutboundBuffer() {
302 return outboundMessages();
303 }
304
305
306
307
308 @SuppressWarnings("unchecked")
309 public <T> T readInbound() {
310 T message = (T) poll(inboundMessages);
311 if (message != null) {
312 ReferenceCountUtil.touch(message, "Caller of readInbound() will handle the message from this point");
313 }
314 return message;
315 }
316
317
318
319
320 @SuppressWarnings("unchecked")
321 public <T> T readOutbound() {
322 T message = (T) poll(outboundMessages);
323 if (message != null) {
324 ReferenceCountUtil.touch(message, "Caller of readOutbound() will handle the message from this point.");
325 }
326 return message;
327 }
328
329
330
331
332
333
334
335
336 public boolean writeInbound(Object... msgs) {
337 ensureOpen();
338 if (msgs.length == 0) {
339 return isNotEmpty(inboundMessages);
340 }
341
342 ChannelPipeline p = pipeline();
343 for (Object m: msgs) {
344 p.fireChannelRead(m);
345 }
346
347 flushInbound(false, voidPromise());
348 return isNotEmpty(inboundMessages);
349 }
350
351
352
353
354
355
356
357 public ChannelFuture writeOneInbound(Object msg) {
358 return writeOneInbound(msg, newPromise());
359 }
360
361
362
363
364
365
366
367 public ChannelFuture writeOneInbound(Object msg, ChannelPromise promise) {
368 if (checkOpen(true)) {
369 pipeline().fireChannelRead(msg);
370 }
371 return checkException(promise);
372 }
373
374
375
376
377
378
379 public EmbeddedChannel flushInbound() {
380 flushInbound(true, voidPromise());
381 return this;
382 }
383
384 private ChannelFuture flushInbound(boolean recordException, ChannelPromise promise) {
385 if (checkOpen(recordException)) {
386 pipeline().fireChannelReadComplete();
387 runPendingTasks();
388 }
389
390 return checkException(promise);
391 }
392
393
394
395
396
397
398
399 public boolean writeOutbound(Object... msgs) {
400 ensureOpen();
401 if (msgs.length == 0) {
402 return isNotEmpty(outboundMessages);
403 }
404
405 RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length);
406 try {
407 for (Object m: msgs) {
408 if (m == null) {
409 break;
410 }
411 futures.add(write(m));
412 }
413
414 flushOutbound0();
415
416 int size = futures.size();
417 for (int i = 0; i < size; i++) {
418 ChannelFuture future = (ChannelFuture) futures.get(i);
419 if (future.isDone()) {
420 recordException(future);
421 } else {
422
423 future.addListener(recordExceptionListener);
424 }
425 }
426
427 checkException();
428 return isNotEmpty(outboundMessages);
429 } finally {
430 futures.recycle();
431 }
432 }
433
434
435
436
437
438
439
440 public ChannelFuture writeOneOutbound(Object msg) {
441 return writeOneOutbound(msg, newPromise());
442 }
443
444
445
446
447
448
449
450 public ChannelFuture writeOneOutbound(Object msg, ChannelPromise promise) {
451 if (checkOpen(true)) {
452 return write(msg, promise);
453 }
454 return checkException(promise);
455 }
456
457
458
459
460
461
462 public EmbeddedChannel flushOutbound() {
463 if (checkOpen(true)) {
464 flushOutbound0();
465 }
466 checkException(voidPromise());
467 return this;
468 }
469
470 private void flushOutbound0() {
471
472
473 runPendingTasks();
474
475 flush();
476 }
477
478
479
480
481
482
483 public boolean finish() {
484 return finish(false);
485 }
486
487
488
489
490
491
492
493 public boolean finishAndReleaseAll() {
494 return finish(true);
495 }
496
497
498
499
500
501
502
503 private boolean finish(boolean releaseAll) {
504 close();
505 try {
506 checkException();
507 return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
508 } finally {
509 if (releaseAll) {
510 releaseAll(inboundMessages);
511 releaseAll(outboundMessages);
512 }
513 }
514 }
515
516
517
518
519
520 public boolean releaseInbound() {
521 return releaseAll(inboundMessages);
522 }
523
524
525
526
527
528 public boolean releaseOutbound() {
529 return releaseAll(outboundMessages);
530 }
531
532 private static boolean releaseAll(Queue<Object> queue) {
533 if (isNotEmpty(queue)) {
534 for (;;) {
535 Object msg = queue.poll();
536 if (msg == null) {
537 break;
538 }
539 ReferenceCountUtil.release(msg);
540 }
541 return true;
542 }
543 return false;
544 }
545
546 private void finishPendingTasks(boolean cancel) {
547 runPendingTasks();
548 if (cancel) {
549
550 embeddedEventLoop().cancelScheduledTasks();
551 }
552 }
553
554 @Override
555 public final ChannelFuture close() {
556 return close(newPromise());
557 }
558
559 @Override
560 public final ChannelFuture disconnect() {
561 return disconnect(newPromise());
562 }
563
564 @Override
565 public final ChannelFuture close(ChannelPromise promise) {
566
567
568 runPendingTasks();
569 ChannelFuture future = super.close(promise);
570
571
572 finishPendingTasks(true);
573 return future;
574 }
575
576 @Override
577 public final ChannelFuture disconnect(ChannelPromise promise) {
578 ChannelFuture future = super.disconnect(promise);
579 finishPendingTasks(!metadata.hasDisconnect());
580 return future;
581 }
582
583 private static boolean isNotEmpty(Queue<Object> queue) {
584 return queue != null && !queue.isEmpty();
585 }
586
587 private static Object poll(Queue<Object> queue) {
588 return queue != null ? queue.poll() : null;
589 }
590
591
592
593
594
595 public void runPendingTasks() {
596 try {
597 embeddedEventLoop().runTasks();
598 } catch (Exception e) {
599 recordException(e);
600 }
601
602 try {
603 embeddedEventLoop().runScheduledTasks();
604 } catch (Exception e) {
605 recordException(e);
606 }
607 }
608
609
610
611
612
613
614
615
616 public boolean hasPendingTasks() {
617 return embeddedEventLoop().hasPendingNormalTasks() ||
618 embeddedEventLoop().nextScheduledTask() == 0;
619 }
620
621
622
623
624
625
626 public long runScheduledPendingTasks() {
627 try {
628 return embeddedEventLoop().runScheduledTasks();
629 } catch (Exception e) {
630 recordException(e);
631 return embeddedEventLoop().nextScheduledTask();
632 }
633 }
634
635 private void recordException(ChannelFuture future) {
636 if (!future.isSuccess()) {
637 recordException(future.cause());
638 }
639 }
640
641 private void recordException(Throwable cause) {
642 if (lastException == null) {
643 lastException = cause;
644 } else {
645 logger.warn(
646 "More than one exception was raised. " +
647 "Will report only the first one and log others.", cause);
648 }
649 }
650
651
652
653
654
655 public void advanceTimeBy(long duration, TimeUnit unit) {
656 embeddedEventLoop().advanceTimeBy(unit.toNanos(duration));
657 }
658
659
660
661
662
663
664 public void freezeTime() {
665 embeddedEventLoop().freezeTime();
666 }
667
668
669
670
671
672
673
674
675 public void unfreezeTime() {
676 embeddedEventLoop().unfreezeTime();
677 }
678
679
680
681
682 private ChannelFuture checkException(ChannelPromise promise) {
683 Throwable t = lastException;
684 if (t != null) {
685 lastException = null;
686
687 if (promise.isVoid()) {
688 PlatformDependent.throwException(t);
689 }
690
691 return promise.setFailure(t);
692 }
693
694 return promise.setSuccess();
695 }
696
697
698
699
700 public void checkException() {
701 checkException(voidPromise());
702 }
703
704
705
706
707
708 private boolean checkOpen(boolean recordException) {
709 if (!isOpen()) {
710 if (recordException) {
711 recordException(new ClosedChannelException());
712 }
713 return false;
714 }
715
716 return true;
717 }
718
719 private EmbeddedEventLoop embeddedEventLoop() {
720 if (isRegistered()) {
721 return (EmbeddedEventLoop) super.eventLoop();
722 }
723
724 return loop;
725 }
726
727
728
729
730 protected final void ensureOpen() {
731 if (!checkOpen(true)) {
732 checkException();
733 }
734 }
735
736 @Override
737 protected boolean isCompatible(EventLoop loop) {
738 return loop instanceof EmbeddedEventLoop;
739 }
740
741 @Override
742 protected SocketAddress localAddress0() {
743 return isActive()? LOCAL_ADDRESS : null;
744 }
745
746 @Override
747 protected SocketAddress remoteAddress0() {
748 return isActive()? REMOTE_ADDRESS : null;
749 }
750
751 @Override
752 protected void doRegister() throws Exception {
753 state = State.ACTIVE;
754 }
755
756 @Override
757 protected void doBind(SocketAddress localAddress) throws Exception {
758
759 }
760
761 @Override
762 protected void doDisconnect() throws Exception {
763 if (!metadata.hasDisconnect()) {
764 doClose();
765 }
766 }
767
768 @Override
769 protected void doClose() throws Exception {
770 state = State.CLOSED;
771 }
772
773 @Override
774 protected void doBeginRead() throws Exception {
775
776 }
777
778 @Override
779 protected AbstractUnsafe newUnsafe() {
780 return new EmbeddedUnsafe();
781 }
782
783 @Override
784 public Unsafe unsafe() {
785 return ((EmbeddedUnsafe) super.unsafe()).wrapped;
786 }
787
788 @Override
789 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
790 for (;;) {
791 Object msg = in.current();
792 if (msg == null) {
793 break;
794 }
795
796 ReferenceCountUtil.retain(msg);
797 handleOutboundMessage(msg);
798 in.remove();
799 }
800 }
801
802
803
804
805
806
807 protected void handleOutboundMessage(Object msg) {
808 outboundMessages().add(msg);
809 }
810
811
812
813
814 protected void handleInboundMessage(Object msg) {
815 inboundMessages().add(msg);
816 }
817
818 private final class EmbeddedUnsafe extends AbstractUnsafe {
819
820
821
822 final Unsafe wrapped = new Unsafe() {
823 @Override
824 public RecvByteBufAllocator.Handle recvBufAllocHandle() {
825 return EmbeddedUnsafe.this.recvBufAllocHandle();
826 }
827
828 @Override
829 public SocketAddress localAddress() {
830 return EmbeddedUnsafe.this.localAddress();
831 }
832
833 @Override
834 public SocketAddress remoteAddress() {
835 return EmbeddedUnsafe.this.remoteAddress();
836 }
837
838 @Override
839 public void register(EventLoop eventLoop, ChannelPromise promise) {
840 EmbeddedUnsafe.this.register(eventLoop, promise);
841 runPendingTasks();
842 }
843
844 @Override
845 public void bind(SocketAddress localAddress, ChannelPromise promise) {
846 EmbeddedUnsafe.this.bind(localAddress, promise);
847 runPendingTasks();
848 }
849
850 @Override
851 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
852 EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
853 runPendingTasks();
854 }
855
856 @Override
857 public void disconnect(ChannelPromise promise) {
858 EmbeddedUnsafe.this.disconnect(promise);
859 runPendingTasks();
860 }
861
862 @Override
863 public void close(ChannelPromise promise) {
864 EmbeddedUnsafe.this.close(promise);
865 runPendingTasks();
866 }
867
868 @Override
869 public void closeForcibly() {
870 EmbeddedUnsafe.this.closeForcibly();
871 runPendingTasks();
872 }
873
874 @Override
875 public void deregister(ChannelPromise promise) {
876 EmbeddedUnsafe.this.deregister(promise);
877 runPendingTasks();
878 }
879
880 @Override
881 public void beginRead() {
882 EmbeddedUnsafe.this.beginRead();
883 runPendingTasks();
884 }
885
886 @Override
887 public void write(Object msg, ChannelPromise promise) {
888 EmbeddedUnsafe.this.write(msg, promise);
889 runPendingTasks();
890 }
891
892 @Override
893 public void flush() {
894 EmbeddedUnsafe.this.flush();
895 runPendingTasks();
896 }
897
898 @Override
899 public ChannelPromise voidPromise() {
900 return EmbeddedUnsafe.this.voidPromise();
901 }
902
903 @Override
904 public ChannelOutboundBuffer outboundBuffer() {
905 return EmbeddedUnsafe.this.outboundBuffer();
906 }
907 };
908
909 @Override
910 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
911 safeSetSuccess(promise);
912 }
913 }
914
915 private final class EmbeddedChannelPipeline extends DefaultChannelPipeline {
916 EmbeddedChannelPipeline(EmbeddedChannel channel) {
917 super(channel);
918 }
919
920 @Override
921 protected void onUnhandledInboundException(Throwable cause) {
922 recordException(cause);
923 }
924
925 @Override
926 protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) {
927 handleInboundMessage(msg);
928 }
929 }
930 }