#!/usr/bin/env python3
from __future__ import annotations

import argparse
import itertools
import json
import math
import sys
from dataclasses import asdict, dataclass, replace
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import pandas as pd

# Project root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT/src at the front of the import path so the `dn_research` import
# below can be resolved from there; this must run before that import.
sys.path.insert(0, str(ROOT / "src"))

from dn_research.strategy_lab import (  # noqa: E402
    BacktestConfig,
    RuleBasedCarryStrategy,
    StrategyMetadata,
    StrategyParams,
    StrategySpec,
    build_mark_lookup,
    build_side_panel,
    build_surface,
    compute_metrics,
    load_panel,
    objective,
    purged_walk_forward_splits,
    run_strategy_on_signals,
    safe_name,
    score_candidates,
)


@dataclass(frozen=True)
class CostProfile:
    """Immutable bundle of execution-cost assumptions forming one sweep axis.

    ``build_sweep_specs`` copies every field except ``name`` onto the
    identically named ``StrategyParams`` field; ``name`` is embedded in the
    generated strategy name and description. ``grid_manifest`` serializes
    instances through ``__dict__``.
    """

    # Short label embedded in generated strategy names and descriptions.
    name: str
    # Forwarded as ``StrategyParams.execution_style`` (the grid uses "maker" and "taker").
    execution_style: str
    # Entry slippage assumption; presumably basis points, per the ``_bps`` suffix.
    entry_slippage_bps: float
    # Exit slippage assumption; presumably basis points, per the ``_bps`` suffix.
    exit_slippage_bps: float
    # Fee assumption; presumably basis points, per the ``_bps`` suffix.
    fee_bps: float
    # Maker rebate assumption; presumably basis points (0.0 in every profile defined in this file).
    maker_rebate_bps: float


# Sweep axis: (band_name, min_side_price, max_side_price). The name is embedded
# in the strategy name; the bounds become the inclusive side-price filter.
PRICE_BANDS = [
    ("broad_50_95", 0.50, 0.95),
    ("broad_65_95", 0.65, 0.95),
    ("core_80_95", 0.80, 0.95),
    ("favorite_85_92", 0.85, 0.92),
    ("tight_90_95", 0.90, 0.95),
    ("terminal_95_999", 0.95, 0.999),
]

# Sweep axis: values assigned to ``StrategyParams.max_dte_days`` (maximum days to expiry).
DTE_LIMITS = [1.0, 3.0, 7.0, 14.0, 30.0]
# Sweep axis: values assigned to ``StrategyParams.min_edge_after_cost``.
EDGE_LIMITS = [0.03, 0.05, 0.08, 0.12]
# Sweep axis: values assigned to ``StrategyParams.side``; "both" disables the side filter.
SIDES = ["both", "YES", "NO"]
# Sweep axis: values assigned to ``StrategyParams.max_position_usdc`` (single value today).
POSITION_LIMITS = [25_000.0]
# Sweep axis: (exit_name, exit_policy, drawdown_floor, profit_take, trailing_floor).
# ``exit_name`` is only used in the strategy name; the rest map onto ``StrategyParams``.
EXIT_POLICIES = [
    ("hold", "hold_to_resolution", 0.03, 0.08, 0.04),
    ("dd3c", "drawdown_stop", 0.03, 0.08, 0.04),
]
# Sweep axis: positional fields are (name, execution_style, entry_slippage_bps,
# exit_slippage_bps, fee_bps, maker_rebate_bps).
COST_PROFILES = [
    CostProfile("maker_base", "maker", 30.0, 50.0, 25.0, 0.0),
    CostProfile("taker_stress", "taker", 100.0, 100.0, 40.0, 0.0),
]


