# Copyright (C) 2001-2003 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id: message.py,v 1.14 2003/06/20 07:12:38 halley Exp $

"""DNS Messages"""

import cStringIO
import random
import string
import struct
import time

import DNS.exception
import DNS.flags
import DNS.name
import DNS.node
import DNS.opcode
import DNS.rcode
import DNS.rdata
import DNS.rdataclass
import DNS.rdataset
import DNS.rdatatype
import DNS.tokenizer
import DNS.tsig

class ShortHeader(DNS.exception.FormError):
    """Raised by from_wire() when the packet is shorter than the 12-octet
    DNS message header."""

class TrailingJunk(DNS.exception.FormError):
    """Raised by from_wire() when octets remain in the packet after the
    last record described by the header counts has been read."""

class UnknownHeaderField(DNS.exception.DNSException):
    """Raised while parsing a text-format message when a header line begins
    with a field name that is not recognized."""

class BadEDNS(DNS.exception.FormError):
    """Raised when decoding wire format if an OPT record appears outside the
    additional data section, or appears more than once."""

class BadTSIG(DNS.exception.FormError):
    """Raised when decoding wire format if a TSIG record is anywhere other
    than the final record of the additional data section."""

class Message(object):
    """A DNS message.

    @ivar id: The query id; the default is a randomly chosen id.
    @type id: int
    @ivar flags: The DNS flags of the message.  @see: RFC 1035 for an
    explanation of these flags.
    @type flags: int
    @ivar question: The question section.
    @type question: list of DNS.node.Node objects
    @ivar answer: The answer section.
    @type answer: list of DNS.node.Node objects
    @ivar authority: The authority section.
    @type authority: list of DNS.node.Node objects
    @ivar additional: The additional data section.
    @type additional: list of DNS.node.Node objects
    @ivar wire: When building a message object from wire format, this variable
    contains the wire-format message.
    @type wire: string
    @ivar current: When building a message object from wire format, this
    variable contains the offset from the beginning of wire of the next octet
    to be read.
    @type current: int
    @ivar last_name: The most recently read name when building a message object
    from text format.
    @type last_name: DNS.name.Name object
    @ivar edns: The EDNS level (version) to use.  The default is -1, no EDNS.
    @type edns: int
    @ivar ednsflags: The EDNS flags, i.e. the TTL field of the OPT record.
    @type ednsflags: int
    @ivar payload: The EDNS payload size.  The default is 0.
    @type payload: int
    @ivar keyring: The TSIG keyring to use.  The default is None.
    @type keyring: dict
    @ivar keyname: The TSIG keyname to use.  The default is None.
    @type keyname: DNS.name.Name object
    @ivar request_mac: The TSIG MAC of the request message associated with
    this message; used when validating TSIG signatures.   @see: RFC 2845 for
    more information on TSIG fields.
    @type request_mac: string
    @ivar other_data: TSIG other data.
    @type other_data: string
    @ivar tsig_error: TSIG error code; default is 0.
    @type tsig_error: int
    @ivar fudge: TSIG time fudge; default is 300 seconds.
    @type fudge: int
    @ivar mac: The TSIG MAC for this message.
    @type mac: string
    @ivar zone_rdclass: The class of the zone in messages which are used for
    zone transfers or for DNS dynamic updates.
    @type zone_rdclass: int
    @ivar xfr: Is the message being used to contain the results of a DNS
    zone transfer?  The default is False.
    @type xfr: bool
    @ivar origin: The origin of the zone in messages which are used for
    zone transfers or for DNS dynamic updates.  The default is None.
    @type origin: DNS.name.Name object
    @ivar tsig_ctx: The TSIG signature context associated with this
    message.  The default is None.
    @type tsig_ctx: hmac.HMAC object
    @ivar multi: Is this message part of a multi-message sequence?  The
    default is false.  This variable is used when validating TSIG signatures
    on messages which are part of a zone transfer.
    @type multi: bool
    @ivar had_tsig: Did the message decoded from wire format have a TSIG
    signature?
    @type had_tsig: bool"""

    def __init__(self):
        self.id = random.randint(0, 65535)
        self.flags = 0
        self.question = []
        self.answer = []
        self.authority = []
        self.additional = []
        self.wire = ''
        self.current = 0
        self.last_name = None
        self.edns = -1
        self.ednsflags = 0
        self.payload = 0
        self.keyring = None
        self.keyname = None
        self.request_mac = ''
        self.other_data = ''
        self.tsig_error = 0
        self.fudge = 300
        self.mac = ''
        self.zone_rdclass = DNS.rdataclass.IN
        self.xfr = False
        self.origin = None
        self.tsig_ctx = None
        self.multi = False
        self.had_tsig = False

    def __repr__(self):
        return '<DNS message, ID ' + repr(self.id) + '>'

    def __str__(self):
        """Return the message in text format (the format from_text() reads).
        @rtype: string"""
        s = cStringIO.StringIO()
        print >> s, 'id %d' % self.id
        print >> s, 'opcode %s' % \
              DNS.opcode.to_text(DNS.opcode.from_flags(self.flags))
        rc = DNS.rcode.from_flags(self.flags, self.ednsflags)
        print >> s, 'rcode %s' % DNS.rcode.to_text(rc)
        print >> s, 'flags %s' % DNS.flags.to_text(self.flags)
        if self.edns >= 0:
            print >> s, 'edns %s' % self.edns
            if self.ednsflags != 0:
                print >> s, 'eflags %s' % \
                      DNS.flags.edns_to_text(self.ednsflags)
            print >> s, 'payload', self.payload
        # UPDATE messages reuse the same sections under different names.
        is_update = DNS.opcode.is_update(self.flags)
        if is_update:
            print >> s, ';ZONE'
        else:
            print >> s, ';QUESTION'
        for n in self.question:
            print >> s, n
        if is_update:
            print >> s, ';PREREQ'
        else:
            print >> s, ';ANSWER'
        for n in self.answer:
            print >> s, n
        if is_update:
            print >> s, ';UPDATE'
        else:
            print >> s, ';AUTHORITY'
        for n in self.authority:
            print >> s, n
        print >> s, ';ADDITIONAL'
        for n in self.additional:
            print >> s, n
        #
        # We strip off the final \n so the caller can print the result without
        # doing weird things to get around eccentricities in Python print
        # formatting
        #
        return s.getvalue()[:-1]

    def _same_nodes(self, a, b):
        """Return True if every node of section I{a} is in section I{b} and
        every node of I{b} is in I{a} (order-insensitive comparison).
        @rtype: bool"""
        for n in a:
            if n not in b:
                return False
        for n in b:
            if n not in a:
                return False
        return True

    def __eq__(self, other):
        """Two messages are equal if they have the same content in the
        header, question, answer, and authority sections.

        Comparing against an object that is not a Message returns False.
        @rtype: bool"""
        if not isinstance(other, Message):
            return False
        if self.id != other.id:
            return False
        if self.flags != other.flags:
            return False
        return self._same_nodes(self.question, other.question) and \
               self._same_nodes(self.answer, other.answer) and \
               self._same_nodes(self.authority, other.authority)

    def __ne__(self, other):
        """Are two messages not equal?
        @rtype: bool"""
        return not self.__eq__(other)

    def is_response(self, other):
        """Is other a response to self?

        I{other} must have QR set and match self's id and opcode.  If its
        rcode is not NOERROR, or self is an UPDATE, that is sufficient;
        otherwise the question sections must also hold the same nodes.
        @rtype: bool"""
        if (other.flags & DNS.flags.QR) == 0 or \
           self.id != other.id or \
           DNS.opcode.from_flags(self.flags) != \
           DNS.opcode.from_flags(other.flags):
            return False
        if DNS.rcode.from_flags(other.flags, other.ednsflags) != \
               DNS.rcode.NOERROR:
            return True
        if DNS.opcode.is_update(self.flags):
            return True
        return self._same_nodes(self.question, other.question)

    def find_node(self, section, name, create = False, force_unique = False):
        """Find the node named I{name} in the specified section.

        @param section: the section of the message to look in, e.g.
        self.answer.
        @type section: list of DNS.node.Node objects
        @param name: the name of the node to find
        @type name: DNS.name.Name object
        @param create: If True, create the node if it is not found, unless
        I{force_unique} is also True, in which case always create a new node.
        The created node is appended to I{section}.
        @type create: bool
        @param force_unique: If I{create} is True, always create the node,
        even if it already exists.
        @type force_unique: bool
        @raises KeyError: the node was not found and I{create} is False
        @rtype: DNS.node.Node object"""

        if create and force_unique:
            n = DNS.node.Node(name)
            section.append(n)
            return n
        for n in section:
            if n.name == name:
                return n
        if not create:
            raise KeyError(name)
        n = DNS.node.Node(name)
        section.append(n)
        return n

    def get_question(self, qcount):
        """Read the next I{qcount} records from the wire data and add them to
        the question section.
        @param qcount: the number of questions in the message
        @type qcount: int
        @raises DNS.exception.FormError: a question's type/class fields extend
        past the end of the wire data"""

        for i in xrange(0, qcount):
            (qname, used) = DNS.name.from_wire(self.wire, self.current)
            if not self.origin is None:
                qname = qname.relativize(self.origin)
            self.current = self.current + used
            # Without this check a truncated packet surfaces as struct.error
            # instead of a DNS form error.
            if self.current + 4 > len(self.wire):
                raise DNS.exception.FormError
            (rdtype, rdclass) = \
                     struct.unpack('!HH',
                                   self.wire[self.current:self.current + 4])
            self.current = self.current + 4
            n = self.find_node(self.question, qname, True, True)
            n.find_rdataset(rdclass, rdtype, DNS.rdatatype.NONE, True, True)

    def get_section(self, section, count):
        """Read the next I{count} records from the wire data and add them to
        the specified section.

        OPT and TSIG records are not stored in the section.  An OPT record
        sets the EDNS fields of the message.  A TSIG record is validated when
        a keyring is set, and is recorded in I{had_tsig}.
        @param section: the section of the message to which to add records
        @type section: list of DNS.node.Node objects
        @param count: the number of records to read
        @type count: int
        @raises DNS.exception.FormError: a record's fixed fields or rdata
        extend past the end of the wire data
        @raises BadEDNS: an OPT record is outside the additional data section
        or occurs more than once
        @raises BadTSIG: a TSIG record is not the last additional record"""

        if DNS.opcode.is_update(self.flags):
            updating = True
            force_unique = True
        else:
            updating = False
            force_unique = False
        seen_opt = False
        wire_len = len(self.wire)
        for i in xrange(0, count):
            rr_start = self.current
            (absolute_name, used) = DNS.name.from_wire(self.wire,
                                                       self.current)
            # Owner names are relativized for storage, but the TSIG key
            # lookup and MAC validation need the name as it was on the wire.
            nm = absolute_name
            if not self.origin is None:
                nm = nm.relativize(self.origin)
            self.current = self.current + used
            if self.current + 10 > wire_len:
                raise DNS.exception.FormError
            (rdtype, rdclass, ttl, rdlen) = \
                     struct.unpack('!HHIH',
                                   self.wire[self.current:self.current + 10])
            self.current = self.current + 10
            if self.current + rdlen > wire_len:
                raise DNS.exception.FormError
            if rdtype == DNS.rdatatype.OPT:
                if not section is self.additional or seen_opt:
                    raise BadEDNS
                # OPT reuses the class field for the payload size and the TTL
                # field for the EDNS flags, whose bits 16-23 are the version.
                self.payload = rdclass
                self.ednsflags = ttl
                self.edns = (ttl & 0xff0000) >> 16
                seen_opt = True
            elif rdtype == DNS.rdatatype.TSIG:
                if not (section is self.additional and i == (count - 1)):
                    raise BadTSIG
                if not self.keyring is None:
                    self.tsig_ctx = \
                        DNS.tsig.validate(self.wire,
                                          absolute_name,
                                          self.keyring[absolute_name],
                                          int(time.time()),
                                          self.request_mac,
                                          rr_start,
                                          self.current,
                                          rdlen,
                                          self.tsig_ctx,
                                          self.multi)
                self.had_tsig = True
            else:
                # The TTL was unpacked as unsigned, so it is never negative.
                # Per RFC 2181 section 8, a TTL with the most significant bit
                # set is treated as zero.
                if ttl > 0x7fffffff:
                    ttl = 0
                if updating and \
                   (rdclass == DNS.rdataclass.ANY or
                    rdclass == DNS.rdataclass.NONE):
                    deleting = rdclass
                    rdclass = self.zone_rdclass
                else:
                    deleting = False
                if deleting == DNS.rdataclass.ANY:
                    covers = DNS.rdatatype.NONE
                    rd = None
                else:
                    rd = DNS.rdata.from_wire(rdclass, rdtype, self.wire,
                                             self.current, rdlen, self.origin)
                    covers = rd.covers()
                if self.xfr and rdtype == DNS.rdatatype.SOA:
                    force_unique = True
                n = self.find_node(section, nm, True, force_unique)
                rds = n.find_rdataset(rdclass, rdtype, covers, True,
                                      force_unique, deleting)
                if not rd is None:
                    rds.add(rd, ttl)
            self.current = self.current + rdlen

    def to_wire(self, origin=None):
        """Return a string containing the message in DNS compressed wire
        format.

        If I{edns} is >= 0 an OPT record is emitted first in the additional
        section.  If I{keyname} is set, the message is TSIG-signed with the
        matching key from I{keyring}, and the resulting MAC is stored in
        I{mac}.

        @param origin: The origin to be appended to any relative names.
        @type origin: DNS.name.Name object
        @rtype: string"""

        f = cStringIO.StringIO()
        compress = {}
        # Placeholder header; the real one is written once the counts are
        # known.
        header = '\x00' * 12
        f.write(header)
        qcount = 0
        ancount = 0
        aucount = 0
        adcount = 0
        for n in self.question:
            qcount += n.to_wire(f, compress, origin, True)
        for n in self.answer:
            ancount += n.to_wire(f, compress, origin, False)
        for n in self.authority:
            aucount += n.to_wire(f, compress, origin, False)
        if self.edns >= 0:
            # OPT pseudo-record: root owner name (one zero octet), class is
            # the payload size, TTL is the EDNS flags, and the rdata is empty.
            # The version octet (bits 16-23) comes from self.edns, so that
            # the level set by use_edns() is what gets sent.
            ednsflags = (self.ednsflags & ~0x00ff0000) | \
                        ((self.edns & 0xff) << 16)
            stuff = struct.pack('!BHHIH', 0, DNS.rdatatype.OPT, self.payload,
                                ednsflags, 0)
            f.write(stuff)
            adcount += 1
        for n in self.additional:
            adcount += n.to_wire(f, compress, origin, False)
        header = struct.pack('!HHHHHH', self.id, self.flags, qcount, ancount,
                             aucount, adcount)
        f.seek(0)
        f.write(header)
        if not self.keyname is None:
            secret = self.keyring[self.keyname]
            s = f.getvalue()
            (tsig_rdata, self.mac, ctx) = DNS.tsig.hmac_md5(s, self.keyname,
                                                            secret,
                                                            int(time.time()),
                                                            self.fudge,
                                                            self.id,
                                                            self.tsig_error,
                                                            self.other_data,
                                                            self.request_mac)
            f.seek(0, 2)
            self.keyname.to_wire(f, compress)
            # Write the fixed fields with rdlen 0, then patch rdlen once the
            # rdata has been written.
            stuff = struct.pack('!HHIH', DNS.rdatatype.TSIG,
                                DNS.rdataclass.ANY, 0, 0)
            f.write(stuff)
            start = f.tell()
            f.write(tsig_rdata)
            end = f.tell()
            assert end - start < 65536
            f.seek(start - 2)
            f.write(struct.pack('!H', end - start))
            # Bump ARCOUNT (header offset 10) to include the TSIG record.
            adcount += 1
            stuff = struct.pack('!H', adcount)
            f.seek(10)
            f.write(stuff)
        s = f.getvalue()
        return s

    def use_tsig(self, keyring, keyname=None):
        """When sending, a TSIG signature using the specified keyring
        and keyname should be added.

        @param keyring: The TSIG keyring to use.
        @type keyring: dict
        @param keyname: The name of the TSIG key to use; defaults to None.
        The key must be defined in the keyring.  If a keyring is specified
        but a keyname is not, then the key used will be the first key in the
        keyring.  Note that the order of keys in a dictionary is not defined,
        so applications should supply a keyname when a keyring is used, unless
        they know the keyring contains only one key.
        @type keyname: DNS.name.Name or string"""

        self.keyring = keyring
        if keyname is None:
            self.keyname = self.keyring.keys()[0]
        else:
            if isinstance(keyname, str):
                keyname = DNS.name.from_text(keyname)
            self.keyname = keyname

    def use_edns(self, edns, ednsflags, payload):
        """Configure EDNS behavior.
        @param edns: The EDNS level to use.  Specifying None or -1 means
        'do not use EDNS'.
        @type edns: int or None
        @param ednsflags: EDNS flag values.
        @type ednsflags: int
        @param payload: The EDNS sender's payload field, which is the maximum
        size of UDP datagram the sender can handle.
        @type payload: int
        @see: RFC 2671
        """
        if edns is None:
            edns = -1
        self.edns = edns
        self.ednsflags = ednsflags
        self.payload = payload

    def rcode(self):
        """Return the rcode of the message, combining the header rcode bits
        with any extended rcode bits held in the EDNS flags.
        @rtype: int"""
        return DNS.rcode.from_flags(self.flags, self.ednsflags)

