"""
Lightweight Neo4j proxy API for the graph viewer.
Runs Cypher queries server-side and returns vis.js-ready JSON.
"""

import json
import os
import re

from flask import Flask, request, jsonify
from flask_cors import CORS
from neo4j import GraphDatabase

# Connection settings; each one can be overridden through the environment.
# NOTE(review): the fallback password below is committed to source control.
# It should be rotated, and the fallback removed once every deployment sets
# NEO4J_PASS explicitly.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://nimani.diplomacy.edu:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASS = os.environ.get("NEO4J_PASS", "011diplo011")

app = Flask(__name__)
# NOTE(review): CORS(app) with no arguments lets pages from any origin call
# this query proxy; consider passing an explicit `origins=` list.
CORS(app)

# A single driver instance shared by all request handlers; each request
# opens its own session from it.
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))


def neo4j_value_to_json(val):
    """Convert a value returned by the neo4j driver into JSON-serializable data.

    Graph entities become dicts tagged with a ``_type`` key:

    * ``Node`` -> ``{"_type": "node", "id", "labels", "properties"}``
    * ``Relationship`` -> ``{"_type": "relationship", "id", "type",
      "startNodeId", "endNodeId", "properties"}``
    * ``Path`` -> ``{"_type": "path", "nodes", "relationships"}``

    Lists, tuples and dicts are converted element by element, and so are the
    property values of nodes and relationships. Values that provide an
    ``iso_format()`` method (the ``neo4j.time`` Date/Time/DateTime/Duration
    types) are returned as ISO-8601 strings, because ``jsonify`` cannot
    encode them. Any other value is returned unchanged.

    Args:
        val: A record value as produced by the neo4j driver.

    Returns:
        A structure made only of dicts, lists and scalar values.
    """
    from neo4j.graph import Node, Path, Relationship

    if isinstance(val, Node):
        return {
            "_type": "node",
            "id": val.element_id,
            "labels": list(val.labels),
            # Convert property values too, so that temporal properties do not
            # break JSON encoding of the whole response.
            "properties": neo4j_value_to_json(dict(val)),
        }
    if isinstance(val, Relationship):
        return {
            "_type": "relationship",
            "id": val.element_id,
            "type": val.type,
            "startNodeId": val.start_node.element_id,
            "endNodeId": val.end_node.element_id,
            "properties": neo4j_value_to_json(dict(val)),
        }
    if isinstance(val, Path):
        return {
            "_type": "path",
            "nodes": [neo4j_value_to_json(n) for n in val.nodes],
            "relationships": [neo4j_value_to_json(r) for r in val.relationships],
        }
    if isinstance(val, (list, tuple)):
        return [neo4j_value_to_json(v) for v in val]
    if isinstance(val, dict):
        return {k: neo4j_value_to_json(v) for k, v in val.items()}
    if hasattr(val, "iso_format"):
        # neo4j.time temporal values are not JSON-encodable as-is.
        return val.iso_format()
    return val


# Clauses that can modify data or schema, each paired with the label shown in
# the 403 error message. The patterns match on word boundaries rather than
# plain substrings, so identifiers such as `created_at`, `dataset` or
# `is_deleted` are not mistaken for write clauses. This is only a coarse
# pre-filter; the query itself is still executed through `execute_read`.
_WRITE_CLAUSES = [
    ("CREATE", re.compile(r"\bCREATE\b")),
    ("MERGE", re.compile(r"\bMERGE\b")),
    ("DELETE", re.compile(r"\bDELETE\b")),
    ("DETACH", re.compile(r"\bDETACH\b")),
    ("REMOVE", re.compile(r"\bREMOVE\b")),
    ("SET", re.compile(r"\bSET\b")),
    ("DROP", re.compile(r"\bDROP\b")),
    ("CALL {", re.compile(r"\bCALL\s*\{")),
    ("FOREACH", re.compile(r"\bFOREACH\b")),
    ("LOAD CSV", re.compile(r"\bLOAD\s+CSV\b")),
]
# A query must produce output through one of these clauses.
_RESULT_CLAUSE = re.compile(r"\b(?:RETURN|YIELD)\b")


