#!/usr/bin/python

from ldaptor.protocols.ldap import distinguishedname, ldapconnector, ldapsyntax, ldapclient
from ldaptor.protocols import pureber, pureldap
from ldaptor import usage, ldapfilter, config, dns
import sys, os, socket
from twisted.internet import protocol, defer, reactor

def formatIPAddress(name, ip):
    return '%s\tIN A\t%s\n' % (name, ip)

def formatPTR(name, ip):
    octets = ip.split('.')
    octets.reverse()
    octets.append('in-addr.arpa.')
    return '%s\tIN PTR\t%s.\n' % ('.'.join(octets), name)

class HostIPAddress:
    def __init__(self, host, ipAddress):
        self.host=host
        self.ipAddress=ipAddress

    def getForward(self, domain):
        return ((';  %s\n' % self.host.dn)
                + formatIPAddress(self.host.name+'.'+domain,
                                  self.ipAddress))

    def getReverse(self, domain):
        return ((';  %s\n' % self.host.dn)
                + formatPTR(self.host.name+'.'+domain, self.ipAddress))

    def __repr__(self):
        return (self.__class__.__name__
                +'('
                +'host=%r, ' % self.host.name
                +'ipAddress=%s' % repr(self.ipAddress)
                +')')

class Host:
    def __init__(self, dn, name, ipAddresses):
        self.dn=dn
        self.name=name
        self.ipAddresses=[HostIPAddress(self, ip) for ip in ipAddresses]

    def __repr__(self):
        return (self.__class__.__name__
                +'('
                +'dn=%s, ' % repr(self.dn)
                +'name=%s, ' % repr(self.name)
                +'ipAddresses=%s' % repr(self.ipAddresses)
                +')')

class Net:
    reverseZone = None

    def __init__(self, dn, name, address, mask):
        self.dn=dn
        self.name=name
        self.address=address
        self.mask=mask

    def isInNet(self, ipAddress):
        try:
            net = dns.aton(self.address)
            mask = dns.aton(self.mask)
            ip = dns.aton(ipAddress)
        except socket.error:
            # no need to log here, higher levels will log a warning
            # when they see the address is in no net
            return False
        if ip&mask == net:
            return True
        return False

    def getForward(self):
        ip = dns.aton(self.address)
        mask = dns.aton(self.mask)
        ipmask = dns.ntoa(mask)
        broadcast = dns.ntoa(ip|~mask)

        return (('; %s\n' % self.dn)
                + formatIPAddress(self.name, self.address)
                + formatIPAddress('netmask.'+self.name, ipmask)
                + formatIPAddress('broadcast.'+self.name, broadcast))

    def getReverse(self, domain):
        ip = dns.aton(self.address)
        mask = dns.aton(self.mask)
        broadcast = dns.ntoa(ip|~mask)

        return (('; %s\n' % self.dn)
                + formatPTR(self.name+'.'+domain, self.address)
                + formatPTR('broadcast.'+self.name+'.'+domain, broadcast))

    def __repr__(self):
        return (self.__class__.__name__
                +'('
                +'dn=%s, ' % repr(self.dn)
                +'name=%s, ' % repr(self.name)
                +'address=%s, ' % repr(self.address)
                +'mask=%s' % repr(self.mask)
                +')')



exitStatus=0

def error(fail):
    print >>sys.stderr, 'fail:', str(fail) #.getErrorMessage()
    global exitStatus
    exitStatus=1

def only(e, attrName):
    assert len(e[attrName])==1, \
           "object %s attribute %r has multiple values: %s" \
           % (e.dn, attrName, e[attrName])
    for val in e[attrName]:
        return val

def getNets(e, domain, forward, reverse, filter):
    filt=pureldap.LDAPFilter_and(value=(
        pureldap.LDAPFilter_present('cn'),
        pureldap.LDAPFilter_present('ipNetworkNumber'),
        pureldap.LDAPFilter_present('ipNetmaskNumber'),
        ))
    if filter:
        filt = pureldap.LDAPFilter_and(value=(filter, filt))
    d = e.search(filterObject=filt,
                 attributes=['cn',
                             'ipNetworkNumber',
                             'ipNetmaskNumber',
                             ])
    def _cbGotNets(nets, forward, reverse):
        r = []
        for e in nets:
            net = Net(str(e.dn),
                      str(only(e, 'cn')),
                      str(only(e, 'ipNetworkNumber')),
                      str(only(e, 'ipNetmaskNumber')))
            print >>forward, net.getForward()

            for data in reverse:
                ip = dns.aton(net.address)
                if ip & data['netmask'] == data['address']:
                    if 'file' not in data:
                        data['tempname'] = '%s.%d.tmp' % (data['filename'], os.getpid())
                        data['file'] = open(data['tempname'], 'w')
                    print >>data['file'], net.getReverse(domain)
                    net.reverseZone = data
            r.append(net)
        return r
    d.addCallback(_cbGotNets, forward, reverse)
    return d

