#! /usr/bin/env python3


# eurostat-fetcher -- Fetch series from Eurostat database
# By: Christophe Benz <christophe.benz@cepremap.org>
#
# Copyright (C) 2017-2018 Cepremap
# https://git.nomics.world/dbnomics-fetchers/eurostat-fetcher
#
# eurostat-fetcher is free software; you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# eurostat-fetcher is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


"""Convert Eurostat provider, categories, datasets and time series to DBnomics JSON and JSON Lines files."""


import argparse
import logging
import os
import re
import sys
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from typing import Any, cast

import sdmx
import ujson as json
from contexttimer import Timer
from humanfriendly import format_timespan
from lxml import etree
from more_itertools import one
from sdmx.message import StructureMessage
from sdmx.model.common import Dimension
from sdmx.model.v21 import DataflowDefinition, DataStructureDefinition

from eurostat_fetcher.convert_utils import iter_series_json

# Provider identity; main() writes provider_json as-is to provider.json.
provider_code = "Eurostat"
provider_json = dict(
    code=provider_code,
    name=provider_code,
    region="EU",
    terms_of_use="https://ec.europa.eu/eurostat/about/policies/copyright",
    website="https://ec.europa.eu/eurostat",
)

# Sub-directory of the source directory holding one directory per dataset.
datasets_dir_name = "data"
log = logging.getLogger(__name__)
# NOTE(review): not referenced by the functions of this module — possibly used by other importers; verify.
namespace_url_by_name = dict(xml="http://www.w3.org/XML/1998/namespace")

# Compact daily period such as "20240131" (YYYYMMDD).
# NOTE(review): not referenced by the functions of this module; verify before removing.
DAILY_PERIOD_RE = re.compile(r"(?P<year>\d{4})(?P<month>\d{2})(?P<day>\d{2})")

# Environment variable names.
DATASETS_ENV_VAR = "DATASETS"  # comma-separated default value of --datasets
FULL_ENV_VAR = "FULL"  # NOTE(review): not read in this module; verify.


def convert_datasets(
    datasets_to_convert: list[tuple[str, Path]], dataset_json_stubs: dict[str, dict], target_dir: Path
) -> None:
    """Convert each ``(dataset_code, source_dataset_dir)`` pair into ``target_dir / dataset_code``.

    Pairs are processed in sorted order and each dataset code is converted at most once.

    A dataset whose code has no entry in *dataset_json_stubs* (the stubs built from the table of
    contents) is skipped with an error log. It does not abort the whole conversion with a ``KeyError``.
    """
    total = len(datasets_to_convert)
    log.info("Converting %d datasets...", total)

    converted_datasets_codes: set[str] = set()
    for index, (dataset_code, source_dataset_dir) in enumerate(sorted(datasets_to_convert), start=1):
        if dataset_code in converted_datasets_codes:
            log.debug("Skipping dataset %r because it was already converted", dataset_code)
            continue

        dataset_json_stub = dataset_json_stubs.get(dataset_code)
        if dataset_json_stub is None:
            # Checked before creating the target directory: an empty directory would make a later
            # --resume run consider the dataset as already converted.
            log.error("Skipping dataset %r because it has no entry in the table of contents", dataset_code)
            continue

        log.info("Converting dataset %d/%d %r", index, total, dataset_code)

        dataset_dir = target_dir / dataset_code
        dataset_dir.mkdir(exist_ok=True)

        convert_dataset(dataset_code, dataset_json_stub, source_dataset_dir, dataset_dir)

        converted_datasets_codes.add(dataset_code)


