# Copyright (C) 2003 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id: zone.py,v 1.24 2003/07/31 06:36:54 halley Exp $

"""DNS Zones."""

import dns.exception
import dns.name
import dns.node
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tokenizer
        
class Zone(object):
    """A DNS zone.

    A Zone is a mapping from names to nodes.  The zone object may be
    treated like a Python dictionary, e.g. zone[name] will retrieve
    the node associated with that name.  The I{name} may be a
    dns.name.Name object, or it may be a string.  In either case,
    if the name is relative it is treated as relative to the origin of
    the zone.

    @ivar rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @ivar origin: The origin of the zone.
    @type origin: dns.name.Name object
    @ivar nodes: A dictionary mapping the names of nodes in the zone to the
    nodes themselves.
    @type nodes: dict
    @cvar node_factory: the factory used to create a new node
    @type node_factory: class or callable
    """

    node_factory = dns.node.Node

    __slots__ = ['rdclass', 'origin', 'nodes']

    def __init__(self, origin, rdclass=dns.rdataclass.IN):
        """Initialize a zone object.

        @param origin: The origin of the zone.
        @type origin: dns.name.Name object
        @param rdclass: The zone's rdata class; the default is class IN.
        @type rdclass: int"""

        self.rdclass = rdclass
        self.origin = origin
        self.nodes = {}

    def _validate_key(self, key):
        """Convert I{key} into a dns.name.Name usable as a key of self.nodes.

        Strings are parsed with dns.name.from_text() using no origin, so a
        relative string yields a relative name.

        @raises KeyError: I{key} is neither a string nor a dns.name.Name
        @rtype: dns.name.Name object
        """
        if isinstance(key, str):
            key = dns.name.from_text(key, None)
        if not isinstance(key, dns.name.Name):
            raise KeyError("zone key must be convertable to a DNS name")
        return key

    def _validate_types(self, rdtype, covers):
        """Convert textual rdata type names to their integer values.

        @rtype: (int, int) tuple of (rdtype, covers)
        """
        if isinstance(rdtype, str):
            rdtype = dns.rdatatype.from_text(rdtype)
        if isinstance(covers, str):
            covers = dns.rdatatype.from_text(covers)
        return (rdtype, covers)

    def __getitem__(self, key):
        key = self._validate_key(key)
        return self.nodes[key]

    def __setitem__(self, key, value):
        key = self._validate_key(key)
        self.nodes[key] = value

    def __delitem__(self, key):
        key = self._validate_key(key)
        del self.nodes[key]

    def __iter__(self):
        return iter(self.nodes)

    def iterkeys(self):
        return self.nodes.iterkeys()

    def keys(self):
        return self.nodes.keys()

    def itervalues(self):
        return self.nodes.itervalues()

    def values(self):
        return self.nodes.values()

    def iteritems(self):
        return self.nodes.iteritems()

    def items(self):
        return self.nodes.items()

    def get(self, key, default=None):
        """Return the node named I{key}, or I{default} if there is none.

        @raises KeyError: I{key} cannot be converted to a DNS name
        """
        key = self._validate_key(key)
        return self.nodes.get(key, default)

    def __contains__(self, other):
        """Is there a node named I{other}?

        Strings are converted to names exactly as for zone[name], so
        C{'www' in zone} agrees with C{zone['www']}.  Keys that cannot be
        converted to a DNS name are simply reported as absent.
        @rtype: bool
        """
        try:
            key = self._validate_key(other)
        except KeyError:
            return False
        return key in self.nodes

    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
        """Look for rdata with the specified name and type in the zone,
        and return an rdataset encapsulating it.

        The I{name}, I{rdtype}, and I{covers} parameters may be
        strings, in which case they will be converted to their proper
        type.

        This routine is more efficient than the similar L{find_rrset}
        because it does not copy data, but may be less convenient for
        some uses since rdatasets are not bound to owner names.

        @param name: the owner name to look for
        @type name: DNS.name.Name object or string
        @param rdtype: the rdata type desired
        @type rdtype: int or string
        @param covers: the covered type (defaults to dns.rdatatype.NONE)
        @type covers: int or string
        @raises KeyError: the node or rdata could not be found
        @rtype: dns.rdataset.Rdataset object
        """

        name = self._validate_key(name)
        (rdtype, covers) = self._validate_types(rdtype, covers)
        return self.nodes[name].find_rdataset(self.rdclass, rdtype, covers)

    def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE):
        """Look for rdata with the specified name and type in the zone,
        and return an RRset encapsulating it.

        The I{name}, I{rdtype}, and I{covers} parameters may be
        strings, in which case they will be converted to their proper
        type.

        This routine is less efficient than the similar
        L{find_rdataset} because it copies data, but may be more
        convenient for some uses since it returns an object which
        binds the owner name to the rdata.

        @param name: the owner name to look for
        @type name: DNS.name.Name object or string
        @param rdtype: the rdata type desired
        @type rdtype: int or string
        @param covers: the covered type (defaults to dns.rdatatype.NONE)
        @type covers: int or string
        @raises KeyError: the node or rdata could not be found
        @rtype: dns.rrset.RRset object
        """

        name = self._validate_key(name)
        (rdtype, covers) = self._validate_types(rdtype, covers)
        rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers)
        # Copy the rdataset into a new RRset bound to the (validated) name.
        rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers)
        rrset.update(rdataset)
        return rrset

    def __eq__(self, other):
        """Two zones are equal if they have the same origin, class, and
        nodes.
        @rtype: bool
        """

        if not isinstance(other, Zone):
            return False
        if self.rdclass != other.rdclass or \
           self.origin != other.origin or \
           self.nodes != other.nodes:
            return False
        return True

    def __ne__(self, other):
        """Are two zones not equal?
        @rtype: bool
        """

        return not self.__eq__(other)