def from_wire(wire, keyring=None, request_mac='', xfr=False, origin=None,
              tsig_ctx = None, multi = False):
    """Decode a DNS wire-format message and return the resulting Message.

    @param keyring: Keys used to validate a TSIG signature, if present.
    @type keyring: dict
    @param request_mac: When decoding a response to a TSIG-signed request,
    the MAC of that request.
    @type request_mac: string
    @param xfr: True if the message is part of a zone transfer.
    @type xfr: bool
    @param origin: For zone transfers, the zone origin; owner names are
    relativized to it.
    @type origin: DNS.name.Name object
    @param tsig_ctx: The running TSIG context of a zone transfer.
    @type tsig_ctx: hmac.HMAC object
    @param multi: True if the message is one of a multi-message sequence.
    @type multi: bool
    @raises ShortHeader: The data is shorter than the 12-octet header.
    @raises TrailingJunk: Octets remain after the last counted record.
    @raises BadEDNS: An OPT record is outside the additional section, or
    there is more than one.
    @raises BadTSIG: A TSIG record is not the final additional record.
    @rtype: DNS.message.Message object"""

    msg = Message()
    msg.wire = wire
    msg.keyring = keyring
    msg.request_mac = request_mac
    msg.xfr = xfr
    msg.origin = origin
    msg.tsig_ctx = tsig_ctx
    msg.multi = multi
    wire_len = len(msg.wire)
    if wire_len < 12:
        raise ShortHeader
    counts = struct.unpack('!HHHHHH', msg.wire[:12])
    msg.id = counts[0]
    msg.flags = counts[1]
    msg.current = 12
    msg.get_question(counts[2])
    for target, how_many in ((msg.answer, counts[3]),
                             (msg.authority, counts[4]),
                             (msg.additional, counts[5])):
        msg.get_section(target, how_many)
    if msg.current != wire_len:
        raise TrailingJunk
    # An unsigned message in a signed multi-message sequence is still folded
    # into the running TSIG digest.
    if msg.multi and msg.tsig_ctx and not msg.had_tsig:
        msg.tsig_ctx.update(msg.wire)
    return msg

