#!/usr/bin/python3

# Check that signatures on regulatory.db are valid and made by keys
# trusted by the current kernel in the target suite

import glob
import io
import os
import os.path
import re
import subprocess
import sys


# Scratch directory used for the extracted sources and for temporary
# files; None if AUTOPKGTEST_TMP is not set in the environment
tmp_dir = os.environ.get('AUTOPKGTEST_TMP')


# Find the versioned kernel source package name (of the form
# "linux-source-<version>") listed in the Depends field of the
# installed linux-source package.  Returns None if dpkg prints no
# such Depends line.
def get_linux_source_name():
    # subprocess.run() waits for dpkg and closes the pipe exactly once.
    # The exit status is deliberately not checked: a failed query just
    # produces no matching line, which is reported as None.
    result = subprocess.run(['dpkg', '-s', 'linux-source'],
                            stdout=subprocess.PIPE,
                            encoding='utf-8')
    # With MULTILINE, '^' anchors at the start of every line, and '.'
    # never crosses a newline, so this finds the first matching line.
    match = re.search(r'^Depends:.*\b(linux-source-[0-9.]*)\b',
                      result.stdout, re.MULTILINE)
    if match:
        return match.group(1)
    return None


# Unpack only the wireless regdb certificate directory from the
# source tarball /usr/src/<source_name>.tar.xz into tmp_dir, and
# return the path of the unpacked directory.  Raises
# subprocess.CalledProcessError if tar fails.
def extract_certs_source(source_name):
    tarball = f'/usr/src/{source_name}.tar.xz'
    member = f'{source_name}/net/wireless/certs'
    tar_cmd = ['tar', '-C', tmp_dir, '-xa', '-f', tarball, member]
    subprocess.check_call(tar_cmd)
    return os.path.join(tmp_dir, member)


# A C block comment.  DOTALL lets '.' match newlines so that comments
# spanning several lines are removed as well.
c_comment_re = re.compile(r'/\*.*?\*/', re.DOTALL)
# One two-digit hex byte literal, optionally surrounded by whitespace
c_hex_byte_re = re.compile(r'\s*0x([0-9A-Fa-f]{2})\s*')


# Convert from a C hex literal array to bytes
#
# hex is the text of a comma-separated list of 0xNN literals, which
# may contain C block comments and arbitrary whitespace.  Empty or
# all-whitespace elements (as left by a trailing comma or an empty
# input) are skipped.  Raises ValueError if any other element is not
# a two-digit hex literal.
def convert_c_hex_to_bytes(hex):
    decoded = bytearray()
    # Replace comments with a space so the text around them stays
    # separated
    hex = c_comment_re.sub(' ', hex)
    for word in hex.split(','):
        match = c_hex_byte_re.fullmatch(word)
        if match:
            decoded.append(int(match.group(1), 16))
        elif word.strip():
            raise ValueError(
                f'invalid hex byte literal: {word.strip()!r}')
    return bytes(decoded)


# Convert an X.509 certificate from DER to PEM format
#
# der is the DER encoding as bytes; the PEM encoding is returned as
# bytes.  Raises subprocess.CalledProcessError if openssl exits with
# a non-zero status, instead of silently returning whatever (possibly
# empty) output it produced.
def convert_cert_der_to_pem(der):
    result = subprocess.run(['openssl', 'x509', '-inform', 'der',
                             '-outform', 'pem'],
                            input=der,
                            stdout=subprocess.PIPE,
                            check=True)
    return result.stdout


# Return the total length in bytes (header plus contents) of the DER
# SEQUENCE that starts at data[start].  Raises ValueError if the bytes
# there are not a complete, definite-length SEQUENCE.
def _der_sequence_length(data, start):
    header = data[start:start + 2]
    if len(header) < 2 or header[0] != 0x30:
        raise ValueError(f'no DER SEQUENCE at offset {start}')
    if header[1] < 0x80:
        # Short form: the byte is the content length itself
        header_len = 2
        content_len = header[1]
    else:
        # Long form: the low 7 bits count the length bytes that follow
        num_len_bytes = header[1] & 0x7f
        len_bytes = data[start + 2:start + 2 + num_len_bytes]
        if num_len_bytes == 0 or len(len_bytes) != num_len_bytes:
            raise ValueError(f'bad DER length at offset {start}')
        header_len = 2 + num_len_bytes
        content_len = int.from_bytes(len_bytes, 'big')
    total = header_len + content_len
    if start + total > len(data):
        raise ValueError(f'truncated DER SEQUENCE at offset {start}')
    return total


