Skip to content

Commit

Permalink
adds scorer to AggregateRequest (#3409)
Browse files Browse the repository at this point in the history
* adds scorer to AggregateRequest

* fix linting

* update tests for BM25

* enum for aggregation scorer

* update signature

* revert back to string input

---------

Co-authored-by: Vladyslav Vildanov <[email protected]>
  • Loading branch information
rbs333 and vladvildanov authored Oct 22, 2024
1 parent 4c4d4af commit 00f5be4
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 0 deletions.
16 changes: 16 additions & 0 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
self._cursor = []
self._dialect = None
self._add_scores = False
self._scorer = "TFIDF"

def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Expand Down Expand Up @@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
self._add_scores = True
return self

def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self

def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
Expand All @@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
if self._verbatim:
ret.append("VERBATIM")

if self._scorer:
ret.extend(["SCORER", self._scorer])

if self._add_scores:
ret.append("ADDSCORES")

Expand All @@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
if self._loadall:
ret.append("LOAD")
ret.append("*")

elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
Expand Down
55 changes: 55 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
assert await decoded_r.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

assert await decoded_r.hset(
"doc1",
mapping={
"name": "cat book",
"description": "an animal book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
assert await decoded_r.hset(
"doc2",
mapping={
"name": "dog book",
"description": "an animal book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = await decoded_r.ft().aggregate(
req,
query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
)

if isinstance(res, dict):
assert len(res["results"]) == 2
else:
assert len(res.rows) == 2
for row in res.rows:
len(row) == 6


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
Expand Down
55 changes: 55 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(client):
client.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

client.hset(
"doc1",
mapping={
"name": "cat book",
"description": "an animal book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
client.hset(
"doc2",
mapping={
"name": "dog book",
"description": "an animal book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = client.ft().aggregate(
req,
query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()},
)

if isinstance(res, dict):
assert len(res["results"]) == 2
else:
assert len(res.rows) == 2
for row in res.rows:
len(row) == 6


@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_index_definition(client):
Expand Down

0 comments on commit 00f5be4

Please sign in to comment.