#!/usr/bin/env python3
# Licensed under GPLv3
#
# Simple script to query the management interface of a running n2n edge node

import argparse
import socket
import json
import collections


class JsonUDP():
    """Encapsulate communication with the edge.

    Each request is a single UDP datagram of the form
    ``<msgtype> <tag>[:1:<key>] <cmdline>``.  Every reply datagram is
    decoded as a JSON object carrying a ``_tag`` (compared against the
    request tag) and a ``_type`` (``begin``, ``row``, ``end`` or ``error``).
    """

    def __init__(self, port):
        """Create the UDP socket used to talk to the edge on localhost.

        port -- UDP port of the management interface
        """
        self.address = "127.0.0.1"
        self.port = port
        self.tag = 0        # next request tag, see _next_tag()
        self.key = None     # optional auth key added to each request
        self.debug = False  # when true, print every received row
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # recvfrom() raises socket.timeout if the edge stays silent this long
        self.sock.settimeout(1)

    def _next_tag(self):
        """Return the current tag as a string and advance the counter"""
        tagstr = str(self.tag)
        # keep the tag short: wrap around after 999
        self.tag = (self.tag + 1) % 1000
        return tagstr

    def _cmdstr(self, msgtype, cmdline):
        """Create the full command string to send

        Returns a (tagstr, message) tuple so that the caller can match the
        replies against the tag that was used.
        """
        tagstr = self._next_tag()

        options = [tagstr]
        if self.key is not None:
            options += ['1']  # Flags set for auth key field
            options += [self.key]
        optionsstr = ':'.join(options)

        return tagstr, ' '.join((msgtype, optionsstr, cmdline))

    def _recv_packet(self):
        """Receive one datagram and decode it as JSON

        Raises socket.timeout when nothing arrives within the socket
        timeout configured in __init__.
        """
        data, _ = self.sock.recvfrom(1024)
        return json.loads(data.decode('utf8'))

    def _rx(self, tagstr):
        """Wait for rx packets

        Collects the reply to the request tagged ``tagstr`` and returns the
        list of "row" packets (as dicts, with the "_tag" and "_type"
        metadata removed).

        Raises ValueError if the edge reports an error or sends something
        unexpected, and socket.timeout if the edge stops answering.
        """
        data = self._recv_packet()

        # TODO: We assume the first packet we get will be tagged for us
        # and be either an "error" or a "begin"
        #
        # These are explicit checks rather than asserts so that they are
        # still performed when python runs with -O
        if data['_tag'] != tagstr:
            raise ValueError('Unexpected tag {} from edge '
                             '(expected {})'.format(data['_tag'], tagstr))

        if data['_type'] == 'error':
            raise ValueError('Error: {}'.format(data['error']))

        if data['_type'] != 'begin':
            raise ValueError('Unexpected data type {} from edge '
                             '(expected begin)'.format(data['_type']))

        # Ideally, we would confirm that this is our "begin", but that
        # would need the cmd passed into this method, and that would
        # probably require parsing the cmdline passed to us :-(

        result = list()
        error = None

        while True:
            data = self._recv_packet()

            if data['_tag'] != tagstr:
                # this packet is not for us, ignore it
                continue

            if data['_type'] == 'error':
                # we still expect an end packet, so save the error
                error = ValueError('Error: {}'.format(data['error']))
                continue

            if data['_type'] == 'end':
                if error:
                    raise error
                return result

            if data['_type'] != 'row':
                raise ValueError('Unknown data type {} from '
                                 'edge'.format(data['_type']))

            # remove our boring metadata
            del data['_tag']
            del data['_type']

            if self.debug:
                print(data)

            result.append(data)

    def _call(self, msgtype, cmdline):
        """Perform a rpc call

        Sends the request and returns the rows of the reply (see _rx).
        """
        tagstr, msgstr = self._cmdstr(msgtype, cmdline)
        self.sock.sendto(msgstr.encode('utf8'), (self.address, self.port))
        return self._rx(tagstr)

    def read(self, cmdline):
        """Send cmdline as a read ('r') request and return the reply rows"""
        return self._call('r', cmdline)

    def write(self, cmdline):
        """Send cmdline as a write ('w') request and return the reply rows"""
        return self._call('w', cmdline)


def str_table(rows, columns):
    """Given an array of dicts, do a simple table print

    rows    -- list of dicts, one per table line; keys missing from a dict
               are shown as empty cells
    columns -- list of keys (strings) to show, in display order

    Returns the table as one string: a heading line followed by one line
    per row, each cell followed by a single space.

    When there is data, each column is as wide as its widest value and the
    heading is truncated to that width.  Without data the full headings
    are shown.
    """
    widths = {}
    for col in columns:
        if rows:
            width = max(
                (len(str(row[col])) for row in rows if col in row),
                default=0)
        else:
            # No data to show, be sure not to truncate the column headings
            width = len(col)
        # never let a column vanish completely
        widths[col] = max(width, 1)

    lines = [
        ''.join("{:{w}.{w}} ".format(col, w=widths[col]) for col in columns)
    ]

    for row in rows:
        cells = []
        for col in columns:
            data = row[col] if col in row else ''
            if not isinstance(data, (str, int, float)):
                # None, lists and dicts are all possible JSON values, but
                # they raise TypeError when formatted with a width, so
                # show their str() - which is what the width was
                # measured from.  Numbers are left alone to keep their
                # right alignment.
                data = str(data)
            cells.append("{:{}} ".format(data, widths[col]))
        lines.append(''.join(cells))

    return ''.join(line + "\n" for line in lines)


