"""MCP server exposing file + quota + share tools to remote agents.

Transports
----------
Default: streamable-http on 127.0.0.1:8790, mounted at /mcp.
Set SHARE_MCP_TRANSPORT=stdio to run as a CLI subprocess (pipe-style).

Tools
-----
- list_dir(path="")                    directory listing
- read_file(path, max_bytes=...)       fetch file content (truncated)
- file_info(path)                      metadata
- mkdir(path)                          create folder under share root
- upload_text(path, content)           write a text file via MCP
- delete(path)                         remove file or empty dir
- share_link(path)                     build a public https://files.spannerjun.top/raw?path=... URL
- search(pattern)                      glob filename match
- miniMax_usage()                      proxy to miniquota (raw API)
- opencode_quota()                     proxy to miniquota (raw API)
- litellm_models()                     list models on lapi.spannerjun.top
- litellm_chat(model, prompt, ...)     run a quick chat through LiteLLM proxy
- smail_status()                        SSH reachability, mailserver container, mail queue and recent logs
- smail_list_accounts()                 list mail accounts on newyork-server (dms mailserver)
- smail_create_account(email, password) create a mail account
- smail_delete_account(email)           delete a mail account
- smail_read_email(folder="INBOX", limit=10)     fetch the latest emails via IMAP
- smail_send_email(to, subject, body, html="")   send an email via SMTP
- smail_postqueue()                     show the outbound mail queue
- smail_logs(n=50)                      tail recent mail log lines
- smail_run(cmd)                        raw docker exec into the mailserver container (escape hatch)

Run
---
/home/hermes/venv/bin/python mcp_server.py --root /home/hermes/files
"""
from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Any
from urllib.parse import quote

import httpx
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

# --- Config ---------------------------------------------------------------

# Share root (all file tools are confined to it) and backing-service endpoints.
SHARE_ROOT = Path(os.getenv("SHARE_ROOT", "/root/files")).resolve()
MINIQUOTA_URL = os.getenv("MINIQUOTA_URL", "http://127.0.0.1:8788")
LITELLM_URL = os.getenv("LITELLM_URL", "http://127.0.0.1:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "")
PUBLIC_BASE = os.getenv("SHAREKIT_PUBLIC_BASE", "https://files.spannerjun.top")

# MCP transport settings; MCP_HOST / MCP_PORT / MCP_TRANSPORT are also the
# defaults of main()'s --host / --port / --transport flags.
MCP_HOST = os.getenv("SHARE_MCP_HOST", "127.0.0.1")
MCP_PORT = int(os.getenv("SHARE_MCP_PORT", "8790"))
MCP_TRANSPORT = os.getenv("SHARE_MCP_TRANSPORT", "streamable-http")
MCP_PATH = os.getenv("SHARE_MCP_PATH", "/mcp")

# --- smail.icu mail system (newyork-server, docker-mailserver) ---------------
SMAIL_HOST = os.getenv("SMAIL_HOST", "mail.smail.icu")
SMAIL_SSH_ALIAS = os.getenv("SMAIL_SSH_ALIAS", "newyork-server")
SMAIL_CONTAINER = os.getenv("SMAIL_CONTAINER", "mailserver")
# Mailbox the agent itself sends from / reads (hermes@smail.icu by default).
SMAIL_AGENT_EMAIL = os.getenv("SMAIL_AGENT_EMAIL", "hermes@smail.icu")
SMAIL_AGENT_PASSWORD = os.getenv("SMAIL_AGENT_PASSWORD", "")
SMAIL_ADMIN_EMAIL = os.getenv("SMAIL_ADMIN_EMAIL", "admin@smail.icu")
SMAIL_ADMIN_PASSWORD = os.getenv("SMAIL_ADMIN_PASSWORD", "")


# --- FastMCP instance ------------------------------------------------------

# DNS rebinding protection is ON (enable_dns_rebinding_protection=True below):
# requests whose Host header, or Origin header when one is sent, does not
# match the allow-lists built here are rejected. SHARE_MCP_ALLOWED_HOSTS is a
# comma-separated list of host names: by default the public reverse-proxy name
# plus the loopback names used for direct local access.
TS_ALLOWED_HOSTS = [h.strip() for h in os.environ.get(
    "SHARE_MCP_ALLOWED_HOSTS",
    "mcp.spannerjun.top,127.0.0.1,localhost",
).split(",") if h.strip()]

# The Host header includes the port whenever it is not the scheme default
# (e.g. "127.0.0.1:8790" for a direct connection) and the SDK matches entries
# exactly, so a bare "127.0.0.1" would reject local clients. Each bare name is
# therefore allowed as-is (reverse proxy on 443) and with any port through the
# SDK's "name:*" wildcard syntax. Entries already containing ':' (explicit
# port, wildcard or IPv6 literal) are used verbatim.
_TS_HOST_PATTERNS = [
    pattern
    for host in TS_ALLOWED_HOSTS
    for pattern in ((host,) if ":" in host else (host, f"{host}:*"))
]

mcp = FastMCP(
    name="sharekit",
    instructions=(
        "Tools for sharing files between agents and inspecting quota/policy on this host. "
        "All file ops are scoped to the configured share root."
    ),
    host=MCP_HOST,
    port=MCP_PORT,
    transport_security=TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        allowed_hosts=_TS_HOST_PATTERNS,
        allowed_origins=[f"https://{pattern}" for pattern in _TS_HOST_PATTERNS],
    ),
)


