Semantic search, or embedding-based retrieval, has long been a key component of many AI applications. Yet, a surprising number of applications I’ve seen still don’t do reranking, despite the relative ease of implementation.
If you’ve ever built a RAG pipeline and thought “the results are okay but not great”, the solution isn’t always to choose a better embedding model. Instead, you should consider including a reranking step, and cross-encoders are probably your best bet.
This article covers what cross-encoders are, why they’re so good at reranking, how to fine-tune them on your own data, and some ideas for pushing them even further.
All the code is available at https://github.com/ianhohoho/cross-encoder-and-reranking-demo.
The Retrieval Problem
Most semantic search systems use bi-encoders. They encode your query into a vector, encode your documents into vectors, and find the closest matches. It’s a fast operation that scales and gives you moderately decent results most of the time.
However, encoding the query and document independently throws away any interaction signal between them, because the embedding model has to compress all semantics into a single vector before it ever compares anything.
Here’s a concrete example. You search “cheap hotels in Tokyo” and get back:
- “Luxury hotels in Tokyo starting at $500/night”
- “Budget hostels in Tokyo at $30/night”
- “Cheap flights to Tokyo”
Result #1 scores high because it matches “hotels” and “Tokyo.” Result #3 matches “cheap” and “Tokyo.” But result #2 — the one you actually want — might rank below both because “cheap” and “budget” aren’t that close in embedding space.
A bi-encoder can’t reason about the relationship between “cheap” in your query and “$500/night” in the document. It just sees similarity between the compressed vectors. A cross-encoder ‘reads’ the query and document together in one go, so it catches that $500/night contradicts “cheap” and ranks that document lower. At least, that’s the layman’s way of explaining it.
The Two-Stage Pattern
In the real world, we can combine bi-encoders and cross-encoders to get the best of both retrieval speed and relevance quality.
- Stage 1: Fast, approximate retrieval. Cast a wide net to achieve high recall with a bi-encoder or BM25. Get your top k candidates.
- Stage 2: Precise reranking. Run a cross-encoder over those candidates in a pair-wise manner. Get a much better ranking that directly measures relevance.
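To make the pattern concrete, here’s a minimal numpy sketch of the two stages. The document embeddings are random and `expensive_pair_score` is just a cheap stand-in for a real cross-encoder forward pass — the point is the shape of the funnel, not the scoring:

```python
import numpy as np

rng = np.random.default_rng(0)
doc_embeddings = rng.normal(size=(1000, 64))  # precomputed at index time
doc_embeddings /= np.linalg.norm(doc_embeddings, axis=1, keepdims=True)

def expensive_pair_score(q: np.ndarray, d: np.ndarray) -> float:
    # A real system would run a transformer forward pass per pair here.
    return float(q @ d)

def stage1_retrieve(query_emb: np.ndarray, k: int = 100) -> np.ndarray:
    """Fast, approximate: score every document with one dot product."""
    scores = doc_embeddings @ query_emb
    return np.argsort(-scores)[:k]  # top-k candidate indices

def stage2_rerank(query_emb: np.ndarray, candidates: np.ndarray, k: int = 10) -> np.ndarray:
    """Precise: run the expensive scorer only on the shortlisted candidates."""
    pair_scores = np.array([expensive_pair_score(query_emb, doc_embeddings[i])
                            for i in candidates])
    return candidates[np.argsort(-pair_scores)][:k]

query_emb = rng.normal(size=64)
query_emb /= np.linalg.norm(query_emb)
top10 = stage2_rerank(query_emb, stage1_retrieve(query_emb, k=100), k=10)
```

The expensive scorer only ever sees 100 candidates, never the full 1,000.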
It’s already quite a standard pattern in production, at least for teams at the frontier:
- Cohere offers Rerank as a standalone API — designed to sit on top of any first-stage retrieval. Their rerank-v4.0-pro is one such example.
- Pinecone has built-in reranking with hosted models, describing it as “a two-stage vector retrieval process to improve the quality of results”. One of the multilingual models they offer is bge-reranker-v2-m3, which has its own HuggingFace model card.
- In fact, this practice has been around for a pretty long time already. Google announced back in 2019 that BERT is used to re-rank search results by reading queries & snippets together to judge relevance.
- LangChain and LlamaIndex both have built-in reranking steps for RAG pipelines.
Why Not Just Use Cross-Encoders for Everything?
Well, it’s a compute problem.
A bi-encoder encodes all your documents once at index time, and so the upfront complexity is O(n). At query time, you just encode the query and conduct an approximate nearest-neighbor lookup. With FAISS or any ANN index, that’s effectively O(1).
A cross-encoder can’t precompute anything. It needs to see the query and document together. So at query time, it runs a full transformer forward pass for every candidate of (query, document).
At the risk of failing my professors who used to teach about complexity, each pass costs O(L × (s_q + s_d)² × d), because that’s L layers, the combined sequence length squared, times the hidden dimension.
For a corpus of 1M documents, that’s 1M forward passes per query. Even with a small model like MiniLM (6 layers, 384 hidden dim), you’re looking at a silly amount of GPU time per query, so that’s obviously a non-starter.
But what if we narrowed it down to roughly 100 candidates? On a single GPU, that would probably take just a few hundred milliseconds.
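Plugging illustrative numbers into that cost model shows why the shortlist matters. The layer count, hidden dimension, and token lengths below are assumptions for the sake of arithmetic, not benchmarks:

```python
# Back-of-envelope cost model from above: L layers x (combined length)^2 x hidden dim.
L, d = 6, 384            # MiniLM-ish layer count and hidden dimension (assumed)
s_q, s_d = 16, 128       # assumed query / document token lengths

per_pair = L * (s_q + s_d) ** 2 * d   # rough cost of one cross-encoder pass
rerank_all = per_pair * 1_000_000     # score every document in a 1M corpus
rerank_top100 = per_pair * 100        # score only the shortlisted candidates

ratio = rerank_all // rerank_top100   # how much work the shortlist saves
```

The per-pair cost is identical in both cases; the 10,000x difference comes entirely from how many pairs you score.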
That’s why two-stage retrieval works: retrieve cheaply and then rerank precisely.
How Bi-Encoders and Cross-Encoders Work
Bi-Encoder Architecture
A bi-encoder encodes the query and the document independently (often with the same transformer), with each producing a fixed-size embedding.
Query → [Transformer] → query_embedding (768-dim vector)
↓
cosine similarity
↑
Doc → [Transformer] → doc_embedding (768-dim vector)
The similarity score is just cosine similarity between the two vectors, and it’s fast because you can precompute all document embeddings and use approximate nearest-neighbor (ANN) search.
However, the key limitation is that the model compresses all meaning into one vector before any comparison happens. Query and document tokens never interact, and so it’s akin to summarising two essays separately and then comparing between them. You lose all sorts of nuances as a result.
Cross-Encoder Architecture
A cross-encoder takes a different approach. It concatenates the query and document into one input sequence before feeding it through a single transformer, something like this:
Input: [CLS] query tokens [SEP] document tokens [SEP]
↓
[Transformer — full self-attention across ALL tokens]
↓
[CLS] → Linear Head → sigmoid → relevance score (0 to 1)
Every token in the query can attend to every token in the document. Consequently, the output isn’t an embedding but a relevance score for the query-document pair, produced directly.
How Cross-Encoders Are Trained
Why not train a cross-encoder from scratch? Well, just like the LLMs themselves, training a transformer from scratch requires massive compute and data. BERT was trained on 3.3 billion words so… you probably don’t want to redo that.
Instead, you can use transfer learning. Take a pre-trained transformer that already understands language (grammar, semantics, word relationships), and teach it one new skill, which is “given a query and document together, is this document relevant?”
The setup looks something like this:
- Start with a pre-trained transformer (BERT, RoBERTa, MiniLM).
- Add a linear classification head on top of the [CLS] token, and this maps the hidden state to a single logit.
- Apply sigmoid to get a relevance score between 0 and 1 (or sometimes a softmax over pairs, for example positive vs negative examples).
- Train on (query, document, relevance_label) triples.
The most well-known training dataset is MS MARCO, which contains about 500k queries from Bing with human-annotated relevant passages.
For the loss function, you have a few options:
- Binary cross-entropy (BCE): This treats the problem as classification, basically asking “is this document relevant or not?”.
- MSE loss: More commonly used for distillation (briefly mentioned later). Instead of hard labels, you match soft scores from a stronger teacher model.
- Pairwise margin loss: Given one relevant (positive) and one irrelevant (negative) document, ensure the relevant one scores higher by a margin.
The training loop is actually pretty straightforward too: sample a query, pair it with positive and negative documents, concatenate each pair as [CLS] query [SEP] document [SEP], do a forward pass, compute loss, backprop, rinse and repeat.
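Here’s that loop as a runnable toy. A real setup would tokenize [CLS] query [SEP] document [SEP] and fine-tune a transformer; this sketch swaps in a hashed bag-of-words featurizer and a bare linear head so the BCE loop runs anywhere, with made-up training triples:

```python
import torch
import torch.nn as nn

VOCAB = 512

def featurize(query: str, document: str) -> torch.Tensor:
    """Hashed bag-of-words stand-in for '[CLS] query [SEP] document [SEP]'."""
    x = torch.zeros(VOCAB)
    for tok in (query + " " + document).lower().split():
        x[hash(tok) % VOCAB] += 1.0
    return x

triples = [  # (query, document, relevance_label) -- invented examples
    ("cheap hotels tokyo", "budget hostels in tokyo at $30/night", 1.0),
    ("cheap hotels tokyo", "luxury hotels in tokyo from $500/night", 0.0),
    ("reset password", "how to change your account password", 1.0),
    ("reset password", "quarterly earnings report", 0.0),
]

head = nn.Linear(VOCAB, 1)           # the "classification head"
loss_fn = nn.BCEWithLogitsLoss()     # binary cross-entropy on the logit
opt = torch.optim.Adam(head.parameters(), lr=0.05)

for _ in range(200):                 # forward pass, loss, backprop, repeat
    X = torch.stack([featurize(q, d) for q, d, _ in triples])
    y = torch.tensor([[lbl] for _, _, lbl in triples])
    opt.zero_grad()
    loss = loss_fn(head(X), y)
    loss.backward()
    opt.step()

score = torch.sigmoid(head(featurize(
    "cheap hotels tokyo", "budget hostels in tokyo at $30/night"))).item()
```

After a couple of hundred steps the head separates the relevant pair from the distractor, which is all the real training loop does too, just with a much better encoder.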
In practice, most fine-tuning use-cases would start from an already trained cross-encoder like cross-encoder/ms-marco-MiniLM-L-6-v2 and further fine-tune on their domain-specific data.
Why Cross-Attention Matters: The Technical Deep Dive
We’ve kept things pretty abstract for now, so this section gets into the core of why cross-encoders are better. Let’s get into the math.
In any transformer layer, self-attention works like this. Each token i produces a query vector q_i = W_Q·x_i, a key vector k_i = W_K·x_i, and a value vector v_i = W_V·x_i, where x_i is the token’s hidden state. The attention score between tokens i and j is then
score(i, j) = (q_i · k_j) / √d_k
This score determines how much token i “pays attention to” token j.
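For intuition, the standard scaled dot-product score (q_i · k_j) / √d_k is just a couple of matrix multiplies. This numpy sketch uses random weights and a made-up 5-token sequence in place of learned projections:

```python
import numpy as np

rng = np.random.default_rng(42)
seq_len, d_model, d_k = 5, 16, 8
X = rng.normal(size=(seq_len, d_model))   # one hidden state per token
W_Q = rng.normal(size=(d_model, d_k))     # "learned" projections (random here)
W_K = rng.normal(size=(d_model, d_k))

Q, K = X @ W_Q, X @ W_K                   # per-token query and key vectors
scores = (Q @ K.T) / np.sqrt(d_k)         # (5, 5) pairwise attention scores
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
# Row i of `weights` is how much token i attends to each token j.
```

Each softmax row sums to 1, so the weights act as a distribution over which tokens to attend to.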
In a bi-encoder, the query and document are separate sequences. The query has tokens [q1,q2,…,qm] and the document has [d1,d2,…,dn]. The attention matrix for the query is m×m and for the document, n×n.
Specifically, the cross terms score(q_i, d_j) are simply never computed: no query token ever attends to any document token. The model independently compresses each sequence into a single vector, then compares them with a single cosine similarity, sim(q, d) = cos(embed(q), embed(d)).
In a cross-encoder, the input is one concatenated sequence [q1,…,qm,d1,…,dn], and the attention matrix is (m+n)×(m+n).
Now the cross terms score(q_i, d_j) exist. In a very approximate manner, the query token for “cheap” can attend to the document token for “$500”, and the model learns through training that this combination means “not relevant.” This cross-attention happens at every layer, building increasingly abstract relationships.
Multi-head attention makes this even more powerful. Each attention head has its own weight matrices, so different heads learn to detect different types of relationships simultaneously:
- One head might learn lexical matching — same or similar words
- Another might learn semantic equivalence — “cheap” ↔ “budget”
- Another might learn contradiction detection — “without sugar” vs “contains sugar”
- Another might learn entity matching — the same person or place referenced differently
At the end of it, the outputs of all heads are concatenated and projected: MultiHead(X) = Concat(head_1, …, head_h)·W_O.
With multiple heads across multiple layers, the model has many independent heads examining query-document interaction at every level of abstraction. Theoretically, that’s why cross-encoders are so much more expressive than bi-encoders.
But of course the tradeoff is then compute: attention costs more and nothing is precomputed.
Enough theory. Let’s look at actual code.
I’ve built a companion repo with eight example .py files that progress from basic implementation to distillation pipelines and full latency-profiled ColBERT implementations.
Each one runs end-to-end and you can follow along as you read through this section.
The first is pretty straightforward. All my code does is pair the query with every document and score each pair through the cross-encoder:
def predict_scores(self, query: str, documents: list[str]) -> list[float]:
    pairs = [(query, doc) for doc in documents]
    scores = self._model.predict(pairs)
    return [float(s) for s in scores]
We begin by feeding the query “How does photosynthesis work in plants?”, along with 10 documents.
- Five are about photosynthesis
- Five are noise about stock markets, electric vehicles, and ancient Rome.
Naturally the photosynthesis documents float to the top:
— Reranked Order (10 results) —
#1 (score: 8.0888) [was #0] Photosynthesis is the process by which green plants convert…
#2 (score: 3.7970) [was #4] During photosynthesis, carbon dioxide and water are converted…
#3 (score: 2.4054) [was #6] Chloroplasts are the organelles where photosynthesis takes…
#4 (score: 1.8762) [was #2] Plants use chlorophyll in their leaves to absorb light…
#5 (score: -9.7185) [was #8] The light-dependent reactions occur in the thylakoid…
…
#10 (score: -11.2886) [was #7] Machine learning algorithms can process vast amounts…
And there’s really nothing more to it. The model concatenates the query and document as [CLS] query [SEP] document [SEP], runs a forward pass, produces a relevance score, and the results are ordered by descending score.
Picking the Right Model
The natural follow-up question: which cross-encoder should I use?
We benchmark four MS MARCO models on the same query — from tiny to large.
I run all four models in parallel via ThreadPoolExecutor, so you get results in the time of the slowest model rather than the sum. Here’s what the output looks like:
— Speed Comparison —
Model                      Time (s)   Docs/sec
ms-marco-MiniLM-L-12-v2    0.560      14.3
ms-marco-electra-base      0.570      14.0
ms-marco-MiniLM-L6-v2      0.811       9.9
ms-marco-TinyBERT-L-2-v2   1.036       7.7
— Ranking Order (by document index) —
ms-marco-MiniLM-L6-v2: 0 → 2 → 4 → 6 → 7 → 1 → 3 → 5
ms-marco-TinyBERT-L-2-v2: 2 → 4 → 0 → 6 → 5 → 3 → 1 → 7
ms-marco-MiniLM-L-12-v2: 2 → 0 → 4 → 6 → 1 → 7 → 3 → 5
ms-marco-electra-base: 2 → 4 → 0 → 6 → 1 → 3 → 7 → 5
All four models agree on the top-4 documents (0, 2, 4, 6), just shuffled slightly.
- TinyBERT is the odd one out, putting document 5 (irrelevant) in 5th place while the others push it to the bottom.
Generally speaking:
- TinyBERT-L2-v2: extremely fast but least accurate — use for low-latency or edge scenarios.
- MiniLM-L6-v2: best balance of speed and quality — use as the default for most reranking tasks.
- MiniLM-L12-v2: slightly more accurate but slower — use when maximizing ranking quality matters.
- electra-base: older, larger, and slower with no clear advantage — generally not recommended over MiniLM.
Fine-Tuning: Making the Model Understand Your Domain
Many pre-trained cross-encoders are still generalists, because they are trained on broad datasets like MS MARCO’s Bing search queries paired with web passages.
If your domain is something like legal contracts, medical records, or cybersecurity incident reports, the generalist model might not rank your content correctly. For example, it does not know that “force majeure” is a contract term, not a military phrase.
Fine-tuning might just do the trick.
There are two approaches depending on what kind of training data you have, and the repo includes an example of each.
When you have soft scores, you can use MSE loss.
- A larger teacher model scores your query-document pairs, and the student learns to reproduce those continuous scores:
trainer = MSEDistillationTrainer(student_model_name=STUDENT_MODEL, config=config)
output_path = trainer.train(train_dataset)
When you have binary labels, you can use BCE loss.
- Each training pair is simply marked relevant or not relevant:
finetuner = BCEFineTuner(model_name=BASE_MODEL, config=config)
output_path = finetuner.train(train_dataset)
Both approaches are pretty straightforward to set up. Under the hood it’s as simple as:
class BCEFineTuner:
    """Fine-tune a cross-encoder with binary cross-entropy loss.

    Suitable for binary relevance judgments (relevant/not-relevant).

    Args:
        model_name: HuggingFace model name to fine-tune.
        config: Training configuration.

    Example:
        >>> finetuner = BCEFineTuner("cross-encoder/ms-marco-MiniLM-L6-v2")
        >>> finetuner.train(train_dataset)
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L6-v2",
        config: TrainingConfig | None = None,
    ) -> None:
        self._config = config or TrainingConfig()
        self._model = CrossEncoder(model_name, num_labels=1)
        self._model_name = model_name

    @property
    def model(self) -> CrossEncoder:
        """Return the model being fine-tuned."""
        return self._model

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset | None = None,
    ) -> Path:
        """Run BCE fine-tuning.

        The dataset should have columns: "sentence1", "sentence2", "label"
        where "label" is 0 or 1.

        Args:
            train_dataset: Dataset with query-document-label triples.
            eval_dataset: Optional evaluation dataset.

        Returns:
            Path to the saved model directory.
        """
        from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss

        loss = BinaryCrossEntropyLoss(self._model)
        args = self._config.to_training_arguments(has_eval=eval_dataset is not None)
        trainer = CrossEncoderTrainer(
            model=self._model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            loss=loss,
        )
        trainer.train()

        output_path = Path(self._config.output_dir) / "final_model"
        self._model.save(str(output_path))
        return output_path
The interesting part is the evaluation, and specifically what happens when you throw adversarial distractors at the model.
After training, I test on cases where each query is paired with a relevant document and a hard distractor. In my definition, a hard distractor is a document that shares keywords but is actually about something different. For this evaluation, a “pass” just means the model scored the relevant doc higher:
b_scores = base_model.predict_scores(case.query, docs)
f_scores = fine_tuned.predict_scores(case.query, docs)
b_pass = b_scores[0] > b_scores[1]
f_pass = f_scores[0] > f_scores[1]
We split the eval into ‘SEEN’ topics (same topics as training, different examples) and ‘UNSEEN’ topics (entirely new). The ‘UNSEEN’ split is the one that matters, because it shows the model learned the domain rather than memorising the training set, just as we would check in most ML evaluation workflows.
Here’s the MSE fine-tuning result:
Base Model Fine-Tuned
Overall accuracy: 15/20 ( 75%) 20/20 (100%)
Seen topics: 7/10 10/10
Unseen topics: 8/10 10/10
Fine-tuning fixed 5 case(s) the base model got wrong.
Average confidence: 316x improvement (gap: +0.0001 -> +0.0386)
From the above, we see that fine-tuning fixed the 5 cases that the base model got wrong, and there was a significant improvement in average confidence. The base model’s correct answers were barely correct (gap of +0.0001), but after fine-tuning, the gap widens to +0.0386. So the model isn’t just getting the right answer more often, it’s getting it with quite a bit of conviction.
The BCE fine-tuning result on legal data (Example 4) is even clearer:
Base Model Fine-Tuned
Overall accuracy: 6/20 ( 30%) 19/20 ( 95%)
Seen topics: 2/10 9/10
Unseen topics: 4/10 10/10
Accuracy increasing from 30% to 95% means that the original base model was somehow worse than random on legal documents. After fine-tuning on just 72 training pairs (12 legal topics with 6 pairs each), the model gets 19 out of 20 right. And notice that unseen topics went from 4/10 to 10/10: in a sense it learnt the domain of legal reasoning, not just the training examples.
The output in my repo marks each case with <– fine-tuning fixed this, essentially where the base model failed but the fine-tuned model got it right.
Here’s one illustrative example:
[SEEN ] What qualifies as wrongful termination?
Relevant: Terminating an employee in retaliation for reporting safety viola…
Distractor: The wrongful termination of the TV series qualified it for a fan …
Base: FAIL (gap: -8.3937) Fine-tuned: PASS (gap: +3.8407)
<– fine-tuning fixed this
The base model confidently chose the TV series distractor due to keyword matches. After fine-tuning, it correctly identifies the employment law document instead.
One thing I really want to call out, as I was figuring all of this out, is that your distractors strongly influence what your model learns. Example 4 trains on legal data where the distractors come from related legal topics: for example, a contract dispute distractor for a tort case, or a regulatory compliance distractor for a criminal law query. (No, I am not a legal expert; I had AI generate these examples for me.)
The challenge is that these examples share vocabulary like “plaintiff”, “jurisdiction”, “statute”. If you used cooking recipes as distractors for legal queries, the model would learn nothing because it can already tell those apart. So the hard negatives from the same domain are what force it to learn fine-grained distinctions.
In many ways, this shares similarities with how I’ve always viewed imbalanced datasets in supervised training. The way you select (downsample) your majority class is extremely important. Pick the observations that look really similar to the minority class, and you have yourself a dataset that will train a really powerful (precise) discriminator.
Semantic Query Caching
In production, users ask the same question a dozen different ways. “How do I reset my password?” and “I forgot my password, how do I change it?” should ideally return the same cached results rather than triggering two separate, expensive search-rerank-generate runs.
The idea is simple: use a cross-encoder fine-tuned on something like the Quora duplicate question dataset to detect semantic duplicates at query time.
def find_duplicate(self, query: str) -> tuple[CacheEntry | None, float]:
    if not self._cache:
        return None, 0.0
    ...
    cached_queries = [entry.query for entry in self._cache]
    scores = self._reranker.predict_scores(query, cached_queries)
    best_idx = max(range(len(scores)), key=lambda i: scores[i])
    best_score = scores[best_idx]
    if best_score >= self._threshold:
        return self._cache[best_idx], best_score
    return None, best_score
Every incoming query gets scored against everything already in the cache. If the best score exceeds a threshold, it’s a duplicate, so return the cached ranking. If not, run the full reranking pipeline and cache the new result.
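Here is that gating logic as a runnable miniature. The cross-encoder is replaced by a token-overlap (Jaccard) scorer purely so the sketch runs without a model; the class name, threshold, and example queries are all invented, but the threshold-gated hit/miss flow is the same:

```python
class ToyQueryCache:
    def __init__(self, threshold: float = 0.6):
        self._threshold = threshold
        self._cache: dict[str, list[str]] = {}  # query -> cached ranking

    @staticmethod
    def _score(a: str, b: str) -> float:
        """Jaccard token overlap, standing in for a cross-encoder score."""
        ta, tb = set(a.lower().split()), set(b.lower().split())
        return len(ta & tb) / len(ta | tb)

    def get_or_compute(self, query, compute_ranking):
        best, best_score = None, 0.0
        for cached_query in self._cache:       # O(n) scan over the cache
            s = self._score(query, cached_query)
            if s > best_score:
                best, best_score = cached_query, s
        if best is not None and best_score >= self._threshold:
            return self._cache[best], True     # cache hit: reuse old ranking
        ranking = compute_ranking(query)       # expensive path: full rerank
        self._cache[query] = ranking
        return ranking, False

cache = ToyQueryCache()
r1, hit1 = cache.get_or_compute("how do i reset my password", lambda q: ["doc-a"])
r2, hit2 = cache.get_or_compute("how can i reset my password", lambda q: ["doc-b"])
```

The paraphrase reuses the first query’s cached ranking, so its `compute_ranking` callback never fires.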
To test this properly, we simulate 50 user queries across 12 topics. Each topic starts with a “seed” query that misses the cache, followed by paraphrase variants that should hit:
("How do I reset my password?", None),              # MISS – first time
("How can I reset my password?", 1),                # HIT → query #1
("How to reset my password?", 1),                   # HIT → query #1
("I forgot my password, how do I change it?", 1),   # HIT → query #1
The output shows the cache building up over time. Early queries are all misses, but once the cache has 12 seed queries, everything that follows is a hit:
# Result Time Query Matched
1 ✗ MISS 0ms How do I reset my password? –
2 ✗ MISS 2395ms How do I export my data from the platform? –
…
4 ✓ HIT 844ms How can I reset my password? → #01 (0.99)
…
25 ✓ HIT 61ms I forgot my password, how do I change it? → #01 (0.99)
…
49 ✓ HIT 17ms I need to reset my password, how? → #01 (0.92)
50 ✓ HIT 25ms Can I add or remove people from my team? → #12 (0.93)
The ground-truth labels let us compute precision and recall:
Total queries: 50
Cache hits: 38 (expected 38)
Cache misses: 12 (expected 12)
HIT precision: 38 / 38 (100%)
MISS precision: 12 / 12 (100%)
Overall accuracy: 50 / 50 (100%)
Without caching: 50 rankings needed. With caching: 12 performed. 76% savings.
100% accuracy: every single hit is correct, and every single miss is genuinely new. As a result, we avoid 76% (38/50) of the ranking operations in our test dataset.
Of course, the cache comparison itself has O(n) cost against the cache size. In a real system you’d probably want to limit the cache size or use a more efficient index. But the core idea of using a cross-encoder trained for paraphrase detection to gate expensive downstream operations is sound and production-tested.
The Multi-Stage Funnel
Bringing it all together in production, you can build a funnel where each stage trades speed for precision, and the candidate set shrinks at every step.
For example, 50 documents → 20 (bi-encoder) → 10 (cross-encoder) → 5 (LLM)
The implementation is pretty straightforward:
def run_pipeline(self, query, documents, stage1_k=20, stage2_k=10, stage3_k=5):
    s1 = self.stage1_biencoder(query, documents, top_k=stage1_k)
    s2 = self.stage2_crossencoder(query, documents, s1.doc_indices, top_k=stage2_k)
    s3 = self.stage3_llm(query, documents, s2.doc_indices, top_k=stage3_k)
    return [s1, s2, s3]
Stage 1 is a bi-encoder: encode query and documents independently, rank by cosine similarity. Cheap enough for thousands of documents. Take the top 20.
Stage 2 is the cross-encoder we’ve been discussing. Score the query-document pairs with full cross-attention. Take the top 10.
Stage 3 is an optional step where we can utilise an LLM to do list-wise reranking. Unlike the cross-encoder which scores each pair independently, the LLM sees all 10 candidates at once in a single prompt and produces a global ordering. This is the only stage that can reason about relative relevance: “Document A is better than Document B because…”
In my code, the LLM stage calls OpenRouter and uses structured output to guarantee parseable JSON back:
RANKING_SCHEMA = {
    "name": "ranking_response",
    "strict": True,
    "schema": {
        "type": "object",
        "properties": {
            "ranking": {
                "type": "array",
                "items": {"type": "integer"},
            },
        },
        "required": ["ranking"],
        "additionalProperties": False,
    },
}
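Once a response conforming to that schema comes back, applying the ranking is a couple of lines. The response string below is a stub standing in for a real OpenRouter reply, and the candidate documents are invented:

```python
import json

candidates = ["doc about qubits", "doc about stocks", "doc about chlorophyll"]
llm_response = '{"ranking": [2, 0, 1]}'  # stubbed structured-output reply

parsed = json.loads(llm_response)
order = parsed["ranking"]
# Sanity check: the ranking must be a permutation of the candidate indices.
assert sorted(order) == list(range(len(candidates)))
reranked = [candidates[i] for i in order]
```

Strict structured output means the parse can’t fail on malformed JSON, but it’s still worth checking the indices actually form a permutation before reordering.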
The test corpus has 50 documents with ground-truth relevance tiers: highly relevant, partially relevant, distractors, and irrelevant.
The output shows noise getting filtered at each stage:
Stage Relevant Partial Noise Precision
Bi-Encoder (all-MiniLM-L6-v2) 10/20 7/20 3/20 85%
Cross-Encoder (cross-encoder/ms-marco-MiniLM…) 10/10 0/10 0/10 100%
LLM (google/gemini-2.0-flash-001) 5/5 0/5 0/5 100%
Total pipeline time: 2243ms
The bi-encoder’s top-20 let through 3 noise documents and 7 partial matches. The cross-encoder eliminated all of them, 10 for 10 on relevant documents. The LLM preserved that precision while cutting to the final 5.
The timing breakdown is worth noting too: the bi-encoder took 176ms to score all 50 documents, the cross-encoder took 33ms for 20 pairs, the LLM took 2034ms for a single API call, by far the slowest stage, but it only ever sees 10 documents.
Knowledge Distillation: Teaching the Bi-Encoder to Think Like a Cross-Encoder
The multi-stage funnel works, but the generic bi-encoder was never trained on your domain data. It retrieves based on surface-level semantic similarity, which means it might still miss relevant documents or let through irrelevant ones.
What if you could teach the bi-encoder to rank like the cross-encoder?
That’s the essence of distillation. The cross-encoder (teacher) scores your training pairs. The bi-encoder (student) learns to reproduce those scores. At inference time, you throw away the teacher and just use the fast student.
distiller = CrossEncoderDistillation(
    teacher_model_name="cross-encoder/ms-marco-MiniLM-L6-v2",
    student_model_name="all-MiniLM-L6-v2",
)
output_path = distiller.train(
    training_pairs=TRAINING_PAIRS,
    epochs=4,
    batch_size=16,
)
The train method that I’ve implemented basically looks like this:
train_dataset = Dataset.from_dict({
    "sentence1": [q for q, _, _ in training_pairs],
    "sentence2": [d for _, d, _ in training_pairs],
    "score": [s for _, _, s in training_pairs],
})
loss = losses.CosineSimilarityLoss(self._student)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    learning_rate=2e-5,
    warmup_ratio=0.1,  # fraction of total steps; warmup_steps expects an int
    logging_steps=5,
    logging_strategy="steps",
    save_strategy="no",
)
trainer = SentenceTransformerTrainer(
    model=self._student,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
To demonstrate that this actually works, we chose a deliberately difficult domain: cybersecurity. In cybersecurity, every document shares the same vocabulary. Attack, vulnerability, exploit, malicious, payload, compromise, breach, these words appear in documents about SQL injection, phishing, buffer overflows, and ransomware alike. A generic bi-encoder maps all of them to roughly the same region of embedding space and so it cannot tell them apart.
The AI-generated training dataset I have uses hard distractors from confusable subtopics:
- SQL injection ↔ command injection (both “injection attacks”)
- XSS ↔ CSRF (both client-side web attacks)
- phishing ↔ pretexting (both social engineering)
- buffer overflow ↔ use-after-free (both memory corruption)
After training, we run a three-way comparison on 30 test cases, 15 from attack types the model trained on, and 15 from attack types it’s never seen:
t_scores = teacher.generate_teacher_scores(case.query, docs) # cross-encoder
b_scores = teacher.generate_student_scores(case.query, docs) # base bi-encoder
d_scores = trained.generate_student_scores(case.query, docs) # distilled bi-encoder
Here’s what the output looks like for a typical case:
[SEEN ] What is a DDoS amplification attack?
Teacher: rel=+5.5097 dist=-6.5875
Base: PASS (rel=0.7630 dist=0.3295 gap=+0.4334)
Distilled: PASS (rel=0.8640 dist=0.2481 gap=+0.6160)
The teacher (cross-encoder) provides the ground truth scores. Both the base and distilled bi-encoders get this one right, but look at the gap: the distilled model is 42% more confident. In a way, it pushes the relevant document further from the distractor in embedding space.
The summary of all tests tells the full story of performance:
Base Student Distilled Student
Overall accuracy: 29/30 ( 96.7%) 29/30 ( 96.7%)
Seen topics: 15/15 15/15
Unseen topics: 14/15 14/15
Avg relevance gap: +0.2679 +0.4126
Same accuracy, but 1.5x wider confidence margins. Both models fail on one edge case: the “memory-safe languages” query, where even the cross-encoder teacher disagreed with the expected label. But across the board, the distilled student separates relevant from irrelevant documents more decisively.
This is one of the more innovative and potentially impactful techniques I’ve been experimenting with in this project: you get cross-encoder quality at bi-encoder speed, at least for your specific domain… assuming you have enough data. So think hard about what kinds of data you would want to collect, label, and process if this kind of distillation might be useful to you down the road.
ColBERT-like Late Interaction
So now we have a spectrum. On one end, bi-encoders are fast, can precompute, but there is no interaction between query and document tokens. On the other end, cross-encoders have full interaction, are more accurate, but nothing is precomputable. Is there something in between?
ColBERT (COntextualized Late interaction over BERT) is one such middle ground. The name tells you the architecture. “Contextualised” means the token embeddings are context-dependent (unlike word2vec where “bank” always maps to the same vector, BERT’s representation of “bank” changes depending on whether it appears near “river” or “account”). “Late interaction” means query and document are encoded separately and only interact at the very end, via operationally inexpensive dot products rather than expensive transformer attention. And “BERT” is the backbone encoder.
That “late” part is the key distinction. A cross-encoder does early interaction in the sense that query and document tokens attend to each other inside the transformer. A bi-encoder does no interaction, just cosine similarity between two pooled vectors. ColBERT sits in between.
When a bi-encoder encodes a sentence, it produces one embedding per token, then pools them, typically by averaging into a single vector, for example:
“How do quantum computers achieve speedup?”
→ 9 token embeddings (each 384-dim)
→ mean pool
→ 1 vector (384-dim): [0.12, -0.34, 0.56, …]
That single vector is what gets compared via cosine similarity. It’s fast and it works, but the pooling step crushes the richness of information. The word “quantum” had its own embedding, and so did “speedup.” After mean pooling, their individual signals are averaged together with filler tokens like “do” and “how.” The resulting vector is a blurry summary of the whole sentence.
The ColBERT-like late interaction skips the pooling by keeping all 9 token embeddings:
“How do quantum computers achieve speedup?”
→
“how” → [0.05, -0.21, …] (384-dim)
“quantum” → [0.89, 0.42, …] (384-dim)
“computers” → [0.67, 0.31, …] (384-dim)
“speedup” → [0.44, 0.78, …] (384-dim)
… 9 tokens total → (9 × 384) matrix
Same for the documents we are comparing against. A 30-token document becomes a (30 × 384) matrix instead of a single vector.
Now you need a way to score the match between a (9 × 384) query matrix and a (30 × 384) document matrix. That’s MaxSim.
For each query token, find its best-matching document token (the one with the highest cosine similarity) and take that maximum. Then sum all the maxima across query tokens.
@staticmethod
def _maxsim(q_embs, d_embs):
    sim_matrix = torch.matmul(q_embs, d_embs.T)
    max_sims = sim_matrix.max(dim=1).values
    return float(max_sims.sum())
Let’s trace through the math. The matrix multiply `(9 × 384) × (384 × 30)` produces a `9 × 30` similarity matrix. Each cell tells you how similar one query token is to one document token. Then `.max(dim=1)` takes the best document match for each query token , 9 values. Then `.sum()` adds them up into one score.
The query token “quantum” scans all 30 document tokens and finds its best match, probably something like “qubits” with similarity ~0.85. The query token “speedup” finds something like “faster” at ~0.7. Meanwhile, filler tokens like “how” and “do” match weakly against everything (~0.1). Sum those 9 maxima and you get a score like 9.93, just as an example.
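You can verify the mechanics with a toy example. The 2-dimensional unit vectors here are made up, not real model outputs:

```python
import torch

def maxsim(q_embs, d_embs):
    # (Q × D) matrix of token-pair similarities
    sim = torch.matmul(q_embs, d_embs.T)
    # best document match per query token, summed into one score
    return float(sim.max(dim=1).values.sum())

# Toy unit vectors: query token 0 matches doc token 1 exactly,
# query token 1 matches doc token 0 exactly
q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
d = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]])

print(maxsim(q, d))  # 1.0 + 1.0 = 2.0
```

Each query token found its perfect match somewhere in the document, so the score is the sum of two maxima of 1.0.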
Why does this work better than a single pooled vector? Because the token-level matching preserves fine-grained signal. The query token “quantum” can specifically latch onto the document token “qubit” via their embedding similarity, even though they’re different words.
With mean pooling, that precise match gets averaged away into a blurry centroid where “quantum” and “how” contribute equally.
The key advantage, and the reason you’d consider ColBERT-like late interaction in production, is pre-indexing. Because documents are encoded independently of the query, you can encode your entire corpus offline and cache the token embeddings:
def index(self, documents):
    self._doc_embeddings = []
    for doc in documents:
        # keep per-token embeddings instead of a pooled sentence vector
        emb = self._model.encode(doc, output_value="token_embeddings")
        tensor = torch.nn.functional.normalize(torch.tensor(emb), dim=-1)
        self._doc_embeddings.append(tensor)
At search time, you only encode the query (one forward pass) and then run dot products against the cached embeddings. The cross-encoder, by contrast, would need to encode all 60 (query, document) pairs from scratch.
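The search step over those cached matrices is then just dot products. Here's a self-contained sketch of that idea, with precomputed random tensors standing in for a real encoder's output (the function names and shapes are illustrative, not the repo's API):

```python
import torch

def maxsim(q_embs, d_embs):
    # sum over query tokens of the best-matching document token similarity
    return float(torch.matmul(q_embs, d_embs.T).max(dim=1).values.sum())

def search(query_emb, doc_embeddings, top_k=5):
    # Score the query matrix against every cached document matrix,
    # then return the top_k (doc_index, score) pairs
    scores = [maxsim(query_emb, d) for d in doc_embeddings]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in ranked[:top_k]]

# Toy cached corpus: normalized random token matrices standing in for real docs
torch.manual_seed(0)
docs = [torch.nn.functional.normalize(torch.randn(n, 8), dim=-1) for n in (5, 7, 4)]
query = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)
print(search(query, docs, top_k=2))
```

The expensive part, encoding documents, happened once at index time; each query costs one encode plus a handful of matrix multiplies.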
How close does it get to cross-encoder quality? Here’s the summary from running 10 queries across a 60-document corpus spanning quantum computing, vaccines, ocean chemistry, renewable energy, ML, astrophysics, genetics, blockchain, microbiology, and geography:
Ranking agreement (ColBERT vs cross-encoder ground truth):
Avg Kendall’s tau: +0.376
Avg top-3 overlap: 77%
Avg top-5 overlap: 92%
Latency breakdown:
ColBERT indexing: 358.7ms (one-time, 60 docs)
ColBERT queries: 226.4ms total (22.6ms avg per query)
Cross-encoder: 499.1ms total (49.9ms avg per query)
Query speedup: 2.2x faster
92% top-5 overlap means it’s retrieving the same documents most of the time; it just occasionally shuffles the within-topic ordering. For most applications that’s good enough, and it comes at 2.2x faster per query.
And the real power comes when you observe what happens under load.
I collected 100 real processing time samples for each system, then simulated a single-server queue at increasing QPS levels. Requests arrive at fixed intervals, queue up if the server is busy, and we measure the total response time (queue wait + processing):
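A minimal version of that queue simulation might look like this (simplified from the repo's version, which I haven't reproduced verbatim):

```python
import random

def simulate_queue(qps, proc_samples, n_requests=1000, seed=42):
    """Single-server queue: requests arrive at fixed intervals and
    wait for the server if it's busy. Returns total response times (ms)."""
    rng = random.Random(seed)
    interval = 1000.0 / qps          # ms between arrivals
    server_free_at = 0.0
    latencies = []
    for i in range(n_requests):
        arrival = i * interval
        start = max(arrival, server_free_at)        # queue wait if busy
        processing = rng.choice(proc_samples)       # sampled processing time
        server_free_at = start + processing
        latencies.append(server_free_at - arrival)  # wait + processing
    return latencies
```

With a constant 45ms processing time at 40 QPS (25ms arrival interval), each request adds 20ms of queue debt, so the k-th response time is 45 + 20k ms: linear, unbounded growth, which is the runaway tail you see in the numbers below.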
===========================================================================
LATENCY PROFILING
===========================================================================
Raw processing time (100 samples per system):
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 20.4ms 30.8ms 54.2ms 64.3ms 64.3ms
Cross-encoder 45.2ms 56.7ms 69.0ms 72.1ms 72.1ms
===========================================================================
QPS SIMULATION (single-server queue, 1000 requests per level)
===========================================================================
Response time = queue wait + processing time.
When QPS exceeds throughput, requests queue and tail latencies explode.
QPS: 5 (ColBERT util: 10%, cross-encoder util: 23%)
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 20.4ms 30.8ms 54.2ms 64.3ms 64.3ms
Cross-encoder 45.2ms 56.7ms 69.0ms 72.1ms 72.1ms
QPS: 10 (ColBERT util: 20%, cross-encoder util: 45%)
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 20.4ms 30.8ms 54.2ms 64.3ms 64.3ms
Cross-encoder 45.2ms 56.7ms 69.0ms 72.1ms 72.1ms
QPS: 20 (ColBERT util: 41%, cross-encoder util: 90%)
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 20.4ms 34.0ms 62.9ms 64.3ms 64.3ms
Cross-encoder 50.8ms 74.8ms 80.9ms 82.8ms 82.8ms
QPS: 30 (ColBERT util: 61%, cross-encoder util: 136%)
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 20.7ms 49.1ms 67.3ms 79.6ms 79.6ms
Cross-encoder 6773.0ms 12953.5ms 13408.0ms 13512.6ms 13512.6ms
QPS: 40 (ColBERT util: 82%, cross-encoder util: 181%)
p50 p95 p99 p99.9 max
───────────────────────────────────────────────────────
ColBERT 23.0ms 67.8ms 84.0ms 87.9ms 87.9ms
Cross-encoder 10931.3ms 20861.8ms 21649.7ms 21837.6ms 21837.6ms
Look at 30 QPS: the cross-encoder’s utilization exceeds 100% because requests arrive every 33ms but each takes ~45ms to process. Every request adds about 12ms of queue debt. After 500 requests, the queue has accumulated over 6 seconds of wait time. That’s your p50, so half your users are waiting nearly 7 seconds.
Meanwhile, ColBERT-like late interaction at 61% utilisation is barely sweating: 20.7ms p50, with every percentile roughly where it was at idle.
At 40 QPS, the cross-encoder’s p99.9 is over 21 seconds. ColBERT’s p50 is 23ms.
So this is something to think about in production as well: choose your reranking architecture based on your QPS budget, not just your accuracy requirements.
A caveat: this is a ColBERT-like implementation. It demonstrates the MaxSim mechanism using `all-MiniLM-L6-v2`, which is a general-purpose sentence transformer. Real ColBERT deployments use models specifically trained for token-level late interaction retrieval, like `colbert-ir/colbertv2.0`.
Where Does This Leave Us?
These examples illustrate the options for retrieval and reranking:
- Cross-encoder (raw): Slow, highest quality. Use for small candidate sets under 100 docs.
- Fine-tuned cross-encoder: Slow, highest quality for your domain. Use when general models perform poorly on domain content.
- Semantic caching: Instant on cache hit, same quality as underlying ranker. Use for high-traffic systems with repeated queries.
- Multi-stage funnel: Slow per query, but scales to large corpora with quality near the cross-encoder. Use when the corpus is too big to rerank directly.
- Distilled bi-encoder: Fast, near cross-encoder quality. Use as first stage of a funnel or for domain-specific retrieval.
- ColBERT-like (late interaction): Medium speed, near cross-encoder quality. Use for high-QPS services where tail latency matters.
A mature search system might combine any of them: a distilled bi-encoder for first-pass retrieval, a cross-encoder for reranking the top candidates, semantic caching to skip redundant work, and ColBERT-like interaction as a fallback when the latency budget is tight.
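As a sketch of how those stages might compose (the component functions here are hypothetical placeholders, not the repo's actual API, and the dict cache is an exact-match stand-in for a real semantic cache, which would match by embedding similarity):

```python
def rerank_pipeline(query, corpus, cache, first_stage, reranker,
                    first_stage_k=100, final_k=10):
    # Stage 0: cache — skip all work on a hit
    # (exact-match dict as a stand-in for a semantic cache)
    if query in cache:
        return cache[query]
    # Stage 1: fast bi-encoder recall over the whole corpus
    candidates = first_stage(query, corpus, k=first_stage_k)
    # Stage 2: precise cross-encoder rerank over the small candidate set
    results = reranker(query, candidates)[:final_k]
    cache[query] = results
    return results
```

The funnel shape is the point: the cheap stage sees everything, the expensive stage sees only `first_stage_k` candidates, and the cache sees nothing twice.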
All the code is available at https://github.com/ianhohoho/cross-encoder-and-reranking-demo. In fact, every example runs end-to-end without API keys, except Example 6, which calls an LLM through OpenRouter for the list-wise reranking stage.
If you’ve made it to the end, I’d love to hear how you’re handling retrieval and reranking in production. What does your stack look like? Are you running a multi-stage funnel, or is a single bi-encoder doing the job?
I’m always happy to hear your thoughts on the approaches I’ve laid out above, and feel free to make suggestions to my implementation as well!
Let’s connect 🤝🏻 on LinkedIn, or check out my site!

