from __future__ import annotations

import argparse
import os
import shutil
import statistics
import time
import warnings
from pathlib import Path

import duckdb
import ladybug as lb
import lancedb
import metaxy as mx
import narwhals as nw
import polars as pl
from lance_graph import CypherQuery, GraphConfig
from metaxy.ext.duckdb.metadata_store import DuckDBMetadataStore
from metaxy.ext.lancedb.metadata_store import LanceDBMetadataStore


# Filesystem layout. Every artefact the demo writes lives under ROOT, which
# reset_demo_data() deletes and recreates wholesale.
ROOT = Path(__file__).parent / "_demo_data"
METAXY_CONFIG = Path(__file__).parent / "metaxy.toml"
LANCE_DIR = ROOT / "lance"
LANCE_META_DIR = ROOT / "metaxy_lance"
DUCKDB_META = ROOT / "metaxy.duckdb"
LADYBUG_DB = ROOT / "ladybug.db"

# Edge phase used by the "filtered_one_hop" scenario's WHERE clause.
QUERY_PHASE = "conference"

# Benchmark knobs, each overridable through an environment variable.
BENCHMARK_RUNS = int(os.environ.get("METAXY_GRAPH_BENCH_ROUNDS", "3"))
WARMUP_RUNS = int(os.environ.get("METAXY_GRAPH_BENCH_WARMUPS", "1"))
# Comma-separated scale factors; each unit of scale adds four faces and four
# edges. Blank entries (e.g. a trailing comma) are skipped rather than
# aborting the import with a ValueError from int("").
BENCHMARK_SCALE_FACTORS = tuple(
    int(part)
    for part in os.environ.get("METAXY_GRAPH_BENCH_SCALES", "100,1000,10000,100000,1000000").split(",")
    if part.strip()
)
# Scenario names understood by cypher_for(), duckpgq_query_for() and
# expected_graph_result().
BENCHMARK_SCENARIOS = (
    "filtered_one_hop",
    "filtered_two_hop",
    "topological_two_hop",
    "topological_three_hop",
    "topological_four_hop",
)
# Hash algorithm handed to both Metaxy metadata stores.
HASH_ALGORITHM = mx.HashAlgorithm.XXH3_64


class DuckPGQUnavailable(RuntimeError):
    """Raised when the DuckPGQ extension cannot be installed or loaded."""


# Metaxy feature describing detected faces: one row per face, identified by
# ``face_id`` and registered under the key "vision/detected_faces".
# The demo data also carries an "embedding" column, but it is not a declared
# field here; write_metaxy() drops it before writing this feature.
class DetectedFaces(
    mx.BaseFeature,
    spec=mx.FeatureSpec(
        key="vision/detected_faces",
        id_columns=["face_id"],
        fields=["image_id", "person", "phase", "confidence"],
    ),
):
    face_id: str  # identifier column (see id_columns above)
    image_id: str  # image the face was detected in
    person: str  # display name, e.g. "Ada"
    phase: str  # e.g. "conference", "archive" or "office" in the demo data
    confidence: float  # detection score (0.91-0.98 in the demo data)


# Metaxy feature describing directed co-occurrence edges between faces: one
# row per edge, identified by ``edge_id`` and registered under the key
# "vision/face_cooccurrences".
class FaceCoOccurrences(
    mx.BaseFeature,
    spec=mx.FeatureSpec(
        key="vision/face_cooccurrences",
        id_columns=["edge_id"],
        fields=["src_face_id", "dst_face_id", "image_id", "phase"],
    ),
):
    edge_id: str  # identifier column (see id_columns above)
    src_face_id: str  # face_id of the edge's source face
    dst_face_id: str  # face_id of the edge's destination face
    image_id: str  # image associated with the edge
    phase: str  # edge phase; the filtered scenarios filter on or return it


def detected_faces(scale: int = 1) -> pl.DataFrame:
    """Build the demo face table, replicated ``scale`` times.

    With ``scale == 1`` the four canonical faces keep their short ids
    (``"ada"``, ``"img-001"``, ...). For any other value every replica ``i`` in
    ``range(scale)`` gets ids carrying the zero-padded replica index
    (``"ada-0000"``, ``"img-0000-001"``, ...).

    Returns:
        A frame with columns ``face_id``, ``image_id``, ``person``, ``phase``,
        ``confidence`` and ``embedding`` (four rows per replica).
    """
    # One entry per canonical face:
    # (id stem, image offset, person, phase, confidence, embedding).
    profiles = (
        ("ada", 1, "Ada", "conference", 0.98, [0.10, 0.20, 0.30]),
        ("grace", 1, "Grace", "conference", 0.96, [0.12, 0.18, 0.28]),
        ("katherine", 2, "Katherine", "archive", 0.93, [0.77, 0.13, 0.08]),
        ("alan", 3, "Alan", "office", 0.91, [0.40, 0.80, 0.20]),
    )
    stems, offsets, people, phases, scores, embeddings = zip(*profiles)

    if scale == 1:
        face_ids = list(stems)
        image_ids = [f"img-{offset:03d}" for offset in offsets]
    else:
        face_ids = [f"{stem}-{i:04d}" for i in range(scale) for stem in stems]
        image_ids = [f"img-{i:04d}-{offset:03d}" for i in range(scale) for offset in offsets]

    def tiled(column: tuple) -> list:
        # Repeat the four canonical values once per replica.
        return [value for _ in range(scale) for value in column]

    return pl.DataFrame(
        {
            "face_id": face_ids,
            "image_id": image_ids,
            "person": tiled(people),
            "phase": tiled(phases),
            "confidence": tiled(scores),
            "embedding": tiled(embeddings),
        }
    )


