Package dns :: Module query
[hide private]
[frames] | [no frames]

Source Code for Module dns.query

  1  # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
  2  # 
  3  # Permission to use, copy, modify, and distribute this software and its 
  4  # documentation for any purpose with or without fee is hereby granted, 
  5  # provided that the above copyright notice and this permission notice 
  6  # appear in all copies. 
  7  # 
  8  # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES 
  9  # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 
 10  # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR 
 11  # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 
 12  # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 
 13  # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT 
 14  # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
 15   
 16  """Talk to a DNS server.""" 
 17   
 18  from __future__ import generators 
 19   
import errno
import select
import socket
import struct
import sys
import time

import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsig
from ._compat import long, string_types
 34   
 35  if sys.version_info > (3,): 
 36      select_error = OSError 
 37  else: 
 38      select_error = select.error 
 39   
 40   
class UnexpectedSource(dns.exception.DNSException):

    """A DNS query response came from an unexpected address or port."""
    # Raised by udp() when a reply arrives from an address other than the
    # one queried and ignore_unexpected is False.
    # NOTE(review): DNSException may use the class docstring as the default
    # exception message; keep the docstring text stable.
class BadResponse(dns.exception.FormError):

    """A DNS query response does not respond to the question asked."""
    # Raised (with no arguments) by udp() and tcp() when q.is_response(r)
    # is false.  Being a FormError subclass, callers that catch FormError
    # also catch this.
    # NOTE(review): DNSException may use the class docstring as the default
    # exception message; keep the docstring text stable.
51 -def _compute_expiration(timeout):
52 if timeout is None: 53 return None 54 else: 55 return time.time() + timeout
56 57
58 -def _poll_for(fd, readable, writable, error, timeout):
59 """Poll polling backend. 60 @param fd: File descriptor 61 @type fd: int 62 @param readable: Whether to wait for readability 63 @type readable: bool 64 @param writable: Whether to wait for writability 65 @type writable: bool 66 @param timeout: Deadline timeout (expiration time, in seconds) 67 @type timeout: float 68 @return True on success, False on timeout 69 """ 70 event_mask = 0 71 if readable: 72 event_mask |= select.POLLIN 73 if writable: 74 event_mask |= select.POLLOUT 75 if error: 76 event_mask |= select.POLLERR 77 78 pollable = select.poll() 79 pollable.register(fd, event_mask) 80 81 if timeout: 82 event_list = pollable.poll(long(timeout * 1000)) 83 else: 84 event_list = pollable.poll() 85 86 return bool(event_list)
87 88
89 -def _select_for(fd, readable, writable, error, timeout):
90 """Select polling backend. 91 @param fd: File descriptor 92 @type fd: int 93 @param readable: Whether to wait for readability 94 @type readable: bool 95 @param writable: Whether to wait for writability 96 @type writable: bool 97 @param timeout: Deadline timeout (expiration time, in seconds) 98 @type timeout: float 99 @return True on success, False on timeout 100 """ 101 rset, wset, xset = [], [], [] 102 103 if readable: 104 rset = [fd] 105 if writable: 106 wset = [fd] 107 if error: 108 xset = [fd] 109 110 if timeout is None: 111 (rcount, wcount, xcount) = select.select(rset, wset, xset) 112 else: 113 (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout) 114 115 return bool((rcount or wcount or xcount))
116 117
def _wait_for(fd, readable, writable, error, expiration):
    """Block until fd is ready for the requested conditions.

    The wait is delegated to the module-level _polling_backend.  If the
    wait is interrupted by a signal (EINTR), it is retried with the
    remaining time recomputed from expiration.

    @param fd: socket or file descriptor to wait on
    @param readable: wait for readability
    @param writable: wait for writability
    @param error: wait for an error condition
    @param expiration: absolute deadline (as from _compute_expiration()),
    or None to wait without limit
    @raises dns.exception.Timeout: if the deadline passes before fd is ready
    """
    while True:
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        try:
            if not _polling_backend(fd, readable, writable, error, timeout):
                raise dns.exception.Timeout
        except select_error as e:
            if e.args[0] != errno.EINTR:
                # Bare raise keeps the original traceback.
                raise
            # Interrupted by a signal: nothing became ready yet, so go
            # round again and wait for whatever time is left.
        else:
            return
