from collections import Counter, defaultdict


class NGramLanguageModel:
    """Count-based n-gram language model with maximum-likelihood estimates.

    During training each sentence is padded with ``n - 1`` start markers
    (``"<s>"``) and one end marker (``"</s>"``) so that every token —
    including the first word and the sentence boundary — has a full-length
    context of exactly ``n - 1`` tokens.
    """

    def __init__(self, n: int = 2) -> None:
        """Create an untrained model of order *n* (``n=1`` is a unigram model).

        Raises:
            ValueError: If *n* is less than 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        # Maps an (n-1)-token context tuple to a Counter of follower counts.
        self.table: dict[tuple[str, ...], Counter[str]] = defaultdict(Counter)

    def fit(self, corpus: list[str]) -> None:
        """Accumulate n-gram counts from whitespace-tokenized sentences.

        Calling ``fit`` more than once adds to the existing counts.
        """
        for sentence in corpus:
            tokens = ["<s>"] * (self.n - 1) + sentence.split() + ["</s>"]
            for index in range(len(tokens) - self.n + 1):
                context = tuple(tokens[index:index + self.n - 1])
                target = tokens[index + self.n - 1]
                self.table[context][target] += 1

    def next_token_distribution(self, context_words: list[str]) -> list[tuple[str, float]]:
        """Return ``(token, probability)`` pairs for the next token, most likely first.

        Contexts shorter than ``n - 1`` words are left-padded with ``"<s>"``
        to mirror the padding applied in :meth:`fit`. Returns an empty list
        for a context never seen in training.
        """
        order = self.n - 1
        if order == 0:
            # Unigram model: the context is always empty. The naive slice
            # ``context_words[-0:]`` would wrongly return the whole list here,
            # so the empty tuple must be special-cased.
            context: tuple[str, ...] = ()
        else:
            padded = ["<s>"] * order + context_words
            context = tuple(padded[-order:])
        counts = self.table.get(context, Counter())
        total = sum(counts.values())
        if total == 0:
            return []
        return [(token, count / total) for token, count in counts.most_common()]


if __name__ == "__main__":
    # Tiny demo: train a bigram model and show what tends to follow "agent".
    training_sentences = [
        "agent uses tools",
        "agent reads memory",
        "agent uses memory",
    ]
    model = NGramLanguageModel(n=2)
    model.fit(training_sentences)
    print(model.next_token_distribution(["agent"]))