# --- Helpers --------------------------------------------------------------

def _safe_join(rel: str) -> Path:
    """Map a share-relative path to an absolute path inside SHARE_ROOT.

    The joined path is fully resolved (``..`` segments and symlinks) before
    the containment check, so neither traversal nor a symlink pointing
    elsewhere can escape. An absolute *rel* replaces SHARE_ROOT in the join
    and is therefore rejected unless it already lies inside the root. The
    root itself (``rel == ""``) is allowed.

    Raises:
        ValueError: if the resolved path lies outside SHARE_ROOT.
    """
    target = (SHARE_ROOT / rel).resolve()
    if target.is_relative_to(SHARE_ROOT):
        return target
    raise ValueError(f"path escapes share root: {rel!r}")


def _fmt_size(n: int) -> str:
    for unit in ("B", "KB", "MB", "GB"):
        if n < 1024:
            return f"{n:.1f} {unit}" if unit != "B" else f"{n} {unit}"
        n /= 1024
    return f"{n:.1f} TB"


# --- Tools -----------------------------------------------------------------

@mcp.tool()
def list_dir(path: str = "") -> dict:
    """List a directory inside the share root (relative path). Returns entries with size + mtime.

    Dot-files are hidden. Directories come first, then files, each group
    sorted case-insensitively by name. Problems (path outside the share root,
    missing path, not a directory, unreadable directory) are reported as an
    ``{"error": ..., "path": ...}`` dict instead of being raised, matching the
    other file tools.
    """
    try:
        p = _safe_join(path)
    except ValueError as exc:
        return {"error": str(exc), "path": path}
    if not p.exists():
        return {"error": "no such path", "path": path}
    if not p.is_dir():
        return {"error": "not a directory", "path": path}
    try:
        children = sorted(p.iterdir(), key=lambda x: (not x.is_dir(), x.name.lower()))
    except OSError as exc:  # e.g. permission denied on the directory itself
        return {"error": str(exc), "path": path}
    out = []
    for e in children:
        if e.name.startswith("."):
            continue
        try:
            st = e.stat()
            rel = str(e.relative_to(SHARE_ROOT))
            out.append({
                "name": e.name,
                "path": rel,
                "type": "dir" if e.is_dir() else "file",
                "size": st.st_size,
                "size_h": _fmt_size(st.st_size),
                "mtime": st.st_mtime,
                # Percent-encode: names with spaces, '&', '#', '+' would
                # otherwise break (or be mis-decoded from) the query string.
                "public_url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}" if e.is_file() else None,
            })
        except OSError as exc:
            out.append({"name": e.name, "error": str(exc)})
    return {"path": path or "/", "entries": out, "count": len(out)}


@mcp.tool()
def read_file(path: str, max_bytes: int = 100_000) -> dict:
    """Read a text file (or raw bytes summary for binaries). max_bytes caps at 5MB.

    At most ``max_bytes`` bytes are returned; the value is clamped to
    0..5 MiB. Content that decodes as UTF-8 is returned as text
    (``encoding="utf-8"``), anything else as a hex string
    (``encoding="binary-hex"``). ``truncated`` is True when the file is longer
    than what was read; ``size`` is the full on-disk size.
    """
    import codecs

    if not isinstance(path, str) or not path:
        return {"error": "path required"}
    try:
        p = _safe_join(path)
    except ValueError as e:
        return {"error": str(e)}
    if not p.is_file():
        return {"error": "not a file", "path": path}
    # Clamp both ends: a negative size would make f.read() return the whole
    # file, silently bypassing the 5 MiB cap.
    max_bytes = max(0, min(int(max_bytes), 5 * 1024 * 1024))
    try:
        with p.open("rb") as f:
            data = f.read(max_bytes + 1)
        truncated = len(data) > max_bytes
        if truncated:
            data = data[:max_bytes]
        try:
            # When truncated, the cut can land inside a multi-byte character.
            # final=False lets the incremental decoder hold back that
            # incomplete tail instead of failing and reporting a text file as
            # binary; genuinely invalid bytes still raise.
            content = codecs.getincrementaldecoder("utf-8")().decode(data, final=not truncated)
            encoding = "utf-8"
        except UnicodeDecodeError:
            content = data.hex()
            encoding = "binary-hex"
        size = p.stat().st_size
    except OSError as e:
        return {"error": str(e)}
    rel = str(p.relative_to(SHARE_ROOT))
    return {
        "path": path,
        "size": size,
        "truncated": truncated,
        "encoding": encoding,
        "content": content,
        "public_url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}",
    }


