import re
from dataclasses import dataclass, field
from typing import Optional

from llm_client import LLMMessage, OfflineLLMClient
from tools import ToolRegistry, default_registry


@dataclass
class PlanState:
    """Record of one Plan-and-Solve run: the question, its plan, and step results.

    ``plan`` and ``results`` start empty and each instance gets its own lists.
    """

    # The question this run is answering.
    question: str
    # Plan steps in execution order; default_factory gives every instance a fresh list.
    plan: list[str] = field(default_factory=list)
    # One result string per executed step, appended in plan order.
    results: list[str] = field(default_factory=list)


class PlanAndSolveAgent:
    """Plan-and-Solve agent: ask the LLM for a plan, then execute each step.

    Steps containing "计算" (calculate) are sent to the ``calculator`` tool,
    steps containing "事实" (facts) to the ``local_search`` tool, and any other
    step is answered with a fixed "compose the final answer" message.
    """

    # Fallbacks that reproduce the agent's original hard-coded tool inputs when
    # no expression/query can be supplied or found.
    DEFAULT_EXPRESSION = "18 * 7 + 6"
    DEFAULT_QUERY = "Plan-and-Solve"

    # A numbered plan line such as "1. xxx", "2、xxx" or "3) xxx"; group 1 is the step text.
    _NUMBERED_STEP = re.compile(r"^\s*\d+\s*[.．、)]\s*(.*?)\s*$")
    # A candidate arithmetic span: starts with a digit or "(" and ends with a digit or ")".
    _EXPR_CANDIDATE = re.compile(r"[\d(][\d\s.+\-*/()]*[\d)]")
    # A candidate only counts as an expression if it has a binary operator between operands.
    _HAS_OPERATOR = re.compile(r"[\d)]\s*[-+*/]\s*[\d(]")

    def __init__(self, llm: OfflineLLMClient, tools: ToolRegistry) -> None:
        self.llm = llm
        self.tools = tools

    def plan(self, question: str) -> list[str]:
        """Ask the LLM to plan *question* and return the step texts in order.

        Numbered lines ("1. ...", "2、...", "3) ...") are taken as the steps.
        If the reply contains no numbered line at all, fall back to the text
        after the first "." of every line containing one. Empty steps are
        dropped in both cases.
        """
        raw = self.llm.complete([LLMMessage("user", f"Plan-and-Solve 规划：{question}")])
        lines = raw.splitlines()
        numbered = [m.group(1) for m in map(self._NUMBERED_STEP.match, lines) if m]
        if numbered:
            steps = numbered
        else:
            # Legacy heuristic, kept so non-numbered replies still yield a plan.
            steps = [line.split(".", 1)[1].strip() for line in lines if "." in line]
        return [step for step in steps if step]

    @classmethod
    def _find_expression(cls, text: str) -> Optional[str]:
        """Return the first arithmetic expression (e.g. "18 * 7 + 6") in *text*, or None.

        This is a heuristic: any digit/operator run with at least one operator
        between operands qualifies (so e.g. "2024-01-02" would also match).
        """
        for match in cls._EXPR_CANDIDATE.finditer(text):
            candidate = match.group().strip()
            if cls._HAS_OPERATOR.search(candidate):
                return candidate
        return None

    def execute_step(
        self,
        step: str,
        *,
        expression: Optional[str] = None,
        query: Optional[str] = None,
    ) -> str:
        """Execute one plan step and return its textual result.

        Args:
            step: The plan step text.
            expression: Expression for the calculator, used when *step* itself
                contains none; if also absent, DEFAULT_EXPRESSION is used.
            query: Query for ``local_search``; defaults to DEFAULT_QUERY.

        Returns:
            The tool's output for calculation/fact steps, otherwise the fixed
            "organize the final answer from existing results" message.
        """
        if "计算" in step:
            expr = self._find_expression(step) or expression or self.DEFAULT_EXPRESSION
            return self.tools.get("calculator").run(expr)
        if "事实" in step:
            return self.tools.get("local_search").run(query or self.DEFAULT_QUERY)
        return "根据已有结果组织最终回答"

    def run(self, question: str) -> PlanState:
        """Plan *question*, execute every step in order, and return the full state.

        An arithmetic expression found in *question* is the fallback for
        calculation steps whose own text contains none.
        """
        state = PlanState(question=question, plan=self.plan(question))
        question_expr = self._find_expression(question)
        for step in state.plan:
            outcome = self.execute_step(step, expression=question_expr)
            state.results.append(f"{step} => {outcome}")
        return state


if __name__ == "__main__":
    # Demo: run the agent on a mixed explain-and-calculate question with the
    # offline LLM client and the default tool registry, then show plan and results.
    demo_agent = PlanAndSolveAgent(OfflineLLMClient(), default_registry())
    demo_state = demo_agent.run("解释 Plan-and-Solve，并计算 18 * 7 + 6")
    for section in (demo_state.plan, demo_state.results):
        print(section)
