Skip to content

Commit

Permalink
udpate device.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 6, 2024
1 parent b2c4876 commit 533e7d5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/sentence_embeddings_zh_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys

sys.path.append('..')
from text2vec import SentenceModel, cos_sim, semantic_search, Similarity, EncoderType
from text2vec import SentenceModel, cos_sim, semantic_search, Similarity

if __name__ == '__main__':
m = SentenceModel("shibing624/text2vec-base-chinese")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_qps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
print("data:", data)
num_tokens = sum([len(i) for i in data])
use_cuda = torch.cuda.is_available()
repeat = 10 if use_cuda else 1
repeat = 10 if use_cuda else 4


class TransformersEncoder:
Expand Down
11 changes: 5 additions & 6 deletions text2vec/sentence_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import os
import queue
from enum import Enum
from typing import List, Union, Optional, Dict
from typing import List, Union, Optional, Dict, Literal

import numpy as np
import torch
import torch.multiprocessing as mp
from loguru import logger
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tqdm.autonotebook import trange
from transformers import AutoTokenizer, AutoModel
Expand Down Expand Up @@ -48,9 +47,9 @@ class SentenceModel:
def __init__(
self,
model_name_or_path: str = "shibing624/text2vec-base-chinese",
encoder_type: Union[str, EncoderType] = "MEAN",
encoder_type: Literal["FIRST_LAST_AVG", "LAST_AVG", "CLS", "POOLER", "MEAN"] = "MEAN",
max_seq_length: int = 256,
device: Optional[str] = None,
device: Optional[Literal["cpu", "gpu", "mps"]] = None,
):
"""
Initializes the base sentence model.
Expand Down Expand Up @@ -159,7 +158,7 @@ def encode(
show_progress_bar: bool = False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
device: Optional[Literal["cpu", "gpu", "mps"]] = None,
normalize_embeddings: bool = False,
max_seq_length: int = None,
):
Expand Down

0 comments on commit 533e7d5

Please sign in to comment.