# Copyright (C) 2003 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id: zone.py,v 1.5 2003/06/06 05:17:53 halley Exp $

"""DNS Zones."""

import DNS.name
import DNS.node
import DNS.rdataclass
import DNS.rdatatype
import DNS.rdata
import DNS.tokenizer

class Zone(object):
    """A DNS zone.

    Zones are a container for nodes.  The zone object may be treated like
    a dictionary, e.g. zone[name] will retrieve the node associated with
    that name.  The I{name} may be a DNS.name.Name object, or it may be
    a string.  In either case, if the name is relative it is treated
    as relative to the origin of the zone.
    
    @ivar rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @ivar origin: The origin of the zone.
    @type origin: DNS.name.Name object
    @ivar ttl: The default TTL (used when reading from master file format).
    @type ttl: int
    @ivar nodes: A dictionary mapping the names of nodes in the zone to the
    nodes themselves.
    @type nodes: dict
    @ivar last_name: The last name read (used when reading from master file
    format).
    @type last_name: DNS.name.Name object
    @ivar current_origin: The current origin (used when reading from master
    file format).
    @type current_origin: DNS.name.Name object"""
    
    def __init__(self, origin, rdclass=DNS.rdataclass.IN):
        """Initialize a zone object.

        @param origin: The origin of the zone.
        @type origin: DNS.name.Name object or string
        @param rdclass: The zone's rdata class; the default is class IN.
        @type rdclass: int"""

        self.rdclass = rdclass
        if isinstance(origin, str):
            origin = DNS.name.from_text(origin)
        self.origin = origin
        self.ttl = 0
        self.nodes = {}
        self.last_name = None
        # $ORIGIN directives in master files start from the zone origin.
        self.current_origin = origin

    def __getitem__(self, key):
        """Return the node stored under I{key}.

        @param key: the node's name; a string is converted with
        DNS.name.from_text(key, None), i.e. without an origin.
        @type key: DNS.name.Name object or string
        @raises KeyError: no node with that name exists"""
        if isinstance(key, str):
            key = DNS.name.from_text(key, None)
        return self.nodes[key]

    def __setitem__(self, key, value):
        """Store the node I{value} under I{key}.

        @param key: the node's name; a string is converted the same way
        as in __getitem__.
        @type key: DNS.name.Name object or string
        @param value: the node to store; its name must equal I{key}.
        @type value: DNS.node.Node object
        @raises ValueError: I{key} does not equal value.name"""
        if isinstance(key, str):
            key = DNS.name.from_text(key, None)
        # Explicit check instead of assert (asserts vanish under -O).
        # 'not (a == b)' rather than 'a != b': Python 2 classes that only
        # define __eq__ fall back to identity for '!='.
        if not (key == value.name):
            raise ValueError("key does not match the node's name")
        self.nodes[value.name] = value

    def __iter__(self):
        """Iterate over the nodes (not the names) in the zone."""
        return self.nodes.itervalues()

def _eat_line(tok):
    """Skip the remainder of the current master-file line.

    Tokens are read and thrown away until an end-of-line or an
    end-of-file token has been consumed.

    @param tok: the tokenizer to advance
    @type tok: DNS.tokenizer.Tokenizer object"""
    line_enders = (DNS.tokenizer.EOL, DNS.tokenizer.EOF)
    while True:
        kind, _value = tok.get()
        if kind in line_enders:
            return
        
def _rr_line(z, tok):
    """Process one resource-record line from a DNS master file.

    The line is read as: owner name (or leading whitespace, meaning "reuse
    the previous owner"), an optional TTL, an optional class, the type,
    and the rdata.  The resulting rdata is added to zone I{z}.  Records
    whose owner is not at or below the zone origin are skipped.

    @param z: the zone being built
    @type z: DNS.zone.Zone object
    @param tok: tokenizer positioned at the start of the line
    @type tok: DNS.tokenizer.Tokenizer object
    @raises SyntaxError: the line is malformed, or its class is not the
    zone's class"""
    # Name.  Leading whitespace means "same owner as the previous record".
    token = tok.get(want_leading = True)
    if token[0] != DNS.tokenizer.WHITESPACE:
        z.last_name = DNS.name.from_text(token[1], z.current_origin)
    name = z.last_name
    if name is None:
        # Whitespace-led line with no earlier record to inherit from.
        raise SyntaxError("no owner name to inherit")
    if not name.is_subdomain(z.origin):
        _eat_line(tok)
        return
    name = name.relativize(z.origin)
    token = tok.get()
    if token[0] != DNS.tokenizer.IDENTIFIER:
        raise SyntaxError
    # TTL (optional).  Master-file TTLs are decimal integers; a token that
    # is not one is not a TTL, so the zone's default TTL applies.
    try:
        ttl = int(token[1])
    except ValueError:
        ttl = z.ttl
    else:
        token = tok.get()
        if token[0] != DNS.tokenizer.IDENTIFIER:
            raise SyntaxError
    # Class (optional).  When absent, the record takes the zone's class.
    try:
        rdclass = DNS.rdataclass.from_text(token[1])
    except Exception:
        rdclass = z.rdclass
    else:
        token = tok.get()
        if token[0] != DNS.tokenizer.IDENTIFIER:
            raise SyntaxError
    if rdclass != z.rdclass:
        raise SyntaxError("RR class is not zone's class")
    # Type
    rdtype = DNS.rdatatype.from_text(token[1])
    # NOTE(review): z.origin (not z.current_origin) is passed here even
    # after a $ORIGIN directive; confirm whether relative names inside the
    # rdata should follow current_origin instead.
    rd = DNS.rdata.from_text(rdclass, rdtype, tok, z.origin)
    # Create the node only once the rdata parsed, so a bad line does not
    # leave an empty node behind.
    n = z.nodes.get(name)
    if n is None:
        n = DNS.node.Node(name)
        z.nodes[n.name] = n
    rds = n.find_rdataset(rdclass, rdtype, rd.covers(), True)
    rds.add(rd, ttl)

