"""
Load relationships from CSV into Neo4j using batch MERGE.
Safe to re-run -- MERGE won't create duplicates.

Usage:
    python load_links.py diplo                    # Load diplomacy links (newest diplomacy_links_*.csv)
    python load_links.py dw                       # Load dig.watch links (newest dw_links_*.csv)
    python load_links.py diplo --csv diplomacy_links_16032026.csv  # Specific CSV
    python load_links.py dw --csv dw_links_16032026.csv
"""

import sys
import glob
import os

import pandas as pd

from config import DIPLO_CONFIG, DW_CONFIG
from neo4j_loader import get_driver, load_relationships, ensure_indexes, run_query


def find_latest_csv(prefix: str) -> str:
    """Return the most recently modified file matching ``<prefix>_links_*.csv``.

    The glob pattern is relative, so it is resolved against the current
    working directory. "Most recent" is decided by file modification time.

    Args:
        prefix: Leading part of the file name, placed before ``_links_``.

    Returns:
        The path of the matching file with the newest modification time.

    Raises:
        SystemExit: With status 1, after printing a hint, when nothing matches.
    """
    pattern = f"{prefix}_links_*.csv"
    candidates = glob.glob(pattern)
    if candidates:
        # On equal mtimes the first match in glob order wins.
        return max(candidates, key=os.path.getmtime)
    print(f"No CSV found matching '{pattern}'. Use --csv to specify.")
    sys.exit(1)


def main():
    """Command-line entry point: load one site's link CSV into its Neo4j database.

    Reads ``sys.argv``. The first argument selects the site (``diplo`` or
    ``dw``), which determines the config used and the CSV file-name prefix.
    An optional ``--csv <file>`` names the CSV to load; without it the file
    is chosen by ``find_latest_csv``.

    The CSV is read with pandas, a short summary of its ``link`` column is
    printed, and the frame is handed to ``load_relationships``. Afterwards
    the total relationship count of the database is queried and printed.

    Raises:
        SystemExit: With status 1 on a usage error (missing or unknown site,
            ``--csv`` given without a file name) or when the CSV has no
            ``link`` column.
    """
    usage = "Usage: python load_links.py <diplo|dw> [--csv filename.csv]"
    if len(sys.argv) < 2 or sys.argv[1] not in ('diplo', 'dw'):
        print(usage)
        sys.exit(1)

    site = sys.argv[1]

    # Resolve the CSV path before anything else so a malformed command line
    # fails fast with the usage text instead of an IndexError.
    if '--csv' in sys.argv:
        value_pos = sys.argv.index('--csv') + 1
        if value_pos >= len(sys.argv):
            print(usage)
            sys.exit(1)
        csv_path = sys.argv[value_pos]
    else:
        csv_path = find_latest_csv('diplomacy' if site == 'diplo' else 'dw')

    cfg = DIPLO_CONFIG if site == 'diplo' else DW_CONFIG
    database = cfg['neo4j_database']

    print(f"Site: {site}")
    print(f"Database: {database}")
    print(f"CSV: {csv_path}")

    df_links = pd.read_csv(csv_path)
    # The summary below needs the 'link' column; report its absence clearly
    # rather than crashing with a bare KeyError.
    if 'link' not in df_links.columns:
        print(f"CSV '{csv_path}' has no 'link' column.")
        sys.exit(1)
    print(f"Loaded {len(df_links)} links ({df_links['link'].nunique()} types)")
    print(df_links['link'].value_counts().head(10))

    driver = get_driver()
    try:
        ensure_indexes(driver, database)

        print(f"\n--- Loading relationships into {database} ---")
        load_relationships(driver, database, df_links)

        rels = run_query(driver, database,
                         "MATCH ()-[r]->() RETURN count(r) AS cnt")
        print(f"\nTotal relationships in {database}: {rels[0]['cnt']}")
    finally:
        # Release the driver even when indexing, loading or the count fails.
        driver.close()

    print("Done.")


# Run the loader only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
