import ast
import json
import operator
import re
from dataclasses import dataclass
from typing import Callable


@dataclass()
class Tool:
    """A named capability that can be offered to the agent.

    A plain record: ``name`` and ``description`` are rendered into the
    prompt's tool list, and ``handler`` is invoked with an args dict and
    returns a text result.
    """

    # Identifier used to select the tool.
    name: str
    # Short human-readable summary shown alongside the name.
    description: str
    # Executes the tool on an args dict and returns a string result.
    handler: Callable[[dict], str]


# Binary operators the calculator accepts, mapped to their implementations.
# ast.Pow is handled separately by _calc_pow so its result size can be bounded.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
}
_CALC_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Upper bound on the bit length of an integer power result. The character
# whitelist admits "**", so without this "9**9**9" would hang the process.
_CALC_MAX_POW_BITS = 10_000


def _calc_pow(base, exponent):
    """Return ``base ** exponent``, refusing integer powers whose result would be huge."""
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        and abs(base).bit_length() * exponent > _CALC_MAX_POW_BITS
    ):
        raise ValueError("幂运算结果过大")
    return base ** exponent


def _calc_eval(node):
    """Evaluate a parsed arithmetic AST, allowing only numbers and + - * / // ** ."""
    if isinstance(node, ast.Expression):
        return _calc_eval(node.body)
    # Only real numeric literals; rejects e.g. "..." (Ellipsis) that the whitelist admits.
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _CALC_UNARY_OPS:
        return _CALC_UNARY_OPS[type(node.op)](_calc_eval(node.operand))
    if isinstance(node, ast.BinOp):
        left = _calc_eval(node.left)
        right = _calc_eval(node.right)
        if isinstance(node.op, ast.Pow):
            return _calc_pow(left, right)
        op = _CALC_BINARY_OPS.get(type(node.op))
        if op is not None:
            return op(left, right)
    # Tuples "()", calls "(1)(2)" and any other node shape end up here.
    raise ValueError("不支持的表达式")


def calculator(args: dict) -> str:
    """Evaluate a basic arithmetic expression and return the result as text.

    Args:
        args: Tool arguments; ``args["expression"]`` (default ``""``) may contain
            only digits, ``+ - * / ( ) .`` and spaces.

    Returns:
        The numeric result rendered with ``str()``. On bad input it returns an
        error message instead of raising: the fixed whitelist message for
        disallowed characters, or ``"计算失败：..."`` for syntax errors, division
        by zero, overflow and oversized powers.
    """
    expression = str(args.get("expression", ""))
    if not re.fullmatch(r"[0-9+\-*/(). ]+", expression):
        return "表达式包含不允许的字符"
    try:
        # Evaluate via a whitelisted AST walk rather than eval(). strip() keeps
        # eval()'s tolerance of leading spaces, which ast.parse treats as an indent.
        tree = ast.parse(expression.strip(), mode="eval")
        return str(_calc_eval(tree))
    except (SyntaxError, ValueError, ZeroDivisionError, OverflowError, RecursionError) as exc:
        return f"计算失败：{exc}"


def write_note(args: dict) -> str:
    """Report a note "save" for the given title and content.

    This is a stub: nothing is persisted. It only builds a confirmation
    message containing the title (default ``"untitled"``) and the character
    count of the content (default empty).
    """
    note_title = str(args.get("title", "untitled"))
    char_count = len(str(args.get("content", "")))
    return "已写入笔记：{}，正文 {} 字".format(note_title, char_count)


# Default tool registry, keyed by each tool's own name so key and name cannot drift apart.
TOOLS = {
    tool.name: tool
    for tool in (
        Tool("calculator", "执行基础数学计算", calculator),
        Tool("write_note", "保存一条文本笔记", write_note),
    )
}