def face_cooccurrences(scale: int = 1) -> pl.DataFrame:
    """Build the demo co-occurrence edge table, replicated ``scale`` times.

    Each replica is a directed four-cycle ada -> grace -> katherine -> alan ->
    ada. With ``scale == 1`` the short ids are used (``"img-001:ada-grace"``);
    otherwise ids carry the zero-padded replica index.

    Returns:
        A frame with columns ``edge_id``, ``src_face_id``, ``dst_face_id``,
        ``image_id`` and ``phase`` (four rows per replica).
    """
    # One entry per canonical edge: (source stem, destination stem, image offset, phase).
    links = (
        ("ada", "grace", 1, "conference"),
        ("grace", "katherine", 2, "archive"),
        ("katherine", "alan", 3, "office"),
        ("alan", "ada", 4, "conference"),
    )

    if scale == 1:
        edge_ids = [f"img-{offset:03d}:{src}-{dst}" for src, dst, offset, _ in links]
        sources = [src for src, _, _, _ in links]
        targets = [dst for _, dst, _, _ in links]
        image_ids = [f"img-{offset:03d}" for _, _, offset, _ in links]
    else:
        replicas = range(scale)
        edge_ids = [f"{i:04d}:{src}-{dst}" for i in replicas for src, dst, _, _ in links]
        sources = [f"{src}-{i:04d}" for i in replicas for src, _, _, _ in links]
        targets = [f"{dst}-{i:04d}" for i in replicas for _, dst, _, _ in links]
        image_ids = [f"img-{i:04d}-{offset:03d}" for i in replicas for _, _, offset, _ in links]

    phases = [phase for _ in range(scale) for _, _, _, phase in links]

    return pl.DataFrame(
        {
            "edge_id": edge_ids,
            "src_face_id": sources,
            "dst_face_id": targets,
            "image_id": image_ids,
            "phase": phases,
        }
    )


def expected_graph_result(scenario: str, scale: int = 1) -> pl.DataFrame:
    """Return the rows a correct engine must produce for ``scenario``.

    Filtered scenarios yield one ``source``/``target``/``phase`` row per
    replica; topological scenarios yield a single ``paths`` count of
    ``4 * scale``.

    Raises:
        ValueError: If ``scenario`` is not a known benchmark scenario.
    """
    if scenario == "filtered_one_hop":
        target, phase = "Grace", QUERY_PHASE
    elif scenario == "filtered_two_hop":
        target, phase = "Katherine", "archive"
    elif scenario in ("topological_two_hop", "topological_three_hop", "topological_four_hop"):
        return pl.DataFrame({"paths": [4 * scale]})
    else:
        raise ValueError(f"Unknown benchmark scenario: {scenario}")

    return pl.DataFrame(
        {
            "source": ["Ada"] * scale,
            "target": [target] * scale,
            "phase": [phase] * scale,
        }
    )


def normalize_graph_result(df: pl.DataFrame) -> pl.DataFrame:
    """Reduce a query result to its comparable columns in a stable order.

    Count results keep only ``paths``; row results keep ``source``, ``target``
    and ``phase`` and are sorted by those columns.
    """
    if "paths" not in df.columns:
        ordering = ("source", "target", "phase")
        return df.select(*ordering).sort(*ordering)
    return df.select("paths")


def reset_demo_data() -> None:
    """Delete the demo data root, if present, and recreate it empty."""
    demo_root = ROOT
    if demo_root.exists():
        shutil.rmtree(demo_root)
    demo_root.mkdir(parents=True)


def markdown_table(df: pl.DataFrame) -> str:
    headers = df.columns
    rows = df.rows()
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    for row in rows:
        lines.append("| " + " | ".join(format_cell(value) for value in row) + " |")
    return "\n".join(lines)


def format_cell(value: object) -> str:
    """Format one table cell.

    Floats use four significant digits, lists are rendered recursively inside
    square brackets, and everything else goes through ``str``.
    """
    if isinstance(value, list):
        rendered = [format_cell(element) for element in value]
        return f"[{', '.join(rendered)}]"
    if isinstance(value, float):
        return format(value, ".4g")
    return str(value)


def report(title: str, command: str, sections: list[tuple[str, pl.DataFrame]]) -> str:
    lines = [f"## {title}", "", f"Command: `{command}`", ""]
    for heading, df in sections:
        lines.extend([f"### {heading}", "", markdown_table(df), ""])
    return "\n".join(lines).rstrip() + "\n"


