#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Tool layer for the mapping stage — the hands and eyes, on Windows.

Every method returns a ToolResult dict:
  {"ok": bool, "data": ..., "screenshot": "<abs png|null>",
   "changed": bool, "changed_regions": [[x,y,w,h],...], "error": null|"..."}

Eyes  : screenshot (ImageGrab), read_region (crop->VLM via Gateway),
        locate (full screen -> VLM -> bbox), screenshot diff.
Hands : click / type_text / press_key / scroll (real SendInput via win_input).
Probes: probe_accessibility (UIA), probe_windows (EnumWindows), probe_keyboard.
Map   : save_map_entry / save_template / emit_verdict.

ALL model calls go through the Gateway (model_config.json) — never a hardcoded
endpoint. Switch OpenRouter<->local in that one file.

Safety is enforced HERE, in code, not in a prompt:
  Before any click, the tool independently READS what is under the target
  (it does not trust a caller-supplied label) and classifies it. If the visible
  text is a COMMIT/DESTRUCTIVE control (save/send/delete/pay/approve/submit/
  print/close-case...), the click is BLOCKED and flagged for escalation — the
  agent physically cannot press it. Navigation/neutral controls proceed so the
  app can still be mapped.
"""

from __future__ import annotations

import base64
import io
import json
import os
import subprocess
import time

from gateway import Gateway, image_part, text_part
import win_input
import find_app

HERE = os.path.dirname(os.path.abspath(__file__))
# Per-client working data (session screenshots, popup templates, software map)
# is kept under clients/<client>/ next to this file.
CLIENTS_DIR = os.path.join(HERE, "clients")
# Default question for read_region. In English: "Read the text in the image.
# Quote it exactly. If there is no text, return an empty string."
VLM_DEFAULT_PROMPT = "קרא את הטקסט בתמונה. צטט במדויק. אם אין טקסט, החזר מחרוזת ריקה."

# ----------------------------------------------------------------------------
# Commit / destructive label denylist (the in-code safety core).
# Matched as normalized substrings against whatever text is actually under a
# click target. Hebrew + English. NEVER auto-clicked during mapping.
#
# Hebrew entries, in order: delete, deletion, delete calculation, save, saving,
# send, sending, send to insurance company, approve, approval, execute,
# execution, update (verb), update (noun), closing a case, close case, print,
# printing, approve payment, pay, payment.
#
# Because matching is by substring after whitespace removal (see _is_commit),
# short labels also hit longer words -- e.g. "post" matches "postpone" and
# "send" matches "sender". That over-blocks, which errs on the safe side.
# ----------------------------------------------------------------------------
COMMIT_LABELS = [
    "מחק", "מחיקה", "מחק תחשיב", "שמור", "שמירה", "שלח", "שליחה",
    "שלח לחברת ביטוח", "אשר", "אישור", "בצע", "ביצוע", "עדכן", "עדכון",
    "סגירת תיק", "סגור תיק", "הדפס", "הדפסה", "אשר תשלום", "שלם", "תשלום",
    "submit", "save", "send", "delete", "remove", "approve", "confirm",
    "pay", "print", "post", "commit",
]
# Labels that are explicitly safe to dismiss a popup with (used by agent policy).
# Hebrew entries, in order: cancel, close, hide, no, later.
# NOTE(review): not referenced by the code visible in this file -- presumably
# consumed by the agent policy elsewhere; verify.
SAFE_DISMISS_LABELS = ["ביטול", "סגור", "הסתר", "לא", "אחר כך", "cancel", "close", "no"]


def _norm(s):
    return "".join((s or "").split()).lower()


def _is_commit(text):
    """Check whether *text* contains a commit/destructive label.

    Both *text* and every entry of ``COMMIT_LABELS`` are normalized with
    ``_norm`` and compared by substring, in list order.

    Returns ``(True, label)`` for the first label found, otherwise
    ``(False, None)`` -- including when *text* is empty or whitespace-only.
    """
    haystack = _norm(text)
    if haystack:
        hit = next((lab for lab in COMMIT_LABELS if _norm(lab) in haystack), None)
        if hit is not None:
            return True, hit
    return False, None


# ----------------------------------------------------------------------------
def _result(ok=True, data=None, screenshot=None, changed=False,
            changed_regions=None, error=None, **extra):
    r = {"ok": bool(ok), "data": data, "screenshot": screenshot,
         "changed": bool(changed), "changed_regions": changed_regions or [],
         "error": error}
    r.update(extra)
    return r


def _err(msg, **extra):
    """Return a failed ToolResult (ok=False) whose ``error`` is ``str(msg)``.

    Extra keyword arguments are forwarded to ``_result`` unchanged.
    """
    message = "%s" % (msg,)
    return _result(False, error=message, **extra)


# ----------------------------------------------------------------------------
# pixel diff -> changed regions (PIL only, no numpy)
# ----------------------------------------------------------------------------
def _diff_regions(prev_path, cur_path, down=8, thresh=26, min_area_frac=0.0003):
    """Compare two screenshots and return the rectangles that changed.

    Both images are downsampled to a grid of ``W // down`` x ``H // down``
    cells (at least 1x1). A cell counts as changed when any RGB channel of
    the downsampled difference reaches ``thresh``. 4-connected changed cells
    are grouped into components. Components smaller than ``min_area_frac`` of
    the grid (at least one cell) are dropped as noise. Each remaining
    component's bounding box is mapped back to full-resolution pixels.

    If the images differ in size, the current one is resized to the previous
    one's size first.

    Returns ``(changed, boxes)``. ``boxes`` is a list of ``[x, y, w, h]``
    sorted largest-area first, and ``changed`` is ``bool(boxes)``. If either
    image cannot be opened, returns ``(False, [])``.
    """
    from PIL import Image, ImageChops
    try:
        # ``with`` releases the file handles promptly. On Windows an open
        # handle blocks overwriting or deleting the PNG.
        with Image.open(prev_path) as im:
            a = im.convert("RGB")
        with Image.open(cur_path) as im:
            b = im.convert("RGB")
    except Exception:  # noqa: BLE001
        return False, []
    if a.size != b.size:
        b = b.resize(a.size)
    W, H = a.size
    sw, sh = max(1, W // down), max(1, H // down)
    da = ImageChops.difference(a.resize((sw, sh)), b.resize((sw, sh)))
    px = da.load()
    mask = [[max(px[x, y]) >= thresh for x in range(sw)] for y in range(sh)]
    seen = [[False] * sw for _ in range(sh)]
    boxes = []
    min_cells = max(1, int(min_area_frac * sw * sh))
    for y in range(sh):
        for x in range(sw):
            if not mask[y][x] or seen[y][x]:
                continue
            # Iterative flood fill with an explicit stack, so there is no
            # recursion limit on large changed areas.
            stack = [(x, y)]
            seen[y][x] = True
            minx = maxx = x
            miny = maxy = y
            cells = 0
            while stack:
                cx, cy = stack.pop()
                cells += 1
                minx, maxx = min(minx, cx), max(maxx, cx)
                miny, maxy = min(miny, cy), max(maxy, cy)
                for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    nx, ny = cx + dx, cy + dy
                    if 0 <= nx < sw and 0 <= ny < sh and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        stack.append((nx, ny))
            if cells < min_cells:
                continue
            # Map grid cells back to pixels with the TRUE scale W/sw and H/sh.
            # sw = W // down truncates, so when W is not a multiple of `down`
            # (e.g. 1366 px) a cell is wider than `down` pixels. Multiplying
            # by `down` would drift boxes left/up and miss the far edge.
            # Integer floor for the start, ceiling for the end; exact when
            # W, H are multiples of `down`.
            fx = minx * W // sw
            fy = miny * H // sh
            fw = -(-(maxx + 1) * W // sw) - fx
            fh = -(-(maxy + 1) * H // sh) - fy
            boxes.append([fx, fy, fw, fh])
    boxes.sort(key=lambda box: -box[2] * box[3])
    return bool(boxes), boxes


# ----------------------------------------------------------------------------
class WinTools:
    """Windows tool surface used by the mapping agent.

    Public tool methods return a ToolResult dict (built with ``_result`` /
    ``_err``) instead of raising on ordinary failures. Every screenshot is
    saved as ``clients/<client>/session/NNNN_<tag>.png`` with an increasing
    sequence number. ``self._last_shot`` points at the most recent one and is
    the "before" image for change detection after each action.

    Config keys read here:
      gated_mode (default True)              -- run the click safety gate.
      block_unknown_in_modal (default False) -- also block clicks whose
                                                target reads as empty.
    """

    def __init__(self, client="drumer", gateway=None, config=None,
                 target_software=None, window_title=None, exe_path=None):
        self.client = client
        self.gw = gateway or Gateway()
        self.config = dict(config or {})
        self.target_software = target_software
        self.window_title = window_title
        self.exe_path = exe_path
        self.client_dir = os.path.join(CLIENTS_DIR, client)
        self.session_dir = os.path.join(self.client_dir, "session")
        os.makedirs(self.session_dir, exist_ok=True)
        os.makedirs(os.path.join(self.client_dir, "popups"), exist_ok=True)
        self._seq = 0            # screenshot sequence counter (see _png)
        self._last_shot = None   # path of the most recent screenshot
        self._win_app = None
        # gate policy (see _gate_click)
        self.gated = self.config.get("gated_mode", True)
        self.block_unknown_in_modal = self.config.get("block_unknown_in_modal", False)

    # -- files --------------------------------------------------------------
    def _png(self, tag):
        """Return the next numbered session PNG path for *tag*."""
        self._seq += 1
        return os.path.join(self.session_dir, "%04d_%s.png" % (self._seq, tag))

    # -- screen grab (primary monitor, physical px) -------------------------
    def _grab(self, tag):
        """Capture the screen, save it to a new session PNG, return (path, size)."""
        from PIL import ImageGrab
        img = ImageGrab.grab()
        path = self._png(tag)
        img.save(path)
        return path, img.size

    @staticmethod
    def png_b64(path):
        """Return the bytes of file *path* base64-encoded as ASCII text."""
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("ascii")

    # ======================================================================
    # OBSERVATION
    # ======================================================================
    def open_app(self):
        """Attach to (or launch) the target window so it is focused/foreground.

        Strategy, in priority order: ``exe_path`` (launch and focus the new
        top-level window), ``window_title`` (focus the best-matching open
        window), ``target_software`` (delegate to ``find_app.open_target``),
        otherwise use whatever window is already in the foreground. A
        screenshot is taken after a short settle delay.
        """
        try:
            win_input.set_dpi_aware()
        except Exception:  # noqa: BLE001
            pass  # best-effort: coordinates may be scaled without it
        if not win_input.is_windows():
            return _err("open_app requires Windows")
        try:
            how = None
            if self.exe_path:
                # Launch directly — NO pywin32/pywinauto dependency. Snapshot the
                # open windows before/after so we can find + focus the new window.
                before = {h for h, _ in find_app.enumerate_windows()}
                try:
                    os.startfile(self.exe_path)  # type: ignore[attr-defined]
                except Exception:  # noqa: BLE001
                    subprocess.Popen([self.exe_path])
                # Poll (up to ~16 s) for a new titled top-level window, then focus it.
                appeared = None
                for _ in range(20):
                    time.sleep(0.8)
                    now = find_app.enumerate_windows()
                    new = [(h, t) for h, t in now if h not in before and t.strip()]
                    if new:
                        appeared = new[-1]
                        break
                if appeared:
                    try:
                        find_app.focus_window(appeared[0])
                    except Exception:  # noqa: BLE001
                        pass  # best-effort focus; the screenshot still shows state
                    how = "launched-exe: %s (window: %s)" % (self.exe_path, appeared[1])
                else:
                    how = "launched-exe: %s (no new window detected; using foreground)" \
                          % self.exe_path
            elif self.window_title:
                win, ranked = find_app.find_open_window(self.window_title)
                if not win:
                    return _err("no open window matched %r; top candidates: %s"
                                % (self.window_title, ranked))
                find_app.focus_window(win[0])
                how = "attached-title: %s" % win[1]
            elif self.target_software:
                res = find_app.open_target(self.target_software)
                if not res.get("ok"):
                    return _err("could not find/open %r: %s" %
                                (self.target_software, res.get("error")), data=res)
                how = "%s: %s" % (res.get("how"), res.get("title"))
            else:
                how = "current-foreground-window"
            time.sleep(1.2)  # let the window paint before the first screenshot
            path, size = self._grab("open_app")
            self._last_shot = path
            return _result(ok=True, data={"how": how, "size": list(size)},
                           screenshot=path)
        except Exception as e:  # noqa: BLE001
            return _err("open_app failed: %s" % e)

    def screenshot(self):
        """Take a new screenshot and diff it against the previous one, if any."""
        try:
            prev = self._last_shot
            path, size = self._grab("screenshot")
            changed, regions = (False, [])
            if prev and os.path.exists(prev):
                changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"size": list(size)}, screenshot=path,
                           changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("screenshot failed: %s" % e)

    def _ensure_shot(self, tag="base"):
        """Return the latest screenshot path, grabbing one if none exists on disk."""
        if not self._last_shot or not os.path.exists(self._last_shot):
            path, _ = self._grab(tag)
            self._last_shot = path
        return self._last_shot

    # -- vision: read a region ---------------------------------------------
    def read_region(self, x, y, w, h, question=None):
        """Crop (x, y, w, h) from the latest screenshot and ask the VLM about it.

        The region is clamped to the screenshot bounds and upscaled 3x
        (LANCZOS) before being sent, which helps with small UI text. The
        upscaled crop is saved to the session dir and returned as
        ``screenshot``. ``data["region"]`` is the region actually read, after
        clamping. *question* defaults to ``VLM_DEFAULT_PROMPT``.
        """
        from PIL import Image
        try:
            base = self._ensure_shot("read_base")
            with Image.open(base) as im:
                img = im.convert("RGB")
            rx, ry, rw, rh = int(x), int(y), int(w), int(h)
            x1, y1 = max(0, rx), max(0, ry)
            x2, y2 = min(rx + rw, img.width), min(ry + rh, img.height)
            if x2 <= x1 or y2 <= y1:
                return _err("invalid region (%s,%s,%s,%s)" % (rx, ry, rw, rh))
            crop = img.crop((x1, y1, x2, y2))
            up = crop.resize((crop.width * 3, crop.height * 3), Image.LANCZOS)
            cpath = self._png("read_crop")
            up.save(cpath)
            buf = io.BytesIO()
            up.save(buf, format="PNG")
            b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        except Exception as e:  # noqa: BLE001
            return _err("crop failed: %s" % e)
        region = [x1, y1, x2 - x1, y2 - y1]
        prompt = question or VLM_DEFAULT_PROMPT
        try:
            text, meta = self.gw.chat(
                "vision",
                [{"role": "user", "content": [text_part(prompt), image_part(b64)]}])
            return _result(ok=True, data={"text": text, "region": region},
                           screenshot=cpath, vision_model=meta.get("model"))
        except Exception as e:  # noqa: BLE001
            return _err("vision read failed: %s" % e, screenshot=cpath)

    # -- vision: locate a control ------------------------------------------
    def locate(self, goal):
        """Return the bbox [x,y,w,h] (full-screen pixels) of the control best
        matching `goal` (a Hebrew/English description or label), or null.

        Uses the latest screenshot (grabbing one if none exists). The prompt
        (Hebrew) tells the VLM the image size and asks for JSON only:
        {"found": true/false, "bbox": [x,y,w,h]} in absolute image pixels.
        """
        try:
            from PIL import Image
            base = self._ensure_shot("locate_base")
            b64 = self.png_b64(base)
            # Only the header is needed for .size; ``with`` closes the handle,
            # which a bare Image.open(...).size would leave open.
            with Image.open(base) as im:
                W, H = im.size
        except Exception as e:  # noqa: BLE001
            return _err("locate prep failed: %s" % e)
        q = ("התמונה היא צילום מסך ברוחב %d וגובה %d פיקסלים. "
             "מצא את הפקד שמתאים ל: «%s». "
             "החזר JSON בלבד: {\"found\": true/false, \"bbox\": [x,y,w,h]} "
             "כאשר x,y,w,h בפיקסלים מוחלטים של התמונה (פינה שמאלית-עליונה + רוחב+גובה). "
             "אם לא נמצא, found=false." % (W, H, goal))
        try:
            text, meta = self.gw.chat(
                "vision",
                [{"role": "user", "content": [text_part(q), image_part(b64)]}])
            bbox = _extract_bbox(text, W, H)
            return _result(ok=True, data={"goal": goal, "found": bbox is not None,
                                          "bbox": bbox, "raw": (text or "")[:200]},
                           screenshot=base, vision_model=meta.get("model"))
        except Exception as e:  # noqa: BLE001
            return _err("locate failed: %s" % e)

    # -- probes -------------------------------------------------------------
    def probe_accessibility(self):
        """Estimate how much of the UI is exposed via UI Automation.

        Counts UIA descendants of every desktop window: more than 50 means
        "full", 1-50 "partial", 0 "none". Reports "unavailable" (still ok=True)
        when pywinauto cannot be imported or the walk fails.
        """
        if not win_input.is_windows():
            return _err("probe_accessibility requires Windows")
        # Optional: needs pywinauto+pywin32. If unavailable (common on a fresh
        # box), degrade gracefully — Dcwin is custom-drawn so vision is the path
        # anyway. Never crash the mapping over this.
        try:
            from pywinauto import Desktop  # type: ignore
        except Exception as e:  # noqa: BLE001
            return _result(ok=True, data={"accessibility_tree": "unavailable",
                                          "note": "pywinauto/pywin32 not usable "
                                          "(%s); treating as vision-only" % e},
                           screenshot=self._last_shot)
        try:
            ctrls = 0
            top = None
            for wdw in Desktop(backend="uia").windows():
                try:
                    if top is None:
                        top = wdw.window_text()
                    ctrls += len(wdw.descendants())
                except Exception:  # noqa: BLE001
                    pass  # one inaccessible window must not abort the census
            tree = "full" if ctrls > 50 else ("partial" if ctrls > 0 else "none")
            return _result(ok=True, data={"accessibility_tree": tree,
                                          "control_count": ctrls,
                                          "top_window": top},
                           screenshot=self._last_shot)
        except Exception as e:  # noqa: BLE001
            return _result(ok=True, data={"accessibility_tree": "unavailable",
                                          "note": "UIA walk failed (%s)" % e},
                           screenshot=self._last_shot)

    def probe_windows(self):
        """List visible, titled top-level windows via user32 EnumWindows."""
        if not win_input.is_windows():
            return _err("probe_windows requires Windows")
        try:
            import ctypes
            from ctypes import wintypes
            u = ctypes.windll.user32
            titles = []
            EnumProc = ctypes.WINFUNCTYPE(wintypes.BOOL, wintypes.HWND, wintypes.LPARAM)

            def cb(hwnd, _l):
                # Returning True continues the enumeration.
                if not u.IsWindowVisible(hwnd):
                    return True
                n = u.GetWindowTextLengthW(hwnd)
                if n:
                    buf = ctypes.create_unicode_buffer(n + 1)
                    u.GetWindowTextW(hwnd, buf, n + 1)
                    titles.append({"hwnd": int(hwnd), "title": buf.value})
                return True

            u.EnumWindows(EnumProc(cb), 0)
            return _result(ok=True, data={"window_enumeration": "full",
                                          "windows": titles},
                           screenshot=self._last_shot)
        except Exception as e:  # noqa: BLE001
            return _err("probe_windows failed: %s" % e)

    def probe_keyboard(self, keys):
        """Press *keys* (one key or a sequence) and report whether the screen changed.

        ``keyboard_nav`` is "partial" if any pixel region changed, else "none".
        """
        if not win_input.is_windows():
            return _err("probe_keyboard requires Windows")
        try:
            prev = self._ensure_shot("kbd_pre")
            seq = keys if isinstance(keys, (list, tuple)) else [keys]
            for k in seq:
                win_input.press_key(k)
                time.sleep(0.1)
            path, _ = self._grab("kbd_post")
            changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"keyboard_nav": "partial" if changed else "none",
                                          "keys": list(seq), "screen_changed": changed},
                           screenshot=path, changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("probe_keyboard failed: %s" % e)

    # ======================================================================
    # ACTUATION  (gated in code)
    # ======================================================================
    def _gate_click(self, x, y):
        """Independently read what is under (x,y) and block commits.

        Takes a fresh screenshot and asks the VLM for the text in a 220x44
        window centred on the target, ignoring any caller-supplied label.
        The gate FAILS CLOSED: if the screen cannot be captured or the
        target cannot be read, the click is blocked. Set gated_mode=false to
        disable the gate entirely.

        Returns (allowed: bool, reason: str, seen_text: str).
        """
        if not self.gated:
            return True, "gated_mode=off", ""
        # Judge what is on screen NOW. self._last_shot may predate a dialog
        # that appeared after the previous action.
        try:
            path, _ = self._grab("gate")
            self._last_shot = path
        except Exception as e:  # noqa: BLE001
            return False, ("BLOCKED: could not capture screen to verify target "
                           "(%s) — escalate, do not click" % e), ""
        # read a label-sized window around the target (do NOT trust the caller)
        rd = self.read_region(x - 110, y - 22, 220, 44,
                              question="קרא את הטקסט על הכפתור/הפקד הזה. החזר טקסט בלבד.")
        if not rd.get("ok"):
            # An unreadable target could be a commit button. Never let a
            # vision/gateway outage silently open the gate.
            return False, ("BLOCKED: could not read target (%s) — escalate, "
                           "do not click" % rd.get("error")), ""
        d = rd.get("data") or {}
        seen = d.get("text") if isinstance(d, dict) else d
        seen = str(seen) if seen else ""
        is_commit, lab = _is_commit(seen)
        if is_commit:
            return False, ("BLOCKED: target reads %r which matches commit/"
                           "destructive label %r — escalate, do not click" %
                           (seen[:60], lab)), seen
        if not _norm(seen) and self.block_unknown_in_modal:
            return False, "BLOCKED: empty/unreadable target with strict mode on", seen
        return True, "navigation/neutral (%r)" % (seen[:40],), seen

    def click(self, x, y, label=None):
        """Click at screen pixel (x, y) after passing the safety gate.

        *label* is accepted for API compatibility but deliberately ignored:
        the gate reads the target itself. A blocked click returns an error
        result with ``data["escalate"] = True``.
        """
        if not win_input.is_windows():
            return _err("click requires Windows")
        try:
            ix, iy = int(x), int(y)
        except (TypeError, ValueError, OverflowError) as e:
            return _err("click failed: %s" % e)
        allowed, reason, seen = self._gate_click(ix, iy)
        if not allowed:
            return _err(reason, data={"gated": True, "blocked": True,
                                      "seen_text": seen, "x": x, "y": y,
                                      "escalate": True})
        try:
            prev = self._ensure_shot("click_pre")
            win_input.move_click(ix, iy)
            path, _ = self._grab("click")
            changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"x": x, "y": y, "gate": reason,
                                          "seen_text": seen},
                           screenshot=path, changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("click failed: %s" % e)

    def type_text(self, text):
        """Type *text* as Unicode input, then screenshot and diff."""
        if not win_input.is_windows():
            return _err("type_text requires Windows")
        try:
            prev = self._ensure_shot("type_pre")
            win_input.type_unicode(text)
            path, _ = self._grab("type")
            changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"text": text}, screenshot=path,
                           changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("type_text failed: %s" % e)

    def press_key(self, key):
        """Press *key* via win_input, then screenshot and diff."""
        if not win_input.is_windows():
            return _err("press_key requires Windows")
        try:
            prev = self._ensure_shot("press_pre")
            win_input.press_key(key)
            path, _ = self._grab("press")
            changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"key": key}, screenshot=path,
                           changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("press_key failed: %s" % e)

    def scroll(self, notches, x=None, y=None):
        """Scroll by *notches* (optionally at x, y), then screenshot and diff."""
        if not win_input.is_windows():
            return _err("scroll requires Windows")
        try:
            prev = self._ensure_shot("scroll_pre")
            win_input.scroll(int(notches), x, y)
            path, _ = self._grab("scroll")
            changed, regions = _diff_regions(prev, path)
            self._last_shot = path
            return _result(ok=True, data={"notches": notches}, screenshot=path,
                           changed=changed, changed_regions=regions)
        except Exception as e:  # noqa: BLE001
            return _err("scroll failed: %s" % e)

    # ======================================================================
    # MAP PERSISTENCE
    # ======================================================================
    def save_map_entry(self, obj=None, json_path=None):
        """Merge *obj* into the client's software-map JSON (see _merge_map).

        The file is written to a temporary sibling and then swapped in with
        os.replace. A serialization error therefore leaves the previous map
        intact; writing in place would truncate it.
        """
        try:
            path = json_path or os.path.join(self.client_dir, "software-map.json")
            existing = {}
            if os.path.exists(path):
                with open(path, "r", encoding="utf-8") as f:
                    existing = json.load(f)
            merged = _merge_map(existing, obj or {})
            tmp = path + ".tmp"
            try:
                with open(tmp, "w", encoding="utf-8") as f:
                    json.dump(merged, f, ensure_ascii=False, indent=2)
                os.replace(tmp, path)
            except Exception:  # noqa: BLE001
                try:
                    os.remove(tmp)
                except OSError:
                    pass  # temp file may not exist; the original error matters
                raise
            return _result(ok=True, data={"path": path, "keys": list(merged.keys())},
                           screenshot=self._last_shot)
        except Exception as e:  # noqa: BLE001
            return _err("save_map_entry failed: %s" % e)

    def save_template(self, name, region=None):
        """Save the latest screenshot (or a [x,y,w,h] crop) as popups/<name>.png.

        *name* is reduced to alphanumerics and "-_." (falling back to "tmpl")
        and gets a ".png" suffix if missing.
        """
        try:
            from PIL import Image
            pdir = os.path.join(self.client_dir, "popups")
            os.makedirs(pdir, exist_ok=True)
            safe = "".join(c for c in str(name) if c.isalnum() or c in "-_.") or "tmpl"
            if not safe.endswith(".png"):
                safe += ".png"
            dest = os.path.join(pdir, safe)
            src = self._ensure_shot("tmpl_src")
            with Image.open(src) as im:
                img = im.convert("RGB")
            if region:
                rx, ry, rw, rh = [int(v) for v in region]
                img = img.crop((rx, ry, rx + rw, ry + rh))
            img.save(dest)
            return _result(ok=True, data={"path": dest,
                                          "signature_png": "popups/%s" % safe},
                           screenshot=src)
        except Exception as e:  # noqa: BLE001
            return _err("save_template failed: %s" % e)

    def emit_verdict(self, obj):
        """Store *obj* under the map's "automation_verdict" key."""
        return self.save_map_entry(obj={"automation_verdict": obj})