def _header_line(m, tok, section):
    """Process one line from the text format header section.

    @param m: the message being built
    @param tok: the tokenizer, positioned at the start of the line
    @param section: unused; present so every line handler has the same
    signature
    @raises UnknownHeaderField: the line names an unrecognized field"""
    (ttype, what) = tok.get()
    if what == 'id':
        m.id = tok.get_int()
    elif what == 'flags':
        while True:
            token = tok.get()
            if token[0] != DNS.tokenizer.IDENTIFIER:
                tok.unget(token)
                break
            # Each identifier is a single flag mnemonic (e.g. "QR").  Pass
            # it through whole: string.join() on a str would split it into
            # space-separated characters.
            m.flags = m.flags | DNS.flags.from_text(token[1])
    elif what == 'edns':
        m.edns = tok.get_int()
        # The EDNS version lives in bits 16-23 of the EDNS flags.
        m.ednsflags = m.ednsflags | (m.edns << 16)
    elif what == 'eflags':
        if m.edns < 0:
            m.edns = 0
        while True:
            token = tok.get()
            if token[0] != DNS.tokenizer.IDENTIFIER:
                tok.unget(token)
                break
            m.ednsflags = m.ednsflags | \
                          DNS.flags.edns_from_text(token[1])
    elif what == 'payload':
        m.payload = tok.get_int()
        if m.edns < 0:
            m.edns = 0
    elif what == 'opcode':
        text = tok.get_string()
        m.flags = m.flags | \
                  DNS.opcode.to_flags(DNS.opcode.from_text(text))
    elif what == 'rcode':
        text = tok.get_string()
        (value, evalue) = DNS.rcode.to_flags(DNS.rcode.from_text(text))
        m.flags = m.flags | value
        m.ednsflags = m.ednsflags | evalue
        # Extended rcode bits can only be carried if EDNS is in use.
        if m.ednsflags != 0 and m.edns < 0:
            m.edns = 0
    else:
        raise UnknownHeaderField
    tok.get_eol()

