import random


class BanditEnv:
    """Multi-armed bandit whose arms pay a fixed mean plus uniform noise."""

    def __init__(self, rewards: list[float], noise: float = 0.2) -> None:
        """Create the environment.

        Args:
            rewards: Mean payout of each arm; arm ``i`` pays around ``rewards[i]``.
                The list is stored as-is (not copied).
            noise: Half-width of the uniform noise added to every pull; each
                pull adds a sample from ``U(-noise, noise)``. Defaults to 0.2.
        """
        self.rewards = rewards
        self.noise = noise

    def pull(self, arm: int) -> float:
        """Return the mean reward of ``arm`` plus uniform noise.

        Raises:
            IndexError: If ``arm`` is not in ``range(len(self.rewards))``.
                Negative indices are rejected explicitly, because Python list
                indexing would otherwise wrap them around to a different arm.
        """
        if not 0 <= arm < len(self.rewards):
            raise IndexError(f"arm index out of range: {arm}")
        return self.rewards[arm] + random.uniform(-self.noise, self.noise)


class EpsilonGreedyAgent:
    """Epsilon-greedy arm selection with sample-mean value estimates.

    With probability ``epsilon`` the agent explores (uniformly random arm);
    otherwise it exploits the arm with the highest current estimate.
    """

    def __init__(self, arm_count: int, epsilon: float = 0.2) -> None:
        """Create an agent with ``arm_count`` arms, none of them tried yet.

        Args:
            arm_count: Number of arms; must be at least 1.
            epsilon: Exploration probability per ``choose`` call, in [0, 1].

        Raises:
            ValueError: If ``arm_count`` < 1 or ``epsilon`` is outside [0, 1].
        """
        if arm_count < 1:
            raise ValueError(f"arm_count must be >= 1, got {arm_count}")
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must be in [0, 1], got {epsilon}")
        self.epsilon = epsilon
        # Number of rewards recorded for each arm via ``learn``.
        self.counts = [0] * arm_count
        # Running mean of the rewards recorded for each arm (0.0 before any).
        self.values = [0.0] * arm_count

    def choose(self) -> int:
        """Return the index of the arm to pull next."""
        if random.random() < self.epsilon:
            return random.randrange(len(self.values))
        best = max(self.values)
        # Break ties uniformly at random. ``max`` alone always returns the
        # lowest tied index, which biases early greedy picks toward arm 0
        # while every estimate is still 0.0.
        candidates = [arm for arm, value in enumerate(self.values) if value == best]
        return random.choice(candidates)

    def learn(self, arm: int, reward: float) -> None:
        """Fold ``reward`` into the running-mean estimate of ``arm``.

        Raises:
            IndexError: If ``arm`` is not in ``range(len(self.counts))``.
                Negative indices are rejected rather than silently wrapping
                to another arm.
        """
        if not 0 <= arm < len(self.counts):
            raise IndexError(f"arm index out of range: {arm}")
        self.counts[arm] += 1
        # Incremental mean: v_n = v_{n-1} + (r - v_{n-1}) / n.
        self.values[arm] += (reward - self.values[arm]) / self.counts[arm]


if __name__ == "__main__":
    # Fixed seed so the demo prints the same numbers on every run.
    random.seed(7)
    arm_rewards = [0.2, 0.5, 0.8]
    env = BanditEnv(arm_rewards)
    # Derive the arm count from the reward list instead of repeating it.
    agent = EpsilonGreedyAgent(arm_count=len(arm_rewards), epsilon=0.25)
    for _ in range(80):
        arm = agent.choose()
        agent.learn(arm, env.pull(arm))
    # Output labels mean "selection counts" and "value estimates".
    print("选择次数：", agent.counts)
    print("价值估计：", [round(value, 3) for value in agent.values])