def parse_args() -> argparse.Namespace:
    """Parse the sweep's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``panel_path``, ``output_dir``, the walk-forward window
        sizes (all in days), ``max_strategies``, ``max_signals_per_day`` and
        ``checkpoint_every``.
    """
    artifacts_dir = ROOT / "artifacts"
    cli = argparse.ArgumentParser(
        description=(
            "Run a one-year real-data DNS1 strategy sweep and write one CSV row "
            "for every tested parameter combination."
        )
    )

    # Input/output locations default to the project's artifacts folder.
    cli.add_argument("--panel-path", default=str(artifacts_dir / "empirical_resolution_panel.parquet"))
    cli.add_argument("--output-dir", default=str(artifacts_dir / "final_real_backtest_1y"))

    # Plain integer day-count options that carry no help text.
    day_options = (
        ("--backtest-days", 365),
        ("--train-lookback-days", 365),
        ("--fold-days", 30),
        ("--purge-days", 30),
        ("--embargo-days", 7),
    )
    for flag, default_days in day_options:
        cli.add_argument(flag, type=int, default=default_days)

    cli.add_argument("--max-strategies", type=int, default=0, help="Optional deterministic cap for smoke tests.")
    cli.add_argument(
        "--max-signals-per-day",
        type=int,
        default=2500,
        help="Keep this many top-ranked eligible candidates per day per strategy. Use 0 for no cap.",
    )
    cli.add_argument("--checkpoint-every", type=int, default=50)
    return cli.parse_args()


