# Copyright (C) 2003 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id: query.py,v 1.9 2003/06/20 08:37:51 halley Exp $

"""Talk to a DNS server."""

from __future__ import generators

import socket
import struct
import sys

import DNS.exception
import DNS.name
import DNS.message
import DNS.rdataclass
import DNS.rdatatype

class UnexpectedSource(DNS.exception.DNSException):
    """A reply to a query arrived from an address or port other than the
    one the query was sent to."""

class BadResponse(DNS.exception.FormError):
    """The message received is not a response to the query that was sent."""

def udp(q, where, timeout=None, port=53):
    """Send a query to a DNS server over UDP and return the server's reply.

    @param q: the query
    @type q: DNS.message.Message
    @param where: the address of the server to query.  The reply's source
    (address, port) is compared verbatim with (where, port), so this should
    be given as the numeric IPv4 address the reply will come from (the
    socket is created with AF_INET).
    @type where: string
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.  This option requires Python 2.3a1
    or later.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @raises UnexpectedSource: the reply came from an address or port other
    than (where, port).
    @raises BadResponse: the reply does not answer q.
    @rtype: DNS.message.Message object"""

    destination = (where, port)
    request = q.to_wire()
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    try:
        # Sockets without settimeout() (pre-2.3 Pythons) simply block.
        if hasattr(sock, 'settimeout'):
            sock.settimeout(timeout)
        sock.sendto(request, destination)
        reply, source = sock.recvfrom(65536)
    finally:
        sock.close()

    if source != destination:
        raise UnexpectedSource
    response = DNS.message.from_wire(reply, keyring=q.keyring,
                                     request_mac=q.mac)
    if not q.is_response(response):
        raise BadResponse
    return response

def _net_read(sock, count):
    """Read exactly count bytes from sock.

    A single recv() may return fewer bytes than requested, so keep reading
    until the full amount has arrived.

    @param sock: a connected stream socket
    @param count: the number of bytes to read; zero or less returns ''
    @type count: int
    @raises EOFError: the peer closed the connection before count bytes
    were received
    @rtype: string"""

    chunks = []
    while count > 0:
        data = sock.recv(count)
        # An empty read means the peer has closed the connection.
        if not data:
            raise EOFError
        count = count - len(data)
        chunks.append(data)
    # Join once at the end instead of concatenating on every read.
    return ''.join(chunks)

def tcp(q, where, timeout=None, port=53):
    """Send a query to a DNS server over TCP and return the server's reply.

    @param q: the query
    @type q: DNS.message.Message object
    @param where: the address of the server to query (the socket is created
    with AF_INET)
    @type where: string
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.  This option requires Python 2.3a1
    or later.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @raises EOFError: the server closed the connection before the whole
    reply was read.
    @raises BadResponse: the reply does not answer q.
    @rtype: DNS.message.Message object"""

    request = q.to_wire()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    try:
        if hasattr(sock, 'settimeout'):
            sock.settimeout(timeout)
        sock.connect((where, port))
        # Each message on the stream is preceded by its length as a
        # network-order 16-bit integer.  Copying prefix + message into one
        # string is a little wasteful, but a single sendall() avoids
        # writev() or a short write of just the prefix onto the net.
        sock.sendall(struct.pack("!H", len(request)) + request)
        (reply_len,) = struct.unpack("!H", _net_read(sock, 2))
        reply = _net_read(sock, reply_len)
    finally:
        sock.close()

    response = DNS.message.from_wire(reply, keyring=q.keyring,
                                     request_mac=q.mac)
    if not q.is_response(response):
        raise BadResponse
    return response

def xfr(where, zone, rdtype=DNS.rdatatype.AXFR, rdclass=DNS.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True):
    """Return a generator for the responses to a zone transfer.

    The transfer is performed over TCP.  Each yielded value is one DNS
    message of the transfer stream.  The generator stops after the message
    whose answer section ends with an SOA rdataset at the zone name.

    @param where: the address of the server to transfer from (the socket is
    created with AF_INET)
    @type where: string
    @param zone: The name of the zone to transfer
    @type zone: DNS.name.Name object or string
    @param rdtype: The type of zone transfer.  The default is
    DNS.rdatatype.AXFR.
    @type rdtype: int or string
    @param rdclass: The class of the zone transfer.  The default is
    DNS.rdataclass.IN.
    @type rdclass: int or string
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.  This option requires Python 2.3a1
    or later.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @param keyring: The TSIG keyring to use
    @type keyring: dict
    @param keyname: The name of the TSIG key to use
    @type keyname: DNS.name.Name object or string
    @param relativize: If True, all names in the zone will be relativized to
    the zone origin.
    @type relativize: bool
    @raises DNS.exception.FormError: the first message does not begin with
    the zone's SOA, or TSIG was requested and the final message is unsigned.
    @raises EOFError: the server closed the connection mid-transfer.
    @rtype: generator of DNS.message.Message objects."""

    if isinstance(zone, str):
        zone = DNS.name.from_text(zone)
    q = DNS.message.make_query(zone, rdtype, rdclass)
    if keyring is not None:
        q.use_tsig(keyring, keyname)
    wire = q.to_wire()
    if relativize:
        origin = zone
        oname = DNS.name.empty
    else:
        origin = None
        oname = zone
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    # Close the socket if anything goes wrong.  On Python 2.5+ this also
    # covers abandoning the generator, since GeneratorExit is thrown in at
    # the yield.  A bare except that re-raises is used because try/finally
    # may not enclose a yield before Python 2.5.
    try:
        if hasattr(s, 'settimeout'):
            s.settimeout(timeout)
        s.connect((where, port))
        # Length-prefixed (16-bit, network order) message, sent in one write.
        s.sendall(struct.pack("!H", len(wire)) + wire)
        done = False
        seen_soa = False
        tsig_ctx = None
        multi = False
        while not done:
            ldata = _net_read(s, 2)
            (l,) = struct.unpack("!H", ldata)
            wire = _net_read(s, l)
            r = DNS.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=multi)
            # Thread the TSIG state of this message into the parsing of the
            # next one; every message after the first is parsed as multi.
            tsig_ctx = r.tsig_ctx
            multi = True
            if not seen_soa:
                # The first message must begin with the zone's SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise DNS.exception.FormError
                rds = r.answer[0].rdatasets[0]
                if rds.rdtype != DNS.rdatatype.SOA:
                    raise DNS.exception.FormError
                seen_soa = True
                # A transfer that fits in one message also ends with the SOA.
                if len(r.answer) > 1 and r.answer[-1].name == oname:
                    rds = r.answer[-1].rdatasets[-1]
                    if rds.rdtype == DNS.rdatatype.SOA:
                        done = True
            elif r.answer and r.answer[-1].name == oname:
                rds = r.answer[-1].rdatasets[-1]
                if rds.rdtype == DNS.rdatatype.SOA:
                    done = True
            # The final message of a TSIG-protected transfer must be signed,
            # whether the transfer ended in the first message or a later one.
            if done and q.keyring and not r.had_tsig:
                raise DNS.exception.FormError("missing TSIG")
            yield r
    except:
        s.close()
        raise
    s.close()
