#! /usr/bin/env python3
#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time
import warnings
import netrc
import json
warnings.simplefilter("default")

try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in for ``tqdm.tqdm`` used when tqdm is not installed.

        Accepts and ignores any constructor arguments, works as a context
        manager and exposes ``update()`` so callers do not need to care
        whether a real progress bar is available.
        """

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            # Returning None lets any exception raised in the with-body
            # propagate unchanged.
            pass

        def update(self, n=1):
            # Same call signature as tqdm.tqdm.update(n=1); the increment is
            # accepted and discarded.
            pass

# Put <parent of this script's directory>/lib at the front of the module
# search path so that the project imports below resolve from there first.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))

import hashserv
import bb.asyncrpc

# Server address used when --address is not given on the command line.
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
# Method name sent with every request issued by the 'stress' sub-command.
METHOD = 'stress.test.method'

def print_user(u):
    """Print a user record to stdout, one aligned ``Label: value`` per line.

    ``u`` is a mapping that must contain ``"username"``. The optional
    ``"permissions"`` entry (an iterable of strings) is printed space
    separated, and the optional ``"token"`` entry is printed as-is. Entries
    that are absent are skipped.
    """
    print(f"Username:    {u['username']}")

    if "permissions" in u:
        joined_permissions = " ".join(u["permissions"])
        print("Permissions: " + joined_permissions)

    if "token" in u:
        token = u["token"]
        print(f"Token:       {token}")


def main():
    def handle_get(args, client):
        result = client.get_taskhash(args.method, args.taskhash, all_properties=True)
        if not result:
            return 0

        print(json.dumps(result, sort_keys=True, indent=4))
        return 0

    def handle_get_outhash(args, client):
        result = client.get_outhash(args.method, args.outhash, args.taskhash)
        if not result:
            return 0

        print(json.dumps(result, sort_keys=True, indent=4))
        return 0

    def handle_stats(args, client):
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        print(json.dumps(s, sort_keys=True, indent=4))
        return 0

    def handle_stress(args, client):
        def thread_main(pbar, lock):
            nonlocal found_hashes
            nonlocal missed_hashes
            nonlocal max_time

            with hashserv.create_client(args.address) as client:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    start_time = time.perf_counter()
                    l = client.get_unihash(METHOD, taskhash.hexdigest())
                    elapsed = time.perf_counter() - start_time

                    with lock:
                        if l:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        max_time = max(elapsed, max_time)
                        pbar.update()

        max_time = 0
        found_hashes = 0
        missed_hashes = 0
        lock = threading.Lock()
        total_requests = args.clients * args.requests
        start_time = time.perf_counter()
        with ProgressBar(total=total_requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()

        elapsed = time.perf_counter() - start_time
        with lock:
            print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, total_requests / elapsed))
            print("Average request time %.8fs" % (elapsed / total_requests))
            print("Max request time was %.8fs" % max_time)
            print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

        if args.report:
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    outhash = hashlib.sha256()
                    outhash.update(args.outhash_seed.encode('utf-8'))
                    outhash.update(str(i).encode('utf-8'))

                    client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())

                    with lock:
                        pbar.update()

    def handle_remove(args, client):
        where = {k: v for k, v in args.where}
        if where:
            result = client.remove(where)
            print("Removed %d row(s)" % (result["count"]))
        else:
            print("No query specified")

    def handle_clean_unused(args, client):
        result = client.clean_unused(args.max_age)
        print("Removed %d rows" % (result["count"]))
        return 0

    def handle_refresh_token(args, client):
        r = client.refresh_token(args.username)
        print_user(r)

    def handle_set_user_permissions(args, client):
        r = client.set_user_perms(args.username, args.permissions)
        print_user(r)

    def handle_get_user(args, client):
        r = client.get_user(args.username)
        print_user(r)

    def handle_get_all_users(args, client):
        users = client.get_all_users()
        print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
        print(("-" * 20) + "+" + ("-" * 20))
        for u in users:
            print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))

    def handle_new_user(args, client):
        r = client.new_user(args.username, args.permissions)
        print_user(r)

    def handle_delete_user(args, client):
        r = client.delete_user(args.username)
        print_user(r)

    def handle_get_db_usage(args, client):
        usage = client.get_db_usage()
        print(usage)
        tables = sorted(usage.keys())
        print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
        print(("-" * 20) + "+" + ("-" * 20))
        for t in tables:
            print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
        print()

        total_rows = sum(t["rows"] for t in usage.values())
        print(f"Total rows: {total_rows}")

    def handle_get_db_query_columns(args, client):
        columns = client.get_db_query_columns()
        print("\n".join(sorted(columns)))

    def handle_gc_status(args, client):
        result = client.gc_status()
        if not result["mark"]:
            print("No Garbage collection in progress")
            return 0

        print("Current Mark: %s" % result["mark"])
        print("Total hashes to keep: %d" % result["keep"])
        print("Total hashes to remove: %s" % result["remove"])
        return 0

    def handle_gc_mark(args, client):
        where = {k: v for k, v in args.where}
        result = client.gc_mark(args.mark, where)
        print("New hashes marked: %d" % result["count"])
        return 0

    def handle_gc_sweep(args, client):
        result = client.gc_sweep(args.mark)
        print("Removed %d rows" % result["count"])
        return 0

    def handle_unihash_exists(args, client):
        result = client.unihash_exists(args.unihash)
        if args.quiet:
            return 0 if result else 1

        print("true" if result else "false")
        return 0

    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')
    parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
    parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
    parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
    parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")

    subparsers = parser.add_subparsers()

    get_parser = subparsers.add_parser('get', help="Get the unihash for a taskhash")
    get_parser.add_argument("method", help="Method to query")
    get_parser.add_argument("taskhash", help="Task hash to query")
    get_parser.set_defaults(func=handle_get)

    get_outhash_parser = subparsers.add_parser('get-outhash', help="Get output hash information")
    get_outhash_parser.add_argument("method", help="Method to query")
    get_outhash_parser.add_argument("outhash", help="Output hash to query")
    get_outhash_parser.add_argument("taskhash", help="Task hash to query")
    get_outhash_parser.set_defaults(func=handle_get_outhash)

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    remove_parser = subparsers.add_parser('remove', help="Remove hash entries")
    remove_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                               help="Remove entries from table where KEY == VALUE")
    remove_parser.set_defaults(func=handle_remove)

    clean_unused_parser = subparsers.add_parser('clean-unused', help="Remove unused database entries")
    clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
    clean_unused_parser.set_defaults(func=handle_clean_unused)

    refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
    refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
    refresh_token_parser.set_defaults(func=handle_refresh_token)

    set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
    set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
    set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    set_user_perms_parser.set_defaults(func=handle_set_user_permissions)

    get_user_parser = subparsers.add_parser('get-user', help="Get user")
    get_user_parser.add_argument("--username", "-u", help="Username")
    get_user_parser.set_defaults(func=handle_get_user)

    get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
    get_all_users_parser.set_defaults(func=handle_get_all_users)

    new_user_parser = subparsers.add_parser('new-user', help="Create new user")
    new_user_parser.add_argument("--username", "-u", help="Username", required=True)
    new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    new_user_parser.set_defaults(func=handle_new_user)

    delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
    delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
    delete_user_parser.set_defaults(func=handle_delete_user)

    db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
    db_usage_parser.set_defaults(func=handle_get_db_usage)

    db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
    db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)

    gc_status_parser = subparsers.add_parser("gc-status", help="Show garbage collection status")
    gc_status_parser.set_defaults(func=handle_gc_status)

    gc_mark_parser = subparsers.add_parser('gc-mark', help="Mark hashes to be kept for garbage collection")
    gc_mark_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_mark_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                             help="Keep entries in table where KEY == VALUE")
    gc_mark_parser.set_defaults(func=handle_gc_mark)

    gc_sweep_parser = subparsers.add_parser('gc-sweep', help="Perform garbage collection and delete any entries that are not marked")
    gc_sweep_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_sweep_parser.set_defaults(func=handle_gc_sweep)

    unihash_exists_parser = subparsers.add_parser('unihash-exists', help="Check if a unihash is known to the server")
    unihash_exists_parser.add_argument("--quiet", action="store_true", help="Don't print status. Instead, exit with 0 if unihash exists and 1 if it does not")
    unihash_exists_parser.add_argument("unihash", help="Unihash to check")
    unihash_exists_parser.set_defaults(func=handle_unihash_exists)

    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    login = args.login
    password = args.password

    if login is None and args.netrc:
        try:
            n = netrc.netrc()
            auth = n.authenticators(args.address)
            if auth is not None:
                login, _, password = auth
        except FileNotFoundError:
            pass
        except netrc.NetrcParseError as e:
            sys.stderr.write(f"Error parsing {e.filename}:{e.lineno}: {e.msg}\n")

    func = getattr(args, 'func', None)
    if func:
        try:
            with hashserv.create_client(args.address, login, password) as client:
                if args.become:
                    client.become_user(args.become)
                return func(args, client)
        except bb.asyncrpc.InvokeError as e:
            print(f"ERROR: {e}")
            return 1

    return 0


if __name__ == '__main__':
    # Script entry point: run main() and exit with its return value. Any
    # ordinary exception is printed as a traceback and becomes exit status 1.
    import traceback

    try:
        ret = main()
    except Exception:
        traceback.print_exc()
        ret = 1
    sys.exit(ret)
