#!/usr/bin/env python3
#
#
# Copyright 2025 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool for registering or unregistering a device to receive new KeyMint roots."""

import argparse
import base64
from collections.abc import Callable
import shutil
import subprocess
import sys
import urllib.error
import urllib.parse
import urllib.request

# Magic request IDs for registering and unregistering a device to receive new
# Google-rooted certificates.
REGISTER_REQUEST_ID = 'keymint_register_for_new_root'
UNREGISTER_REQUEST_ID = 'keymint_unregister'
# HAL names that main() tries, in this order. A HAL that does not appear in
# the device's `remote_provisioning list` output is skipped.
CANDIDATE_HALS = ('strongbox', 'default')
# Endpoint that receives the device CSR; the request ID is sent as the
# `request_id` query parameter.
RKP_URL = 'https://remoteprovisioning.googleapis.com/v1:signCertificates'


class GoogleRootRegistrationError(Exception):
  """Base class for google_root_register errors.

  AdbError, CsrError and RequestError all derive from this class, and main()
  catches this type to report a failure without a traceback.
  """


class AdbError(GoogleRootRegistrationError):
  """Error issuing adb commands.

  Raised when the adb binary is not found in PATH, when no device responds in
  time, or when an adb invocation fails.
  """


class CsrError(GoogleRootRegistrationError):
  """Error getting CSR from device.

  Raised when the adb `remote_provisioning csr` command fails or times out, or
  when its output is empty or cannot be base64-decoded.
  """


class RequestError(GoogleRootRegistrationError):
  """Error sending request to RKP server.

  Raised by GoogleRootRegister.send_request when the POST raises PostError,
  either from the transport or from an HTTP error status.
  """


class PostError(IOError):
  """Failure of an HTTP POST to the RKP server.

  Attributes:
    response: The response object that triggered the error, or None when no
      response is available (for example, when the connection failed).
  """

  def __init__(self, *args, response=None, **kwargs):
    # Let the IOError machinery consume the positional arguments first, then
    # attach the optional response for callers that want to inspect it.
    super().__init__(*args, **kwargs)
    self.response = response


class PostResponse:
  """Response to an HTTP POST request.

  Attributes:
    status_code: Numeric HTTP status code.
    reason: Reason phrase that accompanied the status code.
    content: Raw response body.
    text: Body decoded with the default codec (UTF-8); bytes that cannot be
      decoded are dropped rather than raising.
  """

  def __init__(self, status_code: int, reason: str, content: bytes):
    self.status_code = status_code
    self.reason = reason
    self.content = content
    self.text = content.decode(errors='ignore')

  def _error_message(self) -> str:
    """Returns the HTTP error summary, or '' when the status is not 4xx/5xx."""
    if 400 <= self.status_code < 500:
      kind = 'Client'
    elif 500 <= self.status_code < 600:
      # 5xx means the server failed; do not blame the client for it.
      kind = 'Server'
    else:
      return ''
    return f'{self.status_code} {kind} Error: {self.reason}'

  def raise_for_status(self):
    """Raises PostError if the status code is a 4xx or 5xx error.

    Raises:
      PostError: With this response attached as `response`.
    """
    http_error_msg = self._error_message()
    if http_error_msg:
      raise PostError(http_error_msg, response=self)


def http_post(
    url: str,
    *,
    params=None,
    data=None,
    headers=None,
) -> PostResponse:
  """Sends a POST request using urllib.

  Args:
    url: Target URL without a query string.
    params: Optional mapping (or sequence of pairs) that is URL-encoded and
      appended to `url` as the query string.
    data: Optional request body.
    headers: Optional mapping of HTTP header names to values.

  Returns:
    A PostResponse for a successful request, and also for any HTTP error
    status other than 400; callers should use `raise_for_status()` to turn
    those into errors.

  Raises:
    PostError: If the server answers 400, or if the request could not be
      delivered at all.
  """
  query_string = urllib.parse.urlencode(params or {})
  # Only add the '?' separator when there is a query string to follow it.
  full_url = f'{url}?{query_string}' if query_string else url
  # Force the verb: urllib would otherwise send a GET whenever `data` is None.
  request = urllib.request.Request(
      full_url, data=data, headers=headers or {}, method='POST'
  )
  try:
    with urllib.request.urlopen(request) as resp:
      return PostResponse(
          status_code=resp.status, reason=resp.reason, content=resp.read()
      )
  except urllib.error.HTTPError as e:
    if e.code == 400:
      raise PostError(
          'Bad request! This is expected when the device is already registered.'
      ) from e
    # Other HTTP errors are handed back as a response so the caller can
    # inspect the body before deciding to raise.
    return PostResponse(status_code=e.code, reason=e.reason, content=e.read())
  except urllib.error.URLError as e:
    raise PostError(f'Connection error: {e.reason}') from e