def main() -> int:
    """Run the full parameter sweep and write every artifact to the output directory.

    Steps:
      1. Build the strategy grid (optionally truncated by ``--max-strategies``)
         and write ``sweep_grid.json``.
      2. Load the panel, write ``panel_manifest.json``, and score every
         walk-forward fold once (``fold_manifest.json``).
      3. For each strategy: filter the shared scored rows, run the backtest,
         compute metrics and the objective, and append one result row.
         ``all_strategy_results.csv``/``.json`` are rewritten every
         ``--checkpoint-every`` strategies and after the last one.
      4. Persist the best strategy's daily/trades/signals CSVs, params,
         metrics, manifest, and a markdown summary.

    Returns:
        ``0`` on success (used as the process exit status).

    Raises:
        RuntimeError: if no walk-forward folds are available, or if no
            strategy was evaluated.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    config = BacktestConfig(
        panel_path=Path(args.panel_path),
        output_dir=output_dir,
        backtest_days=args.backtest_days,
        train_lookback_days=args.train_lookback_days,
        fold_days=args.fold_days,
        purge_days=args.purge_days,
        embargo_days=args.embargo_days,
    )

    specs = build_sweep_specs()
    if args.max_strategies > 0:
        specs = specs[: args.max_strategies]
    write_json(output_dir / "sweep_grid.json", {"strategy_count": len(specs), "grid": grid_manifest()})

    print(f"[real-sweep] output_dir={output_dir}", flush=True)
    print(f"[real-sweep] loading panel={config.panel_path}", flush=True)
    panel = load_panel(
        config.panel_path,
        backtest_days=config.backtest_days,
        train_lookback_days=config.train_lookback_days,
        purge_days=config.purge_days,
    )
    panel_manifest = panel_summary(panel)
    write_json(output_dir / "panel_manifest.json", panel_manifest)
    print(
        "[real-sweep] panel rows={rows:,} quote_date={start}..{end} markets={markets:,}".format(**panel_manifest),
        flush=True,
    )

    side_panel = build_side_panel(panel)
    mark_lookup = build_mark_lookup(side_panel)
    # Scoring happens once with zero costs; each strategy re-applies its own
    # cost assumptions inside filter_signals.
    scored, fold_manifest = build_scored_walk_forward(side_panel, specs, config)
    write_json(output_dir / "fold_manifest.json", fold_manifest)
    if not fold_manifest["folds"]:
        raise RuntimeError("No walk-forward folds were available for the requested configuration.")
    start_date = pd.Timestamp(fold_manifest["period_start"])
    end_date = pd.Timestamp(fold_manifest["period_end"])
    print(
        f"[real-sweep] scored rows={len(scored):,} folds={len(fold_manifest['folds'])} period={start_date.date()}..{end_date.date()}",
        flush=True,
    )

    rows: list[dict[str, Any]] = []
    best: tuple[float, StrategySpec, pd.DataFrame, pd.DataFrame, dict[str, float], pd.DataFrame] | None = None
    checkpoint_every = max(args.checkpoint_every, 1)
    for index, spec in enumerate(specs, start=1):
        signals = filter_signals(scored, spec.params, max_signals_per_day=args.max_signals_per_day)
        daily, trades = run_strategy_on_signals(
            RuleBasedCarryStrategy(spec),
            signals,
            mark_lookup,
            start_date=start_date,
            end_date=end_date,
        )
        metrics = compute_metrics(daily, trades, target_gross_usdc=spec.params.target_gross_usdc)
        score = objective(metrics, target_daily_notional_usdc=config.target_daily_notional_usdc)
        row = result_row(
            spec=spec,
            metrics=metrics,
            objective_score=score,
            signal_count=len(signals),
            config=config,
            period_start=start_date,
            period_end=end_date,
            max_signals_per_day=args.max_signals_per_day,
        )
        rows.append(row)
        # Every comparison against NaN is False, so a NaN incumbent could never
        # be displaced by a plain ``score > best[0]``. Let a NaN stand only
        # until the first real score shows up, and never let NaN replace one.
        if best is None or (not math.isnan(score) and (math.isnan(best[0]) or score > best[0])):
            best = (score, spec, daily, trades, metrics, signals)
        if index % checkpoint_every == 0 or index == len(specs):
            # Rewrite the full ranking so a partial run still leaves usable output.
            ranking = write_results(output_dir / "all_strategy_results.csv", rows)
            write_json(output_dir / "all_strategy_results.json", ranking.to_dict(orient="records"))
            leader = ranking.iloc[0]
            print(
                "[real-sweep] {done:,}/{total:,} leader={leader} objective={objective:.4f} pnl={pnl:,.2f}".format(
                    done=index,
                    total=len(specs),
                    leader=leader["strategy"],
                    objective=float(leader["objective_score"]),
                    pnl=float(leader["total_pnl_usdc"]),
                ),
                flush=True,
            )

    if best is None:
        raise RuntimeError("No strategies were evaluated.")
    _, best_spec, best_daily, best_trades, best_metrics, best_signals = best
    best_dir = output_dir / "best_strategy"
    best_dir.mkdir(parents=True, exist_ok=True)
    best_daily.to_csv(best_dir / "daily.csv", index=False)
    best_trades.to_csv(best_dir / "trades.csv", index=False)
    best_signals.to_csv(best_dir / "signals.csv", index=False)
    write_json(best_dir / "params.json", asdict(best_spec.params))
    write_json(best_dir / "metrics.json", best_metrics)
    write_json(output_dir / "best_strategy_manifest.json", best_manifest(best_spec, best_metrics, config, start_date, end_date))
    write_summary(output_dir / "summary.md", output_dir / "all_strategy_results.csv", panel_manifest, fold_manifest, best_spec, best_metrics)
    print(output_dir)
    print(output_dir / "all_strategy_results.csv")
    print(best_dir)
    return 0


def build_sweep_specs() -> list[StrategySpec]:
    """Expand the module-level grid constants into one ``StrategySpec`` per combination.

    The cartesian product is walked in the order price band, max DTE, min
    edge, side, cost profile, position limit, exit policy, with the rightmost
    axis varying fastest. That order is what makes ``--max-strategies``
    truncation deterministic.

    Returns:
        Specs whose params are the shared base ``StrategyParams`` with the
        swept fields replaced, and whose metadata name encodes every axis.
    """
    base = StrategyParams(
        min_count=50,
        use_category_surface=True,
        target_gross_usdc=1_000_000.0,
        max_gross_usdc=1_250_000.0,
        max_category_fraction=0.35,
        dynamic_sizing=True,
        edge_scale=0.20,
        volume_cap_fraction=0.0,
        zero_volume_cap_usdc=5_000.0,
        min_position_usdc=1_000.0,
    )
    # itertools.product iterates exactly like the equivalent nested for-loops.
    grid = itertools.product(
        PRICE_BANDS,
        DTE_LIMITS,
        EDGE_LIMITS,
        SIDES,
        COST_PROFILES,
        POSITION_LIMITS,
        EXIT_POLICIES,
    )
    specs: list[StrategySpec] = []
    for band, max_dte, min_edge, side, cost, max_position, exit_row in grid:
        band_name, min_price, max_price = band
        exit_name, exit_policy, drawdown_floor, profit_take, trailing_floor = exit_row
        params = replace(
            base,
            min_edge_after_cost=min_edge,
            min_side_price=min_price,
            max_side_price=max_price,
            max_dte_days=max_dte,
            side=side,
            execution_style=cost.execution_style,
            entry_slippage_bps=cost.entry_slippage_bps,
            exit_slippage_bps=cost.exit_slippage_bps,
            fee_bps=cost.fee_bps,
            maker_rebate_bps=cost.maker_rebate_bps,
            max_position_usdc=max_position,
            exit_policy=exit_policy,
            drawdown_floor=drawdown_floor,
            profit_take=profit_take,
            trailing_floor=trailing_floor,
        )
        # round() rather than int(): a float product that lands just below an
        # integer would otherwise be truncated and mislabel the threshold.
        edge_tag = round(min_edge * 1000)
        name = (
            f"{band_name}__dte{int(max_dte)}__edge{edge_tag:03d}"
            f"__{side.lower()}__{cost.name}__pos{int(max_position)}__{exit_name}"
        )
        description = (
            f"{band_name}, max DTE {max_dte:g}, min edge {min_edge:.3f}, side {side}, "
            f"{cost.name}, max position ${max_position:,.0f}, exit {exit_name}."
        )
        specs.append(
            StrategySpec(
                metadata=StrategyMetadata(
                    name=name,
                    description=description,
                    version="real-sweep-v1",
                    family="real_panel_grid",
                ),
                params=params,
            )
        )
    return specs


def build_scored_walk_forward(
    side_panel: pd.DataFrame,
    specs: list[StrategySpec],
    config: BacktestConfig,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Score every walk-forward test fold once, using zero-cost scoring parameters.

    The panel is first narrowed to the widest side-price range and the largest
    max-DTE requested by any spec, so that a single scored frame can be
    re-filtered per strategy afterwards.

    Args:
        side_panel: Side-level panel with ``side_price`` and ``days_to_expiry`` columns.
        specs: Strategies whose parameter bounds define the global pre-filter.
        config: Supplies the walk-forward window sizes.

    Returns:
        ``(scored, manifest)``: the concatenated scored rows (an empty frame
        when no fold produced any), sorted by quote date, market and side, and
        a manifest with the overall period, per-fold row counts and the global
        filter bounds.
    """
    param_sets = [spec.params for spec in specs]
    min_side_price = min(p.min_side_price for p in param_sets)
    max_side_price = max(p.max_side_price for p in param_sets)
    max_dte_days = max(p.max_dte_days for p in param_sets)

    price_in_range = side_panel["side_price"].between(min_side_price, max_side_price, inclusive="both")
    dte_in_range = side_panel["days_to_expiry"].between(0.0, max_dte_days, inclusive="both")
    signal_panel = side_panel[price_in_range & dte_in_range].copy()

    # Default params with every cost component zeroed out.
    scoring_params = replace(
        StrategyParams(),
        entry_slippage_bps=0.0,
        exit_slippage_bps=0.0,
        fee_bps=0.0,
        maker_rebate_bps=0.0,
    )
    folds = purged_walk_forward_splits(
        signal_panel,
        backtest_days=config.backtest_days,
        train_lookback_days=config.train_lookback_days,
        fold_days=config.fold_days,
        purge_days=config.purge_days,
        embargo_days=config.embargo_days,
    )

    parts: list[pd.DataFrame] = []
    fold_rows: list[dict[str, Any]] = []
    for fold_index, fold in enumerate(folds, start=1):
        fold_start, fold_end, train, test = fold
        surface = build_surface(train, scoring_params)
        scored = score_candidates(test, surface, scoring_params)
        if not scored.empty:
            # Tag rows with their fold and keep the zero-cost edge alongside them.
            scored["fold_index"] = fold_index
            scored["fold_start"] = fold_start
            scored["fold_end"] = fold_end
            scored["zero_cost_predicted_edge"] = scored["predicted_edge_after_cost"]
            parts.append(scored)
        fold_summary: dict[str, Any] = {
            "fold_index": fold_index,
            "fold_start": fold_start.isoformat(),
            "fold_end": fold_end.isoformat(),
            "train_rows": int(len(train)),
            "test_rows": int(len(test)),
            "scored_rows": int(len(scored)),
        }
        fold_rows.append(fold_summary)
        print(
            f"[real-sweep] scored fold {fold_index}/{len(folds)} train={len(train):,} test={len(test):,} scored={len(scored):,}",
            flush=True,
        )

    if parts:
        scored_all = pd.concat(parts, ignore_index=True)
    else:
        scored_all = pd.DataFrame()
    if not scored_all.empty:
        scored_all = scored_all.sort_values(["quote_date", "market_id", "side"]).reset_index(drop=True)

    # Period bounds are compared as ISO strings; blank when there were no folds.
    if fold_rows:
        period_start = min(row["fold_start"] for row in fold_rows)
        period_end = max(row["fold_end"] for row in fold_rows)
    else:
        period_start = ""
        period_end = ""
    manifest: dict[str, Any] = {
        "period_start": period_start,
        "period_end": period_end,
        "folds": fold_rows,
        "global_signal_filter": {
            "min_side_price": min_side_price,
            "max_side_price": max_side_price,
            "max_dte_days": max_dte_days,
        },
    }
    return scored_all, manifest


