#!/usr/bin/python3
# -*- coding: UTF-8 -*-
"""
Handle OpenVPN client connection event.
"""
# © 2012, Multapplied Networks, Inc.
import json
import logging
import os
import re
import socket
import sys
from systemd.journal import JournalHandler


# umask in effect while save_client_config() writes the client config file,
# so a newly created file carries no group/other permission bits.
CONFIG_FILE_UMASK = 0o77  # Mask all group, other bits
# Timeout (seconds) set on the event handler socket.
# This uses a very short timeout since OpenVPN blocks while this script runs
TIMEOUT = 1
# Level applied to both the root logger and the journal handler when this
# file is run as a script.
LOG_LEVEL = logging.INFO


# Text of the per-client config file. get_client_config() fills it in with the
# %-operator from the event handler's JSON response, so the response has to
# provide every key named in the template.
CONFIG_TEMPLATE = """ifconfig-push %(local_ipv4)s %(remote_ipv4)s
push "route %(network_ipv4)s %(netmask_ipv4)s"
ifconfig-ipv6-push %(local_ipv6)s
push "route-ipv6 %(network_ipv6)s"
"""


# NOTE: when this file is run as a script, the __main__ block rebinds this
# name to the root logger.
logger = logging.getLogger("openvpn-client-connect")


class EventError(Exception):
    """Raised when the connect event cannot be handled by the event server."""


def get_client_config(
    node_id,
    client_addr,
    path="/run/bondingadmin/mgmtvpn-event.sock"
) -> str:
    """
    Send the connect event and get the client config.

    A JSON "connect" request is written to the Unix stream socket at
    ``path`` and a single JSON reply (at most 1024 bytes) is read back.
    The reply is then substituted into CONFIG_TEMPLATE.

    Args:
        node_id: Node identifier, sent as the request's "id" field.
        client_addr: Client address, sent as the request's "client_addr"
            field.
        path: Filesystem path of the event handler's Unix socket.

    Returns:
        CONFIG_TEMPLATE filled in with the values from the reply.

    Raises:
        EventError: If the socket cannot be connected, the exchange fails or
            times out, the reply is not a JSON object, the reply carries an
            "error" key, or the reply lacks a key the template needs.
    """
    request = {
        "event": "connect",
        "id": node_id,
        "client_addr": client_addr,
    }

    # The context manager closes the socket on every exit path, including the
    # error paths below.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        sock.settimeout(TIMEOUT)
        try:
            sock.connect(path)
        except OSError as e:
            raise EventError(
                f"Could not connect to event handler socket: {e}"
            ) from e
        try:
            # send() may write only part of the buffer; sendall() keeps going
            # until the whole request is written or an error occurs.
            sock.sendall(json.dumps(request).encode())
            data = sock.recv(1024)
        except OSError as e:
            # Includes timeouts, which are OSError subclasses.
            raise EventError(
                f"Could not exchange data with event handler: {e}"
            ) from e

    try:
        response = json.loads(data)
    except ValueError as e:
        raise EventError("Invalid response data") from e

    # The template substitution and the "error" lookup both need a mapping.
    if not isinstance(response, dict):
        raise EventError("Invalid response data")

    if "error" in response:
        raise EventError(f"Got error from server: {response['error']}")

    try:
        return CONFIG_TEMPLATE % response
    except KeyError as e:
        raise EventError(f"Response is missing field {e}") from e


def save_client_config(filename, client_config):
    """Save the client configuration.

    The file is written with CONFIG_FILE_UMASK in effect, and the previous
    umask is put back afterwards whether or not the write succeeds.

    Args:
        filename: Path of the file to write; it is opened in "w" mode.
        client_config: Text to write to the file.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    original_umask = os.umask(CONFIG_FILE_UMASK)
    try:
        with open(filename, "w") as config:
            config.write(client_config)
    finally:
        # The umask is process-wide state, so restore it even when the open
        # or write above raises.
        os.umask(original_umask)


if __name__ == "__main__":
    # Send everything logged through the root logger to the systemd journal.
    logger = logging.getLogger()
    logger.setLevel(LOG_LEVEL)
    journal_handler = JournalHandler()
    journal_handler.setLevel(LOG_LEVEL)
    logger.addHandler(journal_handler)

    # Remains None until the common name has been parsed, so the failure log
    # below can still be formatted.
    node_id = None
    try:
        common_name = os.environ["common_name"]
        node_id = int(re.match("node-([0-9]+)-", common_name).group(1))
        logger.info(f"Node {node_id}: Starting connect")

        # Use the IPv4 peer address when present, otherwise the IPv6 one.
        for address_key in ("trusted_ip", "trusted_ip6"):
            if address_key in os.environ:
                client_addr = "%s:%s" % (
                    os.environ[address_key],
                    os.environ["trusted_port"],
                )
                break
        else:
            logger.error(f"Node {node_id}: No trusted IP in environment")
            sys.exit(1)

        try:
            config = get_client_config(node_id, client_addr)
        except EventError as e:
            logger.error(f"Node {node_id}: Got error handling event: {e}")
            sys.exit(1)
        else:
            # The first command-line argument is the file to write.
            save_client_config(sys.argv[1], config)
    except Exception:
        logger.exception(f"Node {node_id}: Got exception handling event")
        sys.exit(1)  # This disconnects the client.
    else:
        logger.info(f"Node {node_id}: Completed connect event")
