Skip to content

Commit

Permalink
refactor: move random sampling to outside caller of visualiser
Browse files Browse the repository at this point in the history
  • Loading branch information
paluchasz committed Dec 13, 2024
1 parent 28db0e5 commit fba3ecb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 2 additions & 0 deletions kazu/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class TrainingConfig:
architecture: str = "bert"
#: fraction of epoch to complete before evaluations begin
epoch_completion_fraction_before_evals: float = 0.75
#: The random seed to use
seed: int = 42


@dataclass
Expand Down
13 changes: 5 additions & 8 deletions kazu/training/modelling_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import json
import logging
import random
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional, Union
Expand Down Expand Up @@ -102,26 +101,24 @@ def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list
return result

def update(
self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True
self, docs: list[Document], global_step: Union[int, str], has_gs: bool = True
) -> None:
ls_manager = LabelStudioManager(
headers=self.ls_manager.headers,
project_name=f"{self.ls_manager.project_name}_test_{global_step}",
)

ls_manager.delete_project_if_exists()
ls_manager.create_linking_project()
docs_subset = random.sample(test_docs, min([len(test_docs), 100]))
if not docs_subset:
if not docs:
logger.info("no results to represent yet")
return
if has_gs:
side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset)
side_by_side = self.get_gold_ents_for_side_by_side_view(docs)
ls_manager.update_view(self.view, side_by_side)
ls_manager.update_tasks(side_by_side)
else:
ls_manager.update_view(self.view, docs_subset)
ls_manager.update_tasks(docs_subset)
ls_manager.update_view(self.view, docs)
ls_manager.update_tasks(docs)


def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
Expand Down
5 changes: 4 additions & 1 deletion kazu/training/train_multilabel_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import math
import pickle
import random
import shutil
import tempfile
from collections import defaultdict
Expand Down Expand Up @@ -337,6 +338,7 @@ def __init__(
self.label_list = label_list
self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.keys_to_use = _select_keys_to_use(self.training_config.architecture)
random.seed(training_config.seed)

def _write_to_tensorboard(
self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
Expand All @@ -360,7 +362,8 @@ def evaluate_model(

model_test_docs = self._process_docs(model)
if self.ls_wrapper:
self.ls_wrapper.update(model_test_docs, global_step)
sample_test_docs = random.sample(model_test_docs, min([len(model_test_docs), 100]))
self.ls_wrapper.update(sample_test_docs, global_step)

all_results, tensorboad_loggables = calculate_metrics(
epoch_loss, model_test_docs, self.label_list
Expand Down

0 comments on commit fba3ecb

Please sign in to comment.