#!/usr/bin/env python3

from __future__ import annotations

import json
import logging

from datetime import date, datetime
from importlib.metadata import version
from pathlib import PurePosixPath
from typing import Any
from urllib.parse import urljoin, urlparse

import requests

from requests.adapters import HTTPAdapter
from urllib3.util import Retry


def enable_full_debug() -> None:
    '''Turn on verbose debugging of every HTTP exchange made through requests.

    Sets ``http.client.HTTPConnection.debuglevel`` to 1 so http.client prints the raw
    request/response lines and headers, installs a basic logging handler, and raises the
    root logger and the urllib3 loggers to DEBUG.
    '''
    import http.client as http_client
    http_client.HTTPConnection.debuglevel = 1
    logging.basicConfig()
    logging.getLogger().setLevel(logging.DEBUG)
    # urllib3 logs under its own module names ("urllib3.*"). "requests.packages.urllib3" is
    # only a module alias kept by requests for backward compatibility, so a logger with that
    # name never receives records. Configure both to honour the intent on old and new requests.
    for logger_name in ("urllib3", "requests.packages.urllib3"):
        urllib3_log = logging.getLogger(logger_name)
        urllib3_log.setLevel(logging.DEBUG)
        urllib3_log.propagate = True


class PyVulnerabilityLookup():
    '''Client for the HTTP API of a vulnerability-lookup instance.

    Every call goes through one shared requests session. The session carries the user agent,
    the ``X-API-KEY`` header, JSON ``Accept``/``Content-Type`` headers, optional proxies and a
    retry policy: up to 5 retries with backoff, also on HTTP 500/502/503/504.
    '''

    def __init__(self, root_url: str='https://vulnerability.circl.lu', useragent: str | None=None, token: str | None=None,
                 *, proxies: dict[str, str] | None=None) -> None:
        '''Query a specific instance.

        :param root_url: URL of the instance to query. ``http://`` is prepended when no scheme is given,
                         and a trailing ``/`` is added so relative API paths resolve under it.
        :param useragent: The User Agent used by requests to run the HTTP requests against the vulnerability lookup instance
        :param token: The API key, sent in the ``X-API-KEY`` header (an empty value is sent when omitted)
        :param proxies: The proxies to use to connect to the vulnerability lookup instance - More details: https://requests.readthedocs.io/en/latest/user/advanced/#proxies
        '''
        self.root_url = root_url

        if not urlparse(self.root_url).scheme:
            self.root_url = 'http://' + self.root_url
        if not self.root_url.endswith('/'):
            self.root_url += '/'
        # requests.session() is a deprecated alias; instantiate the class directly.
        self.session = requests.Session()
        self.session.headers['user-agent'] = useragent if useragent else f'PyVulnerabilityLookup / {version("pyvulnerabilitylookup")}'
        self.session.headers['X-API-KEY'] = token if token else ''
        self.session.headers['Accept'] = 'application/json'
        self.session.headers['Content-Type'] = 'application/json'
        if proxies:
            self.session.proxies.update(proxies)
        retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
        # Mount for both schemes: a scheme-less root_url is turned into http:// above,
        # and such instances previously got no retry policy at all.
        adapter = HTTPAdapter(max_retries=retries)
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

    def _build_url(self, *parts: str | PurePosixPath) -> str:
        '''Return the absolute URL for the API path made of ``parts``, resolved against ``root_url``.'''
        return urljoin(self.root_url, str(PurePosixPath(*parts)))

    def set_apikey(self, apikey: str) -> None:
        '''Set the API key to use for the requests'''
        self.session.headers['X-API-KEY'] = apikey

    @property
    def is_up(self) -> bool:
        '''Test if the given instance is accessible.

        :return: True when a HEAD request on ``root_url`` answers HTTP 200. False on any request
                 failure: connection error, timeout, or a ``RetryError`` raised once the retry
                 policy gives up on a 5xx.
        '''
        try:
            r = self.session.head(self.root_url)
        except requests.exceptions.RequestException:
            return False
        return r.status_code == 200

    def redis_up(self) -> bool:
        '''Check if redis is up and running'''
        r = self.session.get(self._build_url('api', 'system', 'redis_up'))
        return r.json()

    # #### DB status ####

    def get_info(self) -> dict[str, Any]:
        '''Get more information about the current databases in use and when it was updated'''
        r = self.session.get(self._build_url('api', 'system', 'dbInfo'))
        return r.json()

    def get_config_info(self) -> dict[str, Any]:
        '''Get the configuration information exposed by the instance (``api/system/configInfo``)'''
        r = self.session.get(self._build_url('api', 'system', 'configInfo'))
        return r.json()

    # #### Vulnerabilities ####

    def get_vulnerability(self, vulnerability_id: str) -> dict[str, Any]:
        '''Get a vulnerability

        :param vulnerability_id: The ID of the vulnerability to get (can be from any source, as long as it is a valid ID)
        '''
        r = self.session.get(self._build_url('api', 'vulnerability', vulnerability_id))
        return r.json()

    def create_vulnerability(self, vulnerability: dict[str, Any]) -> dict[str, Any]:
        '''Create a vulnerability.

        :param vulnerability: The vulnerability, sent as the JSON body
        '''
        r = self.session.post(self._build_url('api', 'vulnerability'), json=vulnerability)
        return r.json()

    def delete_vulnerability(self, vulnerability_id: str) -> int:
        '''Delete a vulnerability.

        :param vulnerability_id: The vulnerability ID
        :return: The HTTP status code of the response
        '''
        r = self.session.delete(self._build_url('api', 'vulnerability', vulnerability_id))
        return r.status_code

    def get_last(self, number: int | None=None, source: str | None = None) -> list[dict[str, Any]]:
        '''Get the last vulnerabilities

        Queries ``api/vulnerability/last[/<source>][/<number>]``.

        :param number: The number of vulnerabilities to get
        :param source: The source of the vulnerabilities
        '''
        path = PurePosixPath('last')
        if source:
            path /= source
        if number is not None:
            path /= str(number)
        r = self.session.get(self._build_url('api', 'vulnerability', path))
        return r.json()

    def get_vendors(self) -> list[str]:
        '''Get the known vendors'''
        r = self.session.get(self._build_url('api', 'vulnerability', 'browse'))
        return r.json()

    def get_vendor_products(self, vendor: str) -> list[str]:
        '''Get the known products for a vendor

        :param vendor: A vendor owning products (must be in the known vendor list)
        '''
        r = self.session.get(self._build_url('api', 'vulnerability', 'browse', vendor))
        return r.json()

    def get_vendor_product_vulnerabilities(self, vendor: str, product: str) -> list[str]:
        '''Get the vulnerabilities for a vendor and a specific product

        :param vendor: A vendor owning products (must be in the known vendor list)
        :param product: A product owned by that vendor
        '''
        r = self.session.get(self._build_url('api', 'vulnerability', 'search', vendor, product))
        return r.json()

    # #### Comments ####

    def create_comment(self, /, *, comment: dict[str, Any] | None=None, description: str | None=None,
                       description_format: str | None = None, meta: dict[str, str] | None = None,
                       related_vulnerabilities: list[str] | None=None, title: str | None=None,
                       uuid: str | None=None, vulnerability: str | None = None) -> dict[str, Any]:
        '''Create a comment.

        Either pass a complete ``comment`` object, or leave it empty and pass the individual
        fields; in the latter case only the truthy fields are sent.

        :param comment: The comment
        :param description: The description of the comment
        :param description_format: Description format (markdown or text).
        :param meta: Zero or more meta-fields.
        :param related_vulnerabilities: Zero or more related vulnerabilities.
        :param title: The title of the comment
        :param uuid: The UUID of the comment
        :param vulnerability: The vulnerability ID of the comment
        '''

        if not comment:
            comment = {}
            if description:
                comment['description'] = description
            if description_format:
                comment['description_format'] = description_format
            if meta:
                comment['meta'] = meta
            if related_vulnerabilities:
                comment['related_vulnerabilities'] = related_vulnerabilities
            if title:
                comment['title'] = title
            if uuid:
                comment['uuid'] = uuid
            if vulnerability:
                comment['vulnerability'] = vulnerability

        r = self.session.post(self._build_url('api', 'comment'), json=comment)
        return r.json()

    def get_comments(self, uuid: str | None = None, vuln_id: str | None = None,
                     author: str | None = None) -> dict[str, Any]:
        '''Get comment(s)

        :param uuid: The UUID of a specific comment
        :param vuln_id: The vulnerability ID to get comments of
        :param author: The author of the comment(s)
        '''
        params: dict[str, str] = {}
        if uuid:
            params['uuid'] = uuid
        if vuln_id:
            params['vuln_id'] = vuln_id
        if author:
            params['author'] = author
        r = self.session.get(self._build_url('api', 'comment'), params=params)
        return r.json()

    def get_comment(self, comment_uuid: str) -> dict[str, Any]:
        '''Get a comment

        :param comment_uuid: The UUID of the comment
        '''
        r = self.session.get(self._build_url('api', 'comment', comment_uuid))
        return r.json()

    def delete_comment(self, comment_uuid: str) -> int:
        '''Delete a comment.

        :param comment_uuid: The comment UUID
        :return: The HTTP status code of the response
        '''
        r = self.session.delete(self._build_url('api', 'comment', comment_uuid))
        return r.status_code

    # #### Bundles ####

    def create_bundle(self, /, *, bundle: dict[str, Any] | None=None, description: str | None=None,
                      meta: dict[str, str] | None=None, name: str | None=None, related_vulnerabilities: list[str] | None=None,
                      uuid: str | None=None) -> dict[str, Any]:
        '''Create a bundle.

        Either pass a complete ``bundle`` object, or leave it empty and pass the individual
        fields; in the latter case only the truthy fields are sent.

        :param bundle: The bundle
        :param description: The description of the bundle
        :param meta: Zero or more meta-fields.
        :param name: The name of the bundle
        :param related_vulnerabilities: Zero or more related vulnerabilities.
        :param uuid: The UUID of the bundle
        '''

        if not bundle:
            bundle = {}
            if description:
                bundle['description'] = description
            if meta:
                bundle['meta'] = meta
            if name:
                bundle['name'] = name
            if related_vulnerabilities:
                bundle['related_vulnerabilities'] = related_vulnerabilities
            if uuid:
                bundle['uuid'] = uuid

        r = self.session.post(self._build_url('api', 'bundle'), json=bundle)
        return r.json()

    def get_bundles(self, uuid: str | None = None, vuln_id: str | None = None,
                    author: str | None = None, per_page: int | None=None,
                    meta: list[dict[str, Any]] | str | None=None) -> dict[str, Any]:
        '''Get bundle(s)

        :param uuid: The UUID a specific bundle
        :param vuln_id: The vulnerability ID to get bundles of
        :param author: The author of the bundle(s)
        :param per_page: The number of bundles to get per page
        :param meta: Query for the meta JSON field. Example: meta=[{"tags": ["tcp"]}].
                     Python objects are JSON-encoded; a string is sent as-is.
        '''
        params: dict[str, Any] = {}
        if uuid:
            params['uuid'] = uuid
        if vuln_id:
            params['vuln_id'] = vuln_id
        if author:
            params['author'] = author
        if per_page is not None:
            params['per_page'] = per_page
        if meta:
            # requests URL-encodes a list of dicts as just the dict keys (meta=tags),
            # so send the JSON text shown in the docstring example instead.
            params['meta'] = meta if isinstance(meta, str) else json.dumps(meta)

        r = self.session.get(self._build_url('api', 'bundle'), params=params)
        return r.json()

    def get_bundle(self, bundle_uuid: str) -> dict[str, Any]:
        '''Get a bundle

        :param bundle_uuid: The UUID of the bundle
        '''
        r = self.session.get(self._build_url('api', 'bundle', bundle_uuid))
        return r.json()

    def delete_bundle(self, bundle_uuid: str) -> int:
        '''Delete a bundle.

        :param bundle_uuid: The bundle UUID
        :return: The HTTP status code of the response
        '''
        r = self.session.delete(self._build_url('api', 'bundle', bundle_uuid))
        return r.status_code

    # #### Users ####

    def create_user(self, /, *, user: dict[str, Any] | None=None,
                    login: str | None=None, name: str | None=None,
                    organisation: str | None=None, email: str | None=None) -> dict[str, Any]:
        '''Create a user.

        Either pass a complete ``user`` object, or leave it empty and pass the individual
        fields; in the latter case only the truthy fields are sent.

        :param user: The user, as an object
        :param login: The login of the user
        :param name: The name of the user
        :param organisation: The organisation of the user
        :param email: The email of the user
        '''

        if not user:
            user = {}
            if login:
                user['login'] = login
            if name:
                user['name'] = name
            if organisation:
                user['organisation'] = organisation
            if email:
                user['email'] = email

        r = self.session.post(self._build_url('api', 'user'), json=user)
        return r.json()

    def list_users(self) -> dict[str, Any]:
        '''List users (alias of :meth:`get_users`, kept for consistency).'''
        return self.get_users()

    def get_users(self) -> dict[str, Any]:
        '''List users'''
        r = self.session.get(self._build_url('api', 'user'))
        return r.json()

    def get_user_information(self) -> dict[str, Any]:
        '''Get information about the user owning the current API key (``api/user/me``)'''
        r = self.session.get(self._build_url('api', 'user', 'me'))
        return r.json()

    def reset_api_key(self) -> dict[str, Any]:
        '''Reset the API key'''
        r = self.session.get(self._build_url('api', 'user', 'api_key'))
        return r.json()

    def delete_user(self, user_id: str) -> int:
        '''Delete a user.

        :param user_id: The user ID
        :return: The HTTP status code of the response
        '''
        r = self.session.delete(self._build_url('api', 'user', user_id))
        return r.status_code

    # #### Sightings ####

    def get_sighting(self, sighting_uuid: str) -> dict[str, Any]:
        '''Get a sighting

        :param sighting_uuid: The UUID of the sighting
        '''
        r = self.session.get(self._build_url('api', 'sighting', sighting_uuid))
        return r.json()

    def get_sightings(self, /, *, sighting_uuid: str | None=None,
                      sighting_type: str | None=None, vuln_id: str | None = None,
                      author: str | None = None,
                      date_from: date | datetime | None=None,
                      date_to: date | datetime | None=None) -> dict[str, Any]:
        '''Get sightings

        :param sighting_uuid: The UUID of a specific sighting
        :param sighting_type: The type of sighting, can be one of: 'seen', 'exploited', 'not-exploited', 'confirmed', 'not-confirmed', 'patched', 'not-patched'.
        :param vuln_id: The vulnerability ID to get sightings of
        :param author: The author of the sighting(s)
        :param date_from: The date from which to get sightings (a datetime is truncated to its date)
        :param date_to: The date to which to get sightings (a datetime is truncated to its date)
        '''

        params: dict[str, str] = {}
        if sighting_uuid:
            params['uuid'] = sighting_uuid
        if sighting_type:
            params['type'] = sighting_type
        if vuln_id:
            params['vuln_id'] = vuln_id
        if author:
            params['author'] = author
        if date_from:
            if isinstance(date_from, datetime):
                date_from = date_from.date()
            params['date_from'] = date_from.isoformat()
        if date_to:
            if isinstance(date_to, datetime):
                date_to = date_to.date()
            params['date_to'] = date_to.isoformat()

        r = self.session.get(self._build_url('api', 'sighting'), params=params)
        return r.json()

    def create_sighting(self, /, *, sighting: dict[str, Any] | None=None,
                        creation_timestamp: datetime | None=None,
                        source: str | None = None,
                        sighting_type: str | None=None,
                        vulnerability: str | None=None) -> dict[str, Any]:
        '''Create a sighting.

        Either pass a complete ``sighting`` object, or leave it empty and pass the individual
        fields; in the latter case only the truthy fields are sent. The caller's dict is not modified.

        :param sighting: The sighting, as an object.
        :param creation_timestamp: The timestamp of the sighting. A naive datetime is interpreted as
                                   local time. When omitted, no timestamp is sent (the server is
                                   expected to set it to now).
        :param source: The source of the sighting
        :param sighting_type: The type of sighting, can be one of: 'seen', 'exploited', 'not-exploited', 'confirmed', 'not-confirmed', 'patched', 'not-patched'.
        :param vulnerability: The vulnerability ID of the sighting
        '''
        if not sighting:
            sighting = {}
            if creation_timestamp:
                # This value may or may not have a TZ at this point
                sighting['creation_timestamp'] = creation_timestamp
            if source:
                sighting['source'] = source
            if sighting_type:
                sighting['type'] = sighting_type
            if vulnerability:
                sighting['vulnerability'] = vulnerability

        r = self.session.post(self._build_url('api', 'sighting'), json=self._prepare_sighting(sighting))
        return r.json()

    @staticmethod
    def _prepare_sighting(sighting: dict[str, Any]) -> dict[str, Any]:
        '''Return a shallow copy of ``sighting`` with ``creation_timestamp`` made JSON-serialisable.

        A datetime is converted to an ISO 8601 string. A naive datetime is first localised with
        ``astimezone()``, so the string always carries a UTC offset. A plain date is ISO-formatted.
        Any other value (e.g. an already formatted string) is sent unchanged.
        '''
        payload = dict(sighting)
        timestamp = payload.get('creation_timestamp')
        if isinstance(timestamp, datetime):
            if timestamp.tzinfo is None:
                timestamp = timestamp.astimezone()
            payload['creation_timestamp'] = timestamp.isoformat()
        elif isinstance(timestamp, date):
            payload['creation_timestamp'] = timestamp.isoformat()
        return payload

    # #### EPSS ####

    def get_epss(self, vulnerability: str) -> dict[str, Any]:
        '''Get the EPSS for a vulnerability

        :param vulnerability: The vulnerability ID
        '''
        r = self.session.get(self._build_url('api', 'epss', vulnerability))
        return r.json()