def convert_dataset(
    dataset_code: str, dataset_json_stub: dict[str, Any], source_dataset_dir: Path, dataset_dir: Path
) -> None:
    """Convert one Eurostat dataset from SDMX files to DBnomics files.

    Reads ``<dataset_code>.dsd.xml`` (structure) and ``<dataset_code>.sdmx.xml`` (data) from
    *source_dataset_dir*. Writes two files into *dataset_dir*:

    - ``series.jsonl``: one JSON object per line, one line per series, sorted by series code.
    - ``dataset.json``: *dataset_json_stub* merged with the labels of dimensions, attributes and of the
      values actually used by the series, plus ``updated_at``. Top-level falsy values are dropped.

    Both files are UTF-8 encoded, since JSON is dumped with ``ensure_ascii=False``.
    """
    dsd_file_path = source_dataset_dir / f"{dataset_code}.dsd.xml"
    with Timer() as timer:
        dataflow_message = cast(StructureMessage, sdmx.read_sdmx(dsd_file_path))
    log.debug("%s file was read in %s", str(dsd_file_path), format_timespan(timer.elapsed))

    # `one` raises if the structure message does not hold exactly one dataflow / one DSD.
    dataflow_definition = cast(DataflowDefinition, one(dataflow_message.dataflow.values()))
    data_structure_definition = cast(DataStructureDefinition, one(dataflow_message.structure.values()))

    sdmx_file_path = source_dataset_dir / f"{dataset_code}.sdmx.xml"
    series_jsonl_file = dataset_dir / "series.jsonl"

    # Only `Dimension` instances are kept here and in the labels below; other dimension components
    # (presumably the time dimension) are excluded. NOTE(review): confirm against the sdmx model.
    dimensions_codes_order = [
        dimension.id for dimension in data_structure_definition.dimensions if isinstance(dimension, Dimension)
    ]

    attribute_codes = sorted(attribute.id for attribute in data_structure_definition.attributes)
    # Presumably filled by iter_series_json with the codes seen in the data. They are used below to
    # restrict value labels to the values actually present.
    used_attribute_value_codes: defaultdict[str, set[str]] = defaultdict(set)
    used_dimension_value_codes: defaultdict[str, set[str]] = defaultdict(set)

    with Timer() as timer, series_jsonl_file.open("w", encoding="utf-8") as series_jsonl_fp:
        # Stream the data file, handling each Series element once its end tag is parsed.
        iterparse_context = etree.iterparse(sdmx_file_path, huge_tree=True, tag="{*}Series", events=["end"])
        # All series are materialized in memory so that they are written sorted by code.
        for series_json in sorted(
            iter_series_json(
                iterparse_context,
                attribute_codes=attribute_codes,
                dimensions_codes_order=dimensions_codes_order,
                used_attribute_value_codes=used_attribute_value_codes,
                used_dimension_value_codes=used_dimension_value_codes,
            ),
            key=lambda series_json: series_json["code"],
        ):
            json.dump(series_json, series_jsonl_fp, ensure_ascii=False, sort_keys=True)
            series_jsonl_fp.write("\n")

    log.debug("Series were written to %s in %s", str(series_jsonl_file), format_timespan(timer.elapsed))

    with Timer() as timer:
        updated_at = extract_updated_at(dataflow_definition)
        dataset_json = {
            **dataset_json_stub,
            "attributes_labels": {
                attribute.id: attribute.concept_identity.name["en"]
                for attribute in data_structure_definition.attributes
            },
            "attributes_values_labels": {
                attribute.id: {
                    k: v.name["en"]
                    for k, v in attribute.local_representation.enumerated.items.items()
                    if k in used_attribute_value_codes[attribute.id]
                }
                for attribute in data_structure_definition.attributes
            },
            "dimensions_codes_order": dimensions_codes_order,
            "dimensions_labels": {
                dimension.id: dimension.concept_identity.name["en"]
                for dimension in data_structure_definition.dimensions
                if isinstance(dimension, Dimension)
            },
            "dimensions_values_labels": {
                dimension.id: {
                    k: v.name["en"]
                    for k, v in dimension.local_representation.enumerated.items.items()
                    if k in used_dimension_value_codes[dimension.id]
                }
                for dimension in data_structure_definition.dimensions
                if isinstance(dimension, Dimension)
            },
            "updated_at": updated_at.isoformat() if updated_at is not None else None,
        }
        dataset_json_file = dataset_dir / "dataset.json"
        write_json_file(dataset_json_file, without_falsy_values(dataset_json))

    log.debug("Dataset metadata was written to %s in %s", str(dataset_json_file), format_timespan(timer.elapsed))


def extract_updated_at(dataflow_definition: DataflowDefinition) -> datetime | None:
    """Return the latest update timestamp declared in the dataflow annotations.

    Only annotations of type ``UPDATE_DATA`` or ``UPDATE_STRUCTURE`` are considered. Their title is
    parsed with ``datetime.fromisoformat`` and the maximum is returned. Matching annotations with a
    missing or empty title are ignored, since they carry no timestamp to parse.

    Returns None when no matching annotation has a title.
    Raises ValueError if a title is not accepted by ``datetime.fromisoformat``.
    """
    update_annotation_types = {"UPDATE_DATA", "UPDATE_STRUCTURE"}
    return max(
        (
            datetime.fromisoformat(cast(str, annotation.title))
            for annotation in dataflow_definition.annotations
            if annotation.type in update_annotation_types and annotation.title
        ),
        default=None,
    )