class FirstAgent:
    """A minimal tool-using agent loop driven by JSON actions.

    Each step builds a prompt from the task, the registered tools and recent
    observations, asks ``call_model`` for a JSON action, and either returns
    the final answer or runs the requested tool and records its result.
    """

    def __init__(self, tools: "dict[str, Tool]", max_steps: int = 5) -> None:
        # Registry keyed by tool name; values must expose .name, .description, .handler.
        self.tools = tools
        # Maximum number of model calls per run() before giving up.
        self.max_steps = max_steps
        # Trace of the current run: one line per executed tool call.
        self.observations: list[str] = []

    def build_prompt(self, task: str) -> str:
        """Render the prompt for the next model call.

        Only the last three observations are included to keep the prompt short.
        """
        tool_text = "\n".join(f"{tool.name}: {tool.description}" for tool in self.tools.values())
        history = "\n".join(self.observations[-3:]) or "暂无观察"
        return (
            f"任务：{task}\n"
            f"可用工具：\n{tool_text}\n"
            f"历史观察：\n{history}\n"
            '请输出 JSON：{"action":"tool","tool":"工具名","args":{}} '
            '或 {"action":"finish","answer":"最终答案"}'
        )

    def call_model(self, prompt: str) -> str:
        """Scripted stand-in for a model call; returns a JSON action string.

        ``prompt`` is ignored. The next action depends only on which tool names
        already appear in ``self.observations``: calculator first, then
        write_note, then finish with a fixed answer.
        """
        if not any("calculator" in item for item in self.observations):
            return json.dumps(
                {"action": "tool", "tool": "calculator", "args": {"expression": "23 * 17 + 8"}},
                ensure_ascii=False,
            )
        if not any("write_note" in item for item in self.observations):
            return json.dumps(
                {
                    "action": "tool",
                    "tool": "write_note",
                    "args": {"title": "第一次 Agent 运行", "content": "先计算，再保存观察结果。"},
                },
                ensure_ascii=False,
            )
        return json.dumps({"action": "finish", "answer": "计算结果为 399，笔记已保存。"}, ensure_ascii=False)

    def parse_action(self, raw: str) -> dict:
        """Decode and validate one model reply.

        Returns:
            The parsed action dict. For ``tool`` actions a missing or null
            ``args`` is normalized to ``{}``.

        Raises:
            ValueError: if ``raw`` is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError``), is not a JSON object, has an ``action`` other than
                ``tool``/``finish``, is a ``finish`` without ``answer``, or carries
                non-object ``args``.
        """
        action = json.loads(raw)
        if not isinstance(action, dict):
            raise ValueError("模型输出必须是 JSON 对象")
        # Tuple membership compares with ==, so an unhashable value such as a list
        # is rejected with ValueError instead of raising TypeError from set hashing.
        if action.get("action") not in ("tool", "finish"):
            raise ValueError("action 必须是 tool 或 finish")
        if action["action"] == "finish":
            if "answer" not in action:
                raise ValueError("finish 动作缺少 answer")
        elif action.get("args") is None:
            action["args"] = {}
        elif not isinstance(action["args"], dict):
            raise ValueError("args 必须是 JSON 对象")
        return action

    def run(self, task: str) -> str:
        """Drive the model/tool loop for ``task`` and return the final text.

        Returns the model's answer (as ``str``), or a "停止：..." message when the
        model names an unknown tool or ``max_steps`` is exhausted. Exceptions from
        ``parse_action`` and from tool handlers propagate.
        """
        # Start each task with a fresh trace. Rebinding (rather than clear())
        # leaves any list a caller kept from a previous run intact.
        self.observations = []
        for step in range(1, self.max_steps + 1):
            action = self.parse_action(self.call_model(self.build_prompt(task)))
            if action["action"] == "finish":
                return str(action["answer"])

            tool_name = action.get("tool")
            # isinstance guard: an unhashable tool value would make the dict lookup raise TypeError.
            if not isinstance(tool_name, str) or tool_name not in self.tools:
                return f"停止：未知工具 {tool_name}"

            result = self.tools[tool_name].handler(action["args"])
            self.observations.append(f"第 {step} 步 {tool_name} 返回：{result}")

        return "停止：超过最大步数"


if __name__ == "__main__":
    # Demo: solve the sample task with the default tools, then print the recorded trace.
    agent = FirstAgent(TOOLS)
    final_answer = agent.run("计算 23 * 17 + 8，并保存运行记录")
    print(final_answer)
    print("\n执行轨迹：")
    for entry in agent.observations:
        print(f"- {entry}")
