Commit
Add column for text pairs
lukasgarbas committed Nov 11, 2024
1 parent 7f8d021 commit 6584960
Showing 3 changed files with 87 additions and 93 deletions.
116 changes: 58 additions & 58 deletions transformer_ranker/datacleaner.py
@@ -17,34 +17,34 @@ def __init__(
pre_tokenizer: Optional[Whitespace] = None,
exclude_test_split: bool = False,
merge_data_splits: bool = True,
change_ner_encoding_to_spans: bool = True,
remove_empty_sentences: bool = True,
dataset_downsample: Optional[float] = None,
task_type: Optional[str] = None,
text_column: Optional[str] = None,
label_column: Optional[str] = None,
label_map: Optional[Dict[str, int]] = None,
text_pair_column: Optional[str] = None,
remove_bio_notation: bool = True,
):
"""
Prepare a Hugging Face dataset: identify the task type, find the text and label columns, down-sample, and merge data splits.
:param pre_tokenizer: Pre-tokenizer to use, such as Whitespace from huggingface pre-tokenizers.
:param exclude_test_split: Whether to exclude the test split.
:param merge_data_splits: Whether to merge train, dev, and test splits into one.
:param change_ner_encoding_to_spans: Whether to change BIO encoding to single class labels.
:param remove_bio_notation: Whether to change BIO encoding to single class labels by removing B- and I- prefixes.
:param remove_empty_sentences: Whether to remove empty sentences.
:param dataset_downsample: Fraction to downsample the dataset to.
:param task_type: Type of task (e.g., 'sentence classification', 'word classification', 'sentence regression').
:param text_column: Column name where texts are stored.
:param label_column: Column name where labels are stored.
:param label_map: Mapping of labels to integers.
:param text_pair_column: Column name where the second text pair is stored. For entailment-type tasks.
:param dataset_downsample: Fraction to reduce the dataset size.
:param task_type: Task category (e.g., 'token classification', 'text classification', 'text regression').
:param text_column: Column name for texts.
:param label_column: Column name for labels.
:param label_map: A dictionary which maps label names to integers.
:param text_pair_column: Column name where the second text pair is stored (for entailment-like tasks)
"""
self.pre_tokenizer = pre_tokenizer
self.exclude_test_split = exclude_test_split
self.merge_data_splits = merge_data_splits
self.change_ner_encoding_to_spans = change_ner_encoding_to_spans
self.remove_bio_notation = remove_bio_notation
self.remove_empty_sentences = remove_empty_sentences
self.dataset_downsample = dataset_downsample
self.task_type = task_type
@@ -87,42 +87,42 @@ def prepare_dataset(self, dataset: Union[str, DatasetDict, Dataset]) -> Union[Da
self.text_column,
self.label_column)

# Determine task type based on label type if not specified
# Find task type based on label type
task_type = self._find_task_type(label_column, label_type) if not self.task_type else self.task_type

# Clean the dataset by removing empty sentences and negative labels
# Clean the dataset by removing empty sentences and empty/negative labels
if self.remove_empty_sentences:
dataset = self._remove_empty_rows(dataset, text_column, label_column)

# Down-sample the original dataset
if self.dataset_downsample:
dataset = self._downsample(dataset, ratio=self.dataset_downsample)

# Pre-tokenize sentences if pre-tokenizer is specified
if not task_type == "word classification" and self.pre_tokenizer:
# Pre-tokenize sentences if pre-tokenizer is given
if not task_type == "token classification" and self.pre_tokenizer:
dataset = self._tokenize(dataset, self.pre_tokenizer, text_column)

# Concatenate text columns for text-pair tasks
if self.text_pair_column:
dataset = self._merge_textpairs(dataset, text_column, self.text_pair_column)
dataset, text_column = self._merge_textpairs(dataset, text_column, self.text_pair_column)

# Convert string labels to integers
if label_type == str:
dataset, label_map = self._make_labels_categorical(dataset, label_column)
logger.info(f"Label map: {label_map}")

# Change NER encoding to spans if specified
if task_type == "word classification" and self.change_ner_encoding_to_spans:
dataset, self.label_map = self._change_to_span_encoding(dataset, label_column, self.label_map)
# Remove BIO prefixes for ner or chunking tasks
if task_type == "token classification" and self.remove_bio_notation:
dataset, self.label_map = self._remove_bio_notation(dataset, label_column, self.label_map)

# Store updated attributes and log them
# Set updated attributes and log them
self.text_column = text_column
self.label_column = label_column
self.task_type = task_type
self.dataset_size = len(dataset)
self.log_dataset_info()

# Simplify the dataset: keep only relevant columns
# Keep only text and label columns
keep_columns = [col for col in (self.text_column, self.text_pair_column, self.label_column) if col is not None]
dataset = self._remove_columns(dataset, keep_columns=keep_columns)
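
For orientation, a minimal usage sketch of the cleaner with the new text-pair handling. The class name DatasetCleaner, the 'snli' dataset, and its 'hypothesis' column are assumptions made for illustration, not taken from this diff:

    from transformer_ranker.datacleaner import DatasetCleaner

    # Sketch only: 'snli' and its 'hypothesis' column are illustrative assumptions
    cleaner = DatasetCleaner(
        dataset_downsample=0.2,
        text_pair_column="hypothesis",
    )
    dataset = cleaner.prepare_dataset("snli")

    # After preparation, both texts live in a single merged column
    print(cleaner.text_column)  # e.g. 'premise+hypothesis'

With text_pair_column set, prepare_dataset concatenates both columns as shown in _merge_textpairs below.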

Expand Down Expand Up @@ -177,13 +177,16 @@ def _find_text_and_label_columns(dataset: Dataset, text_column: Optional[str] =
return text_column, label_column, label_type

@staticmethod
def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Dataset:
def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Tuple[Dataset, str]:
"""Concatenate text pairs into a single text using separator token"""
new_text_column_name = text_column + '+' + text_pair_column

def merge_texts(example: Dict[str, str]) -> Dict[str, str]:
example[text_column] = example[text_column] + " [SEP] " + example[text_pair_column]
example[new_text_column_name] = example.pop(text_column)
return example
dataset = dataset.map(merge_texts, num_proc=None, desc="Merge sentence pair columns")
return dataset
dataset = dataset.map(merge_texts, num_proc=None, desc="Merging text pair columns")
return dataset, new_text_column_name
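
To make the merging behavior concrete, a small sketch on an assumed toy dataset (column names are illustrative):

    from datasets import Dataset

    pairs = Dataset.from_dict({
        "text": ["A man is eating."],
        "text_pair": ["Someone is having a meal."],
        "label": [1],
    })
    # _merge_textpairs(pairs, "text", "text_pair") would return the dataset with a
    # new column named "text+text_pair" containing:
    # "A man is eating. [SEP] Someone is having a meal."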

@staticmethod
def _find_task_type(label_column: str, label_type: type) -> str:
@@ -261,60 +264,49 @@ def map_labels(example):
return dataset, label_map

@staticmethod
def _change_to_span_encoding(
def _remove_bio_notation(
dataset: Dataset,
label_column: str,
label_map: Optional[Dict[str, int]] = None,
) -> Tuple[Dataset, Dict[str, int]]:
"""Remove BIO prefixes for NER labels and create a new label map.
Example: ['B-PER', 'I-PER', 'O'] -> ['PER', 'PER', 'O']
Original label map: {'B-PER': 0, 'I-PER': 1, 'O': 2}
Converted span label map: {'PER': 0, 'O': 1}
Example: ['O', 'B-PER', 'I-PER'] -> ['O', 'PER', 'PER']
Original label map: {'O': 0, 'B-PER': 1, 'I-PER': 2}
Label map without BIO notation: {'O': 0, 'PER': 1}
:param dataset: The dataset containing BIO labels.
:param label_column: The name of the label column.
:param label_map: Optional dictionary to map BIO labels to integers. If not provided, a new one will be created.
:return: A tuple with the dataset containing new labels and the updated label map.
"""
# Attempt to get the label map from dataset features information
if not label_map:
features = dataset.features
if label_column in features and hasattr(features[label_column], 'feature') and hasattr(
features[label_column].feature, 'names'):
label_map = {name: idx for idx, name in enumerate(features[label_column].feature.names)}
else:
# Create label map manually if not found
try:
# Attempt to get the label map from dataset feature information
label_map = {label: idx for idx, label in enumerate(dataset.features[label_column].feature.names)}
except AttributeError:
# Try to create label map manually
logger.info('Label map not found. Creating manually...')
unique_labels: Set[str] = set()
label_data = dataset[label_column] if isinstance(dataset, Dataset) else [dataset[split][label_column]
for split in dataset]
for label_list in label_data:

for label_list in dataset[label_column]:
unique_labels.update(
label.split('-')[-1] if isinstance(label, str) else str(label) for label in label_list)
label_map = {label: idx for idx, label in enumerate(sorted(unique_labels, key=int))}

logger.info(f"Label map: {label_map}")

# Remove BIO encoding from the label map
span_label_map: Dict[str, int] = {}
# Remove BIO encoding and create a new label map
new_label_map: Dict[str, int] = {}
for label in label_map:
main_label = label.split('-')[-1] if isinstance(label, str) else label
if main_label not in span_label_map:
span_label_map[main_label] = len(span_label_map)

logger.info(f"Simplified label map: {span_label_map}")
if main_label not in new_label_map:
new_label_map[main_label] = len(new_label_map)

if label_map == span_label_map:
logger.warning("Could not convert BIO labels to span labels. "
"Please add the label map as parameter label_map: Dict[str, int] = ... manually.")

# Create a reverse map from the original integer labels to the simplified span labels
# Create a reverse map from original integer labels to labels without BIO prefixes
reverse_map = {}
for original_label, index in label_map.items():
main_label = original_label.split('-')[-1] if isinstance(original_label, str) else original_label
reverse_map[index] = span_label_map[main_label]
reverse_map[index] = new_label_map[main_label]

# Map labels to their corresponding span encoding
# Map labels to their class labels without BIO
def map_to_spans(example):
example_labels = example[label_column]
new_labels = [reverse_map[bio_label] for bio_label in example_labels]
@@ -323,15 +315,23 @@ def map_to_spans(example):

if isinstance(dataset, DatasetDict):
for split in dataset.keys():
dataset[split] = dataset[split].map(map_to_spans, num_proc=None, desc="Mapping BIO to span encoding")
dataset[split] = dataset[split].map(map_to_spans, num_proc=None, desc="Removing BIO encoding")
else:
dataset = dataset.map(map_to_spans, num_proc=None, desc="Mapping BIO to span encoding")
dataset = dataset.map(map_to_spans, num_proc=None, desc="Removing BIO encoding")

if label_map == new_label_map:
logger.warning("Could not remove BIO prefixes for this tagging dataset. "
"Please add the label map as parameter label_map: Dict[str, int] = ... manually.")
else:
logger.info(f"Label map: {label_map}")
logger.info(f"New label map: {new_label_map}")

return dataset, span_label_map
return dataset, new_label_map
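
A standalone sketch of the label-map conversion this method performs, using illustrative CoNLL-style labels:

    # Original CoNLL-style map with BIO prefixes (illustrative values)
    label_map = {"O": 0, "B-PER": 1, "I-PER": 2, "B-LOC": 3, "I-LOC": 4}

    # Strip the B-/I- prefixes and re-index, mirroring the loop above
    new_label_map = {}
    for label in label_map:
        main_label = label.split("-")[-1]
        if main_label not in new_label_map:
            new_label_map[main_label] = len(new_label_map)

    print(new_label_map)  # {'O': 0, 'PER': 1, 'LOC': 2}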

def log_dataset_info(self) -> None:
"""Log information about dataset"""
logger.info(f"Text and label columns: '{self.text_column}', '{self.label_column}'")
logger.info(f"Task type identified: '{self.task_type}'")
downsample_info = f"(down-sampled to {self.dataset_downsample})" if self.dataset_downsample else ""
logger.info(f"Dataset size: {self.dataset_size} {downsample_info}")
logger.info(f"Texts and labels: '{self.text_column}', '{self.label_column}'")
logger.info(f"Task category: '{self.task_type}'")
is_downsampled = self.dataset_downsample and self.dataset_downsample < 1.0
downsample_info = f"(down-sampled to {self.dataset_downsample})" if is_downsampled else ""
logger.info(f"Dataset size: {self.dataset_size} texts {downsample_info}")
41 changes: 19 additions & 22 deletions transformer_ranker/embedder.py
@@ -20,17 +20,17 @@ def __init__(
device: Optional[str] = None,
):
"""
Embed texts using a pre-trained transformer model. This embedder works at the word level, representing each
text as a list of word vectors. It supports various sub-word pooling and effective sentence pooling options.
♻️ Feel free to use it if you ever need a simple implementation for transformer embeddings.
Embed texts using a pre-trained transformer model. This embedder works at the word level, where
each text is a list of word vectors. It supports different sub-word pooling and sentence pooling options.
♻️ Feel free to use it if you need a simple implementation for word or text embeddings.
:param model: The model to use, either by name (e.g., 'bert-base-uncased') or a loaded model instance.
:param tokenizer: Optional tokenizer, either by name or a loaded tokenizer instance.
:param subword_pooling: Method for pooling sub-word embeddings into word-level embeddings.
:param model: Model name 'bert-base-uncased' or a model instance e.g. AutoModel.from_pretrained(...)
:param tokenizer: Optional tokenizer, either a string name or a tokenizer instance.
:param subword_pooling: Method for pooling sub-words into word embeddings.
:param layer_ids: Layers to use e.g., '-1' for the top layer, '-1,-2' for multiple, or 'all'. Default is 'all'.
:param layer_pooling: Optional method for pooling across selected layers.
:param use_pretokenizer: Whether to pre-tokenize texts using whitespace.
:param device: Device for computations, either 'cpu' or 'cuda'. Defaults to the available device.
:param device: Device for computations, either 'cpu' or 'cuda:0'. Defaults to the available device.
"""
# Load transformer model
if isinstance(model, torch.nn.Module):
Expand All @@ -42,12 +42,10 @@ def __init__(

# Load a model-specific tokenizer
self.tokenizer: PreTrainedTokenizerFast
tokenizer_source = tokenizer if isinstance(tokenizer, str) else self.model_name

# Assign or load tokenizer
if isinstance(tokenizer, PreTrainedTokenizerFast):
self.tokenizer = tokenizer
else:
tokenizer_source = tokenizer if isinstance(tokenizer, str) else self.model_name
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_source,
add_prefix_space=True,
Expand All @@ -68,19 +66,16 @@ def __init__(
# Set relevant layers that will be used for embeddings
self.layer_ids = self._filter_layer_ids(layer_ids)

# Set pooling operations for sub-words and layers
# Set pooling options
self.subword_pooling = subword_pooling
self.layer_pooling = layer_pooling

# Set sentence-pooling to get embedding for the full text if specified
self.sentence_pooling = sentence_pooling

# Set cpu or gpu device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)

self.model = self.model.to(self.device)
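
A hedged usage sketch for the embedder. The class name Embedder is an assumption based on the file name; the arguments correspond to the __init__ parameters documented above, and embed_batch is the method shown further down:

    from transformer_ranker.embedder import Embedder

    # Sketch only: the class name Embedder is an assumption based on the file name
    embedder = Embedder(
        model="bert-base-uncased",
        layer_ids="-1",          # use the top transformer layer
        layer_pooling="mean",    # average the selected layers
    )

    # Without sentence pooling, each text is returned as word-level embeddings
    embeddings = embedder.embed_batch(["The first text.", "Another text."])
    print(len(embeddings))  # one tensor of word embeddings per input text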

def tokenize(self, sentences):
@@ -141,7 +136,7 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to
attention_mask = tokenized_input["attention_mask"].to(self.device)
word_ids = [tokenized_input.word_ids(i) for i in range(len(sentences))]

# Embed using a transformer: forward pass to get all hidden states of the model
# Embed: forward pass to get all hidden states of the model
with torch.no_grad():
hidden_states = self.model(
input_ids, attention_mask=attention_mask, output_hidden_states=True
Expand All @@ -157,7 +152,7 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to
# Extract layers defined by layer_ids, average all layers for a batch of sentences if specified
embeddings = self._extract_relevant_layers(embeddings)

# Process each sentence separately and gather word or sentence embeddings
# Go through each sentence separately
sentence_embeddings = []
for subword_embeddings, word_ids in zip(embeddings, word_ids):

Expand All @@ -167,13 +162,11 @@ def embed_batch(self, sentences, move_embeddings_to_cpu: bool = True) -> List[to
# Stack all word-level embeddings that represent a sentence
word_embeddings = torch.stack(word_embedding_list, dim=0)

# Pool word-level embeddings into a single sentence vector if specified
# Pool word-level embeddings into a sentence embedding
sentence_embedding = self._pool_words(word_embeddings) if self.sentence_pooling else word_embeddings

# Store sentence-embedding tensors in a python list
sentence_embeddings.append(sentence_embedding)

# Move embedding batch to cpu
if move_embeddings_to_cpu:
sentence_embeddings = [sentence_embedding.cpu() for sentence_embedding in sentence_embeddings]
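
For reference, a minimal sketch of the plain transformers forward pass that embed_batch wraps. This uses only the standard Hugging Face API and is independent of this repository:

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    encoded = tokenizer(["A short example sentence."], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoded, output_hidden_states=True)

    # Tuple of (num_layers + 1) tensors, each (batch_size, seq_len, hidden_size);
    # the embedder stacks these into (batch_size, seq_len, num_layers, hidden_size)
    print(len(outputs.hidden_states), outputs.hidden_states[-1].shape)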

@@ -188,18 +181,22 @@ def _filter_layer_ids(self, layer_ids) -> List[int]:
layer_ids = [int(number) for number in layer_ids.split(",")]
layer_ids = [layer_id for layer_id in layer_ids if self.num_transformer_layers >= abs(layer_id)]

if not layer_ids:
raise ValueError(f"\"layer_ids\" are out of bounds for the model size. "
f"Num layers in model {self.model_name}: {self.num_transformer_layers}")

return layer_ids

def _extract_relevant_layers(self, batched_embeddings: torch.Tensor) -> torch.Tensor:
"""Keep only relevant layers in each embedding and apply layer-wise pooling if required"""
# To maintain original layer order, map negative layer IDs to positive indices,
# Use positive layer ids ('-1 -> 23' is the last layer in a 24 layer model)
layer_ids = sorted((layer_id if layer_id >= 0 else self.num_transformer_layers + layer_id)
for layer_id in self.layer_ids)

# A batch of raw embeddings is in this shape (batch_size, sequence_length, num_layers - 1, hidden_size)
# A batch of embeddings is in this shape (batch_size, sequence_length, num_layers, hidden_size)
batched_embeddings = batched_embeddings[:, :, layer_ids, :] # keep only selected layers

# Apply mean pooling over the layer dimension if specified
# average all layers
if self.layer_pooling == "mean":
batched_embeddings = torch.mean(batched_embeddings, dim=2, keepdim=True)