def _question_line(m, tok, section):
    """Process one line from the text format question section.

    The line is "[name] [class] type".  A line starting with whitespace
    reuses the previous owner name, and a missing class defaults to IN.
    @param m: the message being built
    @param tok: the tokenizer, positioned at the start of the line
    @param section: the section list to add the question node to
    @raises DNS.exception.SyntaxError: the line is malformed"""
    token = tok.get(want_leading = True)
    if token[0] != DNS.tokenizer.WHITESPACE:
        m.last_name = DNS.name.from_text(token[1])
    name = m.last_name
    token = tok.get()
    if token[0] != DNS.tokenizer.IDENTIFIER:
        raise DNS.exception.SyntaxError
    # Class (optional): if the token does not parse as a class, it must be
    # the type, and the class defaults to IN.
    try:
        rdclass = DNS.rdataclass.from_text(token[1])
        token = tok.get()
        if token[0] != DNS.tokenizer.IDENTIFIER:
            raise DNS.exception.SyntaxError
    except DNS.exception.SyntaxError:
        raise
    except Exception:
        rdclass = DNS.rdataclass.IN
    # Type
    rdtype = DNS.rdatatype.from_text(token[1])
    n = m.find_node(section, name, True, True)
    n.find_rdataset(rdclass, rdtype, DNS.rdatatype.NONE, True)
    # In an UPDATE the question section is the zone section, and its class
    # is the zone's class.
    if DNS.opcode.is_update(m.flags):
        m.zone_rdclass = rdclass
    tok.get_eol()
        
