from collections import Counter


class SimpleBPE:
    """Minimal byte-pair-encoding helpers.

    Provides the two building blocks of one BPE training step: finding the
    most frequent adjacent pair of pieces across a set of words, and merging
    every occurrence of a pair inside a single word.
    """

    def best_pair(self, words: list[list[str]]) -> tuple[str, str]:
        """Return the most frequent adjacent pair of pieces across *words*.

        Each word is a sequence of pieces; every pair of neighbouring pieces
        counts once per occurrence. On ties, the pair that was encountered
        first (scanning words, then positions, in order) is returned.

        Args:
            words: Words, each already split into pieces.

        Returns:
            The ``(left, right)`` pair with the highest count.

        Raises:
            IndexError: If no word contains at least two pieces, so there is
                no pair to pick. ``IndexError`` is kept (rather than
                ``ValueError``) so existing callers that catch it keep
                working; only the message is new.
        """
        counts: Counter[tuple[str, str]] = Counter()
        for pieces in words:
            # zip against the one-shifted sequence yields each adjacent pair
            # in left-to-right order; words with < 2 pieces contribute nothing.
            counts.update(zip(pieces, pieces[1:]))
        if not counts:
            raise IndexError(
                "best_pair() needs at least one word with two or more pieces; "
                "no adjacent pair was found"
            )
        return counts.most_common(1)[0][0]

    def merge_pair(self, pieces: list[str], pair: tuple[str, str]) -> list[str]:
        """Return a copy of *pieces* with each occurrence of *pair* fused.

        The scan is greedy and left-to-right: once two pieces are merged the
        scan resumes after them, so overlapping candidates are not merged
        twice (``["a", "a", "a"]`` with pair ``("a", "a")`` becomes
        ``["aa", "a"]``). The input list is not modified.

        Args:
            pieces: One word split into pieces.
            pair: The ``(left, right)`` pieces to concatenate.

        Returns:
            A new list in which every matched pair is replaced by the single
            piece ``left + right``.
        """
        merged = []
        index = 0
        while index < len(pieces):
            has_next = index + 1 < len(pieces)
            if has_next and (pieces[index], pieces[index + 1]) == pair:
                merged.append(pieces[index] + pieces[index + 1])
                index += 2  # skip both pieces that were just consumed
            else:
                merged.append(pieces[index])
                index += 1
        return merged


if __name__ == "__main__":
    # Demo: split three sample words into characters plus an end-of-word
    # marker, find the most frequent adjacent pair, and show each word with
    # that pair merged.
    corpus = [[*text, "</w>"] for text in ("agent", "agents", "agenda")]
    encoder = SimpleBPE()
    top_pair = encoder.best_pair(corpus)
    print(top_pair)
    merged_corpus = [encoder.merge_pair(symbols, top_pair) for symbols in corpus]
    print(merged_corpus)
