Skip to content

Commit

Permalink
fix: improve batch import
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 27, 2024
1 parent 8093e61 commit 906c8b9
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions robotoff/batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import duckdb
from more_itertools import chunked

from robotoff import settings
from robotoff.insights.importer import import_insights
Expand Down Expand Up @@ -44,38 +45,37 @@ def import_spellcheck_batch_predictions(batch_dir: str) -> None:
)
logger.info("Number of rows in the batch data: %s", len(df))

# Generate predictions
predictions = []
# We increment to allow import_insights to create a new version
predictor_version = "llm-v1-" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")

for _, row in df.iterrows():
lang_predictions = predict_lang(row["text"], k=1)
lang, lang_confidence = lang_predictions[0].lang, (
lang_predictions[0].confidence if lang_predictions else None
)
predictions.append(
Prediction(
type=PredictionType.ingredient_spellcheck,
data={
"original": row["text"],
"correction": row["correction"],
"lang": lang,
"lang_confidence": lang_confidence,
},
value_tag=row["lang"],
barcode=row["code"],
predictor_version=predictor_version,
predictor="fine-tuned-mistral-7b",
automatic_processing=False,
for batch in chunked((row for _, row in df.iterrows()), 100):
predictions = []
for row in batch:
lang_predictions = predict_lang(row["text"], k=1)
lang, lang_confidence = lang_predictions[0].lang, (
lang_predictions[0].confidence if lang_predictions else None
)
)
# Store predictions and insights
with db:
import_results = import_insights(
predictions=predictions, server_type=ServerType.off
)
logger.info("Batch import results: %s", import_results)
predictions.append(
Prediction(
type=PredictionType.ingredient_spellcheck,
data={
"original": row["text"],
"correction": row["correction"],
"lang": lang,
"lang_confidence": lang_confidence,
},
value_tag=row["lang"],
barcode=row["code"],
predictor_version=predictor_version,
predictor="fine-tuned-mistral-7b",
automatic_processing=False,
)
)
# Store predictions and insights
with db:
import_results = import_insights(
predictions=predictions, server_type=ServerType.off
)
logger.info("Batch import results: %s", import_results)


def launch_spellcheck_batch_job(
Expand Down

0 comments on commit 906c8b9

Please sign in to comment.