def _rr_line(m, tok, section):
    """Process one line from the text format answer, authority, or additional
    data sections.

    The line is "[name] [ttl] [class] type [rdata]".  A line starting with
    whitespace reuses the previous owner name.  A missing TTL defaults to 0
    and a missing class to IN.  Class ANY or NONE marks a deletion, and the
    record is then stored under the zone's class.
    @param m: the message being built
    @param tok: the tokenizer, positioned at the start of the line
    @param section: the section list to add the record to
    @raises DNS.exception.SyntaxError: the line is malformed"""
    if DNS.opcode.is_update(m.flags):
        updating = True
    else:
        updating = False
    deleting = False
    # Name
    token = tok.get(want_leading = True)
    if token[0] != DNS.tokenizer.WHITESPACE:
        m.last_name = DNS.name.from_text(token[1])
    name = m.last_name
    token = tok.get()
    if token[0] != DNS.tokenizer.IDENTIFIER:
        raise DNS.exception.SyntaxError
    # TTL (optional): a token that is not an integer is left for the class
    # or type parsers.
    try:
        ttl = int(token[1], 0)
        token = tok.get()
        if token[0] != DNS.tokenizer.IDENTIFIER:
            raise DNS.exception.SyntaxError
    except DNS.exception.SyntaxError:
        raise
    except Exception:
        ttl = 0
    # Class (optional)
    try:
        rdclass = DNS.rdataclass.from_text(token[1])
        token = tok.get()
        if token[0] != DNS.tokenizer.IDENTIFIER:
            raise DNS.exception.SyntaxError
        if rdclass == DNS.rdataclass.ANY or rdclass == DNS.rdataclass.NONE:
            deleting = rdclass
            rdclass = m.zone_rdclass
    except DNS.exception.SyntaxError:
        raise
    except Exception:
        rdclass = DNS.rdataclass.IN
    # Type
    rdtype = DNS.rdatatype.from_text(token[1])
    n = m.find_node(section, name, True, updating)
    token = tok.get()
    if token[0] != DNS.tokenizer.EOL and token[0] != DNS.tokenizer.EOF:
        tok.unget(token)
        rd = DNS.rdata.from_text(rdclass, rdtype, tok, DNS.name.root)
        covers = rd.covers()
    else:
        # No rdata on the line, so only the (possibly empty) rdataset is
        # recorded.
        rd = None
        covers = DNS.rdatatype.NONE
    rds = n.find_rdataset(rdclass, rdtype, covers, True, updating, deleting)
    if not rd is None:
        rds.add(rd, ttl)