# ----------------------------------------------------------------------------
def _merge_map(base, incoming):
    out = dict(base)
    list_id_keys = {"screens", "popups"}
    for k, v in incoming.items():
        if k in list_id_keys and isinstance(v, list):
            cur = out.get(k, [])
            by_id = {e.get("id"): e for e in cur if isinstance(e, dict)}
            for e in v:
                if isinstance(e, dict) and e.get("id") in by_id:
                    by_id[e["id"]].update(e)
                else:
                    cur.append(e)
                    if isinstance(e, dict):
                        by_id[e.get("id")] = e
            out[k] = cur
        elif k == "transitions" and isinstance(v, list):
            cur = out.get(k, [])
            seen = {json.dumps(t, sort_keys=True, ensure_ascii=False) for t in cur}
            for t in v:
                kk = json.dumps(t, sort_keys=True, ensure_ascii=False)
                if kk not in seen:
                    cur.append(t); seen.add(kk)
            out[k] = cur
        else:
            out[k] = v
    return out


def _extract_bbox(text, W, H):
    """Parse a {found, bbox:[x,y,w,h]} JSON from model text. Tolerates 0..1000
    normalized coords (Qwen-VL grounding) by rescaling to pixels if values look
    normalized."""
    import re
    m = re.search(r'\{.*\}', text or "", re.S)
    if not m:
        return None
    try:
        j = json.loads(m.group(0))
    except Exception:  # noqa: BLE001
        return None
    if not j.get("found"):
        return None
    bb = j.get("bbox")
    if not (isinstance(bb, list) and len(bb) == 4):
        return None
    x, y, w, h = [float(v) for v in bb]
    # if all coords <= 1000 but the screen is bigger, assume 0..1000 normalized
    if max(x + w, y + h) <= 1000 and (W > 1000 or H > 1000):
        x = x * W / 1000.0; w = w * W / 1000.0
        y = y * H / 1000.0; h = h * H / 1000.0
    return [int(x), int(y), int(w), int(h)]
