Skip to content

Commit

Permalink
feat: allow users to filter by confidence in /questions
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Jan 16, 2023
1 parent 6c3ea9a commit 8af5719
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 21 deletions.
61 changes: 53 additions & 8 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ paths:
tags:
- Questions
summary: Get questions for a given product
operationId: getQuestionsBarcode
parameters:
- name: count
in: query
Expand Down Expand Up @@ -50,13 +49,64 @@ paths:
type: array
items:
type: object
/questions:
get:
tags:
- Questions
summary: Fetch questions
parameters:
- $ref: "#/components/parameters/lang"
- $ref: "#/components/parameters/count"
- $ref: "#/components/parameters/server_domain"
- $ref: "#/components/parameters/insight_types"
- $ref: "#/components/parameters/country"
- $ref: "#/components/parameters/brands"
- $ref: "#/components/parameters/value_tag"
- $ref: "#/components/parameters/page"
- $ref: "#/components/parameters/reserved_barcode"
- $ref: "#/components/parameters/campaign"
- $ref: "#/components/parameters/predictor"
- name: order_by
in: query
description: |
The field to use for ordering results:
- confidence: order by descending model confidence; insights with a null confidence come last
- popularity: order by (descending) popularity (=scan count)
- random: use a random order
schema:
type: string
default: popularity
enum:
- confidence
- random
- popularity
responses:
"200":
description: The questions matching the filters
content:
application/json:
schema:
type: object
properties:
status:
type: string
enum:
- "no_questions"
- "found"
questions:
type: array
items:
type: object
count:
type: integer
description: The total number of results with the provided filters