class _MasterReader(object):
    """Read a DNS master file

    @ivar tok: The tokenizer
    @type tok: dns.tokenizer.Tokenizer object
    @ivar ttl: The default TTL
    @type ttl: int
    @ivar last_name: The last name read
    @type last_name: dns.name.Name object
    @ivar current_origin: The current origin
    @type current_origin: dns.name.Name object
    @ivar relativize: should names in the zone be relativized?
    @type relativize: bool
    @ivar zone: the zone
    @type zone: dns.zone.Zone object
    """

    def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone):
        """Create a reader that will fill a new zone built by I{zone_factory}.

        @param tok: the tokenizer supplying master file tokens
        @param origin: the zone origin; a string is converted to a name
        @param rdclass: the zone's rdata class
        @param relativize: should names be relativized to the origin?
        @param zone_factory: callable invoked as zone_factory(origin, rdclass)
        """
        if isinstance(origin, str):
            origin = dns.name.from_text(origin)
        self.tok = tok
        self.current_origin = origin
        self.relativize = relativize
        self.ttl = 0
        self.last_name = None
        self.zone = zone_factory(origin, rdclass)

    def _eat_line(self):
        """Discard tokens up to and including the next EOL or EOF."""
        while True:
            (ttype, t) = self.tok.get()
            if ttype == dns.tokenizer.EOL or ttype == dns.tokenizer.EOF:
                break

    def _rr_line(self):
        """Process one line from a DNS master file.

        @raises dns.exception.SyntaxError: the line is malformed, uses a
        class other than the zone's, or inherits an owner name when no
        name has been read yet.
        """
        # Name: a line beginning with whitespace reuses the previous owner.
        token = self.tok.get(want_leading = True)
        if token[0] != dns.tokenizer.WHITESPACE:
            self.last_name = dns.name.from_text(token[1], self.current_origin)
        elif self.last_name is None:
            raise dns.exception.SyntaxError("the last used name is undefined")
        name = self.last_name
        # Records whose owner is outside the zone are silently skipped.
        if not name.is_subdomain(self.zone.origin):
            self._eat_line()
            return
        if self.relativize:
            name = name.relativize(self.zone.origin)
        token = self.tok.get()
        if token[0] != dns.tokenizer.IDENTIFIER:
            raise dns.exception.SyntaxError
        # TTL (optional): only a failed integer conversion means "absent";
        # errors from reading the next token must propagate.
        try:
            ttl = int(token[1], 0)
        except ValueError:
            ttl = self.ttl
        else:
            token = self.tok.get()
            if token[0] != dns.tokenizer.IDENTIFIER:
                raise dns.exception.SyntaxError
        # Class (optional): an unrecognized class token is taken to be the
        # type, and the zone's class is used instead.
        try:
            rdclass = dns.rdataclass.from_text(token[1])
        except dns.exception.SyntaxError:
            raise
        except Exception:
            rdclass = self.zone.rdclass
        else:
            token = self.tok.get()
            if token[0] != dns.tokenizer.IDENTIFIER:
                raise dns.exception.SyntaxError
        if rdclass != self.zone.rdclass:
            raise dns.exception.SyntaxError("RR class is not zone's class")
        # Type
        rdtype = dns.rdatatype.from_text(token[1])
        # Parse the rdata before touching the zone so that a malformed
        # record does not leave an empty node behind.
        rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
                                 self.current_origin, False)
        rd.choose_relativity(self.zone.origin, self.relativize)
        covers = rd.covers()
        n = self.zone.nodes.get(name)
        if n is None:
            n = self.zone.node_factory()
            self.zone.nodes[name] = n
        rds = n.find_rdataset(rdclass, rdtype, covers, True)
        rds.add(rd, ttl)

    def read(self):
        """Read a DNS master file and build a zone object.

        $TTL sets the default TTL, $ORIGIN sets the current origin, and
        any other $-directive is skipped to the end of its line.
        """

        while True:
            token = self.tok.get(True, True)
            if token[0] == dns.tokenizer.EOF:
                break
            elif token[0] == dns.tokenizer.EOL:
                continue
            elif token[0] == dns.tokenizer.COMMENT:
                self.tok.get_eol()
                continue
            elif token[1].startswith('$'):
                u = token[1].upper()
                if u == '$TTL':
                    self.ttl = self.tok.get_int()
                    self.tok.get_eol()
                elif u == '$ORIGIN':
                    self.current_origin = self.tok.get_name()
                    self.tok.get_eol()
                else:
                    self._eat_line()
                continue
            self.tok.unget(token)
            self._rr_line()