def iter_child_directories(directory: Path) -> Iterator[Path]:
    """Yield the immediate sub-directories of *directory* in ``Path.iterdir`` order, ignoring other entries."""
    yield from (entry for entry in directory.iterdir() if entry.is_dir())


def iter_datasets_to_convert(
    source_datasets_dir: Path, target_dir: Path, *, datasets, resume
) -> Iterator[tuple[str, Path]]:
    """Yield ``(dataset_code, source_dataset_dir)`` for each dataset that should be converted.

    Child directories of *source_datasets_dir* are visited in sorted order, and each directory name is
    taken as a dataset code. A dataset is skipped, with a log message, when any of these holds:

    - *datasets* is non-empty and does not contain its code;
    - its ``<code>.sdmx.xml`` file is missing (logged as an error);
    - *resume* is true and ``target_dir / <code>`` is already a directory.
    """
    for candidate_dir in sorted(iter_child_directories(source_datasets_dir)):
        code = candidate_dir.name

        # An empty/None selection means "every dataset".
        is_selected = not datasets or code in datasets
        if not is_selected:
            log.debug("Skipping dataset %r because it is not mentioned by --datasets option", code)
            continue

        sdmx_path = candidate_dir / f"{code}.sdmx.xml"
        if not sdmx_path.is_file():
            log.error("Skipping dataset %s because SDMX file %s is missing", code, str(sdmx_path))
            continue

        already_converted = resume and (target_dir / code).is_dir()
        if already_converted:
            log.debug("Skipping dataset %r because it already exists (due to --resume option)", code)
            continue

        yield code, candidate_dir


def toc_to_category_tree(source_dir: Path):
    """Parse ``table_of_contents.xml`` located in *source_dir*.

    Returns a ``(category_tree_json, dataset_json_stubs)`` pair:

    - the list of top-level category tree nodes;
    - a dict mapping each dataset code found in the table of contents to its dataset.json stub.
    """
    toc_path = source_dir / "table_of_contents.xml"
    toc_root = etree.parse(str(toc_path)).getroot()

    # iter_category_tree_nodes fills this dict as a side effect while walking the tree, so the
    # generator is fully consumed before the dict is returned.
    stubs_by_dataset_code = {}
    tree_nodes = list(iter_category_tree_nodes(toc_root, stubs_by_dataset_code))

    return tree_nodes, stubs_by_dataset_code


def iter_category_tree_nodes(
    xml_element, dataset_json_stubs: dict[str, dict[str, Any]]
) -> Iterator[dict[str, Any]]:
    """Walk recursively xml_element (a table_of_contents.xml node) and yield category tree nodes.

    Each kind of node is handled as follows:

    - ``tree``: yields the nodes produced by each child element.
    - ``branch``: yields one ``{code, name, children}`` node (falsy values dropped), only when at least
      one child node was produced.
    - ``leaf`` of type ``dataset`` or ``table``: yields a ``{code, name}`` node, then the nodes produced
      by its ``children`` elements, if any.
    - Any other element is logged as a warning and ignored.
    - Non-element nodes, such as comments and processing instructions, are skipped.

    The namespace prefix of the tag is ignored, so the walk does not depend on a specific namespace URI.

    Side-effects: fill dataset_json_stubs with one stub per dataset code (the first occurrence wins).
    """
    tag = xml_element.tag
    if not isinstance(tag, str):
        # Comments / processing instructions expose a factory function as their tag, not a string.
        return

    # Local name of the element: drop the "{namespace-uri}" prefix if present.
    xml_element_tag = tag.rpartition("}")[2]
    # `.get` rather than `.attrib[...]`: unexpected elements may have no "type" attribute.
    element_type = xml_element.get("type")

    if xml_element_tag == "tree":
        for child_element in xml_element:
            yield from iter_category_tree_nodes(child_element, dataset_json_stubs)

    elif xml_element_tag == "branch":
        children = [
            child
            for child_element in xml_element.iterfind("{*}children/*")
            for child in iter_category_tree_nodes(child_element, dataset_json_stubs)
        ]
        # Branches that produced no node (no dataset below them) are pruned.
        if children:
            yield without_falsy_values(
                {
                    "code": xml_element.findtext("{*}code"),
                    "name": xml_element.findtext("{*}title[@language='en']"),
                    "children": children,
                }
            )

    elif xml_element_tag == "leaf" and element_type in {"dataset", "table"}:
        dataset_code = xml_element.findtext("{*}code")
        dataset_name = xml_element.findtext("{*}title[@language='en']")

        # Datasets can appear multiple times in the category tree: keep the first stub.
        if dataset_code not in dataset_json_stubs:
            dataset_json_stubs[dataset_code] = {
                "code": dataset_code,
                "name": dataset_name,
                "description": xml_element.findtext("{*}shortDescription[@language='en']") or None,
                "doc_href": xml_element.findtext("{*}metadata[@format='html']") or None,
            }

        yield {
            "code": dataset_code,
            "name": dataset_name,
        }

        for child_element in xml_element.iterfind("{*}children/*"):
            yield from iter_category_tree_nodes(child_element, dataset_json_stubs)

    else:
        log.warning(
            "Unexpected node type: %r, type %r (code %r)",
            xml_element_tag,
            element_type,
            xml_element.findtext("{*}code"),
        )