@mcp.tool()
def file_info(path: str) -> dict:
    """Stat a single file/folder.

    Returns the path, its type ("dir" or "file"), size in bytes plus a
    human-readable form, mtime (seconds since the epoch) and, for regular
    files, the public download URL. Problems (path outside the share root,
    missing path, stat failure) come back as an ``{"error": ...}`` dict.
    """
    try:
        p = _safe_join(path)
    except ValueError as e:
        return {"error": str(e)}
    try:
        if not p.exists():
            return {"error": "no such path"}
        st = p.stat()
        is_dir = p.is_dir()
        is_file = p.is_file()
    except OSError as e:
        # e.g. permission denied, or the entry vanished between calls.
        return {"error": str(e)}
    rel = str(p.relative_to(SHARE_ROOT))
    return {
        "path": path,
        "type": "dir" if is_dir else "file",
        "size": st.st_size,
        "size_h": _fmt_size(st.st_size),
        "mtime": st.st_mtime,
        "public_url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}" if is_file else None,
    }


@mcp.tool()
def mkdir(path: str) -> dict:
    """Create a directory (recursively) under the share root.

    Missing parents are created, and an already-existing directory is not an
    error. Returns ``{"ok": True, "path": path}``, or ``{"error": ...}`` when
    the path escapes the share root or the filesystem refuses.
    """
    try:
        _safe_join(path).mkdir(parents=True, exist_ok=True)
    except (ValueError, OSError) as err:
        return {"error": str(err)}
    else:
        return {"ok": True, "path": path}


@mcp.tool()
def upload_text(path: str, content: str) -> dict:
    """Write a UTF-8 text file under the share root. Refuses to overwrite; pass a new filename.

    Missing parent directories are created. Returns ``{"ok": True, "path",
    "bytes", "public_url"}`` on success, otherwise ``{"error": ...}`` (empty
    path, path outside the share root, something already at that path, or an
    I/O / encoding failure).
    """
    if not isinstance(path, str) or not path:
        return {"error": "path required"}
    try:
        p = _safe_join(path)
        p.parent.mkdir(parents=True, exist_ok=True)
    except (ValueError, OSError) as e:
        return {"error": str(e)}
    try:
        # Mode "x" creates the file exclusively and fails if anything already
        # exists there. A separate exists() check followed by a write would
        # let a concurrent writer slip in between and be overwritten.
        with p.open("x", encoding="utf-8") as f:
            f.write(content)
    except FileExistsError:
        return {"error": f"file exists: {path}"}
    except (ValueError, OSError) as e:
        return {"error": str(e)}
    rel = str(p.relative_to(SHARE_ROOT))
    return {"ok": True, "path": path, "bytes": len(content.encode("utf-8")),
            "public_url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}"}


@mcp.tool()
def delete(path: str) -> dict:
    """Delete a file or empty directory. Refuses non-empty dirs.

    The share root itself can never be deleted. Returns
    ``{"ok": True, "path": path}`` or an ``{"error": ...}`` dict.
    """
    try:
        target = _safe_join(path)
        if target == SHARE_ROOT:
            return {"error": "refusing to delete share root"}
        if not target.exists():
            return {"error": "no such path"}
        if target.is_dir():
            if any(target.iterdir()):
                return {"error": "directory not empty"}
            target.rmdir()
        else:
            target.unlink()
    except (ValueError, OSError) as exc:
        return {"error": str(exc)}
    return {"ok": True, "path": path}


