#!/usr/bin/env python3
from argparse import ArgumentParser
import itertools
import subprocess
import sys
import os
from os.path import isdir

class InvalidIP(Exception):
    pass

def fatal(msg):
    print(msg, file=sys.stderr)
    sys.exit(1)

def parse_config(path):
    last_header = ''
    server_addr = ''
    server_port = ''
    taken_ips = set()

    with open(path, 'r') as fob:
        for line in fob:
            line = line.strip()
            if not line:
                continue
            elif line.startswith('['):
                last_header = line[1:-1]
            else:
                key, value = line.split('=', 1)
                key = key.strip()
                value = value.strip()
                if last_header == 'Interface':
                    if key == 'Address':
                        server_addr = value
                    elif key == 'ListenPort':
                        server_port = value
                elif last_header == 'Peer' and key == 'AllowedIPs':
                    ips = value.split(',')
                    for ip in map(lambda x: x.strip(), ips):
                        taken_ips.add(ip)
    if not server_addr:
        fatal(f'No interface section or no server address specified in {path}')
    return (server_addr, server_port, taken_ips)

def gen_key():
    ''' Generate a new key pair '''
    priv_key = subprocess.run(['wg', 'genkey'], capture_output=True, text=True).stdout.strip()
    pub_key = subprocess.run(['wg', 'pubkey'], capture_output=True, text=True,
            input=priv_key).stdout.strip()
    return (priv_key, pub_key)

def parse_cidr(addr):
    if not '/' in addr:
        raise InvalidIP(
            f'Expected IP address to be in CIDR format, but {addr!r} is not!')
    if addr.count('/') > 1:
        raise InvalidIP(
            f'Expected IP address to be in CIDR format, but {addr!r} is not!')
    ip, bit_count = addr.strip().split('/')
    if not bit_count:
        raise InvalidIP(
            f'Expected IP address to be in CIDR format, but bitcount is empty!')
    try:
        bit_count = int(bit_count.strip())
    except:
        raise InvalidIP(
            f'Expected IP address to be in CIDR format, but {addr!r}\'s' +\
                    'bit count is NOT an integer!')
    if bit_count < 0 or bit_count > 32:
        raise InvalidIP(
            f'CIDR bitcount ({bit_count!r}) is out of range 0-32!')

    split_ip = ip.split('.')
    if len(split_ip) != 4:
        raise InvalidIP(
            f'IP address should have 4 segments but {addr!r} has {len(split_ip)}!')
    ip_out = []
    for part in split_ip:
        try:
            ip_out.append(int(part))
        except:
            raise InvalidIP(
                f'Each segment of IP address should be numeric but {part!r} '+\
                f'of {addr!r} is not!')

    return (bit_count, ip_out)

def strip_cidr(addr):
    if '/' in addr:
        return addr.split('/')[0]
    else:
        return addr
    
def bits_to_ip(ip_bits):
    return '{}.{}.{}.{}'.format(
        int(ip_bits[:8], 2),
        int(ip_bits[8:16], 2),
        int(ip_bits[16:24], 2),
        int(ip_bits[24:], 2),
    )

def get_next_ip(server_addr, taken_ips):
    ''' Find the next available IP address in a given ip range,
    the ip range is pulled from the server_addr since it's in cidr'''
    # the server addr can't be allocated to a new client either 
    taken_ips.add(server_addr)
    # and just incase, remove the CIDRs
    taken_ips = set(map(strip_cidr, taken_ips))
    bit_count, ip = parse_cidr(server_addr)
    ip_bits = bin(ip[0] << 24 | ip[1] << 16 | ip[2] << 8 | ip[3])[2:].zfill(32)
    left_part = ip_bits[:bit_count]
    right_part_size = 32-bit_count
    # conveniently an itertools.product of ('01', repeat=N) produces
    # all N-bit integers in order
    for right_part in itertools.product('01', repeat=right_part_size):
        test_ip = bits_to_ip(left_part + ''.join(right_part))
        if not test_ip in taken_ips:
            return test_ip
    fatal('unable to find free IP in virtual subnet')

def get_domain():
    with open('/etc/wireguard/domain.txt', 'r') as fob:
        domain = fob.read().strip()
    return domain

def add_client(client_name, allowed_ips):
    subprocess.run(['wg-quick', 'save', 'wg0'])
    server_addr, server_port, taken_ips = parse_config('/etc/wireguard/wg0.conf')
    priv_key, pub_key = gen_key()
    client_virt_ip = get_next_ip(server_addr, taken_ips)
    proc = subprocess.run([
        'python3', '/usr/local/bin/wireguard-client-list.py', 'add',
        client_name, pub_key
    ], capture_output=True, text=True)
    if proc.returncode:
        fatal(proc.stderr)

    subprocess.run([
        'wg', 'set', 'wg0',
        'peer', pub_key,
        'allowed-ips', client_virt_ip + '/32',
    ])
    subprocess.run(['wg-quick', 'save', 'wg0'], capture_output=True)

    with open('/etc/wireguard/server.pub', 'r') as fob:
        server_pub_key = fob.read().strip()
    if not isdir('/etc/wireguard/clients'):
        os.mkdir('/etc/wireguard/clients')
    client_conf = f'/etc/wireguard/clients/{client_name}.conf'
    with open(client_conf, 'w') as fob:
        fob.write(f'''
[Interface]
PrivateKey = {priv_key}
Address = {client_virt_ip}

[Peer]
PublicKey = {server_pub_key}
Endpoint = {get_domain()}:{server_port}
AllowedIPs = {allowed_ips}
''')
    print(f'Generated {client_conf}')

if __name__ == '__main__':
    parser = ArgumentParser(
        description='Generate keys and configuration for a new client')
    parser.add_argument('client_name', help='Unique name for client')
    parser.add_argument('allowed_ips',
        nargs='?',
        default='0.0.0.0/0',
        help="IP range(s) of client's traffic to route through this vpn (default is 0/0)")
    args = parser.parse_args()

    try:
        add_client(args.client_name, args.allowed_ips)
    except InvalidIP as e:
        fatal(e.args[0])