def from_text(text):
    """Parse a text format message and return the resulting Message.

    Comment lines such as ";QUESTION" or ";UPDATE" switch the section that
    subsequent lines are added to.  Parsing stops at the first empty line or
    at end of input.

    @param text: The text format message.
    @type text: string
    @raises UnknownHeaderField:
    @raises DNS.exception.SyntaxError:
    @rtype: DNS.message.Message object"""

    # 'text' can also be a file, but we don't publish that fact
    # since it's an implementation detail.  The official file
    # interface is from_file().

    msg = Message()
    tok = DNS.tokenizer.Tokenizer(text)
    # Section-marker comment -> (line handler, target section).  The UPDATE
    # names (ZONE, PREREQ, UPDATE) alias the ordinary sections.
    markers = {
        'QUESTION': (_question_line, msg.question),
        'ZONE': (_question_line, msg.question),
        'ANSWER': (_rr_line, msg.answer),
        'PREREQ': (_rr_line, msg.answer),
        'AUTHORITY': (_rr_line, msg.authority),
        'UPDATE': (_rr_line, msg.authority),
        'ADDITIONAL': (_rr_line, msg.additional),
    }
    handler = _header_line
    section = None
    while True:
        token = tok.get(True, True)
        kind = token[0]
        if kind == DNS.tokenizer.EOL or kind == DNS.tokenizer.EOF:
            break
        if kind != DNS.tokenizer.COMMENT:
            tok.unget(token)
            handler(msg, tok, section)
            continue
        marker = token[1].upper()
        if marker == 'HEADER':
            # Switching back to the header keeps the current section.
            handler = _header_line
        elif marker in markers:
            (handler, section) = markers[marker]
        tok.get_eol()
    return msg
    
