"""
DiploAI Knowledge Graph - Neo4j Loader (Phase 2b)

Dual label model:
  - Post/Topic nodes: batch CREATE with multiple labels (e.g. :Document:Blog)
  - Tag/Date nodes: batch MERGE (single label)
  - Relationships: batch MERGE grouped by type
"""

import math

import pandas as pd
from neo4j import GraphDatabase
from tqdm import tqdm

from config import NEO4J_URI, NEO4J_USER, NEO4J_PASS

BATCH_SIZE = 500


def get_driver():
    """Build a Neo4j driver from the credentials declared in config.py."""
    credentials = (NEO4J_USER, NEO4J_PASS)
    return GraphDatabase.driver(NEO4J_URI, auth=credentials)


# ------------------------------------------------------------------
# Index management
# ------------------------------------------------------------------

def ensure_indexes(driver, database: str):
    """Create property indexes on :Document/:Tag/:Date for fast lookups.

    Index creation is best-effort: a failing spec (e.g. an equivalent
    index already exists under a different name) does not abort the
    remaining specs. Unlike a silent ``except: pass``, every failure is
    recorded and reported so problems are visible to the operator.

    Args:
        driver: An open neo4j ``Driver`` (or compatible object).
        database: Name of the target Neo4j database.
    """
    specs = [
        ("Document", "node_id"),
        ("Document", "document_hash"),
        ("Document", "wp_id"),
        ("Tag", "node_id"),
        ("Tag", "name"),
        ("Date", "node_id"),
    ]
    failures = []
    with driver.session(database=database) as session:
        for label, prop in specs:
            try:
                session.run(f"CREATE INDEX IF NOT EXISTS FOR (n:`{label}`) ON (n.{prop})")
            except Exception as e:
                # Best-effort, but never silent: collect and report below.
                failures.append((label, prop, str(e)[:80]))
    print(f"[{database}] Indexes ensured ({len(specs)} specs).")
    for label, prop, err in failures:
        print(f"  WARNING: index on :{label}({prop}) failed: {err}")


# ------------------------------------------------------------------
# Property cleaning
# ------------------------------------------------------------------

def _clean_props(d: dict) -> dict:
    """Remove NaN/None/labels from properties dict for Neo4j."""
    skip = {'labels'}
    return {k: (v if v is not None and v == v else '')
            for k, v in d.items() if k not in skip}


# ------------------------------------------------------------------
# Node loading with dual labels
# ------------------------------------------------------------------

def load_all_nodes(driver, database: str, nodes_dict: dict[str, pd.DataFrame]):
    """
    Load every node type into Neo4j.

    Posts and topics carry a 'labels' column and are merged with their
    full label combination (e.g. :Document:Blog); tags and dates are
    merged under a single fixed label.
    """
    total = sum(len(df) for df in nodes_dict.values() if not df.empty)
    print(f"Loading {total} nodes into '{database}'...")

    # Keys handled with one fixed label; everything else in the order
    # below uses the per-row 'labels' column.
    single_label = {'tags': 'Tag', 'dates': 'Date'}

    with driver.session(database=database) as session:
        progress = tqdm(total=total, desc="Nodes")

        for key in ('topics', 'tags', 'posts', 'dates'):
            df = nodes_dict.get(key, pd.DataFrame())
            if df.empty:
                continue
            if key in single_label:
                rows = [_clean_props(r.to_dict()) for _, r in df.iterrows()]
                _batch_merge(session, single_label[key], 'node_id', rows, progress)
            else:
                _load_labeled_nodes(session, df, progress)

        progress.close()

    print(f"Done loading {total} nodes.")


