Skip to content

Commit

Permalink
fix: Reranker tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Jan 7, 2025
1 parent 7ebe956 commit fbf2a38
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 3 deletions.
96 changes: 95 additions & 1 deletion presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_query_index_success(mock_post):
"result": "This is the completion from the API"
}
mock_post.return_value.json.return_value = mock_response
# Index
# Index
request_data = {
"index_name": "test_index",
"documents": [
Expand Down Expand Up @@ -74,6 +74,100 @@ def test_query_index_success(mock_post):
assert response.json()["source_nodes"][0]["metadata"] == {}
assert mock_post.call_count == 1


@patch('requests.post')
def test_reranker_and_query_with_index(mock_post):
    """
    Test reranker and query functionality with indexed documents.

    This test ensures the following:
    1. The custom reranker returns a relevance-sorted list of documents.
    2. The query response matches the expected format and contains the
       correct top results.

    ``requests.post`` is patched, and the values returned by its ``.json()``
    are handed out in order: first the reranker completion, then the query
    completion.

    Template for reranker input:
        A list of documents is shown below. Each document has a number next to
        it along with a summary of the document. A question is also provided.
        Respond with the numbers of the documents you should consult to answer
        the question, in order of relevance, as well as the relevance score.
        The relevance score is a number from 1-10 based on how relevant you
        think the document is to the question. Do not include any documents
        that are not relevant.

        Example format:
        Document 1: <summary of document 1>
        Document 2: <summary of document 2>
        ...
        Document 10: <summary of document 10>
        Question: <question>
        Answer:
        Doc: 9, Relevance: 7
        Doc: 3, Relevance: 4
        Doc: 7, Relevance: 3
    """
    # Mock responses for the reranker and query API calls (consumed in order).
    reranker_mock_response = "Doc: 4, Relevance: 10\nDoc: 5, Relevance: 10"
    query_mock_response = {"result": "This is the completion from the API"}
    mock_http_responses = [reranker_mock_response, query_mock_response]

    mock_post.return_value.json.side_effect = mock_http_responses

    # Define input documents for indexing. The last two entries are each a
    # single document assembled from adjacent string literals, hence the
    # explicit parentheses.
    # NOTE(review): the fourth document contains a newline where the expected
    # text below has a space; this presumably relies on the indexing pipeline
    # normalizing whitespace -- confirm.
    documents = [
        "The capital of France is great.",
        "The capital of France is huge.",
        "The capital of France is beautiful.",
        (
            "Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower.\n"
            "I really enjoyed all the cities in France, but its capital with the Eiffel Tower is my favorite city."
        ),
        (
            "I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. "
            "Such a great capital city."
        ),
    ]

    # Indexing request payload
    index_request_payload = {
        "index_name": "test_index",
        "documents": [{"text": doc} for doc in documents],
    }

    # Perform indexing
    response = client.post("/index", json=index_request_payload)
    assert response.status_code == 200

    # Query request payload with reranking; top_n is derived from the number
    # of "Doc: ..." lines in the mocked reranker answer.
    top_n = len(reranker_mock_response.split("\n"))
    query_request_payload = {
        "index_name": "test_index",
        "query": "what is the capital of france?",
        "top_k": 5,
        "llm_params": {"temperature": 0.7},
        "rerank_params": {"top_n": top_n},
    }

    # Perform query
    response = client.post("/query", json=query_request_payload)
    assert response.status_code == 200
    query_response = response.json()

    # Validate query response: the completion supplied by the mocked LLM must
    # surface in the response. The check has to be made against the mock, not
    # against another key of the response itself.
    # NOTE(review): substring match is deliberate -- it is not visible from
    # this test whether the service returns the bare completion text or a
    # stringified payload; tighten to an equality check once confirmed.
    assert query_mock_response["result"] in query_response["response"]
    assert len(query_response["source_nodes"]) == top_n

    # Validate each source node in the query response
    expected_source_nodes = [
        {
            "text": "Have you ever visited Paris? It is a beautiful city where you can eat "
                    "delicious food and see the Eiffel Tower. I really enjoyed all the cities in "
                    "France, but its capital with the Eiffel Tower is my favorite city.",
            "score": 10.0,
            "metadata": {},
        },
        {
            "text": "I really enjoyed my trip to Paris, France. The city is beautiful and the "
                    "food is delicious. I would love to visit again. Such a great capital city.",
            "score": 10.0,
            "metadata": {},
        },
    ]
    for actual_node, expected_node in zip(query_response["source_nodes"], expected_source_nodes):
        assert actual_node["text"] == expected_node["text"]
        assert actual_node["score"] == expected_node["score"]
        assert actual_node["metadata"] == expected_node["metadata"]

    # Verify the number of mock API calls (one per mocked response)
    assert mock_post.call_count == len(mock_http_responses)
def test_query_index_failure():
# Prepare request data for querying.
request_data = {
Expand Down
5 changes: 3 additions & 2 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def query(self,
top_k (int): Number of initial top results to retrieve
llm_params (dict): Optional parameters for the language model
rerank_params (dict): Optional configuration for reranking
- 'top_n' (int): Number of documents to process in each batch
- 'batch_size' (int): Number of top documents to return after reranking
- 'top_n' (int): Number of top documents to return after reranking
- 'batch_size' (int): Number of documents to process in each batch
Returns:
dict: A dictionary containing the response and source nodes.
Expand All @@ -125,6 +125,7 @@ def query(self,
# Add LLMRerank to postprocessors
node_postprocessors.append(
LLMRerank(
llm=self.llm,
choice_batch_size=rerank_params['choice_batch_size'],
top_n=rerank_params['top_n']
)
Expand Down

0 comments on commit fbf2a38

Please sign in to comment.