@mcp.tool()
def search(pattern: str) -> dict:
    """Glob-style search under share root. pattern like '*.pdf' or 'docs/*'.

    Matching is recursive (``Path.rglob``), so '*.pdf' finds PDFs at any
    depth. Matches whose relative path contains '..' or that resolve outside
    the share root (e.g. through a symlink) are dropped. At most 200 matches
    are returned (``count``); ``total`` is the number of in-root matches.
    ``size`` is None for an entry that cannot be stat'ed.
    """
    limit = 200
    try:
        found = list(SHARE_ROOT.rglob(pattern))
    except (OSError, ValueError, NotImplementedError) as e:
        # ValueError: empty/unacceptable pattern; NotImplementedError:
        # absolute patterns, which pathlib's glob refuses.
        return {"error": str(e)}

    inside = []
    for p in found:
        try:
            rel = p.relative_to(SHARE_ROOT)
        except ValueError:
            continue
        # A pattern like '../*' makes rglob yield paths outside the root.
        if ".." in rel.parts:
            continue
        try:
            resolved = p.resolve()
        except (OSError, RuntimeError):  # RuntimeError: symlink loop on older Pythons
            continue
        if resolved != SHARE_ROOT and SHARE_ROOT not in resolved.parents:
            continue
        inside.append((p, str(rel)))

    matches = []
    for p, rel in inside[:limit]:
        try:
            size = p.stat().st_size
            is_file = p.is_file()
        except OSError:
            # Vanished or broken entry: still report it, without metadata,
            # rather than failing the whole search.
            size, is_file = None, False
        matches.append({
            "path": rel,
            "size": size,
            "public_url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}" if is_file else None,
        })
    return {
        "pattern": pattern,
        "matches": matches,
        "count": len(matches),
        "total": len(inside),
    }


@mcp.tool()
def share_link(path: str) -> dict:
    """Return the public HTTPS URL for downloading this file.

    The URL carries the file's normalised share-relative path, percent-encoded
    as the ``path`` query value. Returns ``{"error": ...}`` for paths outside
    the share root or entries that are not regular files.
    """
    try:
        p = _safe_join(path)
    except ValueError as e:
        return {"error": str(e)}
    if not p.is_file():
        return {"error": "not a file"}
    rel = str(p.relative_to(SHARE_ROOT))
    return {
        "path": path,
        # Encoding keeps '&', '#', '+', spaces etc. from corrupting the query.
        "url": f"{PUBLIC_BASE}/raw?path={quote(rel, safe='/')}",
        "note": "file_server requires the share token in Authorization header.",
    }


@mcp.tool()
def miniMax_usage() -> dict:
    """Proxy to miniquota API: full MiniMax /api/MiniMax/quota response.

    Returns ``{"ok": True, "data": <decoded JSON>}``, or
    ``{"ok": False, "error": <message>}`` when the request, the HTTP status
    check or JSON decoding fails (15 s timeout).
    """
    endpoint = f"{MINIQUOTA_URL}/api/MiniMax/quota"
    try:
        response = httpx.get(endpoint, timeout=15.0)
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
    return {"ok": True, "data": payload}


@mcp.tool()
def opencode_quota() -> dict:
    """Proxy to miniquota API: full /api/opencode_go/quota response.

    Returns ``{"ok": True, "data": <decoded JSON>}``, or
    ``{"ok": False, "error": <message>}`` when the request, the HTTP status
    check or JSON decoding fails (30 s timeout).
    """
    endpoint = f"{MINIQUOTA_URL}/api/opencode_go/quota"
    try:
        response = httpx.get(endpoint, timeout=30.0)
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
    return {"ok": True, "data": payload}


@mcp.tool()
def litellm_models() -> dict:
    """List model aliases on the LiteLLM proxy (https://lapi.spannerjun.top).

    Calls ``GET {LITELLM_URL}/v1/models`` with the master key and returns the
    ``id`` of every entry under ``data``. Any failure comes back as
    ``{"ok": False, "error": ...}``; a missing key as ``{"error": ...}``.
    """
    if not LITELLM_MASTER_KEY:
        return {"error": "LITELLM_MASTER_KEY not set in mcp server env"}
    auth = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    try:
        resp = httpx.get(f"{LITELLM_URL}/v1/models", headers=auth, timeout=10.0)
        resp.raise_for_status()
        ids = [entry["id"] for entry in resp.json().get("data", [])]
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
    return {"ok": True, "models": ids}


