Skip to content

Commit

Permalink
Merge branch 'changed-to-doc2vec'
Browse files Browse the repository at this point in the history
  • Loading branch information
ryogrid committed Oct 29, 2024
2 parents 9d0cb89 + f578295 commit 8b5593f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions genmodel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import sys

from gensim import corpora
Expand All @@ -13,7 +12,7 @@
TRAIN_EPOCHS = 100

# generate corpus for gensim and index text file for search tool
def read_documents_and_gen_idx_text(file_path: str) -> Tuple[List[str],List[TaggedDocument]]:
def read_documents_and_gen_idx_text(file_path: str) -> Tuple[List[List[str]],List[TaggedDocument]]:
processed_docs: List[List[str]] = []
tagged_docs: List[TaggedDocument] = []
idx_text_fpath: str = file_path.split('.')[0] + '_doc2vec_idx.csv'
Expand Down Expand Up @@ -98,9 +97,9 @@ def main(arg_str: list[str]) -> None:
level=logging.DEBUG
)

parser: argparse.ArgumentParser = argparse.ArgumentParser()
parser.add_argument('--dim', nargs=1, type=int, required=True, help='number of dimensions at LSI model')
args: argparse.Namespace = parser.parse_args(arg_str)
# parser: argparse.ArgumentParser = argparse.ArgumentParser()
# parser.add_argument('--dim', nargs=1, type=int, required=True, help='number of dimensions at LSI model')
# args: argparse.Namespace = parser.parse_args(arg_str)

tmp_tuple : [List[List[str]], List[TaggedDocument]] = read_documents_and_gen_idx_text('tags-wd-tagger.txt')
processed_docs: List[List[str]] = tmp_tuple[0]
Expand All @@ -115,7 +114,7 @@ def main(arg_str: list[str]) -> None:
pickle.dump(dictionary, f)

# gen Doc2Vec model with specified number of dimensions
doc2vec_model: Doc2Vec = Doc2Vec(vector_size=args.dim[0], window=50, min_count=1, workers=8, dm=0)
doc2vec_model: Doc2Vec = Doc2Vec(vector_size=300, window=50, min_count=1, workers=1, dm=0)
doc2vec_model.build_vocab(tagged_docs)
doc2vec_model.train(tagged_docs, total_examples=doc2vec_model.corpus_count, epochs=TRAIN_EPOCHS)
doc2vec_model.save("doc2vec_model")
Expand All @@ -127,7 +126,7 @@ def main(arg_str: list[str]) -> None:
for doc in processed_docs:
embed_vec = doc2vec_model.infer_vector(doc)
if index is None:
index = Similarity("doc2vec_index", [embed_vec], num_features=args.dim[0])
index = Similarity("doc2vec_index", [embed_vec], num_features=300)
else:
index.add_documents([embed_vec])

Expand Down
2 changes: 1 addition & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def normalize_and_apply_weight_doc2vec(new_doc: str) -> List[Tuple[int, float]]:
tag_and_weight_list.append((tag_elem.replace('(', '\(').replace(')', '\)'), 1))
all_weight += 1

got_vector: ndarray = np.zeros(args.dim[0])
got_vector: ndarray = np.zeros(len(model.dv[0]))
for tag, weight in tag_and_weight_list:
got_vector += weight * model.infer_vector([tag])
got_vector = got_vector / all_weight
Expand Down

0 comments on commit 8b5593f

Please sign in to comment.