class GoogleRootRegister:
  """Registers or unregisters a device to receive new KeyMint roots.

  Every external effect (PATH lookup, subprocess execution, base64 decoding
  and the HTTP POST) goes through a callable passed to the constructor, so
  each one can be substituted.
  """

  # Upper bound, in seconds, applied to every adb invocation.
  _ADB_TIMEOUT_SECONDS = 10

  def __init__(
      self,
      *,
      which: Callable[[str], object] = shutil.which,
      run: Callable[..., subprocess.CompletedProcess] = subprocess.run,
      b64decode: Callable[[str], bytes] = base64.b64decode,
      post: Callable[..., PostResponse] = http_post,
  ):
    """Initializes the instance.

    Args:
      which: Locates an executable in PATH; returns None when it is missing.
      run: Runs a command; called with the keyword arguments of
        `subprocess.run`.
      b64decode: Decodes the base64 text printed by the device into bytes.
      post: Sends the HTTP POST; called as
        `post(url, params=..., data=..., headers=...)`.
    """
    self._which = which
    self._run = run
    self._b64decode = b64decode
    self._post = post

  def check_adb(self) -> None:
    """Checks if adb is available in PATH and a device is connected.

    Raises:
      AdbError: If adb is not installed, no device shows up before the
        timeout, or the adb invocation fails.
    """
    if self._which('adb') is None:
      raise AdbError(
          'Error: adb command not found. Download and install from'
          " https://developer.android.com/tools/adb and ensure it's in your"
          ' PATH.'
      )
    timeout = self._ADB_TIMEOUT_SECONDS
    print(f'Waiting for device connection up to {timeout} seconds...')
    try:
      self._run(
          ['adb', 'wait-for-device'],
          check=True,
          capture_output=True,
          timeout=timeout,
      )
    except subprocess.TimeoutExpired as e:
      raise AdbError(
          f'Error: No device response after {timeout} seconds. Is the device'
          ' connected?'
      ) from e
    except subprocess.CalledProcessError as e:
      raise AdbError(
          f'Error running adb command: {e!r}\nStderr: {e.stderr}'
      ) from e
    except OSError as e:
      raise AdbError(f'Failed to wait for device: {e!r}') from e

  def is_hal_supported(self, *, hal: str) -> bool:
    """Checks if HAL is supported on device.

    Args:
      hal: HAL name to look for in the `remote_provisioning list` output.

    Returns:
      True if `hal` appears as a whitespace-separated token on any line of
      the command output.

    Raises:
      AdbError: If the adb command times out, exits non-zero, or cannot be
        started.
    """
    timeout = self._ADB_TIMEOUT_SECONDS
    try:
      result = self._run(
          ['adb', 'shell', 'cmd', 'remote_provisioning', 'list'],
          capture_output=True,
          text=True,
          check=True,
          timeout=timeout,
      )
    except subprocess.TimeoutExpired as e:
      # A timeout is neither CalledProcessError nor OSError; without this
      # handler it would escape as a non-GoogleRootRegistrationError.
      raise AdbError(
          f'Error: No device response after {timeout} seconds.\n'
          f'Stderr: {e.stderr}'
      ) from e
    except subprocess.CalledProcessError as e:
      raise AdbError(
          f'Error running adb command: {e!r}\nStderr: {e.stderr}'
      ) from e
    except OSError as e:
      raise AdbError(f'Failed to get HAL list: {e!r}') from e

    # Match whole tokens so that one HAL name cannot match as a substring of
    # another.
    return any(hal in line.split() for line in result.stdout.splitlines())

  def get_csr(self, *, hal: str) -> bytes:
    """Runs adb command to get CSR from the device and returns it as bytes.

    Args:
      hal: HAL name passed to `remote_provisioning csr`.

    Returns:
      The decoded CSR bytes.

    Raises:
      CsrError: If the adb command times out, fails, or prints nothing, or
        if its output cannot be base64-decoded.
    """
    timeout = self._ADB_TIMEOUT_SECONDS
    try:
      result = self._run(
          ['adb', 'shell', 'cmd', 'remote_provisioning', 'csr', hal],
          capture_output=True,
          text=True,
          check=True,
          timeout=timeout,
      )
    except subprocess.TimeoutExpired as e:
      raise CsrError(
          f'Error: No device response after {timeout} seconds.\n'
          f'Stderr: {e.stderr}'
      ) from e
    except subprocess.CalledProcessError as e:
      raise CsrError(
          f'Error running adb command: {e}\nStderr: {e.stderr}'
      ) from e
    except OSError as e:
      raise CsrError(f'Failed to get CSR: {e}') from e

    # The CSR is expected to be a base64 string in stdout.
    csr_b64 = result.stdout.strip()
    if not csr_b64:
      raise CsrError('Error: adb command returned empty CSR.')

    try:
      return self._b64decode(csr_b64)
    except ValueError as e:
      # binascii.Error is a ValueError subclass; catching the base class also
      # covers the plain ValueError raised for non-ASCII input text.
      raise CsrError(
          f'Error decoding base64 CSR: {e}\n'
          'The output from adb may not be valid base64.'
      ) from e

  def send_request(self, *, csr_bytes: bytes, request_id: str) -> None:
    """Sends HTTP request to Remote Provisioning service.

    Args:
      csr_bytes: CSR sent as the request body.
      request_id: Value of the `request_id` query parameter.

    Raises:
      RequestError: If the POST fails or the server returns an error status.
    """
    params = {'request_id': request_id}
    headers = {'Content-Type': 'application/cbor'}
    try:
      response = self._post(
          RKP_URL, params=params, data=csr_bytes, headers=headers
      )
      response.raise_for_status()
    except PostError as e:
      if e.response is not None:
        print(f'Server response: {e.response.text}', file=sys.stderr)
      raise RequestError(
          f'Error during HTTP request to RKP server:\n{e}'
      ) from e

    print('Operation completed successfully.')
    # Server is not expected to return any content, but if it does, print it.
    if response.content:
      print(f'Response: {response.content!r}')


