Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update off sdk #1264

Merged
merged 3 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,138 changes: 623 additions & 515 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ lark = "~1.1.4"
h5py = "~3.8.0"
opencv-contrib-python = "~4.7.0.72"
toml = "~0.10.2"
openfoodfacts = "0.1.5"
openfoodfacts = "0.1.10"

[tool.poetry.dependencies.sentry-sdk]
version = "~1.14.0"
Expand Down
19 changes: 8 additions & 11 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from falcon.media.validators import jsonschema
from falcon_cors import CORS
from falcon_multipart.middleware import MultipartMiddleware
from openfoodfacts import Country
from openfoodfacts.ocr import OCRParsingException, OCRResultGenerationException
from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country
from PIL import Image
from sentry_sdk.integrations.falcon import FalconIntegration

Expand Down Expand Up @@ -66,10 +67,6 @@
from robotoff.prediction import ingredient_list
from robotoff.prediction.category import predict_category
from robotoff.prediction.object_detection import ObjectDetectionModelRegistry
from robotoff.prediction.ocr.dataclass import (
OCRParsingException,
OCRResultGenerationException,
)
from robotoff.products import get_image_id, get_product, get_product_dataset_etag
from robotoff.taxonomy import is_prefixed_value, match_taxonomized_value
from robotoff.types import (
Expand Down Expand Up @@ -125,7 +122,7 @@ def get_server_type_from_req(
raise falcon.HTTPBadRequest(f"invalid `server_type`: {server_type_str}")


COUNTRY_NAME_TO_ENUM = {item.value: item for item in Country}
COUNTRY_NAME_TO_ENUM = {COUNTRY_CODE_TO_NAME[item]: item for item in Country}


def get_countries_from_req(req: falcon.Request) -> Optional[list[Country]]:
Expand All @@ -134,7 +131,9 @@ def get_countries_from_req(req: falcon.Request) -> Optional[list[Country]]:
countries: Optional[list[Country]] = None

if countries_str is None:
# `country` parameter is deprecated
# `country` parameter is deprecated, we check here if it's provided
# when `countries` is not provided.
# It's assumed to be a single country tag (ex: `en:spain`)
country: Optional[str] = req.get_param("country")

if country:
Expand All @@ -144,11 +143,9 @@ def get_countries_from_req(req: falcon.Request) -> Optional[list[Country]]:
else:
countries = None
else:
# countries_str is a list of 2-letter country codes
try:
countries = [
Country.get_from_2_letter_code(country_str)
for country_str in countries_str
]
countries = [Country[country_str] for country_str in countries_str]
except KeyError:
raise falcon.HTTPBadRequest(
description=f"invalid `countries` value: {countries_str}"
Expand Down
7 changes: 4 additions & 3 deletions robotoff/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Iterable, Literal, NamedTuple, Optional, Union

import peewee
from openfoodfacts import Country
from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country
from peewee import JOIN, SQL, fn

from robotoff.app import events
Expand Down Expand Up @@ -159,7 +159,9 @@ def get_insights(

if countries is not None:
where_clauses.append(
ProductInsight.countries.contains_any([c.value for c in countries])
ProductInsight.countries.contains_any(
[COUNTRY_CODE_TO_NAME[c] for c in countries]
)
)

if brands:
Expand Down Expand Up @@ -303,7 +305,6 @@ def get_image_predictions(
count: bool = False,
limit: Optional[int] = None,
) -> Iterable[ImagePrediction]:

query = ImagePrediction.select()

query = query.switch(ImagePrediction).join(ImageModel)
Expand Down
3 changes: 2 additions & 1 deletion robotoff/cli/insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import dacite
import orjson
import tqdm
from openfoodfacts.ocr import OCRResult

from robotoff.insights.extraction import DEFAULT_OCR_PREDICTION_TYPES
from robotoff.off import get_barcode_from_path
from robotoff.prediction.ocr import OCRResult, extract_predictions
from robotoff.prediction.ocr import extract_predictions
from robotoff.prediction.ocr.core import ocr_content_iter
from robotoff.types import Prediction, PredictionType, ProductIdentifier, ServerType
from robotoff.utils import get_logger, jsonl_iter
Expand Down
5 changes: 2 additions & 3 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,8 @@ def pprint_ocr_result(
import sys

import orjson
from openfoodfacts.ocr import OCRResult

from robotoff.prediction.ocr.core import get_ocr_result
from robotoff.prediction.ocr.dataclass import OCRResult
from robotoff.utils import get_logger, http_session

logger = get_logger()
Expand All @@ -795,7 +794,7 @@ def pprint_ocr_result(
logger.info("displaying OCR result %s", uri)

if uri.startswith("http"):
ocr_result = get_ocr_result(uri, http_session)
ocr_result = OCRResult.from_url(uri, http_session)
else:
with open(uri, "rb") as f:
data = orjson.loads(f.read())
Expand Down
4 changes: 2 additions & 2 deletions robotoff/insights/extraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from typing import Iterable, Optional

from openfoodfacts.ocr import OCRResult
from PIL import Image

from robotoff.models import ImageModel, ImagePrediction
Expand All @@ -10,7 +11,6 @@
OBJECT_DETECTION_MODEL_VERSION,
ObjectDetectionModelRegistry,
)
from robotoff.prediction.ocr.core import get_ocr_result
from robotoff.types import (
ObjectDetectionModel,
Prediction,
Expand Down Expand Up @@ -124,7 +124,7 @@ def extract_ocr_predictions(

predictions_all: list[Prediction] = []
source_image = get_source_from_url(ocr_url)
ocr_result = get_ocr_result(ocr_url, http_session, error_raise=False)
ocr_result = OCRResult.from_url(ocr_url, http_session, error_raise=False)

if ocr_result is None:
return predictions_all
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Literal, Optional

import numpy as np
from openfoodfacts.ocr import OCRResult
from PIL import Image
from tritonclient.grpc import service_pb2

from robotoff.images import refresh_images_in_db
from robotoff.models import ImageEmbedding, ImageModel, with_db
from robotoff.off import generate_image_url, generate_json_ocr_url
from robotoff.prediction.ocr.core import get_ocr_result
from robotoff.taxonomy import Taxonomy
from robotoff.triton import (
deserialize_byte_tensor,
Expand Down Expand Up @@ -214,7 +214,7 @@ def fetch_ocr_texts(product: JSONType, product_id: ProductIdentifier) -> list[st
image_ids = (id_ for id_ in product.get("images", {}).keys() if id_.isdigit())
for image_id in image_ids:
ocr_url = generate_json_ocr_url(product_id, image_id)
ocr_result = get_ocr_result(ocr_url, http_session, error_raise=False)
ocr_result = OCRResult.from_url(ocr_url, http_session, error_raise=False)
if ocr_result:
ocr_texts.append(ocr_result.get_full_text_contiguous())

Expand Down
5 changes: 2 additions & 3 deletions robotoff/prediction/ingredient_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from typing import Optional, Union

import numpy as np
from openfoodfacts.ocr import OCRResult
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from tritonclient.grpc import service_pb2

from robotoff import settings
from robotoff.prediction.langid import LanguagePrediction, predict_lang_batch
from robotoff.prediction.ocr.core import get_ocr_result
from robotoff.prediction.ocr.dataclass import OCRResult
from robotoff.triton import get_triton_inference_stub
from robotoff.utils import http_session

Expand Down Expand Up @@ -85,7 +84,7 @@ def predict_from_ocr(
ocr_result: OCRResult
if isinstance(input_ocr, str):
# `input_ocr` is a URL, fetch OCR JSON and get OCRResult
ocr_result = get_ocr_result(input_ocr, http_session, error_raise=True) # type: ignore
ocr_result = OCRResult.from_url(input_ocr, http_session, error_raise=True) # type: ignore
else:
ocr_result = input_ocr

Expand Down
1 change: 0 additions & 1 deletion robotoff/prediction/ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# flake8: noqa
from .core import extract_predictions
from .dataclass import OCRResult
3 changes: 2 additions & 1 deletion robotoff/prediction/ocr/brand.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import functools
from typing import Iterable, Optional, Union

from openfoodfacts.ocr import OCRResult, get_match_bounding_box, get_text

from robotoff import settings
from robotoff.brands import get_brand_blacklist, keep_brand_from_taxonomy
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger, text_file_iter
from robotoff.utils.text import KeywordProcessor, get_tag

from .dataclass import OCRResult, get_match_bounding_box, get_text
from .utils import generate_keyword_processor

logger = get_logger(__name__)
Expand Down
10 changes: 8 additions & 2 deletions robotoff/prediction/ocr/category.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import re
from typing import Optional, Union

from openfoodfacts.ocr import (
OCRField,
OCRRegex,
OCRResult,
get_match_bounding_box,
get_text,
)

from robotoff.off import normalize_tag
from robotoff.taxonomy import get_taxonomy
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger

from .dataclass import OCRField, OCRRegex, OCRResult, get_match_bounding_box, get_text

logger = get_logger(__name__)


Expand Down
54 changes: 1 addition & 53 deletions robotoff/prediction/ocr/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
import pathlib
from typing import Callable, Iterable, Optional, TextIO, Union

import requests
from openfoodfacts.ocr import OCRResult

from robotoff.types import JSONType, Prediction, PredictionType, ProductIdentifier
from robotoff.utils import get_logger, jsonl_iter, jsonl_iter_fp

from .brand import find_brands
from .category import find_category
from .dataclass import OCRParsingException, OCRResult, OCRResultGenerationException
from .expiration_date import find_expiration_date
from .image_flag import flag_image
from .image_lang import get_image_lang
Expand Down Expand Up @@ -47,56 +45,6 @@
}


def get_ocr_result(
    ocr_url: str, session: requests.Session, error_raise: bool = True
) -> Optional[OCRResult]:
    """Generate an OCRResult from the URL of an OCR JSON.

    The function downloads the JSON file, decodes it and parses it into an
    OCRResult. Each of these steps can fail; the way failures are reported
    depends on `error_raise`.

    :param ocr_url: The URL of the JSON OCR
    :param session: the requests Session to use to download the JSON file
    :param error_raise: if True, raises an OCRResultGenerationException if an
        error occurred during download or analysis, defaults to True
    :raises OCRResultGenerationException: if `error_raise` is True and the
        download failed, the server returned a non-OK status code, the
        response body is not valid JSON, or the JSON could not be parsed
        into an OCRResult
    :return: if `error_raise` is True, always return an OCRResult (or raises
        an error). Otherwise return the OCRResult or None if an error
        occurred (a warning is logged in that case).
    """
    # Step 1: download the OCR JSON
    try:
        r = session.get(ocr_url)
    except requests.exceptions.RequestException as e:
        error_message = "HTTP Error when fetching OCR URL"
        if error_raise:
            raise OCRResultGenerationException(error_message, ocr_url) from e

        logger.warning(error_message + ": %s", ocr_url, exc_info=e)
        return None

    # Step 2: reject non-OK responses before trying to decode the body
    if not r.ok:
        error_message = "Non-200 status code (%s) when fetching OCR URL"
        if error_raise:
            raise OCRResultGenerationException(error_message % r.status_code, ocr_url)

        logger.warning(error_message + ": %s", r.status_code, ocr_url)
        return None

    # Step 3: decode the response body as JSON
    try:
        ocr_data: dict = r.json()
    except ValueError as e:
        # `ValueError` is caught instead of `json.JSONDecodeError`: depending
        # on the requests version and on the JSON backend in use, the decode
        # error raised by `Response.json()` is not necessarily a subclass of
        # the standard library `json.JSONDecodeError`, but it is always a
        # `ValueError`. This keeps the `error_raise=False` contract intact.
        error_message = "Error while decoding OCR JSON"
        if error_raise:
            raise OCRResultGenerationException(error_message, ocr_url) from e

        logger.warning(error_message + ": %s", ocr_url, exc_info=e)
        return None

    # Step 4: build the OCRResult from the decoded JSON
    try:
        return OCRResult.from_json(ocr_data)
    except OCRParsingException as e:
        if error_raise:
            raise OCRResultGenerationException(str(e), ocr_url) from e

        logger.warning("Error while parsing OCR JSON from %s", ocr_url, exc_info=e)
        return None


def extract_predictions(
content: Union[OCRResult, str],
prediction_type: PredictionType,
Expand Down
Loading