def filter_signals(scored: pd.DataFrame, params: StrategyParams, *, max_signals_per_day: int = 0) -> pd.DataFrame:
    """Select and rank the scored candidates that one strategy is allowed to trade.

    Args:
        scored: Scored candidates with ``gross_edge``, ``side_price``,
            ``days_to_expiry``, ``empirical_count``, ``side`` and
            ``quote_date`` columns. ``volume`` is optional.
        params: Strategy thresholds; ``params.cost_edge()`` is subtracted from
            ``gross_edge`` to obtain the edge after cost.
        max_signals_per_day: Keep at most this many top-ranked rows per
            ``quote_date``; ``0`` (or negative) disables the cap.

    Returns:
        A filtered copy with ``predicted_edge_after_cost`` and ``score``
        columns, sorted by quote date ascending and then score and edge
        descending. An empty ``scored`` is returned unchanged.
    """
    if scored.empty:
        return scored
    cost_edge = params.cost_edge()
    has_volume = "volume" in scored
    predicted_edge = scored["gross_edge"] - cost_edge
    mask = (
        predicted_edge.ge(params.min_edge_after_cost)
        & scored["side_price"].between(params.min_side_price, params.max_side_price, inclusive="both")
        & scored["days_to_expiry"].between(params.min_dte_days, params.max_dte_days, inclusive="both")
        & scored["empirical_count"].ge(params.min_count)
    )
    if params.side != "both":
        mask &= scored["side"].eq(params.side)
    if params.min_volume > 0 and has_volume:
        mask &= pd.to_numeric(scored["volume"], errors="coerce").fillna(0.0).ge(params.min_volume)
    if params.edge_drag_limit > 0:
        # Share of the gross edge eaten by costs; the clip avoids dividing by zero.
        drag = cost_edge / scored["gross_edge"].clip(lower=1e-9)
        mask &= drag.le(params.edge_drag_limit)
    local = scored.loc[mask].copy()
    if local.empty:
        return local
    local["predicted_edge_after_cost"] = local["gross_edge"] - cost_edge
    if has_volume:
        # Negative volumes are invalid and would make log1p raise at <= -1.
        volume = pd.to_numeric(local["volume"], errors="coerce").fillna(0.0).clip(lower=0.0)
    else:
        # The volume column is optional (see the min_volume guard above), so a
        # panel without it is ranked as if every row had zero volume.
        volume = pd.Series(0.0, index=local.index)
    local["score"] = local["predicted_edge_after_cost"] * volume.map(math.log1p)
    # Rows with no usable liquidity weight fall back to ranking on edge alone.
    local.loc[local["score"].le(0), "score"] = local["predicted_edge_after_cost"]
    local = local.sort_values(["quote_date", "score", "predicted_edge_after_cost"], ascending=[True, False, False])
    if max_signals_per_day > 0:
        local = local.groupby("quote_date", sort=False, group_keys=False).head(max_signals_per_day)
    return local