136 -def _set_polling_backend(fn):
137 """ 138 Internal API. Do not use. 139 """ 140 global _polling_backend 141 142 _polling_backend = fn
# Choose the default backend for _wait_for().  poll() is preferred where
# the platform provides it because it has no limit on the maximum value of
# a file descriptor (and is more efficient for high values); otherwise
# fall back to select().
_polling_backend = _poll_for if hasattr(select, 'poll') else _select_for
def _wait_for_readable(s, expiration):
    """Wait until s is readable (or reports an error) or expiration passes.

    expiration is an absolute deadline, or None for no limit.
    """
    _wait_for(s, readable=True, writable=False, error=True,
              expiration=expiration)
def _wait_for_writable(s, expiration):
    """Wait until s is writable (or reports an error) or expiration passes.

    expiration is an absolute deadline, or None for no limit.
    """
    _wait_for(s, readable=False, writable=True, error=True,
              expiration=expiration)
def _addresses_equal(af, a1, a2):
    """Return True if two socket address tuples denote the same endpoint.

    The host (element 0) is compared in binary form, via dns.inet.inet_pton,
    so that different textual spellings of the same address compare equal.
    The remaining elements (the port, plus the extra fields of an IPv6
    address tuple) must match exactly.
    """
    host1, rest1 = a1[0], a1[1:]
    host2, rest2 = a2[0], a2[1:]
    if dns.inet.inet_pton(af, host1) != dns.inet.inet_pton(af, host2):
        return False
    return rest1 == rest2