def from_text(text, origin, rdclass = DNS.rdataclass.IN):
    """Build a zone object from a master file format string.

    Blank lines and comments are skipped.  The $TTL and $ORIGIN directives
    are honored; any other $-directive is ignored along with the rest of
    its line.  Every other line is handed to the record parser.

    @param text: the master file format input
    @type text: string.
    @param origin: The origin of the zone.
    @type origin: DNS.name.Name object or string
    @param rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @rtype: DNS.zone.Zone object"""

    # 'text' can also be a file, but we don't publish that fact
    # since it's an implementation detail.  The official file
    # interface is from_file().

    z = Zone(origin, rdclass)
    tok = DNS.tokenizer.Tokenizer(text)
    while 1:
        # Ask for leading whitespace and comments: leading whitespace means
        # "reuse the previous owner name" to the record parser.
        token = tok.get(True, True)
        if token[0] == DNS.tokenizer.EOF:
            break
        elif token[0] == DNS.tokenizer.EOL:
            continue
        elif token[0] == DNS.tokenizer.COMMENT:
            tok.get_eol()
            continue
        elif token[1].startswith('$'):
            u = token[1].upper()
            # 'elif' matters: with a plain 'if', a handled $TTL line fell
            # through to the 'else' below and swallowed the NEXT line.
            if u == '$TTL':
                z.ttl = tok.get_int()
                tok.get_eol()
            elif u == '$ORIGIN':
                z.current_origin = tok.get_name()
                tok.get_eol()
            else:
                # Unsupported directive: skip the rest of its line.
                _eat_line(tok)
            continue
        tok.unget(token)
        _rr_line(z, tok)
    return z

def from_file(f, origin, rdclass = DNS.rdataclass.IN):
    """Read a master file and build a zone object.

    @param f: file or string.  If I{f} is a string, it is treated
    as the name of a file to open; that file is closed before returning.
    A file object passed in is left open.
    @param origin: The origin of the zone.
    @type origin: DNS.name.Name object or string
    @param rdclass: The zone's rdata class; the default is class IN.
    @type rdclass: int
    @rtype: DNS.zone.Zone object"""

    if isinstance(f, str):
        f = open(f)
        want_close = True
    else:
        want_close = False
    try:
        # origin and rdclass must be forwarded; from_text requires origin.
        return from_text(f, origin, rdclass)
    finally:
        if want_close:
            f.close()

def from_xfr(xfr):
    """Convert the output of a zone transfer generator into a zone object.

    The zone's origin and class are taken from the first answer entry of
    the first message.  The rdatasets of every answer node are merged into
    the matching node of the zone.
    
    @param xfr: The xfr generator
    @type xfr: generator of DNS.message.Message objects
    @rtype: DNS.zone.Zone object, or None if I{xfr} yields no messages"""
    
    # NOTE(review): node names are stored as they appear in the messages
    # (not relativized to the origin as from_text does); confirm callers
    # expect that.
    z = None
    for r in xfr:
        if z is None:
            origin = r.answer[0].name
            rdclass = r.answer[0].rdatasets[0].rdclass
            z = Zone(origin, rdclass)
        for node in r.answer:
            znode = z.nodes.get(node.name)
            # Test for absence explicitly; a container-like node could be
            # falsy while still present.
            if znode is None:
                znode = DNS.node.Node(node.name)
                z.nodes[znode.name] = znode
            for rds in node.rdatasets:
                zrds = znode.find_rdataset(rds.rdclass, rds.rdtype, rds.covers,
                                           True)
                zrds.update(rds)
    return z
