"""
DiploAI Knowledge Graph - Node Builder (Phase 2b)

Dual label model:
  - ALL WP posts -> :Document + specific label, with document_hash = md5(link)
  - Topics from taxonomy -> :Document:Topic (+ :TopicBasket if parent)
  - Tags -> :Tag only (no Document label, no hash)
  - Dates -> :Date only (no Document label, no hash)
"""

import hashlib

import pandas as pd
from bs4 import BeautifulSoup

from config import POST_TYPE_LABELS
from wp_extractor import WPExtractor


def _md5(text: str) -> str:
    return hashlib.md5(text.encode()).hexdigest()


def _clean_text(html: str) -> str:
    if not html:
        return ''
    return BeautifulSoup(html, "lxml").text


# ------------------------------------------------------------------
# Post nodes (ALL WP posts -> :Document + specific label)
# ------------------------------------------------------------------

def build_post_nodes(extractor: WPExtractor) -> pd.DataFrame:
    """
    Turn every post in ``extractor.all_posts`` into a node record.

    Each node carries:
      - labels: ['Document', specific_label], where specific_label is looked
        up in POST_TYPE_LABELS and falls back to the capitalized post type
      - document_hash: md5 of the URL built as <site>/<post_type>/<slug>/

    A per-label summary is printed when at least one node was built.
    Returns the nodes as a DataFrame, one row per post.
    """
    all_posts = extractor.all_posts
    site = extractor.site_prefix
    # Prefixes containing 'diplomacy' get a "www." host in their URLs.
    if 'diplomacy' in site:
        base_url = f"https://www.{site}"
    else:
        base_url = f"https://{site}"

    def _to_node(post) -> dict:
        """Map one WP post row to its node record."""
        post_type = post['post_type']
        label = POST_TYPE_LABELS.get(post_type, post_type.capitalize())
        url = f"{base_url}/{post_type}/{post['post_name']}/"
        wp_id = post['ID']
        return {
            'node_id': f"{site}_{post_type}_{wp_id}",
            'labels': ['Document', label],
            'document_hash': _md5(url),
            'name': post['post_title'],
            'post_type': post_type,
            'site': site,
            'wp_id': str(wp_id),
            'slug': post['post_name'],
            'url': url,
            'text': _clean_text(post['post_content']),
            'date': str(post.get('post_date_gmt', '')),
        }

    nodes = pd.DataFrame([_to_node(post) for _, post in all_posts.iterrows()])
    if nodes.empty:
        return nodes

    # labels[1] is the specific label; labels[0] is always 'Document'.
    per_label = nodes['labels'].apply(lambda labels: labels[1]).value_counts()
    print(f"[{site}] Post nodes: {len(nodes)}  ({per_label.to_dict()})")
    return nodes


# ------------------------------------------------------------------
# Topic nodes (from taxonomy, :Document:Topic, parent gets :TopicBasket)
# ------------------------------------------------------------------

