From 6584960493460a9b468a3bc006521d8a6eaf48dd Mon Sep 17 00:00:00 2001
From: Lukas Garbas
Date: Mon, 11 Nov 2024 19:36:06 +0100
Subject: [PATCH] Add column for text pairs

---
 transformer_ranker/datacleaner.py | 116 +++++++++++++++---------------
 transformer_ranker/embedder.py    |  41 +++++------
 transformer_ranker/ranker.py      |  23 +++---
 3 files changed, 87 insertions(+), 93 deletions(-)

diff --git a/transformer_ranker/datacleaner.py b/transformer_ranker/datacleaner.py
index b817616..47c468f 100644
--- a/transformer_ranker/datacleaner.py
+++ b/transformer_ranker/datacleaner.py
@@ -17,7 +17,6 @@ def __init__(
         pre_tokenizer: Optional[Whitespace] = None,
         exclude_test_split: bool = False,
         merge_data_splits: bool = True,
-        change_ner_encoding_to_spans: bool = True,
         remove_empty_sentences: bool = True,
         dataset_downsample: Optional[float] = None,
         task_type: Optional[str] = None,
@@ -25,6 +24,7 @@ def __init__(
         label_column: Optional[str] = None,
         label_map: Optional[Dict[str, int]] = None,
         text_pair_column: Optional[str] = None,
+        remove_bio_notation: bool = True,
     ):
         """
         Prepare huggingface dataset. Identify task type, find text and label columns, down-sample, merge data splits.
 
         :param pre_tokenizer: Pre-tokenizer to use, such as Whitespace from huggingface pre-tokenizers.
         :param exclude_test_split: Whether to exclude the test split.
         :param merge_data_splits: Whether to merge train, dev, and test splits into one.
-        :param change_ner_encoding_to_spans: Whether to change BIO encoding to single class labels.
+        :param remove_bio_notation: Change BIO encoding to single class labels by removing B-, I-, O- prefixes
         :param remove_empty_sentences: Whether to remove empty sentences.
-        :param dataset_downsample: Fraction to downsample the dataset to.
-        :param task_type: Type of task (e.g., 'sentence classification', 'word classification', 'sentence regression').
-        :param text_column: Column name where texts are stored.
-        :param label_column: Column name where labels are stored.
-        :param label_map: Mapping of labels to integers.
-        :param text_pair_column: Column name where the second text pair is stored. For entailment-type tasks.
+        :param dataset_downsample: Fraction to reduce the dataset size.
+        :param task_type: Task category (e.g., 'token classification', 'text classification', 'text regression').
+        :param text_column: Column name for texts.
+        :param label_column: Column name for labels.
+        :param label_map: A dictionary which maps label names to integers.
+        :param text_pair_column: Column name where the second text pair is stored (for entailment-like tasks)
         """
         self.pre_tokenizer = pre_tokenizer
         self.exclude_test_split = exclude_test_split
         self.merge_data_splits = merge_data_splits
-        self.change_ner_encoding_to_spans = change_ner_encoding_to_spans
+        self.remove_bio_notation = remove_bio_notation
         self.remove_empty_sentences = remove_empty_sentences
         self.dataset_downsample = dataset_downsample
         self.task_type = task_type
@@ -87,10 +87,10 @@ def prepare_dataset(self, dataset: Union[str, DatasetDict, Dataset]) -> Union[Da
                                                                                    self.text_column,
                                                                                    self.label_column)
 
-        # Determine task type based on label type if not specified
+        # Find task type based on label type
         task_type = self._find_task_type(label_column, label_type) if not self.task_type else self.task_type
 
-        # Clean the dataset by removing empty sentences and negative labels
+        # Clean the dataset by removing empty sentences and empty/negative labels
         if self.remove_empty_sentences:
             dataset = self._remove_empty_rows(dataset, text_column, label_column)
 
@@ -98,31 +98,31 @@ def prepare_dataset(self, dataset: Union[str, DatasetDict, Dataset]) -> Union[Da
         if self.dataset_downsample:
             dataset = self._downsample(dataset, ratio=self.dataset_downsample)
 
-        # Pre-tokenize sentences if pre-tokenizer is specified
-        if not task_type == "word classification" and self.pre_tokenizer:
+        # Pre-tokenize sentences if pre-tokenizer is given
+        if not task_type == "token classification" and self.pre_tokenizer:
             dataset = self._tokenize(dataset, self.pre_tokenizer, text_column)
 
         # Concatenate text columns for text-pair tasks
         if self.text_pair_column:
-            dataset = self._merge_textpairs(dataset, text_column, self.text_pair_column)
+            dataset, text_column = self._merge_textpairs(dataset, text_column, self.text_pair_column)
 
         # Convert string labels to integers
         if label_type == str:
             dataset, label_map = self._make_labels_categorical(dataset, label_column)
             logger.info(f"Label map: {label_map}")
 
-        # Change NER encoding to spans if specified
-        if task_type == "word classification" and self.change_ner_encoding_to_spans:
-            dataset, self.label_map = self._change_to_span_encoding(dataset, label_column, self.label_map)
+        # Remove BIO prefixes for ner or chunking tasks
+        if task_type == "token classification" and self.remove_bio_notation:
+            dataset, self.label_map = self._remove_bio_notation(dataset, label_column, self.label_map)
 
-        # Store updated attributes and log them
+        # Set updated attributes and log them
         self.text_column = text_column
         self.label_column = label_column
         self.task_type = task_type
         self.dataset_size = len(dataset)
         self.log_dataset_info()
 
-        # Simplify the dataset: keep only relevant columns
+        # Keep only text and label columns
         keep_columns = [col for col in (self.text_column, self.text_pair_column, self.label_column) if col is not None]
         dataset = self._remove_columns(dataset, keep_columns=keep_columns)
 
@@ -177,13 +177,16 @@ def _find_text_and_label_columns(dataset: Dataset, text_column: Optional[str] =
         return text_column, label_column, label_type
 
     @staticmethod
-    def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Dataset:
+    def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Tuple[Dataset, str]:
         """Concatenate text pairs into a single text using separator token"""
+        new_text_column_name = text_column + '+' + text_pair_column
+
         def merge_texts(example: Dict[str, str]) -> Dict[str, str]:
             example[text_column] = example[text_column] + " [SEP] " + example[text_pair_column]
+            example[new_text_column_name] = example.pop(text_column)
             return example
-        dataset = dataset.map(merge_texts, num_proc=None, desc="Merge sentence pair columns")
-        return dataset
+        dataset = dataset.map(merge_texts, num_proc=None, desc="Merging text pair columns")
+        return dataset, new_text_column_name
 
     @staticmethod
     def _find_task_type(label_column: str, label_type: type) -> str:
@@ -261,60 +264,49 @@ def map_labels(example):
         return dataset, label_map
 
     @staticmethod
-    def _change_to_span_encoding(
+    def _remove_bio_notation(
         dataset: Dataset,
         label_column: str,
         label_map: Optional[Dict[str, int]] = None,
     ) -> Tuple[Dataset, Dict[str, int]]:
         """Remove BIO prefixes for NER labels and create a new label map.
-        Example: ['B-PER', 'I-PER', 'O'] -> ['PER', 'PER', 'O']
-        Original label map: {'B-PER': 0, 'I-PER': 1, 'O': 2}
-        Converted span label map: {'PER': 0, 'O': 1}
+        Example: ['O', 'B-PER', 'I-PER'] -> ['O', 'PER', 'PER']
+        Original label map: {'O': 0, 'B-PER': 1, 'I-PER': 2}
+        Label map without BIO notation: {'O': 0, 'PER': 1}
 
         :param dataset: The dataset containing BIO labels.
         :param label_column: The name of the label column.
         :param label_map: Optional dictionary to map BIO labels to integers. If not provided, a new one will be created.
         :return: A tuple with the dataset containing new labels and the updated label map.
         """
-        # Attempt to get the label map from dataset features information
         if not label_map:
-            features = dataset.features
-            if label_column in features and hasattr(features[label_column], 'feature') and hasattr(
-                    features[label_column].feature, 'names'):
-                label_map = {name: idx for idx, name in enumerate(features[label_column].feature.names)}
-            else:
-                # Create label map manually if not found
+            try:
+                # Attempt to get the label map from dataset feature information
+                label_map = {label: idx for idx, label in enumerate(dataset.features[label_column].feature.names)}
+            except AttributeError:
+                # Try to create label map manually
                 logger.info('Label map not found. Creating manually...')
                 unique_labels: Set[str] = set()
-                label_data = dataset[label_column] if isinstance(dataset, Dataset) else [dataset[split][label_column]
-                                                                                         for split in dataset]
-                for label_list in label_data:
+
+                for label_list in dataset[label_column]:
                     unique_labels.update(
                         label.split('-')[-1] if isinstance(label, str) else str(label) for label in label_list)
                 label_map = {label: idx for idx, label in enumerate(sorted(unique_labels, key=int))}
 
-        logger.info(f"Label map: {label_map}")
-
-        # Remove BIO encoding from the label map
-        span_label_map: Dict[str, int] = {}
+        # Remove BIO encoding and create a new label map
+        new_label_map: Dict[str, int] = {}
         for label in label_map:
             main_label = label.split('-')[-1] if isinstance(label, str) else label
-            if main_label not in span_label_map:
-                span_label_map[main_label] = len(span_label_map)
-
-        logger.info(f"Simplified label map: {span_label_map}")
+            if main_label not in new_label_map:
+                new_label_map[main_label] = len(new_label_map)
 
-        if label_map == span_label_map:
-            logger.warning("Could not convert BIO labels to span labels. "
-                           "Please add the label map as parameter label_map: Dict[str, int] = ... manually.")
manually.") - - # Create a reverse map from the original integer labels to the simplified span labels + # Create a reverse map from original integer labels to labels without BIO prefixes reverse_map = {} for original_label, index in label_map.items(): main_label = original_label.split('-')[-1] if isinstance(original_label, str) else original_label - reverse_map[index] = span_label_map[main_label] + reverse_map[index] = new_label_map[main_label] - # Map labels to their corresponding span encoding + # Map labels to their class labels without BIO def map_to_spans(example): example_labels = example[label_column] new_labels = [reverse_map[bio_label] for bio_label in example_labels] @@ -323,15 +315,23 @@ def map_to_spans(example): if isinstance(dataset, DatasetDict): for split in dataset.keys(): - dataset[split] = dataset[split].map(map_to_spans, num_proc=None, desc="Mapping BIO to span encoding") + dataset[split] = dataset[split].map(map_to_spans, num_proc=None, desc="Removing BIO encoding") else: - dataset = dataset.map(map_to_spans, num_proc=None, desc="Mapping BIO to span encoding") + dataset = dataset.map(map_to_spans, num_proc=None, desc="Removing BIO encoding") + + if label_map == new_label_map: + logger.warning("Could not remove BIO prefixes for this tagging dataset. " + "Please add the label map as parameter label_map: Dict[str, int] = ... manually.") + else: + logger.info(f"Label map: {label_map}") + logger.info(f"New label map: {new_label_map}") - return dataset, span_label_map + return dataset, new_label_map def log_dataset_info(self) -> None: """Log information about dataset""" - logger.info(f"Text and label columns: '{self.text_column}', '{self.label_column}'") - logger.info(f"Task type identified: '{self.task_type}'") - downsample_info = f"(down-sampled to {self.dataset_downsample})" if self.dataset_downsample else "" - logger.info(f"Dataset size: {self.dataset_size} {downsample_info}") + logger.info(f"Texts and labels: '{self.text_column}', '{self.label_column}'") + logger.info(f"Task category: '{self.task_type}'") + is_downsampled = self.dataset_downsample and self.dataset_downsample < 1.0 + downsample_info = f"(down-sampled to {self.dataset_downsample})" if is_downsampled else "" + logger.info(f"Dataset size: {self.dataset_size} texts {downsample_info}") diff --git a/transformer_ranker/embedder.py b/transformer_ranker/embedder.py index 1dc4569..bb9db49 100644 --- a/transformer_ranker/embedder.py +++ b/transformer_ranker/embedder.py @@ -20,17 +20,17 @@ def __init__( device: Optional[str] = None, ): """ - Embed texts using a pre-trained transformer model. This embedder works at the word level, representing each - text as a list of word vectors. It supports various sub-word pooling and effective sentence pooling options. - ♻️ Feel free to use it if you ever need a simple implementation for transformer embeddings. + Embed texts using a pre-trained transformer model. This embedder works at the word level, where + each text is a list of word vectors. It supports different sub-word pooling and sentence pooling options. + ♻️ Feel free to use it if you need a simple implementation for word or text embeddings. - :param model: The model to use, either by name (e.g., 'bert-base-uncased') or a loaded model instance. - :param tokenizer: Optional tokenizer, either by name or a loaded tokenizer instance. - :param subword_pooling: Method for pooling sub-word embeddings into word-level embeddings. + :param model: Model name 'bert-base-uncased' or a model instance e.g. 
AutoModel.from_pretrained(...) + :param tokenizer: Optional tokenizer, either a string name or a tokenizer instance. + :param subword_pooling: Method for pooling sub-words into word embeddings. :param layer_ids: Layers to use e.g., '-1' for the top layer, '-1,-2' for multiple, or 'all'. Default is 'all'. :param layer_pooling: Optional method for pooling across selected layers. :param use_pretokenizer: Whether to pre-tokenize texts using whitespace. - :param device: Device for computations, either 'cpu' or 'cuda'. Defaults to the available device. + :param device: Device for computations, either 'cpu' or 'cuda:0'. Defaults to the available device. """ # Load transformer model if isinstance(model, torch.nn.Module): @@ -42,12 +42,10 @@ def __init__( # Load a model-specific tokenizer self.tokenizer: PreTrainedTokenizerFast - tokenizer_source = tokenizer if isinstance(tokenizer, str) else self.model_name - - # Assign or load tokenizer if isinstance(tokenizer, PreTrainedTokenizerFast): self.tokenizer = tokenizer else: + tokenizer_source = tokenizer if isinstance(tokenizer, str) else self.model_name self.tokenizer = AutoTokenizer.from_pretrained( tokenizer_source, add_prefix_space=True, @@ -68,11 +66,9 @@ def __init__( # Set relevant layers that will be used for embeddings self.layer_ids = self._filter_layer_ids(layer_ids) - # Set pooling operations for sub-words and layers + # Set pooling options self.subword_pooling = subword_pooling self.layer_pooling = layer_pooling - - # Set sentence-pooling to get embedding for the full text if specified self.sentence_pooling = sentence_pooling # Set cpu or gpu device @@ -80,7 +76,6 @@ def __init__( self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) - self.model = self.model.to(self.device) def tokenize(self, sentences): @@ -141,7 +136,7 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to attention_mask = tokenized_input["attention_mask"].to(self.device) word_ids = [tokenized_input.word_ids(i) for i in range(len(sentences))] - # Embedd using a transformer: forward pass to get all hidden states of the model + # Embedd: forward pass to get all hidden states of the model with torch.no_grad(): hidden_states = self.model( input_ids, attention_mask=attention_mask, output_hidden_states=True @@ -157,7 +152,7 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to # Extract layers defined by layer_ids, average all layers for a batch of sentences if specified embeddings = self._extract_relevant_layers(embeddings) - # Process each sentence separately and gather word or sentence embeddings + # Go through each sentence separately sentence_embeddings = [] for subword_embeddings, word_ids in zip(embeddings, word_ids): @@ -167,13 +162,11 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to # Stack all word-level embeddings that represent a sentence word_embeddings = torch.stack(word_embedding_list, dim=0) - # Pool word-level embeddings into a single sentence vector if specified + # Pool word-level embeddings into a sentence embedding sentence_embedding = self._pool_words(word_embeddings) if self.sentence_pooling else word_embeddings - # Store sentence-embedding tensors in a python list sentence_embeddings.append(sentence_embedding) - # Move embedding batch to cpu if move_embeddings_to_cpu: sentence_embeddings = [sentence_embedding.cpu() for sentence_embedding in sentence_embeddings] @@ -188,18 +181,22 @@ def 
@@ -188,18 +181,22 @@ def _filter_layer_ids(self, layer_ids) -> List[int]:
         layer_ids = [int(number) for number in layer_ids.split(",")]
         layer_ids = [layer_id for layer_id in layer_ids if self.num_transformer_layers >= abs(layer_id)]
 
+        if not layer_ids:
+            raise ValueError(f"\"layer_ids\" are out of bounds for the model size. "
+                             f"Num layers in model {self.model_name}: {self.num_transformer_layers}")
+
         return layer_ids
 
     def _extract_relevant_layers(self, batched_embeddings: torch.Tensor) -> torch.Tensor:
         """Keep only relevant layers in each embedding and apply layer-wise pooling if required"""
 
-        # To maintain original layer order, map negative layer IDs to positive indices,
+        # Use positive layer ids ('-1 -> 23' is the last layer in a 24 layer model)
         layer_ids = sorted((layer_id if layer_id >= 0 else self.num_transformer_layers + layer_id) for layer_id in self.layer_ids)
 
-        # A batch of raw embeddings is in this shape (batch_size, sequence_length, num_layers - 1, hidden_size)
+        # A batch of embeddings is in this shape (batch_size, sequence_length, num_layers, hidden_size)
         batched_embeddings = batched_embeddings[:, :, layer_ids, :]  # keep only selected layers
 
-        # Apply mean pooling over the layer dimension if specified
+        # average all layers
         if self.layer_pooling == "mean":
             batched_embeddings = torch.mean(batched_embeddings, dim=2, keepdim=True)
diff --git a/transformer_ranker/ranker.py b/transformer_ranker/ranker.py
index 6d8bc25..ff94ff2 100644
--- a/transformer_ranker/ranker.py
+++ b/transformer_ranker/ranker.py
@@ -20,7 +20,6 @@ def __init__(
         dataset_downsample: Optional[float] = None,
         text_column: Optional[str] = None,
         label_column: Optional[str] = None,
-        task_type: Optional[str] = None,
         **kwargs
     ):
         """
@@ -36,7 +35,6 @@ def __init__(
         self.data_handler = DatasetCleaner(dataset_downsample=dataset_downsample,
                                            text_column=text_column,
                                            label_column=label_column,
-                                           task_type=task_type,
                                            **kwargs,
                                            )
 
@@ -76,7 +74,7 @@ def run(
         self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator)
 
         # Load all transformers into hf cache
-        self._preload_transformers(models)
+        self._preload_transformers(models, device)
 
         labels = self.data_handler.prepare_labels(self.dataset)
 
@@ -157,13 +155,12 @@ def run(
         return ranking_results
 
     @staticmethod
-    def _preload_transformers(models: List[Union[str, torch.nn.Module]]) -> None:
+    def _preload_transformers(models: List[Union[str, torch.nn.Module]], device: Optional[str] = None) -> None:
         """Loads all models into HuggingFace cache"""
         cached_models, download_models = [], []
-
         for model_name in models:
             try:
-                Embedder(model_name, local_files_only=True)
+                Embedder(model_name, local_files_only=True, device=device)
                 cached_models.append(model_name)
             except OSError:
                 download_models.append(model_name)
@@ -172,7 +169,7 @@ def _preload_transformers(models: List[Union[str, torch.nn.Module]]) -> None:
         logger.info(f"Downloading models: {download_models}") if download_models else None
 
         for model_name in models:
-            Embedder(model_name)
+            Embedder(model_name, device=device)
 
     def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None:
         """Validate estimator and layer selection setup"""
@@ -192,15 +189,15 @@ def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None:
                              "task_type= \"text classification\", \"token classification\", or "
                              "\"text regression\"")
 
+        if self.task_type == 'text regression' and estimator == 'hscore':
+            raise ValueError(f"\"{estimator}\" does not support text regression. "
+                             f"Use one of the following estimators: {valid_estimators.remove('hscore')}")
+
     def _estimate_score(self, estimator, embeddings: torch.Tensor, labels: torch.Tensor) -> float:
         """Use an estimator to score a transformer"""
-        regression = self.task_type == "text regression"
-        if estimator in ['hscore'] and regression:
-            logger.warning(f'Specified estimator="{estimator}" does not support regression tasks.')
-
         estimator_classes = {
-            "knn": KNN(k=3, regression=regression),
-            "logme": LogME(regression=regression),
+            "knn": KNN(k=3, regression=(self.task_type == "text regression")),
+            "logme": LogME(regression=(self.task_type == "text regression")),
            "hscore": HScore(),
         }