def from_file(f):
    """Parse the next text format message from a file.

    @param f: an open file, or a string naming a file to open.  A file
    opened here is closed before returning, even on error; a caller-supplied
    file is left open.
    @raises UnknownHeaderField:
    @raises DNS.exception.SyntaxError:
    @rtype: DNS.message.Message object"""

    if not isinstance(f, str):
        return from_text(f)
    opened = file(f)
    try:
        return from_text(opened)
    finally:
        opened.close()

def make_query(qname, rdtype, rdclass = DNS.rdataclass.IN):
    """Build a query message with a single question.

    Each of the name, type, and class may be passed as the appropriate
    object or as text, which is converted.  The query gets a random id and
    has the RD flag set.

    @param qname: The query name.
    @type qname: DNS.name.Name object or string
    @param rdtype: The desired rdata type.
    @type rdtype: int or string
    @param rdclass: The desired rdata class; the default is class IN.
    @type rdclass: int or string
    @rtype: DNS.message.Message object"""

    if isinstance(qname, str):
        qname = DNS.name.from_text(qname)
    if isinstance(rdtype, str):
        rdtype = DNS.rdatatype.from_text(rdtype)
    if isinstance(rdclass, str):
        rdclass = DNS.rdataclass.from_text(rdclass)
    query = Message()
    query.flags = query.flags | DNS.flags.RD
    node = query.find_node(query.question, qname, True)
    node.find_rdataset(rdclass, rdtype, DNS.rdatatype.NONE, True, True)
    return query
