from dataclasses import dataclass

from llm_client import LLMMessage, OfflineLLMClient


@dataclass
class ReflectionRound:
    """Record of one critique-and-revise iteration of the reflection loop."""

    # Answer text as it stood before this round's critique.
    draft: str
    # Critique text produced for ``draft`` in this round.
    critique: str
    # Answer text after ``draft`` was revised using ``critique``.
    improved: str


class ReflectionAgent:
    """Agent that refines an answer through repeated critique and revision.

    A template draft is produced first; each round then asks the LLM client
    for a critique of the current answer and folds that critique into a
    revised answer. Every round is recorded in ``rounds``.
    """

    def __init__(self, llm: OfflineLLMClient, max_rounds: int = 2) -> None:
        """Store the LLM client and the number of refinement rounds.

        Args:
            llm: Client whose ``complete`` method is called with a list of
                messages to obtain each critique.
            max_rounds: How many critique/revise rounds ``run`` performs.
        """
        self.llm = llm
        self.max_rounds = max_rounds
        # History of every round executed by this agent. ``run`` only ever
        # appends, so the list keeps growing across successive calls.
        self.rounds: list[ReflectionRound] = []

    def draft(self, task: str) -> str:
        """Return the fixed-template initial draft for ``task``."""
        initial = f"初稿：{task} 可以用 Agent 完成。"
        return initial

    def critique(self, task: str, draft: str) -> str:
        """Ask the LLM client to review ``draft`` against ``task``.

        The task and the draft are packed into a single ``"user"`` message;
        whatever ``self.llm.complete`` returns is passed back unchanged.
        """
        prompt = f"Reflection 审查任务：{task}\n草稿：{draft}"
        messages = [LLMMessage("user", prompt)]
        return self.llm.complete(messages)

    def revise(self, draft: str, critique: str) -> str:
        """Return ``draft`` extended with a fixed supplement and ``critique``."""
        revised = f"{draft} 已补充：需要列出工具、验证步骤和人工接管条件。审查意见：{critique}"
        return revised

    def run(self, task: str) -> str:
        """Run the full draft/critique/revise loop and return the final answer.

        Each of the ``max_rounds`` iterations appends one ``ReflectionRound``
        to ``self.rounds``. With ``max_rounds <= 0`` the loop body never runs
        and the initial draft is returned as-is.
        """
        current = self.draft(task)
        for _round in range(self.max_rounds):
            feedback = self.critique(task, current)
            # Keep the pre-revision text so the round record shows both sides.
            previous, current = current, self.revise(current, feedback)
            self.rounds.append(
                ReflectionRound(draft=previous, critique=feedback, improved=current)
            )
        return current


if __name__ == "__main__":
    # Demo: run the reflection loop once with the offline client, then show
    # the final answer followed by the record of the last round.
    agent = ReflectionAgent(OfflineLLMClient())
    final_answer = agent.run("设计一个资料调研 Agent")
    print(final_answer)
    last_round = agent.rounds[-1]
    print(last_round)