def _load_labeled_nodes(session, df: pd.DataFrame, pbar):
    """MERGE nodes in batches, grouped by each row's full label set.

    Rows sharing a label combination (e.g. Document:Blog) are grouped so
    one UNWIND statement covers the whole group; each label is wrapped
    in backticks, which also keeps special characters safe in Cypher.
    """
    grouped: dict[str, list[dict]] = {}
    for _, row in df.iterrows():
        key = ':'.join(row['labels'])
        grouped.setdefault(key, []).append(_clean_props(row.to_dict()))

    for label_key, rows in grouped.items():
        escaped = '`:`'.join(label_key.split(':'))
        stmt = (f"UNWIND $batch AS props "
                f"MERGE (n:`{escaped}` {{node_id: props.node_id}}) "
                f"SET n += props")
        start = 0
        while start < len(rows):
            chunk = rows[start:start + BATCH_SIZE]
            session.run(stmt, batch=chunk)
            pbar.update(len(chunk))
            start += BATCH_SIZE


def _batch_merge(session, label: str, merge_key: str, props_list: list[dict], pbar=None):
    """MERGE nodes under a single label, matching on `merge_key`.

    Properties are applied with ``SET n += props`` so existing nodes are
    updated rather than replaced. Progress is reported via `pbar` when
    one is supplied.
    """
    stmt = (f"UNWIND $batch AS props "
            f"MERGE (n:`{label}` {{{merge_key}: props.{merge_key}}}) "
            f"SET n += props")
    offset = 0
    while offset < len(props_list):
        chunk = props_list[offset:offset + BATCH_SIZE]
        session.run(stmt, batch=chunk)
        if pbar:
            pbar.update(len(chunk))
        offset += BATCH_SIZE


# ------------------------------------------------------------------
# Relationship loading
# ------------------------------------------------------------------

def load_relationships(driver, database: str, df_links: pd.DataFrame):
    """Batch MERGE relationships grouped by type.

    `df_links` must contain 'source_id', 'target_id' and 'link' (the
    relationship type) columns; rows missing any of them are dropped.
    Batch failures are collected per relationship type (first error
    only) and reported at the end instead of aborting the load.

    Fix: `loaded` now counts only batches that actually succeeded, so
    the summary no longer overcounts when some batches fail.
    """
    df_clean = df_links.dropna(subset=['source_id', 'target_id', 'link'])

    # Group endpoint pairs by relationship type so each type gets one
    # parameterised UNWIND statement.
    rel_groups: dict[str, list[dict]] = {}
    for _, row in df_clean.iterrows():
        rel_groups.setdefault(row['link'], []).append({
            'source_id': row['source_id'],
            'target_id': row['target_id'],
        })

    total = len(df_clean)
    print(f"Loading {total} relationships into '{database}' "
          f"({len(rel_groups)} types, batch={BATCH_SIZE})...")

    loaded = 0
    failed: dict[str, str] = {}  # rel_type -> first error message (truncated)
    with driver.session(database=database) as session:
        pbar = tqdm(total=total, desc="Rels")
        for rel_type, pairs in rel_groups.items():
            cypher = f"""
                UNWIND $batch AS row
                MATCH (s {{node_id: row.source_id}}), (t {{node_id: row.target_id}})
                MERGE (s)-[r:`{rel_type}`]->(t)
            """
            for i in range(0, len(pairs), BATCH_SIZE):
                batch = pairs[i:i + BATCH_SIZE]
                try:
                    session.run(cypher, batch=batch)
                except Exception as e:
                    # Keep only the first error per type; keep loading.
                    failed.setdefault(rel_type, str(e)[:80])
                else:
                    # Count only batches that were actually written.
                    loaded += len(batch)
                pbar.update(len(batch))
        pbar.close()

    print(f"Done loading {loaded} relationships.")
    if failed:
        print(f"  Failed types ({len(failed)}):")
        for rt, err in failed.items():
            print(f"    {rt}: {err}")


# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------

def clear_database(driver, database: str):
    """Remove every node (and, via DETACH, every relationship)."""
    wipe_stmt = "MATCH (n) DETACH DELETE n"
    with driver.session(database=database) as session:
        session.run(wipe_stmt)
    print(f"Cleared all data from '{database}'.")


def run_query(driver, database: str, cypher: str):
    """Execute arbitrary Cypher and materialise the records as a list."""
    with driver.session(database=database) as session:
        return list(session.run(cypher))
