from llm_client import LLMMessage, OfflineLLMClient
from tools import ToolRegistry, default_registry


class ReActAgent:
    """Minimal ReAct loop: alternate LLM reasoning steps with tool calls.

    Each step prompts the LLM with the tool list, the question, and the
    accumulated Thought/Observation history; the reply either names a tool
    to run ("Action:" / "Action Input:") or ends with "Final Answer:".
    The loop stops at a final answer or after ``max_steps`` LLM calls.
    """

    def __init__(self, llm: OfflineLLMClient, tools: ToolRegistry, max_steps: int = 4) -> None:
        self.llm = llm
        self.tools = tools
        # Hard cap on LLM round-trips per question, guarding against loops.
        self.max_steps = max_steps
        # Interleaved LLM replies and "Observation: ..." lines for the
        # current run; rendered back into every prompt as the history.
        self.trace: list[str] = []

    def build_prompt(self, question: str) -> str:
        """Render the full prompt: instructions, tool list, question, history."""
        history = "\n".join(self.trace) or "暂无"
        return (
            "你是 ReAct 智能体。按 Thought/Action/Action Input 或 Final Answer 输出。\n"
            f"工具：\n{self.tools.describe()}\n"
            f"问题：{question}\n"
            f"历史：\n{history}\n"
        )

    def parse_action(self, text: str) -> tuple[str, str] | None:
        """Extract ``(action, action_input)`` from an LLM reply.

        Takes the first line starting with ``Action:`` and the first line
        starting with ``Action Input:`` (missing input defaults to "").
        Returns None when no non-empty action is present.
        """
        lines = [line.strip() for line in text.splitlines()]
        action = next((line.split(":", 1)[1].strip() for line in lines if line.startswith("Action:")), None)
        action_input = next((line.split(":", 1)[1].strip() for line in lines if line.startswith("Action Input:")), "")
        if action:
            return action, action_input
        return None

    def run(self, question: str) -> str:
        """Answer ``question``; return the final answer or a stop reason.

        The trace is reset at the start so that consecutive ``run`` calls
        on one agent do not leak a previous question's history into the
        next prompt (the prior behavior polluted later runs).
        """
        self.trace = []  # BUGFIX: start every run with a clean history
        for _ in range(self.max_steps):
            response = self.llm.complete([LLMMessage("user", self.build_prompt(question))])
            self.trace.append(response)
            # A final answer wins even if the reply also names an action.
            if "Final Answer:" in response:
                return response.split("Final Answer:", 1)[1].strip()
            parsed = self.parse_action(response)
            if not parsed:
                return "停止：模型没有给出可执行动作"
            tool_name, tool_input = parsed
            # NOTE(review): assumes ToolRegistry.get handles unknown tool
            # names sensibly (raises or returns a stub) — confirm its contract.
            observation = self.tools.get(tool_name).run(tool_input)
            self.trace.append(f"Observation: {observation}")
        return "停止：超过最大步数"


if __name__ == "__main__":
    # Demo entry point: answer one arithmetic question with the default
    # offline client and tool registry, then dump the full reasoning trace.
    demo_agent = ReActAgent(OfflineLLMClient(), default_registry())
    answer = demo_agent.run("18 * 7 + 6 等于多少？")
    print(answer)
    print("\n".join(demo_agent.trace))
