1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.haproxy;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.handler.codec.ByteToMessageDecoder;
21 import io.netty.handler.codec.ProtocolDetectionResult;
22 import io.netty.util.CharsetUtil;
23
24 import java.util.List;
25
26 import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
27
28
29
30
31
32
33 public class HAProxyMessageDecoder extends ByteToMessageDecoder {
34
35
36
37 private static final int V1_MAX_LENGTH = 108;
38
39
40
41
42 private static final int V2_MAX_LENGTH = 16 + 65535;
43
44
45
46
47 private static final int V2_MIN_LENGTH = 16 + 216;
48
49
50
51
52 private static final int V2_MAX_TLV = 65535 - 216;
53
54
55
56
57 private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length;
58
59
60
61
62 private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V1 =
63 ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1);
64
65
66
67
68 private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V2 =
69 ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2);
70
71
72
73
74 private HeaderExtractor headerExtractor;
75
76
77
78
79 private boolean discarding;
80
81
82
83
84 private int discardedBytes;
85
86
87
88
89 private final boolean failFast;
90
91
92
93
94 private boolean finished;
95
96
97
98
99 private int version = -1;
100
101
102
103
104
105 private final int v2MaxHeaderSize;
106
107
108
109
110
111 public HAProxyMessageDecoder() {
112 this(true);
113 }
114
115
116
117
118
119
120
121 public HAProxyMessageDecoder(boolean failFast) {
122 v2MaxHeaderSize = V2_MAX_LENGTH;
123 this.failFast = failFast;
124 }
125
126
127
128
129
130
131
132
133
134
135
136
137 public HAProxyMessageDecoder(int maxTlvSize) {
138 this(maxTlvSize, true);
139 }
140
141
142
143
144
145
146
147
148 public HAProxyMessageDecoder(int maxTlvSize, boolean failFast) {
149 if (maxTlvSize < 1) {
150 v2MaxHeaderSize = V2_MIN_LENGTH;
151 } else if (maxTlvSize > V2_MAX_TLV) {
152 v2MaxHeaderSize = V2_MAX_LENGTH;
153 } else {
154 int calcMax = maxTlvSize + V2_MIN_LENGTH;
155 if (calcMax > V2_MAX_LENGTH) {
156 v2MaxHeaderSize = V2_MAX_LENGTH;
157 } else {
158 v2MaxHeaderSize = calcMax;
159 }
160 }
161 this.failFast = failFast;
162 }
163
164
165
166
167
168 private static int findVersion(final ByteBuf buffer) {
169 final int n = buffer.readableBytes();
170
171 if (n < 13) {
172 return -1;
173 }
174
175 int idx = buffer.readerIndex();
176 return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1;
177 }
178
179
180
181
182
183 private static int findEndOfHeader(final ByteBuf buffer) {
184 final int n = buffer.readableBytes();
185
186
187 if (n < 16) {
188 return -1;
189 }
190
191 int offset = buffer.readerIndex() + 14;
192
193
194 int totalHeaderBytes = 16 + buffer.getUnsignedShort(offset);
195
196
197 if (n >= totalHeaderBytes) {
198 return totalHeaderBytes;
199 } else {
200 return -1;
201 }
202 }
203
204
205
206
207
208 private static int findEndOfLine(final ByteBuf buffer) {
209 final int n = buffer.writerIndex();
210 for (int i = buffer.readerIndex(); i < n; i++) {
211 final byte b = buffer.getByte(i);
212 if (b == '\r' && i < n - 1 && buffer.getByte(i + 1) == '\n') {
213 return i;
214 }
215 }
216 return -1;
217 }
218
219 @Override
220 public boolean isSingleDecode() {
221
222
223 return true;
224 }
225
226 @Override
227 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
228 super.channelRead(ctx, msg);
229 if (finished) {
230 ctx.pipeline().remove(this);
231 }
232 }
233
234 @Override
235 protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
236
237 if (version == -1) {
238 if ((version = findVersion(in)) == -1) {
239 return;
240 }
241 }
242
243 ByteBuf decoded;
244
245 if (version == 1) {
246 decoded = decodeLine(ctx, in);
247 } else {
248 decoded = decodeStruct(ctx, in);
249 }
250
251 if (decoded != null) {
252 finished = true;
253 try {
254 if (version == 1) {
255 out.add(HAProxyMessage.decodeHeader(decoded.toString(CharsetUtil.US_ASCII)));
256 } else {
257 out.add(HAProxyMessage.decodeHeader(decoded));
258 }
259 } catch (HAProxyProtocolException e) {
260 fail(ctx, null, e);
261 }
262 }
263 }
264
265
266
267
268
269
270
271
272
273 private ByteBuf decodeStruct(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
274 if (headerExtractor == null) {
275 headerExtractor = new StructHeaderExtractor(v2MaxHeaderSize);
276 }
277 return headerExtractor.extract(ctx, buffer);
278 }
279
280
281
282
283
284
285
286
287
288 private ByteBuf decodeLine(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
289 if (headerExtractor == null) {
290 headerExtractor = new LineHeaderExtractor(V1_MAX_LENGTH);
291 }
292 return headerExtractor.extract(ctx, buffer);
293 }
294
295 private void failOverLimit(final ChannelHandlerContext ctx, int length) {
296 failOverLimit(ctx, String.valueOf(length));
297 }
298
299 private void failOverLimit(final ChannelHandlerContext ctx, String length) {
300 int maxLength = version == 1 ? V1_MAX_LENGTH : v2MaxHeaderSize;
301 fail(ctx, "header length (" + length + ") exceeds the allowed maximum (" + maxLength + ')', null);
302 }
303
304 private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
305 finished = true;
306 ctx.close();
307 HAProxyProtocolException ppex;
308 if (errMsg != null && e != null) {
309 ppex = new HAProxyProtocolException(errMsg, e);
310 } else if (errMsg != null) {
311 ppex = new HAProxyProtocolException(errMsg);
312 } else if (e != null) {
313 ppex = new HAProxyProtocolException(e);
314 } else {
315 ppex = new HAProxyProtocolException();
316 }
317 throw ppex;
318 }
319
320
321
322
323 public static ProtocolDetectionResult<HAProxyProtocolVersion> detectProtocol(ByteBuf buffer) {
324 if (buffer.readableBytes() < 12) {
325 return ProtocolDetectionResult.needsMoreData();
326 }
327
328 int idx = buffer.readerIndex();
329
330 if (match(BINARY_PREFIX, buffer, idx)) {
331 return DETECTION_RESULT_V2;
332 }
333 if (match(TEXT_PREFIX, buffer, idx)) {
334 return DETECTION_RESULT_V1;
335 }
336 return ProtocolDetectionResult.invalid();
337 }
338
339 private static boolean match(byte[] prefix, ByteBuf buffer, int idx) {
340 for (int i = 0; i < prefix.length; i++) {
341 final byte b = buffer.getByte(idx + i);
342 if (b != prefix[i]) {
343 return false;
344 }
345 }
346 return true;
347 }
348
349
350
351
352 private abstract class HeaderExtractor {
353
354 private final int maxHeaderSize;
355
356 protected HeaderExtractor(int maxHeaderSize) {
357 this.maxHeaderSize = maxHeaderSize;
358 }
359
360
361
362
363
364
365
366
367
368
369 public ByteBuf extract(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
370 final int eoh = findEndOfHeader(buffer);
371 if (!discarding) {
372 if (eoh >= 0) {
373 final int length = eoh - buffer.readerIndex();
374 if (length > maxHeaderSize) {
375 buffer.readerIndex(eoh + delimiterLength(buffer, eoh));
376 failOverLimit(ctx, length);
377 return null;
378 }
379 ByteBuf frame = buffer.readSlice(length);
380 buffer.skipBytes(delimiterLength(buffer, eoh));
381 return frame;
382 } else {
383 final int length = buffer.readableBytes();
384 if (length > maxHeaderSize) {
385 discardedBytes = length;
386 buffer.skipBytes(length);
387 discarding = true;
388 if (failFast) {
389 failOverLimit(ctx, "over " + discardedBytes);
390 }
391 }
392 return null;
393 }
394 } else {
395 if (eoh >= 0) {
396 final int length = discardedBytes + eoh - buffer.readerIndex();
397 buffer.readerIndex(eoh + delimiterLength(buffer, eoh));
398 discardedBytes = 0;
399 discarding = false;
400 if (!failFast) {
401 failOverLimit(ctx, "over " + length);
402 }
403 } else {
404 discardedBytes += buffer.readableBytes();
405 buffer.skipBytes(buffer.readableBytes());
406 }
407 return null;
408 }
409 }
410
411
412
413
414
415
416
417
418 protected abstract int findEndOfHeader(ByteBuf buffer);
419
420
421
422
423
424
425
426
427 protected abstract int delimiterLength(ByteBuf buffer, int eoh);
428 }
429
430 private final class LineHeaderExtractor extends HeaderExtractor {
431
432 LineHeaderExtractor(int maxHeaderSize) {
433 super(maxHeaderSize);
434 }
435
436 @Override
437 protected int findEndOfHeader(ByteBuf buffer) {
438 return findEndOfLine(buffer);
439 }
440
441 @Override
442 protected int delimiterLength(ByteBuf buffer, int eoh) {
443 return buffer.getByte(eoh) == '\r' ? 2 : 1;
444 }
445 }
446
447 private final class StructHeaderExtractor extends HeaderExtractor {
448
449 StructHeaderExtractor(int maxHeaderSize) {
450 super(maxHeaderSize);
451 }
452
453 @Override
454 protected int findEndOfHeader(ByteBuf buffer) {
455 return HAProxyMessageDecoder.findEndOfHeader(buffer);
456 }
457
458 @Override
459 protected int delimiterLength(ByteBuf buffer, int eoh) {
460 return 0;
461 }
462 }
463 }