# Return an array of (cert, is_upstream) pairs.  Each cert is in PEM
# format.  The is_upstream flag is false for certificates found in
# "debian.hex" and otherwise true.
#
# certs_source_dir is searched (non-recursively) for "*.hex" files,
# which are processed in sorted order.  Raises ValueError if a file
# does not decode to a whole number of DER SEQUENCEs.
def convert_certs(certs_source_dir):
    certs = []

    for hex_name in sorted(glob.glob(f'{certs_source_dir}/*.hex')):
        is_upstream = os.path.basename(hex_name) != 'debian.hex'
        with open(hex_name) as hex_file:
            decoded = convert_c_hex_to_bytes(hex_file.read())

        # Split into certificates and convert one at a time, as the
        # openssl command line doesn't seem to support this.  Each
        # certificate is one DER SEQUENCE, so its encoded length says
        # exactly where the next one starts; this does not depend on
        # the certificate size or on byte patterns inside it.
        start = 0
        while start < len(decoded):
            end = start + _der_sequence_length(decoded, start)
            certs.append((convert_cert_der_to_pem(decoded[start:end]),
                          is_upstream))
            start = end

    return certs


# Verify the DER-format PKCS#7 signature in signature_name over the
# content of payload_name.  The signer must exactly match one of
# certs, an iterable of PEM-format X.509 certificates (as bytes).
# Raises subprocess.CalledProcessError if openssl reports failure.
def verify_signature(payload_name, signature_name, certs):
    # Concatenate the acceptable signer certificates into a single
    # file to hand to openssl
    bundle_path = os.path.join(tmp_dir, 'certificate.pem')
    with open(bundle_path, 'wb') as bundle:
        bundle.writelines(certs)

    # The certificates used here are self-signed and not CAs, so
    # verification of the certificates themselves has to be turned
    # off (-noverify).  The certificate embedded in the signature is
    # then ignored (-nointern), so that the signature is required to
    # have been made by one of the certificates we pass in (I hope).
    openssl_cmd = ['openssl', 'smime', '-verify', '-noverify',
                   '-certfile', bundle_path, '-nointern',
                   '-no_check_time',
                   '-content', payload_name,
                   '-in', signature_name, '-inform', 'der']

    # The payload is printed to stdout by default, so discard that.
    # Successful verification is reported to stderr, which autopkgtest
    # counts as a failure, so send that to our own stdout instead.
    # This is file descriptor 1 rather than subprocess.STDOUT, which
    # would follow the child's stdout into /dev/null.
    subprocess.check_call(openssl_cmd,
                          stdout=subprocess.DEVNULL,
                          stderr=1)


# Extract the regdb certificates from the kernel source package and
# check the signatures of both installed regulatory databases against
# them.  Exits with an error message if the scratch directory or the
# kernel source package cannot be determined; raises
# subprocess.CalledProcessError if a signature does not verify.
def main():
    # Fail with a clear message rather than passing None on to tar
    # and os.path.join()
    if not tmp_dir:
        sys.exit('AUTOPKGTEST_TMP is not set')
    source_name = get_linux_source_name()
    if source_name is None:
        sys.exit('Could not find a linux-source-<version> dependency '
                 'of the linux-source package')

    certs_source_dir = extract_certs_source(source_name)
    certs = convert_certs(certs_source_dir)
    # The -upstream database must be signed by an upstream certificate
    certs_upstream = [cert
                      for cert, is_upstream in certs
                      if is_upstream]
    # Every certificate, upstream or Debian, is accepted for the
    # -debian database
    certs_debian = [cert for cert, is_upstream in certs]

    verify_signature('/lib/firmware/regulatory.db-upstream',
                     '/lib/firmware/regulatory.db.p7s-upstream',
                     certs_upstream)
    verify_signature('/lib/firmware/regulatory.db-debian',
                     '/lib/firmware/regulatory.db.p7s-debian',
                     certs_debian)


if __name__ == '__main__':
    main()