def main() -> None:
  """Parses the command line and runs the action for each candidate HAL.

  Exits via `sys.exit` if the adb check fails. A failure for one HAL is
  printed to stderr and the remaining HALs are still processed.
  """
  cli = argparse.ArgumentParser(
      description=(
          'Register or unregister a device for getting new Google-rooted'
          ' certificates.'
      )
  )
  cli.add_argument(
      'action',
      nargs='?',
      choices=['register', 'unregister'],
      default='register',
      help="Action to perform: 'register' (default) or 'unregister'.",
  )
  action = cli.parse_args().action

  if action == 'register':
    request_id = REGISTER_REQUEST_ID
  else:
    request_id = UNREGISTER_REQUEST_ID

  registrar = GoogleRootRegister()
  try:
    registrar.check_adb()
  except GoogleRootRegistrationError as err:
    sys.exit(err)

  for hal in CANDIDATE_HALS:
    try:
      if registrar.is_hal_supported(hal=hal):
        print(f'\nPerforming `{action}` for HAL {hal}...')
        registrar.send_request(
            csr_bytes=registrar.get_csr(hal=hal), request_id=request_id
        )
      else:
        print(f'HAL `{hal}` is not supported on this device. Skipping...')
    except GoogleRootRegistrationError as err:
      # Report the failure and move on to the next candidate HAL.
      print(
          f'Error performing `{action}` for HAL {hal}: {err}',
          file=sys.stderr,
      )


# Run the command-line entry point only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
  main()