@mcp.tool()
def litellm_chat(model: str, prompt: str, system: str = "", max_tokens: int = 256) -> dict:
    """Quick chat via the LiteLLM proxy. Returns the assistant reply text.

    Sends a /v1/chat/completions request (optional system message, then
    ``prompt`` as the user message) authenticated with LITELLM_MASTER_KEY.
    Success returns ``{"ok": True, "model", "reply", "usage"}``, where
    ``reply`` is "" when the first choice has no text content. Transport,
    HTTP or JSON failures, and responses without any choice, return
    ``{"ok": False, "error": ...}``.
    """
    if not LITELLM_MASTER_KEY:
        return {"error": "LITELLM_MASTER_KEY not set in mcp server env"}
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    try:
        r = httpx.post(
            f"{LITELLM_URL}/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {LITELLM_MASTER_KEY}",
                "Content-Type": "application/json",
            },
            json={"model": model, "messages": messages, "max_tokens": max_tokens},
            timeout=60.0,
        )
        r.raise_for_status()
        j = r.json()
    except Exception as e:
        return {"ok": False, "error": str(e)}
    if not isinstance(j, dict):
        return {"ok": False, "error": "unexpected LiteLLM response (not a JSON object)"}
    # `choices` may be empty or null and `message.content` may be null; report
    # these explicitly instead of leaking an IndexError/TypeError message or
    # returning a None reply.
    choices = j.get("choices") or []
    if not choices or not isinstance(choices[0], dict):
        return {"ok": False, "error": "LiteLLM response contained no choices"}
    message = choices[0].get("message") or {}
    return {
        "ok": True,
        "model": model,
        "reply": message.get("content") or "",
        "usage": j.get("usage") or {},
    }


# --- smail.icu tools (proxy to newyork-server docker-mailserver) ------------

# smail_send_email / smail_read_email import smtplib/imaplib/ssl inside their
# function bodies, so the rest of the server can still start on a Python built
# without TLS support (rare on Linux).
def _ssh_run(remote_cmd: str, timeout: int = 20) -> dict:
    """Execute a shell command on SMAIL_SSH_ALIAS over ssh, return dict.

    ``remote_cmd`` is passed to ssh as one argument and is therefore parsed by
    the remote shell, so callers must ``shlex.quote`` any caller-supplied
    pieces. ``timeout`` (seconds) is used both as ssh's ConnectTimeout and as
    the limit for the whole ssh process.

    BatchMode is left at the default (off) so SSH_ASKPASS is consulted when the
    server asks for a password — required because dms lives behind a password
    gate with no publickey. The environment is passed through so SSH_ASKPASS /
    DISPLAY reach ssh.

    Returns:
        ``{"ok": rc == 0, "rc", "stdout", "stderr"}`` when ssh ran, or
        ``{"ok": False, "error": str}`` when it could not be started or ran
        past ``timeout``.
    """
    env = os.environ.copy()
    # Force ASKPASS to the system-installed location so it works even if
    # /home/hermes is not readable by the running user (defensive).
    if not env.get("SSH_ASKPASS") and os.path.exists("/usr/local/bin/smail-askpass.sh"):
        env["SSH_ASKPASS"] = "/usr/local/bin/smail-askpass.sh"
        env["SSH_ASKPASS_REQUIRE"] = "force"
        env.setdefault("DISPLAY", ":0")
    try:
        proc = subprocess.run(
            ["ssh", "-o", "StrictHostKeyChecking=accept-new",
             "-o", f"ConnectTimeout={timeout}",
             "-o", "NumberOfPasswordPrompts=1",
             "-F", "/home/hermes/.ssh/config",
             SMAIL_SSH_ALIAS, remote_cmd],
            # Never hand our stdin to ssh. Under the stdio transport it is the
            # MCP JSON-RPC channel, and ssh forwards whatever it reads from
            # stdin to the remote command, swallowing protocol messages.
            # Password prompts go through SSH_ASKPASS, not stdin.
            stdin=subprocess.DEVNULL,
            capture_output=True, text=True, timeout=timeout, env=env,
        )
    except Exception as e:  # e.g. ssh binary missing, subprocess.TimeoutExpired
        return {"ok": False, "error": str(e)}
    return {
        "ok": proc.returncode == 0,
        "rc": proc.returncode,
        "stdout": proc.stdout,
        "stderr": proc.stderr,
    }


def _smail_check() -> dict:
    """Snapshot the smail connection settings (host, ssh alias, container, agent mailbox).

    Tools attach this to their results so a caller can see which mail
    host/container a response came from.
    """
    return dict(
        host=SMAIL_HOST,
        ssh_alias=SMAIL_SSH_ALIAS,
        container=SMAIL_CONTAINER,
        agent_email=SMAIL_AGENT_EMAIL,
    )