def result_row(
    *,
    spec: StrategySpec,
    metrics: dict[str, float],
    objective_score: float,
    signal_count: int,
    config: BacktestConfig,
    period_start: pd.Timestamp,
    period_end: pd.Timestamp,
    max_signals_per_day: int,
) -> dict[str, Any]:
    """Flatten one strategy's identity, run settings, metrics and params into a CSV row.

    Key order is: identity and run settings, then every metric, then each
    strategy parameter under a ``param_`` prefix. A metric that shares a name
    with an earlier key overrides that key's value.

    Returns:
        A flat dict suitable for building a results ``DataFrame``.
    """
    meta = spec.metadata
    row: dict[str, Any] = dict(
        strategy=meta.name,
        family=meta.family,
        version=meta.version,
        description=meta.description,
        objective_score=objective_score,
        signal_count=signal_count,
        period_start=period_start.isoformat(),
        period_end=period_end.isoformat(),
        backtest_days=config.backtest_days,
        train_lookback_days=config.train_lookback_days,
        fold_days=config.fold_days,
        purge_days=config.purge_days,
        embargo_days=config.embargo_days,
        max_signals_per_day=max_signals_per_day,
    )
    row.update(metrics)
    row.update({f"param_{field}": setting for field, setting in asdict(spec.params).items()})
    return row


