#! /usr/bin/env python3


# eurostat-fetcher -- Fetch series from Eurostat database
# By: Emmanuel Raviart <emmanuel.raviart@cepremap.org>
#
# Copyright (C) 2017 Cepremap
# https://git.nomics.world/dbnomics-fetchers/eurostat-fetcher
#
# eurostat-fetcher is free software; you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# eurostat-fetcher is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


"""Fetch series from Eurostat, the statistical office of the European Union, using bulk download and SDMX formats.

http://ec.europa.eu/eurostat/data/database

EUROSTAT bulk download:
- http://ec.europa.eu/eurostat/fr/data/bulkdownload
- http://ec.europa.eu/eurostat/estat-navtree-portlet-prod/BulkDownloadListing

EUROSTAT SDMX documentation:
- http://ec.europa.eu/eurostat/web/sdmx-infospace/welcome
- http://ec.europa.eu/eurostat/web/sdmx-web-services/rest-sdmx-2.1
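
Example invocations (the script file name below is illustrative, not fixed by this module):

    python3 eurostat_fetcher.py target_dir
    python3 eurostat_fetcher.py --incremental target_dir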
"""


import argparse
import io
import os
import re
import shutil
import sys
import zipfile

from dulwich.repo import Repo
from lxml import etree
import requests


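# XML namespace of Eurostat's table_of_contents.xml ("navtree"). The dataset leaves queried
# below look roughly like this (structure inferred from the XPath expressions used in this
# script; values are illustrative):
#
#   <nt:leaf type="dataset">
#     <nt:code>...</nt:code>
#     <nt:lastUpdate>...</nt:lastUpdate>
#     <nt:downloadLink format="sdmx">...</nt:downloadLink>
#     <nt:metadata format="sdmx">...</nt:metadata>
#   </nt:leaf>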
nsmap = dict(
    nt='urn:eu.europa.ec.eurostat.navtree',
)
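# Matches the <Prepared> timestamp in SDMX message headers; it changes at every download,
# so write_normalized_xml_file (below) replaces it with a constant value.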
prepared_element_re = re.compile('<Prepared>.+</Prepared>')


def iter_datasets(xml_element, old_xml_element=None):
    """Yield datasets. If old_xml_element is provided, yield only updated datasets."""
    old_last_update_by_dataset_code = {}
    if old_xml_element is not None:
        # Index lastUpdate attributes in old table_of_contents.xml.
        for element in old_xml_element.iterfind('.//nt:leaf[@type="dataset"]', namespaces=nsmap):
            dataset_code = element.findtext("nt:code", namespaces=nsmap)
            old_last_update = element.findtext("nt:lastUpdate", namespaces=nsmap)
            old_last_update_by_dataset_code[dataset_code] = old_last_update

    for leaf_element in xml_element.iterfind('.//nt:leaf[@type="dataset"]', namespaces=nsmap):
        if old_xml_element is None:
            yield leaf_element
        else:
            dataset_code = leaf_element.findtext('nt:code', namespaces=nsmap)
            old_last_update = old_last_update_by_dataset_code.get(dataset_code)
            if old_last_update is None:
                # This leaf_element is new in this version of table_of_contents.xml
                yield leaf_element
            else:
                last_update = leaf_element.findtext("nt:lastUpdate", namespaces=nsmap)
                if last_update != old_last_update:
                    yield leaf_element


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('target_dir', help='path of target directory containing Eurostat series')
    parser.add_argument('--incremental', action='store_true',
                        help='download only datasets that changed since the last commit')
    parser.add_argument('--keep-files', action='store_true', help='keep existing files in target directory')
    args = parser.parse_args()

    old_xml_element = None

    if args.incremental:
        # Read the table_of_contents.xml committed at HEAD of the target Git repository;
        # iter_datasets compares its "lastUpdate" dates with the freshly downloaded ones.
        repo = Repo(args.target_dir)
        assert b'HEAD' in repo.get_refs(), 'incremental mode requires an existing commit in the target directory'
        head_tree = repo[repo[repo.head()].tree]
        _, blob_id = head_tree[b"table_of_contents.xml"]
        old_xml_element = etree.fromstring(repo[blob_id].data)

    # Fetch list of datasets.
    xml_url = 'http://ec.europa.eu/eurostat/estat-navtree-portlet-prod/BulkDownloadListing?file=table_of_contents.xml'
    print('Fetching table of contents {}'.format(xml_url))
    response = requests.get(xml_url)
    response.raise_for_status()
    xml_element = etree.fromstring(response.content, parser=etree.XMLParser(remove_blank_text=True))
    xml_file_path = os.path.join(args.target_dir, 'table_of_contents.xml')
    with open(xml_file_path, 'wb') as xml_file:
        etree.ElementTree(xml_element).write(xml_file, encoding='utf-8', pretty_print=True, xml_declaration=True)

    # Fetch datasets.

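    # Unless --keep-files or --incremental is passed, start from an empty data directory.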
    data_dir = os.path.join(args.target_dir, 'data')
    if os.path.exists(data_dir):
        if not args.keep_files and not args.incremental:
            for node_name in os.listdir(data_dir):
                node_path = os.path.join(data_dir, node_name)
                if os.path.isdir(node_path):
                    shutil.rmtree(node_path)
                else:
                    os.remove(node_path)
    else:
        os.mkdir(data_dir)

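    # Collect the deduplicated SDMX download URL of each dataset to fetch; in incremental
    # mode, iter_datasets only yields new or updated datasets.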
    data_urls = set(
        data_url
        for data_url in (
            leaf_element.findtext('./nt:downloadLink[@format="sdmx"]', namespaces=nsmap)
            for leaf_element in iter_datasets(xml_element, old_xml_element)
        )
        if data_url
    )

    for data_url in data_urls:
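        # The dataset directory is named after the file name of the download URL, minus its
        # extension (e.g. a hypothetical ".../nama_10_gdp.sdmx.zip" gives "nama_10_gdp").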
        dataset_dir = os.path.join(data_dir, data_url.rsplit('/', 1)[-1].split('.', 1)[0])
        if os.path.exists(dataset_dir):
            print('Skipping existing dataset {}'.format(data_url))
        else:
            print('Fetching dataset {}'.format(data_url))
            response = requests.get(data_url)
            response.raise_for_status()
            data_zip_file = zipfile.ZipFile(io.BytesIO(response.content))
            # dataset_dir does not exist at this point (checked above), so just create it.
            os.mkdir(dataset_dir)
            for data_zip_info in data_zip_file.infolist():
                if data_zip_info.filename.endswith('.xml'):
                    with data_zip_file.open(data_zip_info) as data_file:
                        xml_file_path = os.path.join(dataset_dir, data_zip_info.filename)
                        write_normalized_xml_file(xml_file_path, data_file)
                else:
                    data_zip_file.extract(data_zip_info, dataset_dir)

    # Fetch dataset definitions (SDMX data structures).

    data_structures_dir = os.path.join(args.target_dir, 'datastructure')
    if os.path.exists(data_structures_dir):
        if not args.keep_files and not args.incremental:
            for node_name in os.listdir(data_structures_dir):
                node_path = os.path.join(data_structures_dir, node_name)
                if os.path.isdir(node_path):
                    shutil.rmtree(node_path)
                else:
                    os.remove(node_path)
    else:
        os.mkdir(data_structures_dir)

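    # Likewise, collect the deduplicated SDMX metadata (data structure definition) URL of each dataset to fetch.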
    metadata_urls = set(
        metadata_url
        for metadata_url in (
            leaf_element.findtext('./nt:metadata[@format="sdmx"]', namespaces=nsmap)
            for leaf_element in iter_datasets(xml_element, old_xml_element)
        )
        if metadata_url
    )

    for metadata_url in metadata_urls:
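        # As above, the directory is named after the file name of the metadata URL, minus its extension.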
        metadata_dir = os.path.join(data_structures_dir, metadata_url.rsplit('/', 1)[-1].split('.', 1)[0])
        if os.path.exists(metadata_dir):
            print('Skipping existing data structure {}'.format(metadata_url))
        else:
            print('Fetching data structure {}'.format(metadata_url))
            response = requests.get(metadata_url)
            response.raise_for_status()
            metadata_zip_file = zipfile.ZipFile(io.BytesIO(response.content))
            # metadata_dir does not exist at this point (checked above), so just create it.
            os.mkdir(metadata_dir)
            for metadata_zip_info in metadata_zip_file.infolist():
                if metadata_zip_info.filename.endswith('.xml'):
                    with metadata_zip_file.open(metadata_zip_info) as metadata_file:
                        xml_file_path = os.path.join(metadata_dir, metadata_zip_info.filename)
                        write_normalized_xml_file(xml_file_path, metadata_file)
                else:
                    metadata_zip_file.extract(metadata_zip_info, metadata_dir)

    return 0


def write_normalized_xml_file(xml_file_path, source_file):
    """Normalize data that changes at each download, like today date,
    in order to avoid triggering a false commit in source data.

    Use regexes because lxml raises SerialisationError with too large files.
    """
    xml_str = source_file.read().decode('utf-8')
    with open(xml_file_path, mode="w", encoding="utf-8") as xml_file:
        xml_file.write(prepared_element_re.sub("<Prepared>1111-11-11T11:11:11</Prepared>", xml_str, 1))


if __name__ == '__main__':
    sys.exit(main())