def _smail_account_op(subcommand: str, email: str, password: str | None = None,
                      extra_args: list[str] | None = None) -> dict:
    """Run ``setup email <subcommand> <email> [password] [extra...]`` in the DMS container.

    subcommand in {"list", "add", "del", "update"}. ``password`` is only
    forwarded for "add" and "update". Every token is shell-quoted before the
    command is sent through ``_ssh_run``. After a successful "add", dovecot is
    restarted (works around a DMS bug) and that result is attached under
    ``"dovecot_restart"``.

    NOTE(review): the password ends up in the ssh argv and the remote command
    line, so it may be visible in process listings on both hosts.
    """
    argv = ["docker", "exec", SMAIL_CONTAINER, "setup", "email", subcommand, email]
    takes_password = subcommand in ("add", "update")
    if takes_password and password is not None:
        argv.append(password)
    argv.extend(extra_args or [])
    result = _ssh_run(shlex.join(argv), timeout=30)
    if subcommand == "add" and result.get("ok"):
        container = shlex.quote(SMAIL_CONTAINER)
        result["dovecot_restart"] = _ssh_run(
            f"docker exec {container} supervisorctl restart dovecot",
            timeout=20,
        )
    return result


@mcp.tool()
def smail_status() -> dict:
    """Report smail.icu health: SSH reachability, the mailserver container's docker ps line, the outbound queue and the last 10 mail-log lines."""
    status = _smail_check()
    status["ssh_ping"] = _ssh_run("echo OK; hostname", timeout=15)
    # Go-template columns for `docker ps`; the literal backslash-t is
    # interpreted by docker as a tab.
    ps_format = "'{{.Names}}\\t{{.Status}}\\t{{.Image}}'"
    name_filter = shlex.quote(SMAIL_CONTAINER)
    # Note: this replaces the plain container name set by _smail_check().
    status["container"] = _ssh_run(
        f"docker ps --format {ps_format} --filter name={name_filter}",
        timeout=20,
    )
    status["postqueue"] = _smail_run_postqueue()
    status["logs_tail"] = _smail_logs(n=10)
    return status


def _smail_run_postqueue() -> dict:
    """Run ``postqueue -p`` inside SMAIL_CONTAINER over ssh (15 s timeout)."""
    cmd = " ".join(["docker", "exec", shlex.quote(SMAIL_CONTAINER), "postqueue", "-p"])
    return _ssh_run(cmd, timeout=15)


def _smail_logs(n: int = 50) -> dict:
    """Tail ``int(n)`` lines of /var/log/mail.log inside SMAIL_CONTAINER over ssh (15 s timeout)."""
    line_count = int(n)
    cmd = " ".join([
        "docker", "exec", shlex.quote(SMAIL_CONTAINER),
        "tail", "-n", str(line_count), "/var/log/mail.log",
    ])
    return _ssh_run(cmd, timeout=15)


@mcp.tool()
def smail_list_accounts() -> dict:
    """List all smail.icu mail accounts on newyork-server.

    Runs DMS ``setup email list`` in the mailserver container via ssh and
    returns the ``_ssh_run`` result dict.
    """
    container = shlex.quote(SMAIL_CONTAINER)
    return _ssh_run(f"docker exec {container} setup email list", timeout=20)


@mcp.tool()
def smail_create_account(email: str, password: str) -> dict:
    """Create a smail.icu mailbox.

    ``email`` must be ``<local-part>@smail.icu``. Pass ``password`` exactly as
    it should be stored; do NOT wrap it in quotes, because it is shell-quoted
    automatically before being sent to the mail server.
    """
    if not email.endswith("@smail.icu"):
        return {"ok": False, "error": "email must end with @smail.icu"}
    if email == "@smail.icu":
        return {"ok": False, "error": "email needs a local part before @smail.icu"}
    # NOTE(review): DMS's `setup email add` presumably prompts interactively
    # when the password argument is empty, which a non-interactive ssh session
    # cannot answer, so an empty password is rejected up front.
    if not password:
        return {"ok": False, "error": "password required"}
    return _smail_account_op("add", email, password)


@mcp.tool()
def smail_delete_account(email: str) -> dict:
    """Delete a smail.icu mailbox account (DMS `setup email del`). Do not assume the stored mail is kept."""
    # NOTE(review): an earlier docstring claimed the maildir is preserved.
    # Whether mail data survives depends on how DMS's `setup email del`
    # handles its mailbox-deletion prompt with no answer; confirm before
    # relying on it.
    if not email.endswith("@smail.icu"):
        return {"ok": False, "error": "email must end with @smail.icu"}
    if email == "@smail.icu":
        return {"ok": False, "error": "email needs a local part before @smail.icu"}
    return _smail_account_op("del", email)


@mcp.tool()
def smail_postqueue() -> dict:
    """Show the deferred/active outbound mail queue (postfix `postqueue -p`)."""
    queue = _smail_run_postqueue()
    return queue