/questions/random:
get:
tags:
- Questions
summary: Get random questions
operationId: getQuestionsRandom
deprecated: true
parameters:
- $ref: "#/components/parameters/lang"
- $ref: "#/components/parameters/count"
Expand Down Expand Up @@ -95,9 +145,9 @@ paths:
tags:
- Questions
summary: Get questions about popular products
deprecated: true
description: |
Questions are ranked by the product popularity (based on scan count).
operationId: GetQuestionsPopular
parameters:
- $ref: "#/components/parameters/lang"
- $ref: "#/components/parameters/count"
Expand All @@ -122,7 +172,6 @@ paths:
description: |
Get number of unanswered questions grouped by `value_tag`.
The list is ordered from highest count to lowest.
operationId: GetQuestionsUnanswered
parameters:
- name: count
in: query
Expand Down Expand Up @@ -173,7 +222,6 @@ paths:
tags:
- Insights
summary: Get a random insight
operationId: GetInsightsRandom
parameters:
- $ref: "#/components/parameters/insight_type"
- $ref: "#/components/parameters/country"
Expand All @@ -199,7 +247,6 @@ paths:
tags:
- Insights
summary: Get all insights for a specific product
operationId: Getallinsightsforaspecificproduct
parameters:
- $ref: "#/components/parameters/barcode_path"
responses:
Expand All @@ -211,7 +258,6 @@ paths:
tags:
- Insights
summary: Get a specific insight
operationId: GetInsightsDetail
parameters:
- name: id
in: path
Expand Down Expand Up @@ -244,7 +290,6 @@ paths:
(so it won't be applied), `1` means it is correct (so it will be applied) and `-1` means the insight
won't be returned to the user (_skip_). We use the voting mechanism to remember which insight
to skip for a user (authenticated or not).
operationId: Submitanannotation
requestBody:
required: true
content:
Expand Down
18 changes: 16 additions & 2 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io
import tempfile
import uuid
from typing import Optional
from typing import Literal, Optional

import falcon
import orjson
Expand Down Expand Up @@ -1096,8 +1096,21 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
get_questions_resource_on_get(req, resp, "popularity")


class QuestionsCollectionResource:
    """Falcon resource returning questions in a caller-selected order."""

    def on_get(self, req: falcon.Request, resp: falcon.Response):
        """Handle a GET request for questions.

        Reads the `order_by` query parameter (falling back to
        "popularity" when it is absent) and forwards the request to
        `get_questions_resource_on_get` with that ordering.

        Raises:
            falcon.HTTPBadRequest: if `order_by` is not one of "random",
                "popularity" or "confidence".
        """
        order_by = req.get_param("order_by", default="popularity")
        supported_orderings = ("random", "popularity", "confidence")

        if order_by in supported_orderings:
            get_questions_resource_on_get(req, resp, order_by)
            return

        # Anything outside the supported set is rejected before it can
        # reach the shared handler.
        raise falcon.HTTPBadRequest(
            description=f"invalid '{order_by}' value for `order_by` parameter"
        )


def get_questions_resource_on_get(
req: falcon.Request, resp: falcon.Response, order_by: str
req: falcon.Request,
resp: falcon.Response,
order_by: Literal["random", "popularity", "confidence"],
):
response: JSONType = {}
page: int = req.get_param_as_int("page", min_value=1, default=1)
Expand Down Expand Up @@ -1500,6 +1513,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
api.add_route("/api/v1/ann/search/{logo_id:int}", ANNResource())
api.add_route("/api/v1/ann/search", ANNResource())
api.add_route("/api/v1/questions/{barcode}", ProductQuestionsResource())
api.add_route("/api/v1/questions", QuestionsCollectionResource())
api.add_route("/api/v1/questions/random", RandomQuestionsResource())
api.add_route("/api/v1/questions/popular", PopularQuestionsResource())
api.add_route("/api/v1/questions/unanswered", UnansweredQuestionCollection())
Expand Down
9 changes: 6 additions & 3 deletions robotoff/app/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import datetime
import functools
from enum import Enum
from typing import Iterable, NamedTuple, Optional, Union
from typing import Iterable, Literal, NamedTuple, Optional, Union

import peewee
from peewee import JOIN, fn
from peewee import JOIN, SQL, fn

from robotoff import settings
from robotoff.app import events
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_insights(
brands: Optional[list[str]] = None,
annotated: Optional[bool] = False,
annotation: Optional[int] = None,
order_by: Optional[str] = None,
order_by: Optional[Literal["random", "popularity", "n_votes", "confidence"]] = None,
value_tag: Optional[str] = None,
server_domain: Optional[str] = None,
reserved_barcode: Optional[bool] = None,
Expand Down Expand Up @@ -158,6 +158,9 @@ def get_insights(
elif order_by == "popularity":
query = query.order_by(ProductInsight.unique_scans_n.desc())

elif order_by == "confidence":
query = query.order_by(SQL("confidence DESC NULLS LAST"))

elif order_by == "n_votes":
query = query.order_by(ProductInsight.n_votes.desc())

Expand Down
1 change: 1 addition & 0 deletions tests/integration/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Meta:
unique_scans_n = 10
annotation = None
automatic_processing = False
confidence: Optional[float] = None


class PredictionFactory(PeeweeModelFactory):
Expand Down
43 changes: 35 additions & 8 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

insight_id = "94371643-c2bc-4291-a585-af2cb1a5270a"
DEFAULT_BARCODE = "1"


@pytest.fixture(autouse=True)
Expand All @@ -30,7 +31,7 @@ def _set_up_and_tear_down(peewee_db):
# clean db
clean_db()
# Set up.
ProductInsightFactory(id=insight_id, barcode=1)
ProductInsightFactory(id=insight_id, barcode=DEFAULT_BARCODE)
# Run the test case.
yield
with peewee_db:
Expand All @@ -56,7 +57,7 @@ def test_random_question(client, mocker):
}
}
mocker.patch("robotoff.insights.question.get_product", return_value=product)
result = client.simulate_get("/api/v1/questions/random")
result = client.simulate_get("/api/v1/questions?order_by=random")

assert result.status_code == 200
assert result.json == {
Expand Down Expand Up @@ -85,15 +86,15 @@ def test_random_question_user_has_already_seen(client, mocker, peewee_db):
device_id="device1",
)

result = client.simulate_get("/api/v1/questions/random?device_id=device1")
result = client.simulate_get("/api/v1/questions?order_by=random&device_id=device1")

assert result.status_code == 200
assert result.json == {"count": 0, "questions": [], "status": "no_questions"}


def test_popular_question(client, mocker):
mocker.patch("robotoff.insights.question.get_product", return_value={})
result = client.simulate_get("/api/v1/questions/popular")
result = client.simulate_get("/api/v1/questions?order_by=popularity")

assert result.status_code == 200
assert result.json == {
Expand Down Expand Up @@ -121,32 +122,58 @@ def test_popular_question_pagination(client, mocker, peewee_db):
for i in range(0, 12):
ProductInsightFactory(barcode=i, unique_scans_n=100 - i)

result = client.simulate_get("/api/v1/questions/popular?count=5&page=1")
result = client.simulate_get("/api/v1/questions?order_by=popularity&count=5&page=1")
assert result.status_code == 200
data = result.json
assert data["count"] == 12
assert data["status"] == "found"
assert [q["barcode"] for q in data["questions"]] == ["0", "1", "2", "3", "4"]
result = client.simulate_get("/api/v1/questions/popular?count=5&page=2")
result = client.simulate_get("/api/v1/questions?order_by=popularity&count=5&page=2")
assert result.status_code == 200
data = result.json
assert data["count"] == 12
assert data["status"] == "found"
assert [q["barcode"] for q in data["questions"]] == ["5", "6", "7", "8", "9"]
result = client.simulate_get("/api/v1/questions/popular?count=5&page=3")
result = client.simulate_get("/api/v1/questions?order_by=popularity&count=5&page=3")
assert result.status_code == 200
data = result.json
assert data["count"] == 12
assert data["status"] == "found"
assert [q["barcode"] for q in data["questions"]] == ["10", "11"]
result = client.simulate_get("/api/v1/questions/popular?count=5&page=4")
result = client.simulate_get("/api/v1/questions?order_by=popularity&count=5&page=4")
assert result.status_code == 200
data = result.json
assert data["count"] == 12
assert data["status"] == "no_questions"
assert len(data["questions"]) == 0


def test_question_rank_by_confidence(client, mocker, peewee_db):
    """Check that `order_by=confidence` ranks questions by descending confidence.

    Four insights are stored out of confidence order, one of them with no
    confidence at all; the one without a confidence is expected last.
    """
    mocker.patch("robotoff.insights.question.get_source_image_url", return_value=None)

    # (barcode, type, value_tag, confidence) -- deliberately not inserted
    # in confidence order so the ranking has to come from the query.
    samples = [
        ("1", "category", "en:salmon", 0.9),
        ("3", "category", "en:breads", 0.4),
        ("2", "label", "en:eu-organic", 0.7),
        ("4", "brand", "carrefour", None),
    ]

    with peewee_db:
        ProductInsight.delete().execute()  # remove default sample
        for barcode, insight_type, value_tag, confidence in samples:
            ProductInsightFactory(
                barcode=barcode,
                type=insight_type,
                value_tag=value_tag,
                confidence=confidence,
            )

    response = client.simulate_get("/api/v1/questions?order_by=confidence")
    assert response.status_code == 200
    payload = response.json
    assert payload["count"] == 4
    assert payload["status"] == "found"
    returned_barcodes = [question["barcode"] for question in payload["questions"]]
    assert returned_barcodes == ["1", "2", "3", "4"]


def test_barcode_question_not_found(client):
result = client.simulate_get("/api/v1/questions/2")

Expand Down

0 comments on commit 8af5719

Please sign in to comment.