def subcmd_show_supernodes(rpc, args):
    """Fetch the 'supernodes' rows from the edge and render them as a table

    args is accepted for a uniform subcommand signature but is not used.
    """
    supernodes = rpc.read('supernodes')
    wanted = ['version', 'current', 'macaddr', 'sockaddr', 'uptime']
    return str_table(supernodes, wanted)


def subcmd_show_edges(rpc, args):
    """Fetch the 'edges' rows from the edge and render them as a table

    args is accepted for a uniform subcommand signature but is not used.
    """
    edges = rpc.read('edges')
    wanted = ['mode', 'ip4addr', 'macaddr', 'sockaddr', 'desc']
    return str_table(edges, wanted)


def subcmd_show_help(rpc, args):
    """Build the help text: local pretty-printers plus the edge's commands

    The first section lists the entries of the module level subcmds table,
    the second one lists whatever the edge returns for its own 'help'
    command (each row is expected to have a 'cmd' and a 'help' field).

    args is accepted for a uniform subcommand signature but is not used.
    """
    parts = ['Commands with pretty-printed output:\n\n']
    for name, cmd in subcmds.items():
        parts.append("{:12} {}\n".format(name, cmd['help']))

    parts.append("\n")
    parts.append("Possible remote commands:\n")
    parts.append("(those without a pretty-printer will pass-through)\n\n")
    for row in rpc.read('help'):
        parts.append("{:12} {}\n".format(row['cmd'], row['help']))

    return ''.join(parts)


# Commands that have a local pretty-printer.  For each command name:
#   func -- called as func(rpc, args), returns the text to print
#   help -- one line description shown by the "help" command
subcmds = {
    'help': {'func': subcmd_show_help,
             'help': 'Show available commands'},
    'supernodes': {'func': subcmd_show_supernodes,
                   'help': 'Show the list of supernodes'},
    'edges': {'func': subcmd_show_edges,
              'help': 'Show the list of edges/peers'},
}


def subcmd_default(rpc, args):
    """Send the command line to the edge unchanged and dump the reply

    The request is a write when args.write is set, otherwise a read.  The
    reply rows are returned as indented JSON text with sorted keys.
    """
    words = [args.cmd]
    words.extend(args.args)
    request = rpc.write if args.write else rpc.read
    reply = request(' '.join(words))
    return json.dumps(reply, indent=4, sort_keys=True)


def main():
    """Parse the command line, run one command against the edge, print it

    Commands listed in subcmds are pretty-printed unless --raw is given;
    everything else is passed through to the edge and shown as JSON.

    Exits with status 1 if the edge does not answer before the socket
    times out.
    """
    # Local import: only needed here, for a reliable exit status
    import sys

    ap = argparse.ArgumentParser(
            description='Query the running local n2n edge')
    ap.add_argument('-t', '--mgmtport', action='store', default=5644,
                    help='Management Port (default=5644)', type=int)
    ap.add_argument('-k', '--key', action='store',
                    help='Password for mgmt commands')
    ap.add_argument('-d', '--debug', action='store_true',
                    help='Also show raw internal data')
    ap.add_argument('--raw', action='store_true',
                    help='Force cmd to avoid any pretty printing')

    group = ap.add_mutually_exclusive_group()
    group.add_argument('--read', action='store_true',
                       help='Make a read request (default)')
    group.add_argument('--write', action='store_true',
                       help='Make a write request (only to non pretty '
                       'printed cmds)')

    ap.add_argument('cmd', action='store',
                    help='Command to run (try "help" for list)')
    ap.add_argument('args', action='store', nargs="*",
                    help='Optional args for the command')

    args = ap.parse_args()

    if args.raw or (args.cmd not in subcmds):
        func = subcmd_default
    else:
        func = subcmds[args.cmd]['func']

    rpc = JsonUDP(args.mgmtport)
    rpc.debug = args.debug
    rpc.key = args.key

    try:
        result = func(rpc, args)
    except socket.timeout as e:
        print(e)
        # the bare exit() helper is only installed by the site module for
        # interactive use, so use sys.exit() instead
        sys.exit(1)
    finally:
        # the reply is complete (or has failed), release the UDP socket
        rpc.sock.close()

    print(result)


# Only run the command line tool when executed as a script, so that the
# module can also be imported without side effects
if __name__ == '__main__':
    main()