def build_topic_nodes(extractor: WPExtractor) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Build :Document:Topic nodes from taxonomy.

    A topic whose term_id appears as the ``parent`` of another topic also
    gets the :TopicBasket label. Each node's document_hash is the md5 of the
    URL built as <site>/topics/<slug>/.

    When the extractor returns no topics, an empty node frame is returned
    instead of indexing columns that may not exist.

    Returns (df_topic_nodes, df_topics_raw), where df_topics_raw is the frame
    returned by ``extractor.get_topics()`` unchanged.
    """
    prefix = extractor.site_prefix
    site_url = f"https://www.{prefix}" if 'diplomacy' in prefix else f"https://{prefix}"
    df_topics = extractor.get_topics()

    # An empty frame may come back without any columns (assumption about the
    # extractor); df_topics['parent'] would then raise KeyError. The summary
    # line is kept so the log output matches the non-empty path.
    if df_topics.empty:
        print(f"[{prefix}] Topic nodes: 0  (TopicBasket=0, Topic=0)")
        return pd.DataFrame(), df_topics

    # Every non-zero 'parent' value names a term that has at least one child.
    has_parent = df_topics['parent'] != 0
    parent_ids = set(df_topics.loc[has_parent, 'parent'].astype(int).tolist())

    records = []
    baskets = 0
    for _, row in df_topics.iterrows():
        tid = int(row['term_id'])
        link = f"{site_url}/topics/{row['slug']}/"

        labels = ['Document', 'Topic']
        if tid in parent_ids:
            labels.append('TopicBasket')
            baskets += 1

        records.append({
            'node_id': f"{prefix}_topic_{tid}",
            'labels': labels,
            'document_hash': _md5(link),
            'name': row['name'],
            'post_type': 'topic',
            'site': prefix,
            'wp_id': str(tid),
            'slug': row['slug'],
            'url': link,
            'text': _clean_text(row.get('text', '')),
        })

    df = pd.DataFrame(records)
    print(f"[{prefix}] Topic nodes: {len(df)}  (TopicBasket={baskets}, Topic={len(df)-baskets})")
    return df, df_topics


# ------------------------------------------------------------------
# Tag nodes (:Tag only, no Document label, no hash)
# ------------------------------------------------------------------

def build_tag_nodes(extractor: WPExtractor) -> pd.DataFrame:
    """
    Build :Tag nodes from the frame returned by ``extractor.get_tags()``.

    Tag nodes carry only the 'Tag' label and no document_hash. An empty
    DataFrame is returned, without printing, when there are no tags.
    """
    site = extractor.site_prefix
    df_tags = extractor.get_tags()
    if df_tags.empty:
        return pd.DataFrame()

    tag_nodes = pd.DataFrame([
        {
            'node_id': f"{site}_tag_{tag['term_id']}",
            'labels': ['Tag'],
            'name': tag['name'],
            'wp_id': str(tag['term_id']),
            'slug': tag['slug'],
        }
        for _, tag in df_tags.iterrows()
    ])
    print(f"[{site}] Tag nodes: {len(tag_nodes)}")
    return tag_nodes


# ------------------------------------------------------------------
# Date nodes (:Date only, no Document label, no hash)
# ------------------------------------------------------------------

def build_date_nodes(df_posts: pd.DataFrame, site_prefix: str) -> pd.DataFrame:
    """
    Build one :Date node per distinct year-month found in ``df_posts['date']``.

    Values that cannot be parsed as dates are ignored. An empty DataFrame is
    returned when there are no posts, no 'date' column, or no parseable date.
    Nodes are ordered chronologically.
    """
    if 'date' not in df_posts.columns or df_posts.empty:
        return pd.DataFrame()

    parsed = pd.to_datetime(df_posts['date'], errors='coerce')
    valid = parsed.dropna()
    if valid.empty:
        return pd.DataFrame()

    def _month_node(period) -> dict:
        """Map a monthly period to its :Date node record."""
        month_start = period.to_timestamp()
        return {
            'node_id': f"{site_prefix}_date_{period}",
            'labels': ['Date'],
            'name': month_start.strftime('%B %Y'),
            'wp_id': str(period),
            'year': month_start.year,
            'month': month_start.month,
        }

    months = sorted(valid.dt.to_period('M').unique())
    date_nodes = pd.DataFrame([_month_node(period) for period in months])
    print(f"[{site_prefix}] Date nodes: {len(date_nodes)}")
    return date_nodes


# ------------------------------------------------------------------
# Assemble all nodes
# ------------------------------------------------------------------

def build_all_nodes(extractor: WPExtractor) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """
    Build every node family for one site and print the grand total.

    Builders run in the order posts, topics, tags, dates; date nodes are
    derived from the post nodes built in the first step.

    Returns:
      - nodes_dict: {'posts': df, 'topics': df, 'tags': df, 'dates': df}
      - df_topics_raw
    """
    nodes_dict = {'posts': build_post_nodes(extractor)}
    nodes_dict['topics'], df_topics_raw = build_topic_nodes(extractor)
    nodes_dict['tags'] = build_tag_nodes(extractor)
    nodes_dict['dates'] = build_date_nodes(nodes_dict['posts'], extractor.site_prefix)

    total = sum(map(len, nodes_dict.values()))
    print(f"[{extractor.site_prefix}] TOTAL nodes: {total}")
    return nodes_dict, df_topics_raw