def main() -> int:
    """Command-line entry point: convert Eurostat source files to DBnomics files.

    Writes the following into the target directory:

    - ``provider.json``;
    - one sub-directory per converted dataset;
    - ``category_tree.json``.

    Returns the process exit status (0). Invalid arguments exit through ``argparse``'s error handling
    (usage message and exit status 2). This covers missing directories and an unknown log level.
    """
    # The DATASETS environment variable gives the default value of --datasets as a comma-separated list.
    # Surrounding whitespace and empty items are ignored, so "A, B," means ["A", "B"].
    datasets_from_env = os.getenv(DATASETS_ENV_VAR)
    default_datasets = None
    if datasets_from_env:
        default_datasets = [code.strip() for code in datasets_from_env.split(",") if code.strip()]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_dir",
        type=Path,
        help="path of source directory containing Eurostat series in source format",
    )
    parser.add_argument(
        "target_dir",
        type=Path,
        help="path of target directory containing datasets & series in DBnomics JSON and JSON Lines formats",
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        metavar="DATASET_CODE",
        default=default_datasets,
        help="convert only the given datasets (datasets codes, space separated)",
    )
    parser.add_argument("--log", default="INFO", help="level of logging messages")
    parser.add_argument("--resume", action="store_true", help="do not process already written datasets")
    args = parser.parse_args()

    if not args.source_dir.is_dir():
        parser.error(f"Could not find directory {str(args.source_dir)!r}")
    if not args.target_dir.is_dir():
        parser.error(f"Could not find directory {str(args.target_dir)!r}")

    numeric_level = getattr(logging, args.log.upper(), None)
    if not isinstance(numeric_level, int):
        # Reported like the other argument errors instead of an uncaught ValueError traceback.
        parser.error(f"Invalid log level: {args.log}")
    logging.basicConfig()
    log.setLevel(numeric_level)

    log.info("Command-line arguments: %r", args)

    write_json_file(args.target_dir / "provider.json", provider_json)

    source_datasets_dir = args.source_dir / datasets_dir_name

    datasets_to_convert = list(
        iter_datasets_to_convert(
            source_datasets_dir, target_dir=args.target_dir, datasets=args.datasets, resume=args.resume
        )
    )

    # The category tree is built from the whole table of contents, independently of --datasets/--resume.
    category_tree_json, dataset_json_stubs = toc_to_category_tree(source_dir=args.source_dir)

    convert_datasets(
        datasets_to_convert=datasets_to_convert, dataset_json_stubs=dataset_json_stubs, target_dir=args.target_dir
    )

    log.info("Writing category tree...")
    write_json_file(args.target_dir / "category_tree.json", category_tree_json)

    return 0


def without_falsy_values(mapping):
    """Return a new dict holding only the entries of *mapping* whose value is truthy, in original order."""
    kept = {}
    for key, value in mapping.items():
        if value:
            kept[key] = value
    return kept


def write_json_file(path, data) -> None:
    """Write *data* to *path* as pretty-printed JSON (2-space indent, sorted keys).

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the file is explicitly
    UTF-8 encoded rather than using the locale's default encoding. *path* may be a ``Path`` or a str.
    """
    with Path(path).open("w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2, sort_keys=True)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())