@mcp.tool()
def smail_logs(n: int = 50) -> dict:
    """Tail the last N lines of /var/log/mail.log inside the mailserver container."""
    line_count = int(n)
    return _smail_logs(n=line_count)


@mcp.tool()
def smail_run(cmd: str) -> dict:
    """Arbitrary `docker exec mailserver <cmd>` (escape hatch). Refuses destructive ops."""
    # The deny-list is a plain substring check: it is an accident guard, not
    # a security boundary (e.g. extra spaces defeat it).
    forbidden = ["rm -rf /", "docker rm -f", "shutdown", "reboot", "mkfs", "dd if="]
    # Capture WHICH entry matched. The old `any(...)` form then referenced
    # the generator's loop variable, which is out of scope, so every refusal
    # crashed with NameError instead of returning this error dict.
    blocked = next((f for f in forbidden if f in cmd), None)
    if blocked is not None:
        return {"ok": False, "error": f"refused: '{blocked}' not allowed"}
    # `cmd` is appended unquoted on purpose (raw escape hatch). The remote
    # shell on the SSH host parses it, so operators like ';', '&&' or '|'
    # make the remainder run on the host, not inside the container.
    remote_cmd = "docker exec " + shlex.quote(SMAIL_CONTAINER) + " " + cmd
    return _ssh_run(remote_cmd, timeout=60)


@mcp.tool()
def smail_send_email(to: str, subject: str, body: str, html: str = "") -> dict:
    """Send an email as hermes@smail.icu (STARTTLS 587). To may be one address or a comma-separated list.

    Logs in to SMAIL_HOST:587 as SMAIL_AGENT_EMAIL after STARTTLS. ``body`` is
    the text/plain part; a non-empty ``html`` is added as the text/html
    alternative. Returns ``{"ok": True, "to", "subject", "message_id"}`` or
    ``{"ok": False, "error": ...}``.
    """
    if not SMAIL_AGENT_PASSWORD:
        return {"ok": False, "error": "SMAIL_AGENT_PASSWORD not set in mcp server env"}
    # CR/LF inside a header value would let the caller inject extra headers
    # (Bcc:, a second Subject:, ...), so such input is refused.
    if any(ch in value for value in (to, subject) for ch in "\r\n"):
        return {"ok": False, "error": "to/subject must not contain line breaks"}
    import smtplib
    import ssl
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText
    from email.utils import formatdate, getaddresses, make_msgid

    # RCPT TO needs bare addresses. Extracting them lets "Name <a@b>" and
    # "a@b, c@d" work instead of being sent as a single malformed recipient.
    recipients = [addr for _name, addr in getaddresses([to]) if addr]
    if not recipients:
        return {"ok": False, "error": f"no recipient address found in {to!r}"}
    msg = MIMEMultipart("alternative")
    msg["From"] = f"Hermes Agent <{SMAIL_AGENT_EMAIL}>"
    msg["To"] = to
    msg["Date"] = formatdate(localtime=True)
    msg["Message-ID"] = make_msgid(domain="smail.icu")
    msg["Subject"] = subject
    msg.attach(MIMEText(body, "plain", "utf-8"))
    if html:
        msg.attach(MIMEText(html, "html", "utf-8"))
    try:
        with smtplib.SMTP(SMAIL_HOST, 587, timeout=30) as s:
            s.starttls(context=ssl.create_default_context())
            s.login(SMAIL_AGENT_EMAIL, SMAIL_AGENT_PASSWORD)
            s.sendmail(SMAIL_AGENT_EMAIL, recipients, msg.as_string())
        return {"ok": True, "to": to, "subject": subject, "message_id": msg["Message-ID"]}
    except Exception as e:
        return {"ok": False, "error": str(e)}


def _smail_decode_header(value) -> str:
    """Decode an RFC 2047 header value (``=?charset?b?...?=`` words) to plain text.

    Returns the value unchanged (as str) if it cannot be decoded.
    """
    from email.errors import HeaderParseError
    from email.header import decode_header, make_header

    try:
        return str(make_header(decode_header(value)))
    except (HeaderParseError, LookupError, ValueError):
        # Best effort: malformed encoded-words or unknown charsets keep the
        # original header text rather than failing the whole fetch.
        return str(value)


