# Copyright (C) 2003 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id: resolver.py,v 1.11 2003/06/20 07:12:38 halley Exp $

"""DNS stub resolver.

@var default_resolver: The default resolver object
@type default_resolver: DNS.resolver.Resolver object"""

import socket
import sys
import time
import types

import DNS.exception
import DNS.message
import DNS.name
import DNS.query
import DNS.rcode
import DNS.rdataclass
import DNS.rdatatype

if sys.platform == 'win32':
    import _winreg

class NXDOMAIN(DNS.exception.DNSException):
    """The query name does not exist."""
    # Raised by Resolver.query only when every candidate name got NXDOMAIN.

class Timeout(DNS.exception.DNSException):
    """The query timed out."""
    # Raised by Resolver.query when its lifetime budget is used up (or the
    # clock is seen running backwards).

class NoAnswer(DNS.exception.DNSException):
    """The response did not contain an answer to the question."""
    # Raised while building an Answer when no matching rdataset is found.

class NoNameservers(DNS.exception.DNSException):
    """No non-broken nameservers are available to answer the query."""
    # Raised by Resolver.query once its working copy of the server list is
    # empty.

class Answer(object):
    """DNS stub resolver answer

    Instances of this class bundle up the result of a successful DNS
    resolution.

    For convenience, the answer is iterable.  "for a in answer" is
    equivalent to "for a in answer.rdataset".

    Note that CNAMEs or DNAMEs in the response may mean that answer
    node's name might not be the query name.

    @ivar qname: The query name
    @type qname: DNS.name.Name object
    @ivar rdtype: The query type
    @type rdtype: int
    @ivar rdclass: The query class
    @type rdclass: int
    @ivar response: The response message
    @type response: DNS.message.Message object
    @ivar node: The node of the answer
    @type node: DNS.node.Node object
    @ivar rdataset: The answer
    @type rdataset: DNS.rdataset.Rdataset object
    """

    # Maximum number of node lookups made while following CNAMEs from the
    # query name.  This bounds the work done on CNAME loops or overly long
    # chains in a broken or hostile response.
    _max_cname_chain = 15

    def __init__(self, qname, rdtype, rdclass, response):
        """Find the answer to (qname, rdtype, rdclass) in the answer
        section of I{response}, following CNAMEs when the wanted rdataset
        is absent.

        @raises NoAnswer: the answer section has no node for the (possibly
        CNAME-redirected) name; the node has neither the wanted rdataset
        nor a usable CNAME; or the answer was not reached within
        _max_cname_chain lookups."""
        self.qname = qname
        self.rdtype = rdtype
        self.rdclass = rdclass
        self.response = response
        node = None
        rds = None
        for _ in range(self._max_cname_chain):
            try:
                node = response.find_node(response.answer, qname)
            except KeyError:
                raise NoAnswer
            try:
                rds = node.find_rdataset(rdclass, rdtype)
                break
            except KeyError:
                pass
            if rdtype == DNS.rdatatype.CNAME:
                # The CNAME rdataset itself was requested and is missing;
                # there is nothing further to follow.
                raise NoAnswer
            try:
                cname_rds = node.find_rdataset(rdclass, DNS.rdatatype.CNAME)
            except KeyError:
                raise NoAnswer
            # Follow the first CNAME target; an empty rdataset gives no
            # target and would otherwise loop on the same name.
            for rd in cname_rds:
                qname = rd.target
                break
            else:
                raise NoAnswer
        else:
            # Lookup budget exhausted while still following CNAMEs.  Without
            # this, the last CNAME rdataset would be returned as the answer.
            raise NoAnswer
        self.node = node
        self.rdataset = rds

    def __getattr__(self, attr):
        """Expose 'name' from the answer node and 'ttl'/'covers' from the
        answer rdataset; any other missing attribute is an AttributeError."""
        if attr == 'name':
            return self.node.name
        elif attr == 'ttl':
            return self.rdataset.ttl
        elif attr == 'covers':
            return self.rdataset.covers
        else:
            raise AttributeError(attr)

    def __len__(self):
        """Return the number of rdatas in the answer rdataset."""
        return len(self.rdataset)

    def __iter__(self):
        """Iterate over the rdatas of the answer rdataset."""
        return iter(self.rdataset)