@app.route("/query", methods=["POST"])
def run_query():
    """Run a read-only Cypher query and return graph and tabular results.

    Expects a JSON object body with ``cypher`` (required) and ``database``
    (optional, defaults to ``"weaviatediplo"``).

    On success returns ``{"nodes": [...], "edges": [...], "table": [...]}``.
    Node, relationship and path values are deduplicated by id into ``nodes``
    and ``edges``. Rows with no node/relationship/path in any column are
    returned in ``table``.

    Errors are returned as ``{"error": message}`` with one of these statuses:

    * 400: malformed body, empty query, or no RETURN/YIELD clause.
    * 403: a write clause was detected.
    * 500: running the query or building the response failed.
    """
    # silent=True: an unparsable body yields None and gets a JSON 400 below,
    # instead of an HTML error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    cypher = data.get("cypher") or ""
    if not isinstance(cypher, str):
        return jsonify({"error": "'cypher' must be a string"}), 400
    cypher = cypher.strip()
    # NOTE(review): `database` is caller-controlled, so any database the
    # configured user can read is reachable; consider an allow-list.
    database = data.get("database", "weaviatediplo")

    if not cypher:
        return jsonify({"error": "No cypher query provided"}), 400

    upper = cypher.upper()
    for label, pattern in _WRITE_CLAUSES:
        if pattern.search(upper):
            return jsonify({"error": f"Write queries are not allowed (found: {label})"}), 403

    if not _RESULT_CLAUSE.search(upper):
        return jsonify({"error": "Query must contain RETURN or YIELD"}), 400

    try:
        with driver.session(database=database) as session:
            result = session.execute_read(lambda tx: list(tx.run(cypher)))
            records = [
                {key: neo4j_value_to_json(value) for key, value in record.items()}
                for record in result
            ]

        nodes = {}
        edges = {}
        table_rows = []
        for row in records:
            for value in row.values():
                _extract_graph(value, nodes, edges)
            # Rows with a graph entity in any column are rendered only as
            # graph; every other row goes to the table.
            has_graph = any(
                isinstance(value, dict)
                and value.get("_type") in ("node", "relationship", "path")
                for value in row.values()
            )
            if not has_graph:
                table_rows.append(row)

        return jsonify({
            "nodes": list(nodes.values()),
            "edges": list(edges.values()),
            "table": table_rows,
        })

    except Exception as e:  # request boundary: report any failure to the client
        return jsonify({"error": str(e)}), 500


def _extract_graph(val, nodes, edges):
    if not isinstance(val, dict):
        return
    t = val.get("_type")
    if t == "node":
        nodes[val["id"]] = val
    elif t == "relationship":
        edges[val["id"]] = val
    elif t == "path":
        for n in val.get("nodes", []):
            nodes[n["id"]] = n
        for r in val.get("relationships", []):
            edges[r["id"]] = r


@app.route("/health")
def health():
    """Report whether the ``weaviatediplo`` database answers a trivial query.

    Returns ``{"status": "ok", "neo4j": "connected"}`` on success. If opening
    the session or running ``RETURN 1`` raises, returns
    ``{"status": "error", "neo4j": <error message>}`` with HTTP 500.
    """
    try:
        with driver.session(database="weaviatediplo") as session:
            session.run("RETURN 1 AS ok").single()
    except Exception as e:  # health-check boundary: surface any failure
        return jsonify({"status": "error", "neo4j": str(e)}), 500
    return jsonify({"status": "ok", "neo4j": "connected"})


if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8091, debug off.
    bind_host, bind_port = "0.0.0.0", 8091
    print(f"Graph API starting on http://{bind_host}:{bind_port}")
    print("Neo4j:", NEO4J_URI)
    app.run(host=bind_host, port=bind_port, debug=False)