def write_results(path: Path, rows: list[dict[str, Any]]) -> pd.DataFrame:
    """Rank result rows by objective score, write them to CSV, and return the ranking.

    Args:
        path: Destination CSV file (overwritten on every call).
        rows: One dict per evaluated strategy; each must carry ``objective_score``.

    Returns:
        The rows sorted by ``objective_score`` descending, with a 1-based
        ``rank`` column inserted first. Rows with equal scores keep the order
        in which they were supplied.
    """
    ranking = (
        pd.DataFrame(rows)
        # The default quicksort is not stable. A stable sort keeps tied rows in
        # evaluation order, so rank 1 is the first strategy to reach the top score.
        .sort_values("objective_score", ascending=False, kind="mergesort")
        .reset_index(drop=True)
    )
    ranking.insert(0, "rank", range(1, len(ranking) + 1))
    ranking.to_csv(path, index=False)
    return ranking


def write_summary(
    path: Path,
    results_path: Path,
    panel_manifest: dict[str, Any],
    fold_manifest: dict[str, Any],
    best_spec: StrategySpec,
    best_metrics: dict[str, float],
) -> None:
    """Write the markdown report for a finished sweep.

    The report has a bullet-list header (panel coverage, walk-forward period,
    strategy count, best-strategy headline metrics), a table of the first 20
    rows of the results CSV, and a fixed interpretation paragraph.

    Args:
        path: Markdown file to write (UTF-8).
        results_path: Ranked results CSV to read back.
        panel_manifest: Needs ``rows``, ``start`` and ``end``.
        fold_manifest: Needs ``period_start`` and ``period_end``.
        best_spec: Winning strategy; only its metadata name is used.
        best_metrics: Needs the four headline metric keys shown in the header.
    """
    ranking = pd.read_csv(results_path)
    leaderboard_columns = [
        "rank",
        "strategy",
        "objective_score",
        "total_pnl_usdc",
        "avg_daily_deployed_usdc",
        "annualized_sharpe",
        "max_drawdown_usdc",
        "trade_count",
    ]
    data_rule = "- Data rule: no generated prices, synthetic outcomes, or fabricated fills were used. The harness uses the existing empirical Polymarket panel, realized market resolutions, and explicit cost assumptions."
    interpretation = "This sweep is a strategy-ranking artifact, not a live deployment approval. It keeps the one-year evaluation window consistent across every parameter combination. It still needs a final execution-quality layer before capital promotion: point-in-time universe reconstruction, historical L2 depth, queue-position modeling, authenticated fill attribution, and the live exclusion/reversion controls from the post-loss investigation."

    document: list[str] = ["# DNS1 Final Real One-Year Strategy Sweep", ""]
    document += [
        f"- Generated at UTC: {datetime.now(timezone.utc).isoformat()}",
        f"- Real panel rows loaded: {panel_manifest['rows']:,}",
        f"- Real panel quote-date coverage: {panel_manifest['start']} to {panel_manifest['end']}",
        f"- Walk-forward period: {fold_manifest['period_start']} to {fold_manifest['period_end']}",
        f"- Strategies tested: {len(ranking):,}",
        data_rule,
        f"- Best strategy: `{best_spec.metadata.name}`",
        f"- Best total PnL: ${best_metrics['total_pnl_usdc']:,.2f}",
        f"- Best average daily deployed: ${best_metrics['avg_daily_deployed_usdc']:,.2f}",
        f"- Best annualized Sharpe: {best_metrics['annualized_sharpe']:.2f}",
        f"- Best max drawdown: ${best_metrics['max_drawdown_usdc']:,.2f}",
    ]
    document += ["", "## Top 20", ""]
    document.append(markdown_table(ranking.head(20)[leaderboard_columns]))
    document += ["", "## Interpretation", "", interpretation]
    path.write_text("\n".join(document) + "\n", encoding="utf-8")


