查看本类的 API文档回源码主页即时通讯网 - 即时通讯开发者社区!
1   /*
2    * Copyright 2020 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.haproxy;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandler.Sharable;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.handler.codec.MessageToByteEncoder;
22  import io.netty.util.CharsetUtil;
23  import io.netty.util.NetUtil;
24  
25  import java.util.List;
26  
27  import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
28  
29  /**
30   * Encodes an HAProxy proxy protocol message
31   *
32   * @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">Proxy Protocol Specification</a>
33   */
34  @Sharable
35  public final class HAProxyMessageEncoder extends MessageToByteEncoder<HAProxyMessage> {
36  
37      private static final int V2_VERSION_BITMASK = 0x02 << 4;
38  
39      // Length for source/destination addresses for the UNIX family must be 108 bytes each.
40      static final int UNIX_ADDRESS_BYTES_LENGTH = 108;
41      static final int TOTAL_UNIX_ADDRESS_BYTES_LENGTH = UNIX_ADDRESS_BYTES_LENGTH * 2;
42  
43      public static final HAProxyMessageEncoder INSTANCE = new HAProxyMessageEncoder();
44  
45      private HAProxyMessageEncoder() {
46      }
47  
48      @Override
49      protected void encode(ChannelHandlerContext ctx, HAProxyMessage msg, ByteBuf out) throws Exception {
50          switch (msg.protocolVersion()) {
51              case V1:
52                  encodeV1(msg, out);
53                  break;
54              case V2:
55                  encodeV2(msg, out);
56                  break;
57              default:
58                  throw new HAProxyProtocolException("Unsupported version: " + msg.protocolVersion());
59          }
60      }
61  
62      private static void encodeV1(HAProxyMessage msg, ByteBuf out) {
63          out.writeBytes(TEXT_PREFIX);
64          out.writeByte((byte) ' ');
65          out.writeCharSequence(msg.proxiedProtocol().name(), CharsetUtil.US_ASCII);
66          out.writeByte((byte) ' ');
67          out.writeCharSequence(msg.sourceAddress(), CharsetUtil.US_ASCII);
68          out.writeByte((byte) ' ');
69          out.writeCharSequence(msg.destinationAddress(), CharsetUtil.US_ASCII);
70          out.writeByte((byte) ' ');
71          out.writeCharSequence(String.valueOf(msg.sourcePort()), CharsetUtil.US_ASCII);
72          out.writeByte((byte) ' ');
73          out.writeCharSequence(String.valueOf(msg.destinationPort()), CharsetUtil.US_ASCII);
74          out.writeByte((byte) '\r');
75          out.writeByte((byte) '\n');
76      }
77  
78      private static void encodeV2(HAProxyMessage msg, ByteBuf out) {
79          out.writeBytes(BINARY_PREFIX);
80          out.writeByte(V2_VERSION_BITMASK | msg.command().byteValue());
81          out.writeByte(msg.proxiedProtocol().byteValue());
82  
83          switch (msg.proxiedProtocol().addressFamily()) {
84              case AF_IPv4:
85              case AF_IPv6:
86                  byte[] srcAddrBytes = NetUtil.createByteArrayFromIpAddressString(msg.sourceAddress());
87                  byte[] dstAddrBytes = NetUtil.createByteArrayFromIpAddressString(msg.destinationAddress());
88                  // srcAddrLen + dstAddrLen + 4 (srcPort + dstPort) + numTlvBytes
89                  out.writeShort(srcAddrBytes.length + dstAddrBytes.length + 4 + msg.tlvNumBytes());
90                  out.writeBytes(srcAddrBytes);
91                  out.writeBytes(dstAddrBytes);
92                  out.writeShort(msg.sourcePort());
93                  out.writeShort(msg.destinationPort());
94                  encodeTlvs(msg.tlvs(), out);
95                  break;
96              case AF_UNIX:
97                  out.writeShort(TOTAL_UNIX_ADDRESS_BYTES_LENGTH + msg.tlvNumBytes());
98                  int srcAddrBytesWritten = out.writeCharSequence(msg.sourceAddress(), CharsetUtil.US_ASCII);
99                  out.writeZero(UNIX_ADDRESS_BYTES_LENGTH - srcAddrBytesWritten);
100                 int dstAddrBytesWritten = out.writeCharSequence(msg.destinationAddress(), CharsetUtil.US_ASCII);
101                 out.writeZero(UNIX_ADDRESS_BYTES_LENGTH - dstAddrBytesWritten);
102                 encodeTlvs(msg.tlvs(), out);
103                 break;
104             case AF_UNSPEC:
105                 out.writeShort(0);
106                 break;
107             default:
108                 throw new HAProxyProtocolException("unexpected addrFamily");
109         }
110     }
111 
112     private static void encodeTlv(HAProxyTLV haProxyTLV, ByteBuf out) {
113         if (haProxyTLV instanceof HAProxySSLTLV) {
114             HAProxySSLTLV ssltlv = (HAProxySSLTLV) haProxyTLV;
115             out.writeByte(haProxyTLV.typeByteValue());
116             out.writeShort(ssltlv.contentNumBytes());
117             out.writeByte(ssltlv.client());
118             out.writeInt(ssltlv.verify());
119             encodeTlvs(ssltlv.encapsulatedTLVs(), out);
120         } else {
121             out.writeByte(haProxyTLV.typeByteValue());
122             ByteBuf value = haProxyTLV.content();
123             int readableBytes = value.readableBytes();
124             out.writeShort(readableBytes);
125             out.writeBytes(value.readSlice(readableBytes));
126         }
127     }
128 
129     private static void encodeTlvs(List<HAProxyTLV> haProxyTLVs, ByteBuf out) {
130         for (int i = 0; i < haProxyTLVs.size(); i++) {
131             encodeTlv(haProxyTLVs.get(i), out);
132         }
133     }
134 }