From 533e7d550d1e6da99905acfcc932427389591611 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Fri, 6 Sep 2024 09:36:09 +0800 Subject: [PATCH] update device. --- examples/sentence_embeddings_zh_demo.py | 2 +- tests/test_qps.py | 2 +- text2vec/sentence_model.py | 11 +++++------ 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/sentence_embeddings_zh_demo.py b/examples/sentence_embeddings_zh_demo.py index 7386588..8240652 100644 --- a/examples/sentence_embeddings_zh_demo.py +++ b/examples/sentence_embeddings_zh_demo.py @@ -9,7 +9,7 @@ import sys sys.path.append('..') -from text2vec import SentenceModel, cos_sim, semantic_search, Similarity, EncoderType +from text2vec import SentenceModel, cos_sim, semantic_search, Similarity if __name__ == '__main__': m = SentenceModel("shibing624/text2vec-base-chinese") diff --git a/tests/test_qps.py b/tests/test_qps.py index 732fd47..dbc20c5 100644 --- a/tests/test_qps.py +++ b/tests/test_qps.py @@ -27,7 +27,7 @@ print("data:", data) num_tokens = sum([len(i) for i in data]) use_cuda = torch.cuda.is_available() -repeat = 10 if use_cuda else 1 +repeat = 10 if use_cuda else 4 class TransformersEncoder: diff --git a/text2vec/sentence_model.py b/text2vec/sentence_model.py index 605aee5..27c98c2 100644 --- a/text2vec/sentence_model.py +++ b/text2vec/sentence_model.py @@ -8,14 +8,13 @@ import os import queue from enum import Enum -from typing import List, Union, Optional, Dict +from typing import List, Union, Optional, Dict, Literal import numpy as np import torch import torch.multiprocessing as mp from loguru import logger -from torch.utils.data import DataLoader -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from tqdm.autonotebook import trange from transformers import AutoTokenizer, AutoModel @@ -48,9 +47,9 @@ class SentenceModel: def __init__( self, model_name_or_path: str = "shibing624/text2vec-base-chinese", - encoder_type: Union[str, EncoderType] = 
"MEAN", + encoder_type: Literal["FIRST_LAST_AVG", "LAST_AVG", "CLS", "POOLER", "MEAN"] = "MEAN", max_seq_length: int = 256, - device: Optional[str] = None, + device: Optional[Literal["cpu", "gpu", "mps"]] = None, ): """ Initializes the base sentence model. @@ -159,7 +158,7 @@ def encode( show_progress_bar: bool = False, convert_to_numpy: bool = True, convert_to_tensor: bool = False, - device: str = None, + device: Optional[Literal["cpu", "gpu", "mps"]] = None, normalize_embeddings: bool = False, max_seq_length: int = None, ):