def _destination_and_source(af, where, port, source, source_port):
    """Apply defaults and compute destination and source address tuples.

    The returned tuples are suitable for use in connect(), sendto(), or
    bind().

    @param af: address family, or None to infer it from where (AF_INET is
    used if the inference attempt fails)
    @param where: destination address text
    @param port: destination port
    @param source: source address text, or None
    @param source_port: source port; 0 means unspecified
    @return: (af, destination, source).  source stays None when neither a
    source address nor a non-zero source_port was given; otherwise a missing
    source address is replaced by the family's wildcard address.
    @raises ValueError: if af is neither AF_INET nor AF_INET6
    """
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            # Not an address we can classify; default to IPv4.  (Catching
            # Exception rather than everything lets KeyboardInterrupt and
            # SystemExit through.)
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        if source is not None or source_port != 0:
            if source is None:
                source = '0.0.0.0'
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None or source_port != 0:
            if source is None:
                source = '::'
            source = (source, source_port, 0, 0)
    else:
        # Without this, an unsupported family would surface as an
        # UnboundLocalError on 'destination' below.
        raise ValueError('unsupported address family %r' % (af,))
    return (af, destination, source)
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False):
    """Return the response obtained after sending a query via UDP.

    @param q: the query
    @type q: dns.message.Message
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address. The default is the wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param ignore_unexpected: If True, ignore responses from unexpected
    sources. The default is False.
    @type ignore_unexpected: bool
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raise UnexpectedSource: if a reply comes from an unexpected source and
    ignore_unexpected is False
    @raise BadResponse: if the reply does not answer q
    @raise dns.exception.Timeout: if the timeout expires while waiting to
    send or receive
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    s = socket.socket(af, socket.SOCK_DGRAM, 0)
    # Stays None if we fail before sending, so response_time becomes 0.
    begin_time = None
    try:
        # A single deadline covers both sending and every receive attempt.
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        _wait_for_writable(s, expiration)
        begin_time = time.time()
        s.sendto(wire, destination)
        while 1:
            _wait_for_readable(s, expiration)
            (wire, from_address) = s.recvfrom(65535)
            # Accept a reply from the queried address or, when the query
            # went to a multicast address, from any host whose address
            # tuple matches the destination after the host part (i.e. the
            # same port).
            if _addresses_equal(af, from_address, destination) or \
                    (dns.inet.is_multicast(where) and
                     from_address[1:] == destination[1:]):
                break
            # Otherwise either fail, or drop the datagram and keep waiting.
            if not ignore_unexpected:
                raise UnexpectedSource('got a response from '
                                       '%s instead of %s' % (from_address,
                                                             destination))
    finally:
        if begin_time is None:
            response_time = 0
        else:
            response_time = time.time() - begin_time
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    r.time = response_time
    if not q.is_response(r):
        raise BadResponse
    return r
def _net_read(sock, count, expiration):
    """Read exactly count bytes from sock.

    Keep calling recv() until the desired amount has arrived.

    @param sock: a non-blocking stream socket
    @param count: number of bytes to read
    @param expiration: absolute deadline, or None for no limit
    @return: the bytes read (b'' when count is 0)
    @raises EOFError: if the peer closes the connection before count bytes
    have been read
    @raises dns.exception.Timeout: if the operation is not completed by the
    expiration time
    """
    # Collect chunks and join once; repeated 's = s + n' is quadratic.
    chunks = []
    while count > 0:
        _wait_for_readable(sock, expiration)
        n = sock.recv(count)
        if n == b'':
            raise EOFError
        count -= len(n)
        chunks.append(n)
    return b''.join(chunks)
def _net_write(sock, data, expiration):
    """Send all of data on sock, handling partial writes.

    Before each send() the socket is waited on for writability.  A Timeout
    exception will be raised if the operation is not completed by the
    expiration time.
    """
    total = len(data)
    sent = 0
    while sent < total:
        _wait_for_writable(sock, expiration)
        sent += sock.send(data[sent:])
290 -def _connect(s, address):
291 try: 292 s.connect(address) 293 except socket.error: 294 (ty, v) = sys.exc_info()[:2] 295 296 if hasattr(v, 'errno'): 297 v_err = v.errno 298 else: 299 v_err = v[0] 300 if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]: 301 raise v
302 303
def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        one_rr_per_rrset=False):
    """Return the response obtained after sending a query via TCP.

    @param q: the query
    @type q: dns.message.Message object
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address. The default is the wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raise BadResponse: if the reply does not answer q
    @raise dns.exception.Timeout: if the timeout expires during the exchange
    @raise EOFError: if the server closes the connection early
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    s = socket.socket(af, socket.SOCK_STREAM, 0)
    # Stays None if we fail early, so response_time becomes 0.
    begin_time = None
    try:
        # One deadline covers connect, write and both reads.
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        # Timing starts before bind/connect, so r.time includes connection
        # setup.
        begin_time = time.time()
        if source is not None:
            s.bind(source)
        # Non-blocking connect; completion is awaited by _net_write, which
        # waits for the socket to become writable.
        _connect(s, destination)

        l = len(wire)

        # copying the wire into tcpmsg is inefficient, but lets us
        # avoid writev() or doing a short write that would get pushed
        # onto the net
        # The message is framed by a 2-byte network-order length prefix.
        tcpmsg = struct.pack("!H", l) + wire
        _net_write(s, tcpmsg, expiration)
        ldata = _net_read(s, 2, expiration)
        (l,) = struct.unpack("!H", ldata)
        wire = _net_read(s, l, expiration)
    finally:
        if begin_time is None:
            response_time = 0
        else:
            response_time = time.time() - begin_time
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    r.time = response_time
    if not q.is_response(r):
        raise BadResponse
    return r
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        af=None, lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param zone: The name of the zone to transfer
    @type zone: dns.name.Name object or string
    @param rdtype: The type of zone transfer. The default is
    dns.rdatatype.AXFR.
    @type rdtype: int or string
    @param rdclass: The class of the zone transfer. The default is
    dns.rdataclass.IN.
    @type rdclass: int or string
    @param timeout: The number of seconds to wait for each response message.
    If None, the default, wait forever.  (Applies to both TCP and UDP; it is
    always capped by lifetime.)
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param keyring: The TSIG keyring to use
    @type keyring: dict
    @param keyname: The name of the TSIG key to use
    @type keyname: dns.name.Name object or string
    @param relativize: If True, all names in the zone will be relativized to
    the zone origin. It is essential that the relativize setting matches
    the one specified to dns.zone.from_xfr().
    @type relativize: bool
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @param lifetime: The total number of seconds to spend doing the transfer.
    If None, the default, then there is no limit on the time the transfer may
    take.
    @type lifetime: float
    @rtype: generator of dns.message.Message objects.
    @param source: source address. The default is the wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param serial: The SOA serial number to use as the base for an IXFR diff
    sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
    @type serial: int
    @param use_udp: Use UDP (only meaningful for IXFR)
    @type use_udp: bool
    @param keyalgorithm: The TSIG algorithm to use; defaults to
    dns.tsig.default_algorithm
    @type keyalgorithm: string
    @raise ValueError: if use_udp is True and rdtype is not IXFR
    @raise dns.exception.FormError: if the response stream is malformed
    (no initial SOA, IXFR base serial mismatch, answers after the final
    SOA, or a missing TSIG when a keyring was given)
    @raise dns.exception.Timeout: if a message or the whole transfer times
    out
    """

    if isinstance(zone, string_types):
        zone = dns.name.from_text(zone)
    if isinstance(rdtype, string_types):
        rdtype = dns.rdatatype.from_text(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # An IXFR request carries the client's current SOA serial in the
        # authority section.
        # NOTE(review): the class is hard-coded to 'IN' rather than rdclass;
        # confirm this is intended.
        rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    if use_udp:
        if rdtype != dns.rdatatype.IXFR:
            raise ValueError('cannot do a UDP AXFR')
        s = socket.socket(af, socket.SOCK_DGRAM, 0)
    else:
        s = socket.socket(af, socket.SOCK_STREAM, 0)
    # Anything below may raise, and the caller may abandon the generator
    # part-way through; either way the socket must be closed.
    try:
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        expiration = _compute_expiration(lifetime)
        _connect(s, destination)
        l = len(wire)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            tcpmsg = struct.pack("!H", l) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        first = True
        while not done:
            # Per-message deadline, capped by the overall transfer deadline.
            # Only compare when both exist: 'float > None' raises TypeError
            # on Python 3 (and is always true on Python 2, which silently
            # discarded the per-message timeout when lifetime was None).
            mexpiration = _compute_expiration(timeout)
            if mexpiration is None or \
               (expiration is not None and mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                _wait_for_readable(s, mexpiration)
                (wire, from_address) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (l,) = struct.unpack("!H", ldata)
                wire = _net_read(s, l, mexpiration)
            is_ixfr = (rdtype == dns.rdatatype.IXFR)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, first=first,
                                      one_rr_per_rrset=is_ixfr)
            tsig_ctx = r.tsig_ctx
            first = False
            answer_index = 0
            if soa_rrset is None:
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError(
                        "No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if soa_rrset[0].serial <= serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        delete_mode = not delete_mode
                    #
                    # If this SOA RRset is equal to the first we saw then
                    # we're finished. If this is an IXFR we also check that
                    # we're seeing the record in the expected part of the
                    # response.
                    #
                    if rrset == soa_rrset and \
                            (rdtype == dns.rdatatype.AXFR or
                             (rdtype == dns.rdatatype.IXFR and delete_mode)):
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
    finally:
        s.close()