import math


def dot(left: list[float], right: list[float]) -> float:
    """Return the inner product of two vectors.

    Elements are paired by position. If the vectors differ in length, the
    surplus elements of the longer one are ignored (``zip`` semantics).
    Two empty vectors yield ``0``.
    """
    # Lazily pair up the components; the built-in sum() does the reduction.
    products = (x * y for x, y in zip(left, right))
    return sum(products)


def softmax(row: list[float]) -> list[float]:
    """Return the softmax of ``row`` as a new list.

    The maximum of the row is subtracted from every element before
    exponentiating, so large inputs do not overflow ``math.exp``; the result
    is mathematically unchanged by that shift.

    Args:
        row: Scores to normalise.

    Returns:
        A list of the same length as ``row`` whose entries are
        ``exp(x - max(row))`` divided by their sum. An empty ``row`` gives
        an empty list.
    """
    # max() raises on an empty sequence, so handle that case up front.
    if not row:
        return []

    # Compute the maximum once instead of once per element.
    peak = max(row)
    values = [math.exp(item - peak) for item in row]
    total = sum(values)
    return [item / total for item in values]


def attention(
    query: list[list[float]],
    key: list[list[float]],
    value: list[list[float]],
) -> tuple[list[list[float]], list[list[float]]]:
    """Compute scaled dot-product attention over plain nested lists.

    Every query row is scored against every key row (dot product divided by
    the square root of the first query row's length), each row of scores is
    normalised with ``softmax``, and the resulting weights are used to mix
    the rows of ``value``.

    Args:
        query: One row per query vector.
        key: One row per key vector.
        value: One row per key; row ``i`` is scaled by the weight of key
            ``i``. The number of output columns is taken from ``value[0]``.

    Returns:
        A ``(outputs, weights)`` pair. ``weights`` has one row per query and
        one entry per key; ``outputs`` has one row per query and one entry
        per column of ``value``. An empty ``query`` yields ``([], [])``.

    Raises:
        ValueError: If ``key`` and ``value`` have different numbers of rows.
        ZeroDivisionError: If the first query row is empty while ``key`` is
            not, because the scale factor is then zero.
    """
    # Each weight is paired with a value row by index, so the two must match;
    # fail with a clear message instead of an IndexError or silently dropping
    # surplus value rows.
    if len(key) != len(value):
        raise ValueError(
            f"key and value must have the same number of rows "
            f"(got {len(key)} and {len(value)})"
        )

    # Nothing to attend from: avoid indexing query[0] below.
    if not query:
        return [], []

    # The scale depends only on the query width, so compute it once.
    scale = math.sqrt(len(query[0]))
    scores = [[dot(q, k) / scale for k in key] for q in query]
    weights = [softmax(row) for row in scores]

    # With no value rows there are no columns to produce.
    width = len(value[0]) if value else 0
    outputs = []
    for row in weights:
        outputs.append([
            sum(weight * value[index][column] for index, weight in enumerate(row))
            for column in range(width)
        ])
    return outputs, weights


if __name__ == "__main__":
    # Small demo: three 2-d queries attending over three 2-d keys/values.
    query = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0]]
    key = [[1.0, 0.0], [0.7, 0.3], [0.0, 1.0]]
    value = [[10.0, 0.0], [6.0, 4.0], [0.0, 10.0]]

    outputs, weights = attention(query, key, value)

    # Print the attention weights first, then the mixed outputs, each
    # rounded to three decimals for readability.
    for matrix in (weights, outputs):
        print([[round(cell, 3) for cell in line] for line in matrix])