class Resolver(object):
    """DNS stub resolver

    @ivar domain: The domain of this host
    @type domain: DNS.name.Name object
    @ivar nameservers: A list of nameservers to query.  Each nameserver is
    a string which contains the IP address of a nameserver.
    @type nameservers: list of strings
    @ivar search: The search list.  If the query name is a relative name,
    the resolver will construct an absolute query name by appending the search
    names one by one to the query name.
    @type search: list of DNS.name.Name objects
    @ivar port: The port to which to send queries.  The default is 53.
    @type port: int
    @ivar timeout: The number of seconds to wait for a response from a
    server, before timing out.
    @type timeout: float
    @ivar lifetime: The total number of seconds to spend trying to get an
    answer to the question.  If the lifetime expires, a Timeout exception
    will occur.
    @type lifetime: float
    @ivar keyring: The TSIG keyring to use.  The default is None.
    @type keyring: dict
    @ivar keyname: The TSIG keyname to use.  The default is None.
    @type keyname: DNS.name.Name object
    @ivar edns: The EDNS level to use.  The default is -1, no EDNS.
    @type edns: int
    @ivar ednsflags: The EDNS flags
    @type ednsflags: int
    @ivar payload: The EDNS payload size.  The default is 0.
    @type payload: int
    """
    def __init__(self, filename='/etc/resolv.conf', configure=True):
        """Initialize a resolver instance.

        @param filename: The filename of a configuration file in
        standard /etc/resolv.conf format.  This parameter is meaningful
        only when I{configure} is true and the platform is POSIX.
        @type filename: string or file object
        @param configure: If True (the default), the resolver instance
        is configured in the normal fashion for the operating system
        the resolver is running on.  (I.e. a /etc/resolv.conf file on
        POSIX systems and from the registry on Windows systems.)
        @type configure: bool"""

        self.reset()
        if configure:
            if sys.platform == 'win32':
                self.read_registry()
            elif filename:
                self.read_resolv_conf(filename)

    def reset(self):
        """Reset all resolver configuration to the defaults.

        The domain defaults to the host name minus its first label (or the
        root name when nothing remains).  The nameserver and search lists
        are emptied."""
        self.domain = \
            DNS.name.Name(DNS.name.from_text(socket.gethostname())[1:])
        if len(self.domain) == 0:
            self.domain = DNS.name.root
        self.nameservers = []
        self.search = []
        self.port = 53
        self.timeout = 2.0
        self.lifetime = 30.0
        self.keyring = None
        self.keyname = None
        self.edns = -1
        self.ednsflags = 0
        self.payload = 0

    def read_resolv_conf(self, f):
        """Process f as a file in the /etc/resolv.conf format.  If f is
        a string, it is used as the name of the file to open; otherwise it
        is treated as the file itself.

        The 'nameserver', 'domain' and 'search' directives are used; other
        directives, blank lines, and comment lines (starting with '#' or
        ';') are ignored, as are directives with no argument."""
        if isinstance(f, str) or isinstance(f, unicode):
            f = open(f, 'r')
            want_close = True
        else:
            want_close = False
        try:
            for l in f:
                tokens = l.split()
                # Guard before indexing tokens: blank lines would otherwise
                # raise IndexError, and an argument-less directive has
                # nothing to configure.
                if len(tokens) < 2 or tokens[0][0] in '#;':
                    continue
                if tokens[0] == 'nameserver':
                    self.nameservers.append(tokens[1])
                elif tokens[0] == 'domain':
                    self.domain = DNS.name.from_text(tokens[1])
                elif tokens[0] == 'search':
                    for suffix in tokens[1:]:
                        self.search.append(DNS.name.from_text(suffix))
        finally:
            if want_close:
                f.close()

    def _win32_query_value(self, key, name):
        """Return the data of registry value I{name} under I{key}, or None
        if the value cannot be read (e.g. it does not exist)."""
        try:
            value, rtype = _winreg.QueryValueEx(key, name)
        except EnvironmentError:
            return None
        return value

    def _config_win32_nameservers(self, nameservers, split_char=','):
        """Configure a NameServer registry entry.

        Entries are split on I{split_char}; empty entries are skipped and
        duplicates of already-configured servers are not re-added."""
        # we call str() on nameservers to convert it from unicode to ascii
        ns_list = str(nameservers).split(split_char)
        for ns in ns_list:
            # Repeated or trailing separators yield empty strings, which
            # must not become nameserver addresses.
            ns = ns.strip()
            if ns and ns not in self.nameservers:
                self.nameservers.append(ns)

    def _config_win32_domain(self, domain):
        """Configure a Domain registry entry."""
        # we call str() on domain to convert it from unicode to ascii
        self.domain = DNS.name.from_text(str(domain))

    def _config_win32_search(self, search):
        """Configure a Search registry entry (a comma-separated list)."""
        # we call str() on search to convert it from unicode to ascii
        search_list = str(search).split(',')
        for s in search_list:
            s = s.strip()
            if not s:
                continue
            # Convert before the membership test: self.search holds Name
            # objects, so comparing the raw string never matched.
            name = DNS.name.from_text(s)
            if name not in self.search:
                self.search.append(name)

    def _config_win32_fromkey(self, key):
        """Extract DNS info from a registry key.

        Statically configured servers (NameServer/Domain) take precedence
        over DHCP-supplied ones (DhcpNameServer/DhcpDomain).  Values that
        are missing from the key are skipped rather than treated as
        errors."""
        servers = self._win32_query_value(key, 'NameServer')
        if servers:
            self._config_win32_nameservers(servers)
            dom = self._win32_query_value(key, 'Domain')
            if dom:
                self._config_win32_domain(dom)
        else:
            servers = self._win32_query_value(key, 'DhcpNameServer')
            if servers:
                # Annoyingly, the DhcpNameServer list is apparently space
                # separated instead of comma separated like NameServer.
                self._config_win32_nameservers(servers, ' ')
                dom = self._win32_query_value(key, 'DhcpDomain')
                if dom:
                    self._config_win32_domain(dom)
        search = self._win32_query_value(key, 'SearchList')
        if search:
            self._config_win32_search(search)

    def read_registry(self):
        """Extract resolver configuration from the Windows registry.

        Reads the global TCP/IP parameters, then (on XP/2000) every
        interface subkey whose NTEContextList is non-empty.  An interface
        that cannot be opened or read is skipped; it no longer stops the
        scan of the remaining interfaces."""
        lm = _winreg.ConnectRegistry(None, _winreg.HKEY_LOCAL_MACHINE)
        want_scan = False
        try:
            try:
                # XP, 2000
                tcp_params = _winreg.OpenKey(lm,
                                             r'SYSTEM\CurrentControlSet'
                                             r'\Services\Tcpip\Parameters')
                want_scan = True
            except EnvironmentError:
                # ME
                tcp_params = _winreg.OpenKey(lm,
                                             r'SYSTEM\CurrentControlSet'
                                             r'\Services\VxD\MSTCP')
            try:
                self._config_win32_fromkey(tcp_params)
            finally:
                tcp_params.Close()
            if want_scan:
                interfaces = _winreg.OpenKey(lm,
                                             r'SYSTEM\CurrentControlSet'
                                             r'\Services\Tcpip\Parameters'
                                             r'\Interfaces')
                try:
                    i = 0
                    while True:
                        try:
                            guid = _winreg.EnumKey(interfaces, i)
                        except EnvironmentError:
                            # EnumKey fails once i passes the last subkey.
                            break
                        i += 1
                        try:
                            key = _winreg.OpenKey(interfaces, guid)
                        except EnvironmentError:
                            continue
                        try:
                            # enabled interfaces seem to have a non-empty
                            # NTEContextList
                            nte = self._win32_query_value(key,
                                                          'NTEContextList')
                            if nte:
                                self._config_win32_fromkey(key)
                        finally:
                            key.Close()
                finally:
                    interfaces.Close()
        finally:
            lm.Close()

    def query(self, qname, rdtype=DNS.rdatatype.A, rdclass=DNS.rdataclass.IN,
              tcp=False):
        """Query nameservers to find the answer to the question.

        The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects
        of the appropriate type, or strings that can be converted into objects
        of the appropriate type.  E.g. For I{rdtype} the integer 2 and the
        the string 'NS' both mean to query for records with DNS rdata type NS.

        @param qname: the query name
        @type qname: DNS.name.Name object or string
        @param rdtype: the query type
        @type rdtype: int or string
        @param rdclass: the query class
        @type rdclass: int or string
        @param tcp: use TCP to make the query (default is False).
        @type tcp: bool
        @rtype: DNS.resolver.Answer instance
        @raises Timeout: no answers could be found in the specified lifetime
        @raises NXDOMAIN: the query name does not exist
        @raises NoAnswer: the response did not contain an answer
        @raises NoNameservers: no non-broken nameservers are available to
        answer the question."""

        if isinstance(qname, str):
            qname = DNS.name.from_text(qname, None)
        if isinstance(rdtype, str):
            rdtype = DNS.rdatatype.from_text(rdtype)
        if isinstance(rdclass, str):
            rdclass = DNS.rdataclass.from_text(rdclass)
        qnames_to_try = []
        if qname.is_absolute():
            qnames_to_try.append(qname)
        else:
            # A relative name with more than one label is first tried as
            # if it were already absolute.
            if len(qname) > 1:
                qnames_to_try.append(qname.concatenate(DNS.name.root))
            if self.search:
                for suffix in self.search:
                    qnames_to_try.append(qname.concatenate(suffix))
            else:
                qnames_to_try.append(qname.concatenate(self.domain))
        all_nxdomain = True
        start = time.time()
        for qname in qnames_to_try:
            request = DNS.message.make_query(qname, rdtype, rdclass)
            if not self.keyname is None:
                request.use_tsig(self.keyring, self.keyname)
            request.use_edns(self.edns, self.ednsflags, self.payload)
            response = None
            #
            # make a copy of the servers list so we can alter it later.
            #
            nameservers = self.nameservers[:]
            while response is None:
                if len(nameservers) == 0:
                    raise NoNameservers
                # Iterate over a snapshot: a server that sends garbage is
                # removed from nameservers below, and removing from the list
                # being iterated would silently skip the following server.
                for nameserver in nameservers[:]:
                    now = time.time()
                    if now < start:
                        # Time going backwards is bad.  Just give up.
                        raise Timeout
                    duration = now - start
                    if duration >= self.lifetime:
                        raise Timeout
                    # Never wait past the remaining overall lifetime.
                    timeout = min(self.lifetime - duration, self.timeout)
                    try:
                        if tcp:
                            response = DNS.query.tcp(request, nameserver,
                                                     timeout, self.port)
                        else:
                            response = DNS.query.udp(request, nameserver,
                                                     timeout, self.port)
                    except socket.error:
                        #
                        # Communication failure or timeout.  Go to the
                        # next server
                        #
                        response = None
                        continue
                    except DNS.query.UnexpectedSource:
                        #
                        # Who knows?  Keep going.
                        #
                        response = None
                        continue
                    except DNS.exception.FormError:
                        #
                        # We don't understand what this server is
                        # saying.  Take it out of the mix and
                        # continue.
                        #
                        nameservers.remove(nameserver)
                        response = None
                        continue
                    rcode = response.rcode()
                    if rcode == DNS.rcode.NOERROR or \
                           rcode == DNS.rcode.NXDOMAIN:
                        break
                    # Any other rcode: try the next server.
                    response = None
            if response.rcode() == DNS.rcode.NXDOMAIN:
                continue
            all_nxdomain = False
            break
        if all_nxdomain:
            raise NXDOMAIN
        return Answer(qname, rdtype, rdclass, response)

    def use_tsig(self, keyring, keyname=None):
        """Add a TSIG signature to the query.

        @param keyring: The TSIG keyring to use.  Passing None together
        with no keyname disables TSIG signing.
        @type keyring: dict
        @param keyname: The name of the TSIG key to use; defaults to None.
        The key must be defined in the keyring.  If a keyring is specified
        but a keyname is not, then the key used will be the first key in the
        keyring.  Note that the order of keys in a dictionary is not defined,
        so applications should supply a keyname when a keyring is used, unless
        they know the keyring contains only one key.  An empty keyring with
        no keyname raises IndexError."""
        self.keyring = keyring
        if keyname is None:
            if keyring is None:
                # Nothing to choose a key from; query() only signs when
                # keyname is set, so this leaves TSIG off.
                self.keyname = None
            else:
                self.keyname = list(keyring.keys())[0]
        else:
            self.keyname = keyname

    def use_edns(self, edns, ednsflags, payload):
        """Configure EDNS.

        @param edns: The EDNS level to use.  The default is -1, no EDNS.
        None is treated as -1.
        @type edns: int
        @param ednsflags: The EDNS flags
        @type ednsflags: int
        @param payload: The EDNS payload size.  The default is 0.
        @type payload: int"""

        if edns is None:
            edns = -1
        self.edns = edns
        self.ednsflags = ednsflags
        self.payload = payload

# Resolver shared by the module-level query() function.  It stays None
# until the first query() call creates it with the default configuration.
default_resolver = None

def query(qname, rdtype=DNS.rdatatype.A, rdclass=DNS.rdataclass.IN,
          tcp=False):
    """Answer a question using the shared default resolver.

    Convenience wrapper: the module-wide default resolver is created on
    first use and then reused by every later call.
    @see: L{DNS.resolver.Resolver.query} for more information on the
    parameters."""
    global default_resolver
    resolver = default_resolver
    if resolver is None:
        # First use: build the shared resolver and remember it.
        resolver = default_resolver = Resolver()
    return resolver.query(qname, rdtype, rdclass, tcp)