def getHosts(nets, e, domain, forward, reverse, filter):
    filt=pureldap.LDAPFilter_equalityMatch(attributeDesc=pureldap.LDAPAttributeDescription('objectClass'),
                                           assertionValue=pureber.BEROctetString('ipHost'))
    if filter:
        filt = pureldap.LDAPFilter_and(value=(filter, filt))
    def _cbGotHost(e):
        host = Host(str(e.dn),
                    str(only(e, 'cn')),
                    map(str, e['ipHostNumber']))
        for hostIP in host.ipAddresses:
            parent=None
            for net in nets:
                if net.isInNet(hostIP.ipAddress):
                    parent=net
                    break

            if parent:
                print >>forward, hostIP.getForward(parent.name)
                if parent.reverseZone:
                    print >>parent.reverseZone['file'], hostIP.getReverse(parent.name
                                                                          + '.'
                                                                          + domain)
                else:
                    print >>sys.stderr, "Not writing PTR for %s." % hostIP
            else:
                sys.stderr.write("IP address %s is in no net, discarding.\n" % hostIP)
    d = e.search(filterObject=filt,
                 attributes=['ipHostNumber',
                             'cn'],
                 callback=_cbGotHost)
    return d

def cbConnected(client, cfg, domain, forward, reverse, filter):
    e = ldapsyntax.LDAPEntryWithClient(client, cfg.getBaseDN())
    d = getNets(e, domain, forward, reverse, filter)
    d.addCallback(getHosts, e, domain, forward, reverse, filter)
    def unbind(r, e):
        e.client.unbind()
        return r
    d.addCallback(unbind, e)
    return d

def filesOk(result,
            forward, forwardTmp, forwardFile,
            reverse):
    forwardFile.close()
    os.rename(forwardTmp, forward)
    for data in reverse:
        if 'file' in data:
            data['file'].close()
            del data['file']
        if 'tempname' in data:
            os.rename(data['tempname'], data['filename'])
            del data['tempname']
    return result

def filesAbort(reason,
            forward, forwardTmp, forwardFile,
            reverse):
    forwardFile.close()
    os.unlink(forwardTmp)
    for data in reverse:
        if 'file' in data:
            data['file'].close()
        if 'tempname' in data:
            os.unlink(data['tempname'])
    return reason

def main(cfg, domain, forward, reverse, filter_text):
    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    try:
        baseDN = cfg.getBaseDN()
    except config.MissingBaseDNError, e:
        print >>sys.stderr, "%s: %s." % (sys.argv[0], e)
        sys.exit(1)

    if filter_text is not None:
        filt = ldapfilter.parseFilter(filter_text)
    else:
        filt = None

    forwardTmp = '%s.%d.tmp' % (forward, os.getpid())
    forwardFile = file(forwardTmp, 'w')

    print >>forwardFile, '$ORIGIN\t%s.' % domain
    print >>forwardFile

    c=ldapconnector.LDAPClientCreator(reactor, ldapclient.LDAPClient)
    d = c.connectAnonymously(
        baseDN,
        overrides=cfg.getServiceLocationOverrides())
    d.addCallback(cbConnected, cfg, domain, forwardFile, reverse, filt)
    d.addCallbacks(callback=filesOk,
                   callbackArgs=(forward, forwardTmp, forwardFile,
                                 reverse),
                   errback=filesAbort,
                   errbackArgs=(forward, forwardTmp, forwardFile,
                                reverse))
    d.addErrback(error)
    d.addBoth(lambda x: reactor.stop())

    reactor.run()
    sys.exit(exitStatus)

class MyOptions(usage.Options, usage.Options_service_location, usage.Options_base_optional):
    """LDAPtor DNS zone file exporter"""
    synopsis = "Usage: %s [OPTION..] DOMAIN OUTPUTFILE [FILTER]" % sys.argv[0]

    def opt_reverse(self, net_file):
        """Write out reverse zone, in the form ADDRESS/NETMASK:FILE"""
        if ':' not in net_file:
            raise usage.UsageError('--reverse= value must contain semicolon')
        addr_nm, filename = net_file.split(':', 1)

        if '/' not in addr_nm:
            raise usage.UsageError('--reverse= value must have netmask')
        addressString, netmaskString = addr_nm.split('/', 1)

        try:
            address = dns.aton(addressString)
        except socket.error, e:
            raise usage.UsageError('--reverse= address is invalid: %s' % e)

        try:
            netmask = dns.aton(netmaskString)
        except socket.error, e:
            raise usage.UsageError('--reverse= netmask is invalid: %s' % e)

        self.opts.setdefault('reverse', []).append({
            'address': address,
            'netmask': netmask,
            'filename': filename,
            })

    def parseArgs(self, domain, forward, filter=None):
        self.opts['domain'] = domain
        self.opts['forward'] = forward
        self.opts['filter'] = filter

if __name__ == "__main__":
    import sys
    try:
        opts = MyOptions()
        opts.parseOptions()
    except usage.UsageError, ue:
        sys.stderr.write('%s: %s\n' % (sys.argv[0], ue))
        sys.exit(1)

    cfg = config.LDAPConfig(baseDN=opts['base'],
                            serviceLocationOverrides=opts['service-location'])
    main(cfg,
         opts['domain'],
         opts['forward'],
         opts['reverse'],
         opts['filter'])