def panel_summary(panel: pd.DataFrame) -> dict[str, Any]:
    """Summarize the loaded panel: row count, quote-date span and distinct ids.

    Returns:
        Dict with ``rows``, ``start`` and ``end`` (ISO dates of the earliest
        and latest ``quote_date``), ``markets`` (distinct ``market_id``) and
        ``token_ids`` (distinct ``token_id``).
    """
    summary: dict[str, Any] = {"rows": int(len(panel))}
    quote_dates = panel["quote_date"]
    first_day = pd.Timestamp(quote_dates.min()).date()
    last_day = pd.Timestamp(quote_dates.max()).date()
    summary["start"] = first_day.isoformat()
    summary["end"] = last_day.isoformat()
    for key, column in (("markets", "market_id"), ("token_ids", "token_id")):
        summary[key] = int(panel[column].nunique())
    return summary


def grid_manifest() -> dict[str, Any]:
    """Return the sweep axes as a JSON-friendly mapping for ``sweep_grid.json``.

    The axis lists are the module-level constants themselves (not copies);
    each cost profile is exposed as its attribute dict.
    """
    return dict(
        price_bands=PRICE_BANDS,
        dte_limits=DTE_LIMITS,
        edge_limits=EDGE_LIMITS,
        sides=SIDES,
        position_limits=POSITION_LIMITS,
        exit_policies=EXIT_POLICIES,
        cost_profiles=[vars(profile) for profile in COST_PROFILES],
    )


def markdown_table(frame: pd.DataFrame) -> str:
    """Render a DataFrame as a pipe-delimited markdown table.

    Float cells are formatted with six significant digits (``.6g``); every
    other cell goes through ``str``. An empty frame yields ``"_No rows._"``.
    """
    if frame.empty:
        return "_No rows._"

    def as_line(cells) -> str:
        # One markdown table line with leading and trailing pipes.
        return "| " + " | ".join(cells) + " |"

    def render(cell) -> str:
        return f"{cell:.6g}" if isinstance(cell, float) else str(cell)

    column_names = [str(column) for column in frame.columns]
    table = [as_line(column_names), as_line("---" for _ in column_names)]
    table.extend(as_line(render(cell) for cell in record) for _, record in frame.iterrows())
    return "\n".join(table)


def best_manifest(
    spec: StrategySpec,
    metrics: dict[str, float],
    config: BacktestConfig,
    period_start: pd.Timestamp,
    period_end: pd.Timestamp,
) -> dict[str, Any]:
    """Build the manifest describing the strategy selected by the sweep.

    Returns:
        Dict with a UTC generation timestamp, the fixed promotion state and
        selection basis, the strategy's metadata/params/metrics, and the
        backtest configuration including the evaluated period.
    """
    meta = spec.metadata
    strategy_section: dict[str, Any] = dict(
        name=meta.name,
        description=meta.description,
        version=meta.version,
        family=meta.family,
        params=asdict(spec.params),
        metrics=metrics,
    )
    config_section: dict[str, Any] = dict(
        panel_path=str(config.panel_path),
        period_start=period_start.isoformat(),
        period_end=period_end.isoformat(),
        backtest_days=config.backtest_days,
        train_lookback_days=config.train_lookback_days,
        fold_days=config.fold_days,
        purge_days=config.purge_days,
        embargo_days=config.embargo_days,
    )
    generated_at = datetime.now(timezone.utc).isoformat()
    return {
        "generated_at_utc": generated_at,
        "promotion_state": "research_selected_not_live",
        "selection_basis": "highest objective score in the one-year real-panel strategy sweep",
        "strategy": strategy_section,
        "backtest_config": config_section,
    }


def write_json(path: Path, value: Any) -> None:
    """Serialize ``value`` as 2-space-indented JSON and write it to ``path`` as UTF-8.

    Objects that JSON cannot encode natively are written as ``str(obj)``.
    """
    # Serialize first so a failure cannot leave a half-written file behind.
    payload = json.dumps(value, indent=2, default=str)
    path.write_text(payload, encoding="utf-8")


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