def init_metaxy() -> None:
    """Silence the AUTO_CREATE_TABLES warning and activate the local Metaxy config."""
    warnings.filterwarnings("ignore", message="AUTO_CREATE_TABLES is enabled.*")
    # Load only the config file next to this script; parent directories are not searched.
    config = mx.MetaxyConfig.load(METAXY_CONFIG, search_parents=False)
    mx.MetaxyConfig.set(config)


def lance_metaxy_store() -> LanceDBMetadataStore:
    """Create a LanceDB-backed Metaxy metadata store rooted at ``LANCE_META_DIR``."""
    store = LanceDBMetadataStore(LANCE_META_DIR, hash_algorithm=HASH_ALGORITHM)
    return store


def duck_metaxy_store() -> DuckDBMetadataStore:
    """Create a DuckDB-backed Metaxy metadata store at ``DUCKDB_META`` with table auto-creation on."""
    store = DuckDBMetadataStore(
        DUCKDB_META,
        auto_create_tables=True,
        hash_algorithm=HASH_ALGORITHM,
    )
    return store


def metaxy_store_config() -> pl.DataFrame:
    """Return a one-row settings table showing the configured hash algorithm."""
    settings = {
        "setting": ["hash_algorithm"],
        "value": [HASH_ALGORITHM.value],
    }
    return pl.DataFrame(settings)


def write_lancedb_tables(faces: pl.DataFrame, edges: pl.DataFrame) -> lancedb.DBConnection:
    """Write the face and edge frames to LanceDB and sanity-check them.

    Args:
        faces: Face rows, stored as table ``faces``.
        edges: Co-occurrence rows, stored as table ``co_occurs``.

    Returns:
        The open LanceDB connection rooted at ``LANCE_DIR``.
    """
    db = lancedb.connect(LANCE_DIR)
    for table_name, frame in (("faces", faces), ("co_occurs", edges)):
        write_lance_table(db, table_name, frame)

    # Scalar filter check: only Ada and Grace are in the conference phase.
    conference_faces = db.open_table("faces").search().where("phase = 'conference'").to_polars()
    conference_people = set(conference_faces.select("person").to_series())
    assert conference_people == {"Ada", "Grace"}

    # Vector search check: the probe vector sits closest to Ada's embedding.
    probe = [0.09, 0.21, 0.31]
    nearest = db.open_table("faces").search(probe).limit(1).to_polars()
    assert nearest.item(0, "person") == "Ada"
    return db


def write_lance_table(db: lancedb.DBConnection, name: str, data: pl.DataFrame) -> None:
    """Create table ``name`` from ``data``, overwriting it if it is already listed."""
    # NOTE(review): assumes ``db.list_tables()`` supports ``in`` with a plain
    # table name - confirm against the installed lancedb version.
    if name in db.list_tables():
        mode = "overwrite"
    else:
        mode = "create"
    db.create_table(name, data=data, mode=mode)


