#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# © 2020, Multapplied Networks, Inc.
import argparse
import os
import re
import sys
import tempfile

# ANSI escape sequences used to highlight errors on stderr.
RED = '\u001b[31m'
RESET = '\u001b[0m'
SALT_CONFIG_DIR = "/etc/bondingadmin/salt-config/states/node"
# Input: legacy iptables rules, one per line.
KNOWN_IPS_FILE = f"{SALT_CONFIG_DIR}/known_ips"
# Output: the generated nftables rules.
TRUSTED_NETWORKS_FILE = f"{SALT_CONFIG_DIR}/filter-input-99-trusted-networks.nft"
# Both patterns are raw strings: they contain regex escapes (\$, \s, \#) that
# are invalid Python string escapes and warn on Python 3.12+.
# Group 1 captures the command (iptables / ip6tables / iptables_6).
# SOURCE_MATCH: "-s <network> ... -j ACCEPT [# comment]". Hex digits may be
# either case so upper-case IPv6 addresses are accepted.
SOURCE_MATCH = re.compile(r'^(iptables|ip6tables|iptables_6) -A \$CHAIN -s (?P<network>[0-9a-fA-F:./]+)(.*)-j ACCEPT(\s+\#\s*(?P<comment>.*))?$')
# OSPF_MATCH: "... -p ospf ... -j ACCEPT [...] [# comment]". The text after
# ACCEPT stops at the first '#' so the optional comment group can capture it.
OSPF_MATCH = re.compile(r'^(iptables|ip6tables|iptables_6) -A \$CHAIN(.*)-p ospf(.*)-j ACCEPT([^#]*)(\#\s*(?P<comment>.*))?$')


class UnhandledRuleError(Exception):
    """Signal that a known_ips line could not be translated to nftables.

    Raised by the migration when it meets an iptables rule whose form it
    does not recognise; such rules have to be migrated by hand.
    """


def translate_known_ips(args, known_ips):
    """Translate salt ``known_ips`` iptables rules into nftables rule lines.

    Each input line is handled as follows:

    * blank lines are skipped;
    * ``#`` comment lines are copied through verbatim;
    * source-address ACCEPT rules (``SOURCE_MATCH``) have their networks
      collected into a single ``ip saddr { ... } accept`` set (IPv4) or
      ``ip6 saddr { ... } accept`` set (IPv6), keeping any trailing comment;
    * OSPF ACCEPT rules (``OSPF_MATCH``) become ``ip protocol ospf accept``
      (IPv4) or ``ip6 nexthdr ospf accept`` (IPv6).

    Args:
        args: Parsed command-line arguments. Not used here; kept so the call
            signature stays stable.
        known_ips: Iterable of raw lines read from the known_ips file.

    Returns:
        A list of nftables lines without trailing newlines. Comments and OSPF
        rules come first, in input order. They are followed by the IPv4 saddr
        set and then the IPv6 saddr set. A set is omitted when it has no
        entries.

    Raises:
        UnhandledRuleError: if any non-blank, non-comment line matches neither
            pattern. Every such line is reported on stderr before raising.
    """
    failures = 0
    ipv4_sources = []
    ipv6_sources = []
    nftables_lines = []

    for line_no, raw_line in enumerate(known_ips, 1):
        ip_line = raw_line.strip()
        if not ip_line:
            continue

        if ip_line.startswith('#'):
            nftables_lines.append(ip_line)
            continue

        source = SOURCE_MATCH.match(ip_line)
        if source is not None:
            entry = (source.group('network'), source.group('comment'))
            # Pick the family from the matched command (group 1). Searching
            # the whole line would misfile v4 rules whose comment mentions
            # "ip6tables".
            if _is_ipv6_command(source.group(1)):
                ipv6_sources.append(entry)
            else:
                ipv4_sources.append(entry)
            continue

        ospf = OSPF_MATCH.match(ip_line)
        if ospf is not None:
            if _is_ipv6_command(ospf.group(1)):
                nftables_lines.append('ip6 nexthdr ospf accept')
            else:
                nftables_lines.append('ip protocol ospf accept')
            continue

        # Keep going so every unparseable line is reported in a single run.
        sys.stderr.write(f'{RED}Unable to parse line {line_no}{RESET}: {ip_line}\n')
        failures += 1

    if failures > 0:
        raise UnhandledRuleError(
            'Failed to parse entire file. See above for lines which failed and manually migrate.'
        )

    nftables_lines.extend(_saddr_set('ip', ipv4_sources))
    nftables_lines.extend(_saddr_set('ip6', ipv6_sources))
    return nftables_lines


def _is_ipv6_command(command):
    """Return True if *command* (regex group 1) is one of the IPv6 iptables variants."""
    return command in ('ip6tables', 'iptables_6')


def _saddr_set(family, sources):
    """Render (network, comment) pairs as a ``<family> saddr { ... } accept`` set.

    Args:
        family: nftables address family keyword, ``'ip'`` or ``'ip6'``.
        sources: List of ``(network, comment)`` tuples; ``comment`` may be None.

    Returns:
        The set's lines, or an empty list when *sources* is empty, so that no
        empty set is ever emitted.
    """
    if not sources:
        return []
    lines = [f'{family} saddr {{']
    for network, comment in sources:
        line = f'\t {network},'
        if comment:
            line += f" # {comment}"
        lines.append(line)
    lines.append('} accept')
    return lines


if __name__ == '__main__':
    # Script entry point: read KNOWN_IPS_FILE, translate it, then either print
    # the result (--dry-run) or write it to TRUSTED_NETWORKS_FILE.
    parser = argparse.ArgumentParser(description='Migrate salt known_ips iptables rules to nftables')
    parser.add_argument('--dry-run', action='store_true',
                        help='print the translated nftables rules instead of writing them')
    parser.add_argument('--force', action='store_true',
                        help=f'migrate again even if {TRUSTED_NETWORKS_FILE} already exists')
    args = parser.parse_args()

    # An existing output file means a previous run already migrated. A dry run
    # never writes, so it may proceed regardless.
    if os.path.isfile(TRUSTED_NETWORKS_FILE) and not args.force and not args.dry_run:
        print(f'{TRUSTED_NETWORKS_FILE} exists, already migrated. Use --force to migrate again.')
        sys.exit(0)

    # Explicit UTF-8 so comments with non-ASCII text do not depend on the
    # process locale.
    try:
        with open(KNOWN_IPS_FILE, 'r', encoding='utf-8') as f:
            known_ips = f.readlines()
    except FileNotFoundError:
        sys.stderr.write(f'File {KNOWN_IPS_FILE} does not exist, nothing to migrate\n')
        sys.exit(0)
    except (OSError, UnicodeDecodeError) as e:
        sys.stderr.write(f'Unable to read {KNOWN_IPS_FILE}: {e}\n')
        sys.exit(1)

    try:
        nftables_rules = translate_known_ips(args, known_ips)
    except UnhandledRuleError as e:
        sys.stderr.write(f'{RED}{e}{RESET}\n')
        sys.exit(1)

    if args.dry_run:
        print('\n'.join(nftables_rules))
    else:
        try:
            with open(TRUSTED_NETWORKS_FILE, 'w', encoding='utf-8') as f:
                f.writelines(f'{line}\n' for line in nftables_rules)
        except OSError as e:
            sys.stderr.write(f'Unable to write {TRUSTED_NETWORKS_FILE}: {e}\n')
            sys.exit(1)
