RAG — rerank & évals

RAGmiddle

: @2
+ rerankΔ

pipeline.py

def retrieve(query, k=20):
    # этап 1: дешёвый широкий отбор (вектор / BM25) — быстро, грубо
    return vector_search(query, top_k=k)

def rerank(query, candidates):
    # этап 2: кросс-энкодер читает (query, passage) вместе — точно, дороже
    scores = cross_encoder.predict([(query, c.text) for c in candidates])
    return [c for _, c in sorted(zip(scores, candidates), reverse=True)]

def ndcg_at_k(ranking, gold, k):
    dcg  = sum(1/log2(i+2) for i, d in enumerate(ranking[:k]) if d in gold)
    idcg = sum(1/log2(i+2) for i in range(min(k, len(gold))))
    return dcg / idcg

def evaluate(dataset, k=10):
    # рег-гейт: гоняем на каждом изменении, падение метрик = блок мержа
    base = mean(ndcg_at_k(retrieve(q),         gold, k) for q, gold in dataset)
    rr   = mean(ndcg_at_k(rerank(q, retrieve(q)), gold, k) for q, gold in dataset)
    assert rr >= base, "rerank ухудшил nDCG — не мержим"
    return {"baseline": base, "rerank": rr}