def _smail_text_body(msg) -> str:
    """Return the first inline text/plain part of *msg* as text ("" if none).

    The payload is decoded with the part's declared charset (UTF-8 when none
    is declared or the name is unknown); undecodable bytes are dropped.
    """
    for part in msg.walk():
        if part.get_content_type() != "text/plain":
            continue
        if part.get_content_disposition() == "attachment":
            continue  # an attached .txt file is not the message body
        payload = part.get_payload(decode=True)
        if payload is None:
            continue
        charset = part.get_content_charset() or "utf-8"
        try:
            return payload.decode(charset, errors="ignore")
        except LookupError:  # unknown charset name
            return payload.decode("utf-8", errors="ignore")
    return ""


@mcp.tool()
def smail_read_email(folder: str = "INBOX", limit: int = 10) -> dict:
    """Fetch the latest N messages from the agent's IMAP inbox.

    ``folder`` is the IMAP mailbox to open; ``limit`` is how many of the
    newest messages to return (``limit <= 0`` returns none). From/To/Subject
    are decoded from RFC 2047; ``body`` is the first inline text/plain part,
    cut to 5000 characters.
    """
    if not SMAIL_AGENT_PASSWORD:
        return {"ok": False, "error": "SMAIL_AGENT_PASSWORD not set in mcp server env"}
    import email as em
    import imaplib

    count = max(int(limit), 0)
    try:
        # The context manager logs out even when an exception escapes, so
        # the IMAP connection is not leaked on errors.
        with imaplib.IMAP4_SSL(SMAIL_HOST, 993) as M:
            M.login(SMAIL_AGENT_EMAIL, SMAIL_AGENT_PASSWORD)
            typ, detail = M.select(folder)
            if typ != "OK":
                reason = b" ".join(d for d in detail if isinstance(d, bytes)).decode("utf-8", "replace")
                return {"ok": False, "error": f"cannot select folder {folder!r}: {reason}"}
            typ, data = M.search(None, "ALL")
            if typ != "OK":
                return {"ok": False, "error": f"IMAP SEARCH failed in {folder!r}"}
            all_ids = data[0].split()
            # all_ids[-0:] would select every message, so limit <= 0 is
            # handled explicitly.
            ids = all_ids[-count:] if count else []
            msgs = []
            for msg_id in ids:
                typ, msg_data = M.fetch(msg_id, "(RFC822)")
                m = em.message_from_bytes(msg_data[0][1])
                msgs.append({
                    "id": msg_id.decode(),
                    "from": _smail_decode_header(m.get("From", "")),
                    "to": _smail_decode_header(m.get("To", "")),
                    "subject": _smail_decode_header(m.get("Subject", "")),
                    "date": m.get("Date", ""),
                    "message_id": m.get("Message-ID", ""),
                    "body": _smail_text_body(m)[:5000],  # truncate
                })
            M.close()
        return {"ok": True, "folder": folder, "count": len(msgs), "messages": msgs}
    except Exception as e:
        return {"ok": False, "error": str(e)}


# --- Entrypoint ------------------------------------------------------------

def main() -> None:
    """Parse CLI flags, apply them to the module config and the FastMCP server, then serve.

    Flags (defaults come from the environment-derived module constants):
      --root       share root directory, created if missing
      --host/--port  bind address for the streamable-http transport
      --transport  "stdio" or "streamable-http"

    Blocks inside ``mcp.run`` until the server exits.
    """
    global SHARE_ROOT, MCP_HOST, MCP_PORT
    parser = argparse.ArgumentParser(description="sharekit MCP server")
    parser.add_argument("--root", default=str(SHARE_ROOT))
    parser.add_argument("--host", default=MCP_HOST)
    parser.add_argument("--port", type=int, default=MCP_PORT)
    parser.add_argument("--transport", choices=("stdio", "streamable-http"), default=MCP_TRANSPORT)
    args = parser.parse_args()

    SHARE_ROOT = Path(args.root).resolve()
    SHARE_ROOT.mkdir(parents=True, exist_ok=True)
    MCP_HOST = args.host
    MCP_PORT = args.port
    # `mcp` was constructed at import time from the env-derived MCP_HOST /
    # MCP_PORT and never re-reads those globals. The CLI values must be pushed
    # into its settings, or --host/--port are silently ignored. MCP_PATH
    # (SHARE_MCP_PATH) is applied here too so that override takes effect.
    mcp.settings.host = MCP_HOST
    mcp.settings.port = MCP_PORT
    mcp.settings.streamable_http_path = MCP_PATH

    print(f"sharekit mcp: serving share root {SHARE_ROOT}", file=sys.stderr)
    print(f"sharekit mcp: transport={args.transport} host={args.host} port={args.port} path={MCP_PATH}",
          file=sys.stderr)
    mcp.run(transport=args.transport)


if __name__ == "__main__":
    main()
