#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Model Gateway — the single switch for every model call in the system.

Nothing else in the codebase ever names a provider or an endpoint. Every model
call — the mapping BRAIN and the region/crop VISION reader — goes through
`Gateway.chat(role=...)`. Which model actually runs, and whether it's OpenRouter
or a local server, is decided ONLY by `model_config.json` (active_provider).
Switch local<->OpenRouter by editing that one file. No code change.

Transport: if the `litellm` package is installed it is used (gives you the full
LiteLLM provider/router ecosystem). If it is NOT installed, the gateway falls
back to a dependency-free OpenAI-compatible HTTP call to the configured
`api_base`. Either way the SAME model_config.json drives it, so switching always
works — litellm is an optional upgrade, never a requirement to run.

A `chat()` call walks the role's model_chain for the active provider: each model
gets one attempt plus up to `retries_per_model` retries, with exponential backoff
between them; on exhaustion it falls through to the next model in the chain.
Returns (text, meta) or raises GatewayError if the whole chain fails.
"""

from __future__ import annotations

import json
import os
import time
import urllib.request
import urllib.error

# Directory containing this module; the default config file lives beside it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Config used when neither Gateway(config_path=...) nor $RPA_MODEL_CONFIG is set.
DEFAULT_CONFIG_PATH = os.path.join(HERE, "model_config.json")
# Fallback KEY=VALUE file searched for API keys that are not in os.environ.
OPENROUTER_ENV_FILE = os.path.expanduser("~/.config/openrouter.env")


class GatewayError(RuntimeError):
    """Raised for gateway failures: bad config, missing API key, or an exhausted model chain."""


def _resolve_key(api_key_env):
    """Return the API key for env var `api_key_env` (or None if not required).

    Looks in os.environ first, then ~/.config/openrouter.env (KEY=VALUE lines).
    """
    if not api_key_env:
        return None
    key = os.environ.get(api_key_env)
    if key:
        return key.strip()
    if os.path.exists(OPENROUTER_ENV_FILE):
        try:
            with open(OPENROUTER_ENV_FILE, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith("#") or "=" not in line:
                        continue
                    k, _, v = line.partition("=")
                    if k.strip() == api_key_env:
                        return v.strip().strip('"').strip("'")
        except Exception:  # noqa: BLE001
            pass
    return None


class Gateway:
    """Provider-agnostic chat client driven entirely by the JSON model config.

    The config selects ``active_provider`` (overridable per instance), that
    provider's ``api_base`` / ``api_key_env`` / optional ``extra_headers``,
    the global ``timeout_seconds`` / ``retries_per_model`` / ``backoff_seconds``,
    and, per role, a model chain keyed by provider name plus optional
    ``temperature`` / ``max_tokens`` defaults.

    Transport is ``litellm`` when it imports cleanly, otherwise a plain
    OpenAI-compatible HTTP POST to ``{api_base}/chat/completions``.
    """

    # HTTP statuses still worth retrying against the SAME model. Any other 4xx
    # (bad request, auth, unknown model, ...) fails identically on a retry, so
    # chat() moves straight on to the next model in the chain instead.
    _RETRYABLE_4XX = frozenset({408, 409, 425, 429})

    def __init__(self, config_path=None, provider_override=None):
        """Load the config and bind this gateway to one provider.

        Args:
            config_path: JSON config path; falls back to $RPA_MODEL_CONFIG,
                then DEFAULT_CONFIG_PATH.
            provider_override: provider name to use instead of the config's
                ``active_provider``.

        Raises:
            GatewayError: unknown provider, provider without ``api_base``, or
                a required API key that cannot be found.
            OSError / ValueError: config file missing or not valid JSON.
        """
        self.config_path = (config_path or os.environ.get("RPA_MODEL_CONFIG")
                            or DEFAULT_CONFIG_PATH)
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.cfg = json.load(f)
        self.provider = provider_override or self.cfg.get("active_provider", "openrouter")
        providers = self.cfg.get("providers", {})
        if self.provider not in providers:
            raise GatewayError("active_provider %r not in providers" % self.provider)
        self.pcfg = providers[self.provider]
        api_base = self.pcfg.get("api_base")
        if not api_base:
            raise GatewayError("provider %r has no api_base" % self.provider)
        self.api_base = api_base.rstrip("/")
        self.api_key = _resolve_key(self.pcfg.get("api_key_env"))
        if self.pcfg.get("api_key_env") and not self.api_key:
            raise GatewayError(
                "provider %r needs key env %r but it is not set (checked env and %s)"
                % (self.provider, self.pcfg["api_key_env"], OPENROUTER_ENV_FILE))
        self.timeout = int(self.cfg.get("timeout_seconds", 120))
        # Clamp: a negative retry count would mean zero attempts per model, and
        # a negative backoff makes time.sleep() raise inside chat().
        self.retries = max(0, int(self.cfg.get("retries_per_model", 2)))
        self.backoff = max(0.0, float(self.cfg.get("backoff_seconds", 1.5)))
        # Optional litellm: any import failure falls back to the HTTP transport.
        try:
            import litellm  # type: ignore  # noqa: F401
            self._litellm = litellm
        except Exception:  # noqa: BLE001
            self._litellm = None

    # -- public --------------------------------------------------------------
    def model_chain(self, role):
        """Return ``(models, role_cfg)`` for `role` under the active provider.

        A chain may be a list of model ids or, for one model, a bare string.

        Raises:
            GatewayError: role unknown, or it has no chain for this provider.
        """
        rcfg = self.cfg.get("roles", {}).get(role)
        if not rcfg:
            raise GatewayError("unknown role %r" % role)
        chain = rcfg.get(self.provider)
        if not chain:
            raise GatewayError("role %r has no model_chain for provider %r"
                               % (role, self.provider))
        if isinstance(chain, str):
            # list("vendor/model") would split the id into single characters
            chain = [chain]
        return list(chain), rcfg

    def describe(self):
        """Summarize provider, api_base, transport and each role's model chain.

        Roles without a chain for the active provider are omitted.
        """
        out = {"provider": self.provider, "api_base": self.api_base,
               "transport": "litellm" if self._litellm else "http",
               "roles": {}}
        for role in self.cfg.get("roles", {}):
            try:
                chain, _ = self.model_chain(role)
                out["roles"][role] = chain
            except GatewayError:
                pass
        return out

    def chat(self, role, messages, **overrides):
        """Run a chat completion for `role` and return ``(text, meta)``.

        `messages` is an OpenAI-style list (text and/or image_url content).
        Only ``temperature`` and ``max_tokens`` are read from `overrides`;
        they win over the role config (whose own defaults are 0.0 / 512).

        Each model in the chain gets one attempt plus up to
        ``retries_per_model`` retries with exponential backoff. An error that
        carries a non-retryable 4xx ``status_code`` ends that model's retries
        early. ``meta`` is ``{"provider", "model", "attempts"}`` where
        ``attempts`` records every failed call made before the success.

        Raises:
            GatewayError: every model in the chain failed (the message embeds
                the JSON list of attempts).
        """
        chain, rcfg = self.model_chain(role)
        temperature = overrides.get("temperature", rcfg.get("temperature", 0.0))
        max_tokens = overrides.get("max_tokens", rcfg.get("max_tokens", 512))
        attempts = []
        for model in chain:
            delay = self.backoff
            for attempt in range(self.retries + 1):
                try:
                    text = self._one_call(model, messages, temperature, max_tokens)
                except Exception as e:  # noqa: BLE001 - any failure moves the chain on
                    attempts.append({"model": model, "attempt": attempt + 1,
                                     "error": str(e)[:300]})
                    if not self._is_retryable(e):
                        break  # same request would fail again; try next model
                    if attempt < self.retries:
                        time.sleep(delay)
                        delay *= 2
                else:
                    return text, {"provider": self.provider, "model": model,
                                  "attempts": attempts}
        raise GatewayError("all models failed for role %r: %s"
                           % (role, json.dumps(attempts, ensure_ascii=False)))

    @classmethod
    def _is_retryable(cls, exc):
        """False only when `exc` carries an HTTP 4xx ``status_code`` outside _RETRYABLE_4XX."""
        status = getattr(exc, "status_code", None)
        try:
            status = int(status)
        except (TypeError, ValueError):
            return True  # no usable status (network error, timeout, ...): assume transient
        return (not 400 <= status < 500) or status in cls._RETRYABLE_4XX

    @staticmethod
    def _require_text(content, model):
        """Return `content`, raising GatewayError when the model produced none."""
        if content is None:
            # Otherwise a null message would leak out as the "text" of a success.
            raise GatewayError("model %r returned no message content" % model)
        return content

    # -- transports ----------------------------------------------------------
    def _one_call(self, model, messages, temperature, max_tokens):
        """Dispatch one completion to litellm if available, else to raw HTTP."""
        if self._litellm is not None:
            return self._call_litellm(model, messages, temperature, max_tokens)
        return self._call_http(model, messages, temperature, max_tokens)

    def _call_litellm(self, model, messages, temperature, max_tokens):
        """One completion through ``litellm.completion``; returns the reply text."""
        kwargs = {
            "model": "%s/%s" % (self._litellm_provider(), model),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self.api_base,
            "timeout": self.timeout,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        eh = self.pcfg.get("extra_headers")
        if eh:
            kwargs["extra_headers"] = eh
        resp = self._litellm.completion(**kwargs)
        return self._require_text(resp["choices"][0]["message"]["content"], model)

    def _litellm_provider(self):
        """Pick the litellm provider prefix for the model string."""
        # litellm routes by a provider prefix on the model string. For an
        # OpenAI-compatible endpoint (OpenRouter or a local llama.cpp/vLLM
        # server) the 'openai' provider + api_base is the portable choice.
        kind = self.pcfg.get("kind", "openai")
        return "openrouter" if self.provider == "openrouter" and kind == "openrouter" \
            else "openai"

    def _call_http(self, model, messages, temperature, max_tokens):
        """POST an OpenAI-compatible chat request and return the reply text.

        Raises:
            GatewayError: HTTP error status (with ``status_code`` attached), an
                ``error`` object in the response body, a malformed body, or
                null content.
        """
        url = "%s/chat/completions" % self.api_base
        payload = {"model": model, "messages": messages,
                   "temperature": temperature, "max_tokens": max_tokens}
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = "Bearer %s" % self.api_key
        headers.update(self.pcfg.get("extra_headers") or {})
        req = urllib.request.Request(
            url, data=json.dumps(payload).encode("utf-8"), headers=headers,
            method="POST")
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                body = json.loads(resp.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            detail = ""
            try:
                with e:  # HTTPError wraps the open response; release it
                    detail = e.read().decode("utf-8", "ignore")[:300]
            except Exception:  # noqa: BLE001 - the body is diagnostic only
                pass
            err = GatewayError("HTTP %s from %s: %s" % (e.code, url, detail))
            err.status_code = e.code  # lets chat() skip pointless same-model retries
            raise err from e
        if isinstance(body, dict) and body.get("error"):
            # some OpenAI-compatible servers report failures in the body
            err_obj = body["error"]
            err = GatewayError("API error from %s: %s"
                               % (url, json.dumps(err_obj, ensure_ascii=False)[:300]))
            if isinstance(err_obj, dict) and "code" in err_obj:
                err.status_code = err_obj["code"]
            raise err
        try:
            content = body["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise GatewayError("malformed response from %s: %s"
                               % (url, json.dumps(body, ensure_ascii=False)[:300])) from e
        return self._require_text(content, model)


def image_part(b64_png, mime_type="image/png"):
    """Build an OpenAI-style ``image_url`` content part from base64 image data.

    Args:
        b64_png: base64-encoded image as ``str`` or ``bytes`` (``bytes`` is
            what ``base64.b64encode`` returns).
        mime_type: media type written into the data URL; PNG by default.

    Returns:
        ``{"type": "image_url", "image_url": {"url": "data:<mime>;base64,<data>"}}``
    """
    if isinstance(b64_png, (bytes, bytearray)):
        # "%s" % b"..." would embed the repr "b'...'" and corrupt the data URL
        b64_png = bytes(b64_png).decode("ascii")
    return {"type": "image_url",
            "image_url": {"url": "data:%s;base64,%s" % (mime_type, b64_png)}}


def text_part(text):
    """Wrap `text` as an OpenAI-style ``{"type": "text"}`` message content part."""
    part = dict(type="text")
    part["text"] = text
    return part


if __name__ == "__main__":
    # Tiny self-describe (no model call) to eyeball the active wiring.
    # Optional argv[1] overrides active_provider from the config.
    import sys
    try:
        g = Gateway(provider_override=(sys.argv[1] if len(sys.argv) > 1 else None))
        print(json.dumps(g.describe(), ensure_ascii=False, indent=2))
    except Exception as e:  # noqa: BLE001 - CLI boundary: report and exit non-zero
        # stderr keeps stdout pure JSON for anything piping this output
        print("gateway error: %s" % e, file=sys.stderr)
        sys.exit(1)