def query_lancedb_tables(db: lancedb.DBConnection) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Run the two showcase LanceDB queries against the ``faces`` table.

    Returns:
        ``(conference_faces, nearest)``: the faces whose phase is
        ``conference``, and the single nearest neighbour of a fixed probe
        embedding together with its ``_distance``.
    """
    filtered = db.open_table("faces").search().where("phase = 'conference'").to_polars()
    conference_faces = filtered.select("face_id", "person", "phase", "confidence")

    probe = [0.09, 0.21, 0.31]
    closest = db.open_table("faces").search(probe).limit(1).to_polars()
    nearest = closest.select("face_id", "person", "_distance")

    return conference_faces, nearest


def write_metaxy(store: mx.MetadataStore, faces: pl.DataFrame, edges: pl.DataFrame) -> None:
    """Write both demo features to ``store`` under materialization id ``local-demo``.

    The ``embedding`` column is dropped from ``faces`` because it is not a
    declared field of ``DetectedFaces``.
    """
    with store.open("w"):
        payloads = (
            (DetectedFaces, faces.drop("embedding")),
            (FaceCoOccurrences, edges),
        )
        for feature, frame in payloads:
            provenance = store.compute_provenance(feature, nw.from_native(frame))
            store.write(feature, provenance.to_native(), materialization_id="local-demo")


def read_metaxy(store: mx.MetadataStore) -> pl.DataFrame:
    """Read the ``DetectedFaces`` metadata back from ``store`` as a Polars frame."""
    with store.open("r"):
        native = store.read(DetectedFaces).collect().to_native()
    # Objects exposing ``to_batches`` are treated as Arrow data and converted;
    # anything else is returned unchanged.
    if hasattr(native, "to_batches"):
        return pl.from_arrow(native)
    return native


def display_metaxy_faces(df: pl.DataFrame) -> pl.DataFrame:
    """Pick the columns shown in reports and order the rows by ``face_id``."""
    visible = df.select("face_id", "person", "phase", "metaxy_materialization_id")
    return visible.sort("face_id")


def graph_config() -> GraphConfig:
    """Build the Lance Graph config: ``Face`` nodes joined by ``CO_OCCURS_WITH`` edges."""
    builder = GraphConfig.builder()
    builder = builder.with_node_label("Face", "face_id")
    builder = builder.with_relationship("CO_OCCURS_WITH", "src_face_id", "dst_face_id")
    return builder.build()


def cypher_for(scenario: str) -> str:
    """Return the Cypher text for a benchmark scenario.

    Filtered scenarios return ``source``/``target``/``phase`` rows; topological
    scenarios return a single ``paths`` count.

    Raises:
        ValueError: If ``scenario`` is not a known graph scenario.
    """
    count_paths = "RETURN count(*) AS paths"
    two_hop = "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m:Face)-[r2:CO_OCCURS_WITH]->(b:Face) "

    if scenario == "filtered_one_hop":
        pattern = "MATCH (a:Face)-[r:CO_OCCURS_WITH]->(b:Face) "
        predicate = f"WHERE a.person = 'Ada' AND r.phase = '{QUERY_PHASE}' "
        return pattern + predicate + "RETURN a.person AS source, b.person AS target, r.phase AS phase"

    if scenario == "filtered_two_hop":
        predicate = "WHERE a.person = 'Ada' AND m.person = 'Grace' "
        return two_hop + predicate + "RETURN a.person AS source, b.person AS target, r2.phase AS phase"

    if scenario == "topological_two_hop":
        return two_hop + count_paths

    if scenario == "topological_three_hop":
        three_hop = (
            "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m1:Face)-[r2:CO_OCCURS_WITH]->(m2:Face)-[r3:CO_OCCURS_WITH]->(b:Face) "
        )
        return three_hop + count_paths

    if scenario == "topological_four_hop":
        four_hop = (
            "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m1:Face)-[r2:CO_OCCURS_WITH]->(m2:Face)-[r3:CO_OCCURS_WITH]->(m3:Face)-[r4:CO_OCCURS_WITH]->(b:Face) "
        )
        return four_hop + count_paths

    raise ValueError(f"Unknown graph scenario: {scenario}")


def duckpgq_query_for(scenario: str) -> str:
    """Return the SQL/PGQ text for a benchmark scenario against ``face_graph``.

    Filtered scenarios select from the ``GRAPH_TABLE`` and order the rows;
    topological scenarios wrap it in ``SELECT count(*) AS paths``.

    Raises:
        ValueError: If ``scenario`` is not a known graph scenario.
    """
    two_hop = "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m:Face)-[r2:CO_OCCURS_WITH]->(b:Face)"
    # Defaults used by every topological scenario; filtered ones override them.
    where_clause = ""
    columns = "a.person AS source"

    if scenario == "filtered_one_hop":
        match_clause = "MATCH (a:Face)-[r:CO_OCCURS_WITH]->(b:Face)"
        where_clause = f"WHERE a.person = 'Ada' AND r.phase = '{QUERY_PHASE}'"
        columns = "a.person AS source, b.person AS target, r.phase AS phase"
    elif scenario == "filtered_two_hop":
        match_clause = two_hop
        where_clause = "WHERE a.person = 'Ada' AND m.person = 'Grace'"
        columns = "a.person AS source, b.person AS target, r2.phase AS phase"
    elif scenario == "topological_two_hop":
        match_clause = two_hop
    elif scenario == "topological_three_hop":
        match_clause = "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m1:Face)-[r2:CO_OCCURS_WITH]->(m2:Face)-[r3:CO_OCCURS_WITH]->(b:Face)"
    elif scenario == "topological_four_hop":
        match_clause = "MATCH (a:Face)-[r1:CO_OCCURS_WITH]->(m1:Face)-[r2:CO_OCCURS_WITH]->(m2:Face)-[r3:CO_OCCURS_WITH]->(m3:Face)-[r4:CO_OCCURS_WITH]->(b:Face)"
    else:
        raise ValueError(f"Unknown graph scenario: {scenario}")

    graph_table_query = f"""
        FROM GRAPH_TABLE (
          face_graph
          {match_clause}
          {where_clause}
          COLUMNS ({columns})
        )
        """
    counts_only = scenario.startswith("topological_")
    if not counts_only:
        return f"{graph_table_query} ORDER BY source, target, phase"
    return f"SELECT count(*) AS paths {graph_table_query}"


def lance_graph_cypher(db: lancedb.DBConnection) -> pl.DataFrame:
    """Run the ``filtered_one_hop`` Cypher query over the LanceDB tables.

    The ``faces`` and ``co_occurs`` tables are handed to Lance Graph as Arrow
    data under the ``Face`` and ``CO_OCCURS_WITH`` labels.
    """
    query = CypherQuery(cypher_for("filtered_one_hop")).with_config(graph_config())
    datasets = {
        "Face": db.open_table("faces").to_arrow(),
        "CO_OCCURS_WITH": db.open_table("co_occurs").to_arrow(),
    }
    return pl.from_arrow(query.execute(datasets))


def lance_graph_query(scenario: str) -> CypherQuery:
    """Build the configured (not yet executed) Lance Graph query for ``scenario``."""
    query = CypherQuery(cypher_for(scenario))
    return query.with_config(graph_config())


def duckpgq_graph(database: Path, scenario: str = "filtered_one_hop") -> tuple[duckdb.DuckDBPyConnection, str]:
    """Open ``database``, load DuckPGQ and define the ``face_graph`` property graph.

    The graph is built from temp copies of the ``vision__detected_faces`` and
    ``vision__face_cooccurrences`` tables, which must already exist in
    ``database`` (presumably written there by the Metaxy DuckDB store).

    Args:
        database: Path of the DuckDB database file.
        scenario: Benchmark scenario whose SQL/PGQ text is returned.

    Returns:
        ``(connection, query)``. The caller owns the connection and must close it.

    Raises:
        DuckPGQUnavailable: If the extension cannot be installed or loaded.
        ValueError: If ``scenario`` is unknown.
    """
    con = duckdb.connect(str(database))
    try:
        try:
            con.install_extension("duckpgq", repository="community")
            con.load_extension("duckpgq")
        except duckdb.Error as exc:
            raise DuckPGQUnavailable(
                f"DuckPGQ is not available for DuckDB {duckdb.__version__}; DuckDB 1.5.x docs recommend DuckDB 1.4.4 for DuckPGQ."
            ) from exc
        con.execute(
            """
        CREATE TEMP TABLE Face AS
        SELECT face_id, person, phase
        FROM vision__detected_faces
        """
        )
        con.execute(
            """
        CREATE TEMP TABLE CoOccurs AS
        SELECT edge_id, src_face_id, dst_face_id, phase
        FROM vision__face_cooccurrences
        """
        )
        con.execute(
            """
        CREATE PROPERTY GRAPH face_graph
        VERTEX TABLES (Face)
        EDGE TABLES (
          CoOccurs SOURCE KEY (src_face_id) REFERENCES Face (face_id)
                   DESTINATION KEY (dst_face_id) REFERENCES Face (face_id)
                   LABEL CO_OCCURS_WITH
        )
        """
        )
        query = duckpgq_query_for(scenario)
    except BaseException:
        # Any failure before ownership passes to the caller (extension load,
        # table or graph creation, unknown scenario) must not leak the connection.
        con.close()
        raise
    return con, query


def duckpgq_same_question(database: Path, scenario: str = "filtered_one_hop") -> pl.DataFrame:
    """Run ``scenario`` through DuckPGQ and return the result as a Polars frame.

    Args:
        database: Path of the DuckDB database file holding the Metaxy tables.
        scenario: Benchmark scenario to execute.

    Raises:
        DuckPGQUnavailable: If the DuckPGQ extension cannot be loaded.
    """
    con, query = duckpgq_graph(database, scenario)
    try:
        # ``.pl()`` materializes the result, so it stays valid after the close below.
        return con.sql(query).pl()
    finally:
        # Release the database file; later steps delete and recreate it.
        con.close()


def validate_graph_result(state: tuple[object, ...], result: pl.DataFrame) -> None:
    """Check ``result`` against the expected rows for the round described by ``state``.

    Args:
        state: Round state whose first two items are the scale and the scenario name.
        result: Frame returned by the engine under test.

    Raises:
        AssertionError: If the normalized result differs from the expectation.
    """
    scale, scenario = int(state[0]), str(state[1])
    actual = normalize_graph_result(result)
    expected = normalize_graph_result(expected_graph_result(scenario, scale))
    if actual.equals(expected):
        return
    raise AssertionError(f"{scenario} returned unexpected rows at scale {scale}")


def close_graph_state(state: tuple[object, ...]) -> None:
    """Call ``close()`` on every closable item of ``state``, last item first."""
    for position in range(len(state) - 1, -1, -1):
        closer = getattr(state[position], "close", None)
        if closer is None:
            continue
        closer()


def benchmark_rounds(setup, run) -> list[float]:
    """Time ``run`` over fresh state built by ``setup``.

    ``WARMUP_RUNS`` untimed rounds are executed first, then ``BENCHMARK_RUNS``
    timed ones. Only the ``run`` call is timed; setup, validation and teardown
    are excluded. Every round validates its result and closes its state.

    Args:
        setup: Zero-argument callable returning the round state tuple.
        run: Callable taking that state and returning the result frame.

    Returns:
        One elapsed time in milliseconds per timed round.
    """

    def one_round() -> float:
        state = setup()
        try:
            began = time.perf_counter()
            outcome = run(state)
            elapsed_ms = (time.perf_counter() - began) * 1000
            validate_graph_result(state, outcome)
        finally:
            close_graph_state(state)
        return elapsed_ms

    # Warm-up rounds run the same path; their timings are discarded.
    for _ in range(WARMUP_RUNS):
        one_round()

    return [one_round() for _ in range(BENCHMARK_RUNS)]


def percentile(values: list[float], q: float) -> float:
    """Return the ``q``-quantile of ``values`` using linear interpolation.

    Args:
        values: Non-empty sample; it is not modified.
        q: Quantile in ``[0, 1]`` (``0.25`` is the first quartile).

    Returns:
        The value interpolated between the two closest ranks.

    Raises:
        ValueError: If ``values`` is empty or ``q`` is outside ``[0, 1]``.
    """
    ordered = sorted(values)
    if len(ordered) == 0:
        raise ValueError("percentile() requires at least one value")
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"percentile() quantile must be within [0, 1], got {q!r}")
    if len(ordered) == 1:
        return ordered[0]
    # Fractional rank into the sorted sample; its integer part picks the lower
    # neighbour and the remainder is the interpolation weight.
    rank = (len(ordered) - 1) * q
    lower = int(rank)
    upper = min(lower + 1, len(ordered) - 1)
    weight = rank - lower
    return ordered[lower] * (1 - weight) + ordered[upper] * weight


def benchmark_row(engine: str, scenario: str, scale: int, timings: list[float]) -> dict[str, object]:
    """Summarize one engine/scenario/scale run as a result-table row.

    Args:
        engine: Display name of the engine.
        scenario: Benchmark scenario name.
        scale: Scale factor; the graph has ``4 * scale`` nodes and edges.
        timings: Per-round timings in milliseconds.
    """
    graph_size = 4 * scale
    row: dict[str, object] = {"engine": engine, "scenario": scenario, "scale": scale}
    row["nodes"] = graph_size
    row["edges"] = graph_size
    row["runs"] = BENCHMARK_RUNS
    row["warmups"] = WARMUP_RUNS
    row["median_ms"] = statistics.median(timings)
    row["q1_ms"] = percentile(timings, 0.25)
    row["q3_ms"] = percentile(timings, 0.75)
    row["min_ms"] = min(timings)
    row["max_ms"] = max(timings)
    return row


def setup_lance_graph_round(scale: int, scenario: str) -> tuple[int, str, CypherQuery, dict[str, object]]:
    """Prepare one Lance Graph benchmark round from a clean data directory.

    Returns:
        ``(scale, scenario, query, inputs)`` where ``inputs`` maps the graph
        labels to Arrow data read from the freshly written LanceDB tables.
    """
    reset_demo_data()
    faces = detected_faces(scale)
    edges = face_cooccurrences(scale)
    lance_db = write_lancedb_tables(faces, edges)

    query = lance_graph_query(scenario)
    # Tables are materialized here so the timed run does not include the read.
    inputs = {
        "Face": lance_db.open_table("faces").to_arrow(),
        "CO_OCCURS_WITH": lance_db.open_table("co_occurs").to_arrow(),
    }
    return scale, scenario, query, inputs


def run_lance_graph_round(state: tuple[int, str, CypherQuery, dict[str, object]]) -> pl.DataFrame:
    """Execute the prepared Cypher query and return the result as a Polars frame."""
    cypher_query, inputs = state[2], state[3]
    arrow_result = cypher_query.execute(inputs)
    return pl.from_arrow(arrow_result)


def setup_duckpgq_round(scale: int, scenario: str) -> tuple[int, str, duckdb.DuckDBPyConnection, str]:
    """Prepare one DuckPGQ benchmark round from a clean data directory.

    Returns:
        ``(scale, scenario, connection, query)``; the connection has the
        property graph defined and is left open for the timed run.

    Raises:
        DuckPGQUnavailable: If the DuckPGQ extension cannot be loaded.
    """
    init_metaxy()
    reset_demo_data()
    face_rows = detected_faces(scale)
    edge_rows = face_cooccurrences(scale)
    write_metaxy(duck_metaxy_store(), face_rows, edge_rows)
    connection, sql = duckpgq_graph(DUCKDB_META, scenario)
    return scale, scenario, connection, sql


def run_duckpgq_round(state: tuple[int, str, duckdb.DuckDBPyConnection, str]) -> pl.DataFrame:
    """Execute the prepared SQL/PGQ query and return the result as a Polars frame."""
    con, query = state[2], state[3]
    relation = con.sql(query)
    return relation.pl()


def setup_ladybug_round(scale: int, scenario: str) -> tuple[int, str, lb.Database, lb.Connection, str]:
    """Prepare one LadybugDB benchmark round from a clean data directory.

    The demo data is written through the Metaxy DuckDB store, exported from
    DuckDB as Polars frames, loaded into a fresh LadybugDB database, and the
    database is then reopened read-only for the timed query.

    Returns:
        ``(scale, scenario, database, connection, query)`` where ``query`` is
        the Cypher text for ``scenario``.
    """
    init_metaxy()
    reset_demo_data()
    faces = detected_faces(scale)
    edges = face_cooccurrences(scale)
    write_metaxy(duck_metaxy_store(), faces, edges)

    # Export the Metaxy-managed tables; the DuckDB connection is closed as soon
    # as both frames are materialized.
    source = duckdb.connect(str(DUCKDB_META))
    try:
        exported_faces = source.sql(
            """
            SELECT face_id, person, phase
            FROM vision__detected_faces
            """
        ).pl()
        exported_edges = source.sql(
            """
            SELECT src_face_id, dst_face_id, edge_id, image_id, phase
            FROM vision__face_cooccurrences
            """
        ).pl()
    finally:
        source.close()

    database = lb.Database(LADYBUG_DB)
    conn = lb.Connection(database)
    conn.execute("CREATE NODE TABLE Face(face_id STRING PRIMARY KEY, person STRING, phase STRING)")
    conn.execute(
        """
        CREATE REL TABLE CO_OCCURS_WITH(
            FROM Face TO Face,
            edge_id STRING,
            image_id STRING,
            phase STRING
        )
        """
    )
    # Demo setup uses an explicit export/import hop from the Metaxy DuckDB tables.
    # NOTE(review): the COPY statements name ``exported_faces`` / ``exported_edges``,
    # which presumably are resolved from the Python locals above - do not rename
    # those variables without confirming how LadybugDB looks them up.
    conn.execute("COPY Face FROM exported_faces")
    conn.execute("COPY CO_OCCURS_WITH FROM exported_edges")

    # Close the writable handles before reopening the database read-only.
    conn.close()
    database.close()

    database = lb.Database(LADYBUG_DB, read_only=True)
    conn = lb.Connection(database)
    return scale, scenario, database, conn, cypher_for(scenario)


def run_ladybug_round(state: tuple[int, str, lb.Database, lb.Connection, str]) -> pl.DataFrame:
    """Execute the prepared Cypher query on LadybugDB and return a Polars frame."""
    conn, query = state[3], state[4]
    query_result = conn.execute(query)
    return query_result.get_as_pl()


def duckpgq_unavailable_status(exc: DuckPGQUnavailable) -> pl.DataFrame:
    """Build the one-row "skipped" status table shown when DuckPGQ cannot be used."""
    status_row = {
        "engine": "DuckPGQ",
        "duckdb_version": duckdb.__version__,
        "status": "skipped",
        "reason": str(exc),
    }
    return pl.DataFrame({column: [value] for column, value in status_row.items()})


def benchmark_graph_traversals() -> tuple[pl.DataFrame, pl.DataFrame | None]:
    """Benchmark every scenario at every scale on all engines.

    Engines run in the order Lance Graph, DuckPGQ, LadybugDB. Once DuckPGQ
    turns out to be unavailable it is skipped for all remaining combinations.

    Returns:
        ``(results, duckpgq_status)``: one row per engine/scenario/scale, and
        the "skipped" status frame for DuckPGQ (``None`` if it ran).
    """
    records = []
    duckpgq_status = None

    for scenario in BENCHMARK_SCENARIOS:
        for scale in BENCHMARK_SCALE_FACTORS:
            # Default arguments bind the current loop values to each setup callable.
            def prepare_lance(scale=scale, scenario=scenario):
                return setup_lance_graph_round(scale, scenario)

            def prepare_duckpgq(scale=scale, scenario=scenario):
                return setup_duckpgq_round(scale, scenario)

            def prepare_ladybug(scale=scale, scenario=scenario):
                return setup_ladybug_round(scale, scenario)

            lance_timings = benchmark_rounds(prepare_lance, run_lance_graph_round)

            pgq_timings = None
            if duckpgq_status is None:
                try:
                    pgq_timings = benchmark_rounds(prepare_duckpgq, run_duckpgq_round)
                except DuckPGQUnavailable as exc:
                    duckpgq_status = duckpgq_unavailable_status(exc)

            ladybug_timings = benchmark_rounds(prepare_ladybug, run_ladybug_round)

            records.append(benchmark_row("Lance Graph", scenario, scale, lance_timings))
            records.append(benchmark_row("LadybugDB", scenario, scale, ladybug_timings))
            if pgq_timings is not None:
                records.append(benchmark_row("DuckPGQ", scenario, scale, pgq_timings))

    return pl.DataFrame(records), duckpgq_status


def run_lancedb() -> lancedb.DBConnection:
    """Example 1: write the demo data to LanceDB and print both showcase queries.

    Returns:
        The open LanceDB connection.
    """
    reset_demo_data()
    db = write_lancedb_tables(detected_faces(), face_cooccurrences())
    conference_faces, nearest = query_lancedb_tables(db)

    sections = [
        ("Scalar filter: faces in the conference phase", conference_faces),
        ("Vector search: nearest face embedding", nearest),
    ]
    print(report("1. LanceDB from local Polars data", "just run-1-lancedb", sections))
    return db


def run_metaxy_lance() -> None:
    """Example 2: round-trip the demo features through the LanceDB-backed Metaxy store."""
    init_metaxy()
    reset_demo_data()
    face_rows = detected_faces()
    edge_rows = face_cooccurrences()
    write_metaxy(lance_metaxy_store(), face_rows, edge_rows)

    # Read back through a fresh store and check the materialization id.
    metaxy_lance = read_metaxy(lance_metaxy_store())
    assert set(metaxy_lance["metaxy_materialization_id"]) == {"local-demo"}

    sections = [
        ("Metaxy store config", metaxy_store_config()),
        ("Read back Metaxy-managed face metadata", display_metaxy_faces(metaxy_lance)),
    ]
    print(report("2. Metaxy metadata backed by LanceDB", "just run-2-metaxy-lance", sections))


def run_lance_graph() -> pl.DataFrame:
    """Example 3: answer the one-hop question with Lance Graph Cypher and print it.

    Returns:
        The normalized Cypher result.
    """
    reset_demo_data()
    face_rows = detected_faces()
    edge_rows = face_cooccurrences()
    lance_db = write_lancedb_tables(face_rows, edge_rows)

    cypher = normalize_graph_result(lance_graph_cypher(lance_db))
    assert cypher.equals(expected_graph_result("filtered_one_hop"))

    sections = [("Cypher result", cypher)]
    print(report("3. Lance Graph Cypher", "just run-3-lance-graph", sections))
    return cypher


def run_duckpgq() -> pl.DataFrame:
    """Example 4: answer the one-hop question with DuckPGQ over the Metaxy DuckDB store.

    Returns:
        The normalized SQL/PGQ result, or an empty frame when DuckPGQ is
        unavailable (the report then shows the skip status instead).
    """
    init_metaxy()
    reset_demo_data()
    face_rows = detected_faces()
    edge_rows = face_cooccurrences()
    write_metaxy(duck_metaxy_store(), face_rows, edge_rows)

    try:
        pgq = normalize_graph_result(duckpgq_same_question(DUCKDB_META))
    except DuckPGQUnavailable as exc:
        pgq = pl.DataFrame()
        sections = [("DuckPGQ status", duckpgq_unavailable_status(exc))]
    else:
        assert pgq.equals(expected_graph_result("filtered_one_hop"))
        sections = [("SQL/PGQ result", pgq)]

    print(report("4. Metaxy on DuckDB plus DuckPGQ", "just run-4-duckpgq", sections))
    return pgq


def run_benchmark() -> pl.DataFrame:
    """Example 5: run the traversal benchmark and print the timing table.

    Returns:
        The benchmark result frame.
    """
    benchmark, duckpgq_status = benchmark_graph_traversals()

    sections = [("Median query time after setup", benchmark)]
    if duckpgq_status is not None:
        sections += [("Skipped engine", duckpgq_status)]

    rendered = report("5. Tiny traversal benchmark", "just run-5-benchmark", sections)
    print(rendered)
    return benchmark


def run_all() -> None:
    """Run every example against one shared dataset and print a combined report.

    Statement order matters: the benchmark resets the demo data directory for
    each round, so everything that reads the LanceDB tables written here must
    run before ``benchmark_graph_traversals()`` is called.
    """
    init_metaxy()
    reset_demo_data()
    faces = detected_faces()
    edges = face_cooccurrences()
    lance_db = write_lancedb_tables(faces, edges)
    write_metaxy(lance_metaxy_store(), faces, edges)

    # Round-trip check for the LanceDB-backed Metaxy store.
    metaxy_lance = read_metaxy(lance_metaxy_store())
    assert set(metaxy_lance["metaxy_materialization_id"]) == {"local-demo"}

    # One-hop question answered by Lance Graph.
    cypher = normalize_graph_result(lance_graph_cypher(lance_db))
    assert cypher.equals(expected_graph_result("filtered_one_hop"))

    # Same question answered by DuckPGQ; an unavailable extension is reported, not fatal.
    write_metaxy(duck_metaxy_store(), faces, edges)
    pgq_status = None
    try:
        pgq = normalize_graph_result(duckpgq_same_question(DUCKDB_META))
        assert pgq.equals(expected_graph_result("filtered_one_hop"))
    except DuckPGQUnavailable as exc:
        pgq = None
        pgq_status = duckpgq_unavailable_status(exc)

    # Query LanceDB now: the benchmark below wipes the data directory.
    conference_faces, nearest = query_lancedb_tables(lance_db)
    benchmark, benchmark_duckpgq_status = benchmark_graph_traversals()
    sections = [
        ("LanceDB scalar filter", conference_faces),
        ("LanceDB vector search", nearest),
        ("Metaxy store config", metaxy_store_config()),
        (
            "Metaxy metadata from LanceDB",
            display_metaxy_faces(metaxy_lance),
        ),
        ("Lance Graph Cypher", cypher),
    ]
    # Show either the DuckPGQ result or the reason it was skipped.
    sections.append(("DuckPGQ", pgq) if pgq is not None else ("DuckPGQ status", pgq_status))
    sections.append(("Tiny traversal benchmark", benchmark))
    if benchmark_duckpgq_status is not None:
        sections.append(("Benchmark skipped engine", benchmark_duckpgq_status))
    print(
        report(
            "All examples",
            "just all",
            sections,
        )
    )


def main() -> None:
    """Parse the command line and run the selected example (default: ``all``)."""
    examples = {
        "lancedb": run_lancedb,
        "metaxy-lance": run_metaxy_lance,
        "lance-graph": run_lance_graph,
        "duckpgq": run_duckpgq,
        "benchmark": run_benchmark,
        "all": run_all,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("example", choices=list(examples), nargs="?", default="all")
    args = parser.parse_args()

    # argparse has already restricted the value to one of the keys above.
    examples[args.example]()


if __name__ == "__main__":
    main()
