1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.dns;
17
18 import io.netty.util.AbstractReferenceCounted;
19 import io.netty.util.ReferenceCountUtil;
20 import io.netty.util.ReferenceCounted;
21 import io.netty.util.ResourceLeakDetector;
22 import io.netty.util.ResourceLeakDetectorFactory;
23 import io.netty.util.ResourceLeakTracker;
24 import io.netty.util.internal.StringUtil;
25 import io.netty.util.internal.UnstableApi;
26
27 import java.util.ArrayList;
28 import java.util.List;
29
30 import static io.netty.util.internal.ObjectUtil.checkNotNull;
31
32
33
34
35 @UnstableApi
36 public abstract class AbstractDnsMessage extends AbstractReferenceCounted implements DnsMessage {
37
38 private static final ResourceLeakDetector<DnsMessage> leakDetector =
39 ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DnsMessage.class);
40
41 private static final int SECTION_QUESTION = DnsSection.QUESTION.ordinal();
42 private static final int SECTION_COUNT = 4;
43
44 private final ResourceLeakTracker<DnsMessage> leak = leakDetector.track(this);
45 private short id;
46 private DnsOpCode opCode;
47 private boolean recursionDesired;
48 private byte z;
49
50
51
52 private Object questions;
53 private Object answers;
54 private Object authorities;
55 private Object additionals;
56
57
58
59
60 protected AbstractDnsMessage(int id) {
61 this(id, DnsOpCode.QUERY);
62 }
63
64
65
66
67 protected AbstractDnsMessage(int id, DnsOpCode opCode) {
68 setId(id);
69 setOpCode(opCode);
70 }
71
72 @Override
73 public int id() {
74 return id & 0xFFFF;
75 }
76
77 @Override
78 public DnsMessage setId(int id) {
79 this.id = (short) id;
80 return this;
81 }
82
83 @Override
84 public DnsOpCode opCode() {
85 return opCode;
86 }
87
88 @Override
89 public DnsMessage setOpCode(DnsOpCode opCode) {
90 this.opCode = checkNotNull(opCode, "opCode");
91 return this;
92 }
93
94 @Override
95 public boolean isRecursionDesired() {
96 return recursionDesired;
97 }
98
99 @Override
100 public DnsMessage setRecursionDesired(boolean recursionDesired) {
101 this.recursionDesired = recursionDesired;
102 return this;
103 }
104
105 @Override
106 public int z() {
107 return z;
108 }
109
110 @Override
111 public DnsMessage setZ(int z) {
112 this.z = (byte) (z & 7);
113 return this;
114 }
115
116 @Override
117 public int count(DnsSection section) {
118 return count(sectionOrdinal(section));
119 }
120
121 private int count(int section) {
122 final Object records = sectionAt(section);
123 if (records == null) {
124 return 0;
125 }
126 if (records instanceof DnsRecord) {
127 return 1;
128 }
129
130 @SuppressWarnings("unchecked")
131 final List<DnsRecord> recordList = (List<DnsRecord>) records;
132 return recordList.size();
133 }
134
135 @Override
136 public int count() {
137 int count = 0;
138 for (int i = 0; i < SECTION_COUNT; i ++) {
139 count += count(i);
140 }
141 return count;
142 }
143
144 @Override
145 public <T extends DnsRecord> T recordAt(DnsSection section) {
146 return recordAt(sectionOrdinal(section));
147 }
148
149 private <T extends DnsRecord> T recordAt(int section) {
150 final Object records = sectionAt(section);
151 if (records == null) {
152 return null;
153 }
154
155 if (records instanceof DnsRecord) {
156 return castRecord(records);
157 }
158
159 @SuppressWarnings("unchecked")
160 final List<DnsRecord> recordList = (List<DnsRecord>) records;
161 if (recordList.isEmpty()) {
162 return null;
163 }
164
165 return castRecord(recordList.get(0));
166 }
167
168 @Override
169 public <T extends DnsRecord> T recordAt(DnsSection section, int index) {
170 return recordAt(sectionOrdinal(section), index);
171 }
172
173 private <T extends DnsRecord> T recordAt(int section, int index) {
174 final Object records = sectionAt(section);
175 if (records == null) {
176 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
177 }
178
179 if (records instanceof DnsRecord) {
180 if (index == 0) {
181 return castRecord(records);
182 } else {
183 throw new IndexOutOfBoundsException("index: " + index + "' (expected: 0)");
184 }
185 }
186
187 @SuppressWarnings("unchecked")
188 final List<DnsRecord> recordList = (List<DnsRecord>) records;
189 return castRecord(recordList.get(index));
190 }
191
192 @Override
193 public DnsMessage setRecord(DnsSection section, DnsRecord record) {
194 setRecord(sectionOrdinal(section), record);
195 return this;
196 }
197
198 private void setRecord(int section, DnsRecord record) {
199 clear(section);
200 setSection(section, checkQuestion(section, record));
201 }
202
203 @Override
204 public <T extends DnsRecord> T setRecord(DnsSection section, int index, DnsRecord record) {
205 return setRecord(sectionOrdinal(section), index, record);
206 }
207
208 private <T extends DnsRecord> T setRecord(int section, int index, DnsRecord record) {
209 checkQuestion(section, record);
210
211 final Object records = sectionAt(section);
212 if (records == null) {
213 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
214 }
215
216 if (records instanceof DnsRecord) {
217 if (index == 0) {
218 setSection(section, record);
219 return castRecord(records);
220 } else {
221 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
222 }
223 }
224
225 @SuppressWarnings("unchecked")
226 final List<DnsRecord> recordList = (List<DnsRecord>) records;
227 return castRecord(recordList.set(index, record));
228 }
229
230 @Override
231 public DnsMessage addRecord(DnsSection section, DnsRecord record) {
232 addRecord(sectionOrdinal(section), record);
233 return this;
234 }
235
236 private void addRecord(int section, DnsRecord record) {
237 checkQuestion(section, record);
238
239 final Object records = sectionAt(section);
240 if (records == null) {
241 setSection(section, record);
242 return;
243 }
244
245 if (records instanceof DnsRecord) {
246 final List<DnsRecord> recordList = newRecordList();
247 recordList.add(castRecord(records));
248 recordList.add(record);
249 setSection(section, recordList);
250 return;
251 }
252
253 @SuppressWarnings("unchecked")
254 final List<DnsRecord> recordList = (List<DnsRecord>) records;
255 recordList.add(record);
256 }
257
258 @Override
259 public DnsMessage addRecord(DnsSection section, int index, DnsRecord record) {
260 addRecord(sectionOrdinal(section), index, record);
261 return this;
262 }
263
264 private void addRecord(int section, int index, DnsRecord record) {
265 checkQuestion(section, record);
266
267 final Object records = sectionAt(section);
268 if (records == null) {
269 if (index != 0) {
270 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
271 }
272
273 setSection(section, record);
274 return;
275 }
276
277 if (records instanceof DnsRecord) {
278 final List<DnsRecord> recordList;
279 if (index == 0) {
280 recordList = newRecordList();
281 recordList.add(record);
282 recordList.add(castRecord(records));
283 } else if (index == 1) {
284 recordList = newRecordList();
285 recordList.add(castRecord(records));
286 recordList.add(record);
287 } else {
288 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0 or 1)");
289 }
290 setSection(section, recordList);
291 return;
292 }
293
294 @SuppressWarnings("unchecked")
295 final List<DnsRecord> recordList = (List<DnsRecord>) records;
296 recordList.add(index, record);
297 }
298
299 @Override
300 public <T extends DnsRecord> T removeRecord(DnsSection section, int index) {
301 return removeRecord(sectionOrdinal(section), index);
302 }
303
304 private <T extends DnsRecord> T removeRecord(int section, int index) {
305 final Object records = sectionAt(section);
306 if (records == null) {
307 throw new IndexOutOfBoundsException("index: " + index + " (expected: none)");
308 }
309
310 if (records instanceof DnsRecord) {
311 if (index != 0) {
312 throw new IndexOutOfBoundsException("index: " + index + " (expected: 0)");
313 }
314
315 T record = castRecord(records);
316 setSection(section, null);
317 return record;
318 }
319
320 @SuppressWarnings("unchecked")
321 final List<DnsRecord> recordList = (List<DnsRecord>) records;
322 return castRecord(recordList.remove(index));
323 }
324
325 @Override
326 public DnsMessage clear(DnsSection section) {
327 clear(sectionOrdinal(section));
328 return this;
329 }
330
331 @Override
332 public DnsMessage clear() {
333 for (int i = 0; i < SECTION_COUNT; i ++) {
334 clear(i);
335 }
336 return this;
337 }
338
339 private void clear(int section) {
340 final Object recordOrList = sectionAt(section);
341 setSection(section, null);
342 if (recordOrList instanceof ReferenceCounted) {
343 ((ReferenceCounted) recordOrList).release();
344 } else if (recordOrList instanceof List) {
345 @SuppressWarnings("unchecked")
346 List<DnsRecord> list = (List<DnsRecord>) recordOrList;
347 if (!list.isEmpty()) {
348 for (Object r : list) {
349 ReferenceCountUtil.release(r);
350 }
351 }
352 }
353 }
354
355 @Override
356 public DnsMessage touch() {
357 return (DnsMessage) super.touch();
358 }
359
360 @Override
361 public DnsMessage touch(Object hint) {
362 if (leak != null) {
363 leak.record(hint);
364 }
365 return this;
366 }
367
368 @Override
369 public DnsMessage retain() {
370 return (DnsMessage) super.retain();
371 }
372
373 @Override
374 public DnsMessage retain(int increment) {
375 return (DnsMessage) super.retain(increment);
376 }
377
378 @Override
379 protected void deallocate() {
380 clear();
381
382 final ResourceLeakTracker<DnsMessage> leak = this.leak;
383 if (leak != null) {
384 boolean closed = leak.close(this);
385 assert closed;
386 }
387 }
388
389 @Override
390 public boolean equals(Object obj) {
391 if (this == obj) {
392 return true;
393 }
394
395 if (!(obj instanceof DnsMessage)) {
396 return false;
397 }
398
399 final DnsMessage that = (DnsMessage) obj;
400 if (id() != that.id()) {
401 return false;
402 }
403
404 if (this instanceof DnsQuery) {
405 if (!(that instanceof DnsQuery)) {
406 return false;
407 }
408 } else if (that instanceof DnsQuery) {
409 return false;
410 }
411
412 return true;
413 }
414
415 @Override
416 public int hashCode() {
417 return id() * 31 + (this instanceof DnsQuery? 0 : 1);
418 }
419
420 private Object sectionAt(int section) {
421 switch (section) {
422 case 0:
423 return questions;
424 case 1:
425 return answers;
426 case 2:
427 return authorities;
428 case 3:
429 return additionals;
430 default:
431 break;
432 }
433
434 throw new Error();
435 }
436
437 private void setSection(int section, Object value) {
438 switch (section) {
439 case 0:
440 questions = value;
441 return;
442 case 1:
443 answers = value;
444 return;
445 case 2:
446 authorities = value;
447 return;
448 case 3:
449 additionals = value;
450 return;
451 default:
452 break;
453 }
454
455 throw new Error();
456 }
457
458 private static int sectionOrdinal(DnsSection section) {
459 return checkNotNull(section, "section").ordinal();
460 }
461
462 private static DnsRecord checkQuestion(int section, DnsRecord record) {
463 if (section == SECTION_QUESTION && !(checkNotNull(record, "record") instanceof DnsQuestion)) {
464 throw new IllegalArgumentException(
465 "record: " + record + " (expected: " + StringUtil.simpleClassName(DnsQuestion.class) + ')');
466 }
467 return record;
468 }
469
470 @SuppressWarnings("unchecked")
471 private static <T extends DnsRecord> T castRecord(Object record) {
472 return (T) record;
473 }
474
475 private static ArrayList<DnsRecord> newRecordList() {
476 return new ArrayList<DnsRecord>(2);
477 }
478 }