def from_text(text, origin, rdclass = dns.rdataclass.IN, relativize = True,
              zone_factory=Zone):
    """Build a zone object from a master file format string.

    @param text: the master file format input
    @type text: string.
    @param origin: The origin of the zone.
    @type origin: dns.name.Name object or string
    @param rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @param relativize: should names be relativized?  The default is True
    @type relativize: bool
    @param zone_factory: the callable used to create the zone object; it is
    called as zone_factory(origin, rdclass).  The default is Zone.
    @rtype: dns.zone.Zone object"""

    # The tokenizer also accepts a file object for 'text'.  That is an
    # implementation detail used by from_file(), which is the supported
    # file interface, so it is deliberately not documented above.
    tokenizer = dns.tokenizer.Tokenizer(text)
    master_reader = _MasterReader(tokenizer, origin, rdclass, relativize,
                                  zone_factory)
    master_reader.read()
    return master_reader.zone

def from_file(f, origin, rdclass = dns.rdataclass.IN, relativize = True,
              zone_factory=Zone):
    """Read a master file and build a zone object.

    @param f: file or string.  If I{f} is a string, it is treated
    as the name of a file to open; that file is closed before returning.
    A file object passed in by the caller is left open.
    @param origin: The origin of the zone.
    @type origin: dns.name.Name object or string
    @param rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @param relativize: should names be relativized?  The default is True
    @type relativize: bool
    @param zone_factory: the callable used to create the zone object; it is
    called as zone_factory(origin, rdclass).  The default is Zone.
    @rtype: dns.zone.Zone object"""

    if isinstance(f, str):
        f = open(f)
        want_close = True
    else:
        want_close = False
    try:
        # Forward zone_factory; previously it was silently dropped here.
        return from_text(f, origin, rdclass, relativize, zone_factory)
    finally:
        if want_close:
            f.close()

def from_xfr(xfr, zone_factory=Zone):
    """Convert the output of a zone transfer generator into a zone object.

    The zone's origin and class are taken from the first answer RRset of
    the first message.  If I{xfr} yields no messages, None is returned.

    @param xfr: The xfr generator
    @type xfr: generator of dns.message.Message objects
    @param zone_factory: the callable used to create the zone object; it is
    called as zone_factory(origin, rdclass).  The default is Zone.
    @rtype: dns.zone.Zone object"""

    z = None
    for r in xfr:
        if z is None:
            origin = r.answer[0].name
            rdclass = r.answer[0].rdclass
            z = zone_factory(origin, rdclass)
        for rrset in r.answer:
            znode = z.nodes.get(rrset.name)
            # Test for absence explicitly: a node that exists but holds no
            # rdatasets may be falsy and must not be replaced.
            if znode is None:
                znode = z.node_factory()
                z.nodes[rrset.name] = znode
            zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype,
                                       rrset.covers, True)
            zrds.update(rrset)
    return z
