diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 1596a44..59f0635 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -2,16 +2,9 @@ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python name: Python application - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - +on: [push, pull_request] permissions: contents: read - jobs: build: runs-on: ubuntu-latest @@ -29,4 +22,4 @@ jobs: - name: Mypy Check uses: jpetrucciani/mypy-check@master with: - path: . \ No newline at end of file + path: . diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 0000000..e38c84a --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,5 @@ +[global] +developmentMode = false + +[server] +port = 8501 diff --git a/README.md b/README.md index f47be17..07d6f11 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ - Python 3.10.4 - pip 22.0.4 - $ pip install -r requirements.txt -- $ python make-tags-with-wd-tagger.py --dir "IMAGE FILES CONTAINED DIR PATH" +- $ python tagging.py --dir "IMAGE FILES CONTAINED DIR PATH" - The script searches directory structure recursively :) - This takes quite a while... - About 0.5 sec/file at middle spec desktop PC (GPU is not used) @@ -28,18 +28,20 @@ - Plese see [here](https://onnxruntime.ai/docs/execution-providers/) - Performance key is processing speed of ONNX Runtime at your machine :) - Image files and tags of these are saved to tags-wd-tagger.txt -- $ python count-unique-tag-num.py +- $ python counttag.py - => for deciding appropriate dimension scale fitting your data -- $ python gen-lsi-model.py - - **Please edit [num_topics paramater](https://github.com/ryogrid/local-illust-image-searcher/blob/main/gen-lsi-model.py#L51) before execution** - - I think about 80% of unique tags which is counted with count-unique-tag-num.py is better - - EX: unique tags count is 1000 -> 0.8 * 1000 -> 800 num_topics (dimension) + - unique tag count is shown +- $ python genmodel.py --dim MODEL_DIMENSION + - MODEL_DIMENSION is integer which specify dimension of latent sementic representation + - Dimension after applying LSI + - I think that 80% of unique tags which is counted with counttag.py is better + - EX: unique tags count is 1000 -> 0.8 * 1000 -> 800 (dimension) - This takes quite a while... - LSI processing: dimension reduction with [Singular Value Decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition) - Take several secs only for 1000 files and reduction from 800 to 700 dimension case (case of demo on later section) - But, in 340k files and from 7500 to 6000 dimension case, about 3.5 hour are taken - files are not for demo :) -- $ streamlit run web-ui-image-search-lsi.py +- $ streamlit run webui.py - Search app is opend on your web browser ## Tips (Attention) diff --git a/cmd_run.py b/cmd_run.py new file mode 100644 index 0000000..84f1c61 --- /dev/null +++ b/cmd_run.py @@ -0,0 +1,27 @@ +import argparse +import sys + +import tagging +import genmodel +import counttag + +def main() -> None: + parser: argparse.ArgumentParser = argparse.ArgumentParser() + parser.add_argument('command', nargs=1, help='command to run') + # dummy + parser.add_argument('--dir', nargs=1, help='') + # dummy + parser.add_argument('--dim', nargs=1, type=int, help='') + args: argparse.Namespace = parser.parse_args() + + if args.command[0] == 'tagging': + tagging.main(sys.argv[2:]) + elif args.command[0] == 'genmodel': + genmodel.main(sys.argv[2:]) + elif args.command[0] == 'counttag': + counttag.main() + else: + print('Invalid command') + exit(1) + +main() \ No newline at end of file diff --git a/count-unique-tag-num.py b/count-unique-tag-num.py deleted file mode 100644 index e65af74..0000000 --- a/count-unique-tag-num.py +++ /dev/null @@ -1,12 +0,0 @@ -# -- coding: utf-8 -- - -from typing import Dict, List - -tag_map: Dict[str, bool] = {} -with open('tags-wd-tagger.txt', 'r', encoding='utf-8') as f: - for line in f: - tags: List[str] = line.strip().split(',') - tags = tags[1:-1] - for tag in tags: - tag_map[tag] = True -print(f'{len(tag_map)} unique tags found') \ No newline at end of file diff --git a/counttag.py b/counttag.py new file mode 100644 index 0000000..87ec25c --- /dev/null +++ b/counttag.py @@ -0,0 +1,16 @@ +# -- coding: utf-8 -- + +from typing import Dict, List + +def main() -> None: + tag_map: Dict[str, bool] = {} + with open('tags-wd-tagger.txt', 'r', encoding='utf-8') as f: + for line in f: + tags: List[str] = line.strip().split(',') + tags = tags[1:-1] + for tag in tags: + tag_map[tag] = True + print(f'{len(tag_map)} unique tags found') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/gen-lsi-model.py b/genmodel.py similarity index 84% rename from gen-lsi-model.py rename to genmodel.py index 32f19bc..2f9fe79 100644 --- a/gen-lsi-model.py +++ b/genmodel.py @@ -1,70 +1,74 @@ -from gensim import corpora -from gensim.models import LsiModel -from gensim.similarities import MatrixSimilarity -from gensim.utils import simple_preprocess -import pickle -from typing import List, Tuple -import logging - -# generate corpus for gensim and index text file for search tool -def read_documents_and_gen_idx_text(file_path: str) -> List[List[str]]: - corpus_base: List[List[str]] = [] - idx_text_fpath: str = file_path.split('.')[0] + '_lsi_idx.csv' - with open(idx_text_fpath, 'w', encoding='utf-8') as idx_f: - with open(file_path, 'r', encoding='utf-8') as f: - for line in f: - row: List[str] = line.split(",") - # remove file path element - row = row[1:] - # # remove last element - # row = row[:-1] - - # join tags with space for gensim - tags_line: str = ' '.join(row) - # tokens: List[str] = simple_preprocess(tags_line.strip()) - tokens: List[str] = row - # ignore simple_preprocess failure case and short tags image - if tokens and len(tokens) >= 3: - corpus_base.append(tokens) - idx_f.write(line) - idx_f.flush() - return corpus_base - -# read image files pathes from file -def read_documents(filename: str) -> List[str]: - with open(filename, 'r', encoding='utf-8') as file: - documents: List[str] = [line.strip() for line in file.readlines()] - return documents - -def main() -> None: - format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - logging.basicConfig( - format=format_str, - level=logging.DEBUG - ) - - processed_docs: List[List[str]] = read_documents_and_gen_idx_text('tags-wd-tagger.txt') - - # image file => doc_id - dictionary: corpora.Dictionary = corpora.Dictionary(processed_docs) - # remove frequent tags - #dictionary.filter_n_most_frequent(500) - - with open('lsi_dictionary', 'wb') as f: - pickle.dump(dictionary, f) - - corpus: List[List[Tuple[int, int]]] = [dictionary.doc2bow(doc) for doc in processed_docs] - - # gen LSI model with specified number of topics (dimensions) - # ATTENTION: num_topics should be set to appropriate value!!! - lsi_model: LsiModel = LsiModel(corpus, id2word=dictionary, num_topics=800) - - lsi_model.save("lsi_model") - - # make similarity index - index: MatrixSimilarity = MatrixSimilarity(lsi_model[corpus]) - - index.save("lsi_index") - -if __name__ == "__main__": - main() \ No newline at end of file +import argparse +import sys + +from gensim import corpora +from gensim.models import LsiModel +from gensim.similarities import MatrixSimilarity +import pickle +from typing import List, Tuple +import logging + +# generate corpus for gensim and index text file for search tool +def read_documents_and_gen_idx_text(file_path: str) -> List[List[str]]: + corpus_base: List[List[str]] = [] + idx_text_fpath: str = file_path.split('.')[0] + '_lsi_idx.csv' + with open(idx_text_fpath, 'w', encoding='utf-8') as idx_f: + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + row: List[str] = line.split(",") + # remove file path element + row = row[1:] + + # tokens: List[str] = simple_preprocess(tags_line.strip()) + tokens: List[str] = row + # ignore simple_preprocess failure case and short tags image + if tokens and len(tokens) >= 3: + corpus_base.append(tokens) + idx_f.write(line) + idx_f.flush() + + return corpus_base + +# read image files pathes from file +def read_documents(filename: str) -> List[str]: + with open(filename, 'r', encoding='utf-8') as file: + documents: List[str] = [line.strip() for line in file.readlines()] + return documents + +def main(arg_str: list[str]) -> None: + format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + logging.basicConfig( + format=format_str, + level=logging.DEBUG + ) + + parser: argparse.ArgumentParser = argparse.ArgumentParser() + parser.add_argument('--dim', nargs=1, type=int, required=True, help='number of dimensions at LSI model') + args: argparse.Namespace = parser.parse_args(arg_str) + + processed_docs: List[List[str]] = read_documents_and_gen_idx_text('tags-wd-tagger.txt') + + # image file => doc_id + dictionary: corpora.Dictionary = corpora.Dictionary(processed_docs) + # remove frequent tags + #dictionary.filter_n_most_frequent(500) + + with open('lsi_dictionary', 'wb') as f: + pickle.dump(dictionary, f) + + corpus: List[List[Tuple[int, int]]] = [dictionary.doc2bow(doc) for doc in processed_docs] + + # gen LSI model with specified number of topics (dimensions) + # ATTENTION: num_topics should be set to appropriate value!!! + # lsi_model: LsiModel = LsiModel(corpus, id2word=dictionary, num_topics=800) + lsi_model: LsiModel = LsiModel(corpus, id2word=dictionary, num_topics=args.dim[0]) + + lsi_model.save("lsi_model") + + # make similarity index + index: MatrixSimilarity = MatrixSimilarity(lsi_model[corpus]) + + index.save("lsi_index") + +if __name__ == "__main__": + main(sys.argv) \ No newline at end of file diff --git a/hooks/hook-streamlit.py b/hooks/hook-streamlit.py new file mode 100644 index 0000000..4689830 --- /dev/null +++ b/hooks/hook-streamlit.py @@ -0,0 +1,2 @@ +from PyInstaller.utils.hooks import copy_metadata +datas = copy_metadata('streamlit') diff --git a/packaging.bat b/packaging.bat index 772edac..f21e656 100644 --- a/packaging.bat +++ b/packaging.bat @@ -20,11 +20,8 @@ IF NOT EXIST "%output_dir%" ( MKDIR "%output_dir%" ) -REM For each Python script in the input directory -FOR %%f IN ("%input_dir%\*.py") DO ( - REM Package the script into an executable without the onefile option - pyinstaller --distpath "%output_dir%" --workpath "%output_dir%\build" --specpath "%output_dir%\spec" --noconfirm "%%f" -) +REM Package the script into an executable without the onefile option +pyinstaller --distpath "%output_dir%" --workpath "%output_dir%\build" --specpath "%output_dir%\spec" --noconfirm "%input_dir%\cmd_run.py REM Clean up build and spec files RMDIR /S /Q "%output_dir%\build" diff --git a/packaging_webui_step1.bat b/packaging_webui_step1.bat new file mode 100644 index 0000000..b1d2690 --- /dev/null +++ b/packaging_webui_step1.bat @@ -0,0 +1,12 @@ +@echo off +REM Batch file to package Python scripts into Windows executables +REM Usage: packaging_webui.bat + +REM Package the script into an executable without the onefile option +pyinstaller --additional-hooks-dir=./hooks --noconfirm run_webui.py --clean + +REM Clean up build and spec files +REM RMDIR /S /Q build +REM RMDIR /S /Q spec + +ECHO Packaging complete. diff --git a/packaging_webui_step2.bat b/packaging_webui_step2.bat new file mode 100644 index 0000000..a659ac8 --- /dev/null +++ b/packaging_webui_step2.bat @@ -0,0 +1,12 @@ +@echo off +REM Batch file to package Python scripts into Windows executables +REM Usage: packaging_webui.bat + +REM Package the script into an executable without the onefile option +pyinstaller --noconfirm run_webui.spec --clean + +REM Clean up build and spec files +REM RMDIR /S /Q build +REM RMDIR /S /Q __pycache__ + +ECHO Packaging complete. diff --git a/requirements_with_packager.txt b/requirements_with_packager.txt new file mode 100644 index 0000000..47d9440 Binary files /dev/null and b/requirements_with_packager.txt differ diff --git a/run_webui.py b/run_webui.py new file mode 100644 index 0000000..0b5a812 --- /dev/null +++ b/run_webui.py @@ -0,0 +1,11 @@ +import streamlit.web.cli as stcli +import os +import sys + +def streamlit_run(): + src = os.path.dirname(sys.executable) + '/webui.py' + sys.argv=['streamlit', 'run', src, '--global.developmentMode=false'] + sys.exit(stcli.main()) + +if __name__ == "__main__": + streamlit_run() \ No newline at end of file diff --git a/run_webui.spec b/run_webui.spec new file mode 100644 index 0000000..8b8987a --- /dev/null +++ b/run_webui.spec @@ -0,0 +1,69 @@ +# -*- mode: python ; coding: utf-8 -*- +import site +import os + +block_cipher = None + +assert len(site.getsitepackages()) > 0 + +# Choose path containing "site-packages" +package_path = site.getsitepackages()[0] +for p in site.getsitepackages(): + if "site-package" in p: + package_path = p + break + +a = Analysis( + ['run_webui.py'], + pathex=[], + binaries=[], + # copy streamlit files + datas=[(os.path.join(package_path, "altair/vegalite/v5/schema/vega-lite-schema.json"), + "./altair/vegalite/v5/schema/"), + (os.path.join(package_path, "streamlit/static"),"./streamlit/static"), + (os.path.join(package_path, "streamlit/runtime"),"./streamlit/runtime"), + ], + hiddenimports=[ + 'gensim', + 'gensim.models', + 'gensim.models.lsimodel', + 'gensim.similarities', + 'numpy', + ], + hookspath=['./hooks'], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='run_webui', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=True, + upx_exclude=[], + name='run_webui', +) diff --git a/make-tags-with-wd-tagger.py b/tagging.py similarity index 96% rename from make-tags-with-wd-tagger.py rename to tagging.py index e5f165f..4984260 100644 --- a/make-tags-with-wd-tagger.py +++ b/tagging.py @@ -1,249 +1,249 @@ -import os, time -import numpy as np -import onnxruntime as rt -from huggingface_hub import hf_hub_download -from PIL import Image -import pandas as pd -import argparse -import traceback, sys -import re -from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol - -from numpy import signedinteger - -kaomojis: List[str] = [ - "0_0", - "(o)_(o)", - "+_+", - "+_-", - "._.", - "_", - "<|>_<|>", - "=_=", - ">_<", - "3_3", - "6_9", - ">_o", - "@_@", - "^_^", - "o_o", - "u_u", - "x_x", - "|_|", - "||_||", -] - -VIT_MODEL_DSV3_REPO: str = "SmilingWolf/wd-vit-tagger-v3" -MODEL_FILE_NAME: str = "model.onnx" -LABEL_FILENAME: str = "selected_tags.csv" - -EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"] - -def mcut_threshold(probs: np.ndarray) -> float: - sorted_probs: np.ndarray = probs[probs.argsort()[::-1]] - difs: np.ndarray = sorted_probs[:-1] - sorted_probs[1:] - t: signedinteger[Any] = difs.argmax() - thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2 - return thresh - -def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]: - name_series: pd.Series = dataframe["name"] - name_series = name_series.map( - lambda x: x.replace("_", " ") if x not in kaomojis else x - ) - tag_names: List[str] = name_series.tolist() - - rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0]) - general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0]) - character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0]) - return tag_names, rating_indexes, general_indexes, character_indexes - -def print_traceback() -> None: - tb: traceback.StackSummary = traceback.extract_tb(sys.exc_info()[2]) - trace: List[str] = traceback.format_list(tb) - print('---- traceback ----') - for line in trace: - if '~^~' in line: - print(line.rstrip()) - else: - text: str = re.sub(r'\n\s*', ' ', line.rstrip()) - print(text) - print('-------------------') - -def list_files_recursive(directory: str) -> List[str]: - file_list: List[str] = [] - for root, _, files in os.walk(directory): - for file in files: - file_path: str = os.path.join(root, file) - if any(file_path.endswith(ext) for ext in EXTENSIONS): - file_list.append(file_path) - return file_list - -class Predictor: - def __init__(self) -> None: - self.model_target_size: Optional[int] = None - self.last_loaded_repo: Optional[str] = None - self.tagger_model_path: Optional[str] = None - self.tagger_model: Optional[rt.InferenceSession] = None - self.tag_names: Optional[List[str]] = None - self.rating_indexes: Optional[List[int]] = None - self.general_indexes: Optional[List[int]] = None - self.character_indexes: Optional[List[int]] = None - - def prepare_image(self, image: Image.Image) -> np.ndarray: - target_size: int = self.model_target_size - - if image.mode in ('RGBA', 'LA'): - background: Image.Image = Image.new("RGB", image.size, (255, 255, 255)) - background.paste(image, mask=image.split()[-1]) - image = background - else: - image = image.convert("RGB") - - image_shape: Tuple[int, int] = image.size - max_dim: int = max(image_shape) - pad_left: int = (max_dim - image_shape[0]) // 2 - pad_top: int = (max_dim - image_shape[1]) // 2 - - padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) - padded_image.paste(image, (pad_left, pad_top)) - - if max_dim != target_size: - padded_image = padded_image.resize( - (target_size, target_size), - Image.BICUBIC, - ) - - image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) - image_array = image_array[:, :, ::-1] - - return np.expand_dims(image_array, axis=0) - - def load_model(self) -> None: - if self.tagger_model is not None: - return - - self.tagger_model_path = hf_hub_download(repo_id=VIT_MODEL_DSV3_REPO, filename=MODEL_FILE_NAME) - self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CPUExecutionProvider']) - _, height, _, _ = self.tagger_model.get_inputs()[0].shape - - self.model_target_size = height - - csv_path: str = hf_hub_download( - VIT_MODEL_DSV3_REPO, - LABEL_FILENAME, - ) - tags_df: pd.DataFrame = pd.read_csv(csv_path) - sep_tags: Tuple[List[str], List[int], List[int], List[int]] = load_labels(tags_df) - - self.tag_names = sep_tags[0] - self.rating_indexes = sep_tags[1] - self.general_indexes = sep_tags[2] - self.character_indexes = sep_tags[3] - - def predict( - self, - image: Image.Image, - general_thresh: float, - general_mcut_enabled: bool, - character_thresh: float, - character_mcut_enabled: bool, - ) -> str: - img: np.ndarray = self.prepare_image(image) - - input_name: str = self.tagger_model.get_inputs()[0].name - label_name: str = self.tagger_model.get_outputs()[0].name - preds: np.ndarray = self.tagger_model.run([label_name], {input_name: img})[0] - - labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0].astype(float))) - - general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes] - - if general_mcut_enabled: - general_probs: np.ndarray = np.array([x[1] for x in general_names]) - general_thresh = mcut_threshold(general_probs) - - general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh} - - character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes] - - if character_mcut_enabled: - character_probs: np.ndarray = np.array([x[1] for x in character_names]) - character_thresh = mcut_threshold(character_probs) - character_thresh = max(0.15, character_thresh) - - character_res: Dict[str, float] = {x[0]: x[1] for x in character_names if x[1] > character_thresh} - - sorted_general_strings: List[Tuple[str, float]] = sorted( - general_res.items(), - key=lambda x: x[1], - reverse=True, - ) - sorted_general_strings_str : List[str] = [x[0] for x in sorted_general_strings] - sorted_general_strings_str = [x.replace(' ', '_') for x in sorted_general_strings_str] - ret_string: str = ( - ",".join(sorted_general_strings_str).replace("(", "\(").replace(")", "\)") - ) - - if len(character_res) > 0: - sorted_character_strings: List[Tuple[str, float]] = sorted( - character_res.items(), - key=lambda x: x[1], - reverse=True, - ) - sorted_character_strings_str: List[str] = [x[0] for x in sorted_character_strings] - sorted_character_strings_str = [x.replace(' ', '_') for x in sorted_character_strings_str] - ret_string += ",".join(sorted_character_strings_str).replace("(", "\(").replace(")", "\)") - - return ret_string - - def write_to_file(self, csv_line: str) -> None: - self.f.write(csv_line + '\n') - self.f.flush() - - def process_directory(self, directory: str) -> None: - file_list: List[str] = list_files_recursive(directory) - print(f'{len(file_list)} files found') - - self.f = open('tags-wd-tagger.txt', 'a', encoding='utf-8') - - self.load_model() - - start: float = time.perf_counter() - cnt: int = 0 - for file_path in file_list: - try: - img: Image.Image = Image.open(file_path) - results_in_csv_format: str = self.predict(img, 0.3, True, 0.3, True) - - self.write_to_file(file_path + ',' + results_in_csv_format) - - if cnt % 100 == 0: - now: float = time.perf_counter() - print(f'{cnt} files processed') - diff: float = now - start - print('{:.2f} seconds elapsed'.format(diff)) - if cnt > 0: - time_per_file: float = diff / cnt - print('{:.4f} seconds per file'.format(time_per_file)) - print("", flush=True) - - cnt += 1 - except Exception as e: - error_class: type = type(e) - error_description: str = str(e) - err_msg: str = '%s: %s' % (error_class, error_description) - print(err_msg) - print_traceback() - pass - -def main() -> None: - parser: argparse.ArgumentParser = argparse.ArgumentParser() - parser.add_argument('--dir', nargs=1, required=True, help='tagging target directory path') - args: argparse.Namespace = parser.parse_args() - - predictor: Predictor = Predictor() - predictor.process_directory(args.dir[0]) - -if __name__ == "__main__": - main() \ No newline at end of file +import os, time +import numpy as np +import onnxruntime as rt +from huggingface_hub import hf_hub_download +from PIL import Image +import pandas as pd +import argparse +import traceback, sys +import re +from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol + +from numpy import signedinteger + +kaomojis: List[str] = [ + "0_0", + "(o)_(o)", + "+_+", + "+_-", + "._.", + "_", + "<|>_<|>", + "=_=", + ">_<", + "3_3", + "6_9", + ">_o", + "@_@", + "^_^", + "o_o", + "u_u", + "x_x", + "|_|", + "||_||", +] + +VIT_MODEL_DSV3_REPO: str = "SmilingWolf/wd-vit-tagger-v3" +MODEL_FILE_NAME: str = "model.onnx" +LABEL_FILENAME: str = "selected_tags.csv" + +EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"] + +def mcut_threshold(probs: np.ndarray) -> float: + sorted_probs: np.ndarray = probs[probs.argsort()[::-1]] + difs: np.ndarray = sorted_probs[:-1] - sorted_probs[1:] + t: signedinteger[Any] = difs.argmax() + thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2 + return thresh + +def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]: + name_series: pd.Series = dataframe["name"] + name_series = name_series.map( + lambda x: x.replace("_", " ") if x not in kaomojis else x + ) + tag_names: List[str] = name_series.tolist() + + rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0]) + general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0]) + character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0]) + return tag_names, rating_indexes, general_indexes, character_indexes + +def print_traceback() -> None: + tb: traceback.StackSummary = traceback.extract_tb(sys.exc_info()[2]) + trace: List[str] = traceback.format_list(tb) + print('---- traceback ----') + for line in trace: + if '~^~' in line: + print(line.rstrip()) + else: + text: str = re.sub(r'\n\s*', ' ', line.rstrip()) + print(text) + print('-------------------') + +def list_files_recursive(directory: str) -> List[str]: + file_list: List[str] = [] + for root, _, files in os.walk(directory): + for file in files: + file_path: str = os.path.join(root, file) + if any(file_path.endswith(ext) for ext in EXTENSIONS): + file_list.append(file_path) + return file_list + +class Predictor: + def __init__(self) -> None: + self.model_target_size: Optional[int] = None + self.last_loaded_repo: Optional[str] = None + self.tagger_model_path: Optional[str] = None + self.tagger_model: Optional[rt.InferenceSession] = None + self.tag_names: Optional[List[str]] = None + self.rating_indexes: Optional[List[int]] = None + self.general_indexes: Optional[List[int]] = None + self.character_indexes: Optional[List[int]] = None + + def prepare_image(self, image: Image.Image) -> np.ndarray: + target_size: int = self.model_target_size + + if image.mode in ('RGBA', 'LA'): + background: Image.Image = Image.new("RGB", image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[-1]) + image = background + else: + image = image.convert("RGB") + + image_shape: Tuple[int, int] = image.size + max_dim: int = max(image_shape) + pad_left: int = (max_dim - image_shape[0]) // 2 + pad_top: int = (max_dim - image_shape[1]) // 2 + + padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) + padded_image.paste(image, (pad_left, pad_top)) + + if max_dim != target_size: + padded_image = padded_image.resize( + (target_size, target_size), + Image.BICUBIC, + ) + + image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) + image_array = image_array[:, :, ::-1] + + return np.expand_dims(image_array, axis=0) + + def load_model(self) -> None: + if self.tagger_model is not None: + return + + self.tagger_model_path = hf_hub_download(repo_id=VIT_MODEL_DSV3_REPO, filename=MODEL_FILE_NAME) + self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CPUExecutionProvider']) + _, height, _, _ = self.tagger_model.get_inputs()[0].shape + + self.model_target_size = height + + csv_path: str = hf_hub_download( + VIT_MODEL_DSV3_REPO, + LABEL_FILENAME, + ) + tags_df: pd.DataFrame = pd.read_csv(csv_path) + sep_tags: Tuple[List[str], List[int], List[int], List[int]] = load_labels(tags_df) + + self.tag_names = sep_tags[0] + self.rating_indexes = sep_tags[1] + self.general_indexes = sep_tags[2] + self.character_indexes = sep_tags[3] + + def predict( + self, + image: Image.Image, + general_thresh: float, + general_mcut_enabled: bool, + character_thresh: float, + character_mcut_enabled: bool, + ) -> str: + img: np.ndarray = self.prepare_image(image) + + input_name: str = self.tagger_model.get_inputs()[0].name + label_name: str = self.tagger_model.get_outputs()[0].name + preds: np.ndarray = self.tagger_model.run([label_name], {input_name: img})[0] + + labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0].astype(float))) + + general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes] + + if general_mcut_enabled: + general_probs: np.ndarray = np.array([x[1] for x in general_names]) + general_thresh = mcut_threshold(general_probs) + + general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh} + + character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes] + + if character_mcut_enabled: + character_probs: np.ndarray = np.array([x[1] for x in character_names]) + character_thresh = mcut_threshold(character_probs) + character_thresh = max(0.15, character_thresh) + + character_res: Dict[str, float] = {x[0]: x[1] for x in character_names if x[1] > character_thresh} + + sorted_general_strings: List[Tuple[str, float]] = sorted( + general_res.items(), + key=lambda x: x[1], + reverse=True, + ) + sorted_general_strings_str : List[str] = [x[0] for x in sorted_general_strings] + sorted_general_strings_str = [x.replace(' ', '_') for x in sorted_general_strings_str] + ret_string: str = ( + ",".join(sorted_general_strings_str).replace("(", "\(").replace(")", "\)") + ) + + if len(character_res) > 0: + sorted_character_strings: List[Tuple[str, float]] = sorted( + character_res.items(), + key=lambda x: x[1], + reverse=True, + ) + sorted_character_strings_str: List[str] = [x[0] for x in sorted_character_strings] + sorted_character_strings_str = [x.replace(' ', '_') for x in sorted_character_strings_str] + ret_string += ",".join(sorted_character_strings_str).replace("(", "\(").replace(")", "\)") + + return ret_string + + def write_to_file(self, csv_line: str) -> None: + self.f.write(csv_line + '\n') + self.f.flush() + + def process_directory(self, directory: str) -> None: + file_list: List[str] = list_files_recursive(directory) + print(f'{len(file_list)} files found') + + self.f = open('tags-wd-tagger.txt', 'a', encoding='utf-8') + + self.load_model() + + start: float = time.perf_counter() + cnt: int = 0 + for file_path in file_list: + try: + img: Image.Image = Image.open(file_path) + results_in_csv_format: str = self.predict(img, 0.3, True, 0.3, True) + + self.write_to_file(file_path + ',' + results_in_csv_format) + + if cnt % 100 == 0: + now: float = time.perf_counter() + print(f'{cnt} files processed') + diff: float = now - start + print('{:.2f} seconds elapsed'.format(diff)) + if cnt > 0: + time_per_file: float = diff / cnt + print('{:.4f} seconds per file'.format(time_per_file)) + print("", flush=True) + + cnt += 1 + except Exception as e: + error_class: type = type(e) + error_description: str = str(e) + err_msg: str = '%s: %s' % (error_class, error_description) + print(err_msg) + print_traceback() + pass + +def main(arg_str: list[str]) -> None: + parser: argparse.ArgumentParser = argparse.ArgumentParser() + parser.add_argument('--dir', nargs=1, required=True, help='tagging target directory path') + args: argparse.Namespace = parser.parse_args(arg_str) + + predictor: Predictor = Predictor() + predictor.process_directory(args.dir[0]) + +if __name__ == "__main__": + main(sys.argv) \ No newline at end of file diff --git a/web-ui-image-search-lsi.py b/webui.py similarity index 94% rename from web-ui-image-search-lsi.py rename to webui.py index 54efaf2..7b287f4 100644 --- a/web-ui-image-search-lsi.py +++ b/webui.py @@ -1,394 +1,418 @@ -from gensim.models.lsimodel import LsiModel -from gensim.similarities import MatrixSimilarity -from numpy import ndarray -from streamlit.runtime.state import SessionStateProxy -import pickle - -import numpy as np -import argparse -import streamlit as st -import time -from typing import List, Tuple, Dict, Any, Optional, Protocol - -# $ streamlit run web-ui-image-search-lsi.py - -ss: SessionStateProxy = st.session_state -search_tags: str = '' -image_files_name_tags_arr: List[str] = [] -model: Optional[LsiModel] = None -index: Optional[MatrixSimilarity] = None -dictionary: Optional[Any] = None - -SIMILARITY_THRESHOLD: float = 0.1 - -NG_WORDS: List[str] = ['language', 'english_text', 'pixcel_art'] - -class Arguments(Protocol): - rep: List[str] - -args: Optional[Arguments] = None - -# sorted_scores: sorted_scores[N] >= sorted_scores[N+1] -def filter_searched_result(sorted_scores: List[Tuple[int, float]]) -> List[Tuple[int,float]]: - # sorted_scores: Any = scores[scores.argsort()[:-1]] - # difs: ndarray = sorted_scores[:-1] - sorted_scores[1:] - scores: List[float] = [sorted_scores[i][1] for i in range(len(sorted_scores))] - scores_ndarr: ndarray = np.array(scores) - max_val = scores_ndarr.max() - scores_ndarr = scores_ndarr / max_val - idxes_ndarr = np.where(scores_ndarr > SIMILARITY_THRESHOLD) - - return [(sorted_scores[idx][0], sorted_scores[idx][1] / float(max_val)) for idx in idxes_ndarr[0]] - -# # sorted_scores: sorted_scores[N] >= sorted_scores[N+1] -# def mcut_threshold(sorted_scores: List[Tuple[int, float]]) -> float: -# """ -# Maximum Cut Thresholding (MCut) -# Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy -# for Multi-label Classification. In 11th International Symposium, IDA 2012 -# (pp. 172-183). -# """ -# # sorted_scores: Any = scores[scores.argsort()[:-1]] -# # difs: ndarray = sorted_scores[:-1] - sorted_scores[1:] -# difs: List[float] = [sorted_scores[i + 1][1] - sorted_scores[i][1] for i in range(len(sorted_scores) - 1)] -# tmp_list : List[float] = [] -# # Replace 0 with -inf (same image files exist case) -# for idx, val in enumerate(difs): -# if val == 0: -# tmp_list.append(-np.inf) -# else: -# tmp_list.append(val) -# difs_ndarr: ndarray = np.array(difs) -# -# t: signedinteger = difs_ndarr.argmax() -# thresh: float = (sorted_scores[t][1] + sorted_scores[t + 1][1]) / 2 -# -# # score should be >= thresh -# return thresh - -def normalize_and_apply_weight_lsi(query_bow: List[Tuple[int, int]], new_doc: str) -> List[Tuple[int, float]]: - tags: List[str] = new_doc.split(" ") - - # parse tag:weight format - is_exist_negative_weight: bool = False - tag_and_weight_list: List[Tuple[str, int]] = [] - # all_weight: int = 0 - for tag in tags: - tag_splited: List[str] = tag.split(":") - if len(tag_splited) == 2: - # replace is for specific type of tags - tag_elem: str = tag_splited[0].replace('\\(', '(').replace('\\)', ')') - tag_and_weight_list.append((tag_elem.replace('(', '\\(').replace(')', '\\)'), int(tag_splited[1]))) - # all_weight += int(tag_splited[1]) - else: - # replace is for specific type of tags - tag_elem: str = tag_splited[0].replace('\\(', '(').replace('\\)', ')') - tag_and_weight_list.append((tag_elem.replace('(', '\\(').replace(')', '\\)'), 1)) - # all_weight += 1 - - query_bow_local: List[Tuple[int, int]] = [] - # apply weight to query_bow - for tag, weight in tag_and_weight_list: - tag_id: int = dictionary.token2id[tag] - for ii, _ in enumerate(query_bow): - if query_bow[ii][0] == tag_id: - if weight >= 1: - query_bow_local.append((query_bow[ii][0], query_bow[ii][1]*weight)) - elif weight < 0: - # ignore this elem weight here - query_bow_local.append((query_bow[ii][0], 0)) - is_exist_negative_weight = True - - break - - query_lsi: List[Tuple[int, float]] = model[query_bow_local] - # query_lsi: List[Tuple[int, float]] = model.__getitem__(query_bow_local, scaled=True) - - # reset - query_bow_local = [] - - if is_exist_negative_weight: - for tag, weight in tag_and_weight_list: - tag_id: int = dictionary.token2id[tag] - for ii, _ in enumerate(query_bow): - if query_bow[ii][0] == tag_id: - if weight >= 1: - query_bow_local.append((query_bow[ii][0], 0)) - elif weight < 0: - # negative weighted tags value is changed to positive and multiplied by weight - query_bow_local.append((query_bow[ii][0], -1*weight)) - - break - - query_lsi_neg: List[Tuple[int, float]] = model[query_bow_local] - # query_lsi_neg: List[Tuple[int, float]] = model.__getitem__(query_bow_local, scaled=True) - - # query_lsi - query_lsi_neg - query_lsi_tmp: List[Tuple[int, float]] = [] - for ii, _ in query_lsi: - query_lsi_tmp.append((query_lsi[ii][0], query_lsi[ii][1] - query_lsi_neg[ii][1])) - query_lsi = query_lsi_tmp - - # # normalize query with tag num - # if all_weight > 0: - # query_lsi = [(tag_id, tag_value / all_weight) for tag_id, tag_value in query_lsi] - return query_lsi - -def find_similar_documents(new_doc: str, topn: int = 50) -> List[Tuple[int, float]]: - # when getting bow presentaton, weight description is removed - # because without it, weighted tag is not found in the dictionary - splited_doc = [x.split(":")[0] for x in new_doc.split(' ')] - query_bow: List[Tuple[int, int]] = dictionary.doc2bow(splited_doc) - - query_lsi = normalize_and_apply_weight_lsi(query_bow, new_doc) - #query_lsi: List[Tuple[int, float]] = model[query_bow] - - sims: List[Tuple[int, float]] = index[query_lsi] - - sims = sorted(enumerate(sims), key=lambda item: -item[1]) - # sims = [x for x in sims if x[1] > 0.01] - - # thresh = mcut_threshold(sims) - # sims = [x for x in sims if x[1] >= thresh] - - sims = filter_searched_result(sims) - - ret_len: int = topn - if ret_len > len(sims): - ret_len = len(sims) - return sims[:ret_len] - -def is_include_ng_word(tags: List[str]) -> bool: - for ng_word in NG_WORDS: - if ng_word in tags: - return True - return False - -def init_session_state(data: List[Any] = []) -> None: - global ss - if 'data' not in ss: - ss['data'] = [] - ss['last_search_tags'] = '' - if 'selected_image_info' not in ss: - ss['selected_image_info'] = None - if len(data) > 0: - ss['data'] = data - ss['page_index'] = 0 - return - - if 'page_index' not in ss: - ss['page_index'] = 0 - -def update_index(session_key: str, num: int, max_val: Optional[int] = None) -> None: - global ss - - if max_val: - # to Last - if num == max_val: - ss[session_key] = max_val - 1 - st.rerun() - # Next - if ss[session_key] < max_val - num: - ss[session_key] += num - st.rerun() - else: - # to Top - if num == 0: - ss[session_key] = 0 - st.rerun() - # Prev - if ss[session_key] >= -num: - ss[session_key] += num - st.rerun() - -def convert_data_structure(image_info_list: List[Dict[str, Any]]) -> List[List[List[Dict[str, Any]]]]: - pages: List[List[List[Dict[str, Any]]]] = [] - rows: List[List[Dict[str, Any]]] = [] - cols: List[Dict[str, Any]] = [] - - for ii in range(len(image_info_list)): - cols.append(image_info_list[ii]) - if len(cols) >= 5: - rows.append(cols) - cols = [] - if len(rows) >= 5: - pages.append(rows) - rows = [] - - if cols: - rows.append(cols) - if rows: - pages.append(rows) - - return pages - -def get_all_images() -> List[str]: - images: List[str] = [] - for page in ss['data']: - for row in page: - for image_info in row: - images.append(image_info['file_path']) - return images - -def slideshow() -> None: - images: List[str] = get_all_images() - if len(images) == 0: - st.write("No images to display in slideshow.") - ss['slideshow_active'] = False - st.rerun() - if 'slideshow_index' not in ss: - ss['slideshow_index'] = 0 - cols: Any = st.columns([1]) - - try: - cols[0].image(images[ss['slideshow_index']], use_column_width=True) - except Exception as e: - print(f'Error: {e}') - ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images) - st.rerun() - - if st.button('Stop'): - ss['slideshow_active'] = False - ss['slideshow_index'] = 0 - ss['text_input'] = ss['last_search_tags'] - else: - time.sleep(5) - ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images) - st.rerun() - -def is_now_slideshow() -> bool: - return 'slideshow_active' in ss and ss['slideshow_active'] - -def display_images() -> None: - global ss - - if 'data' in ss and len(ss['data']) > 0: - cols: Any = st.columns([10]) - with cols[0]: - if st.button('Slideshow'): - ss['slideshow_active'] = True - ss['slideshow_index'] = 0 - st.rerun() - - for data_per_page in ss['data'][ss['page_index']]: - cols = st.columns(5) - for col_index, col_ph in enumerate(cols): - try: - image_info: Dict[str, Any] = data_per_page[col_index] - key: str = f"img_{ss['page_index']}_{image_info['doc_id']}_{col_index}" - if col_ph.button('info', key=key): - ss['selected_image_info'] = image_info - st.rerun() - col_ph.image(image_info['file_path'], use_column_width=True) - except Exception as e: - print(f'Error: {e}') - continue - pagination() - -def pagination() -> None: - col1, col2, col3, col4, col5 = st.columns([2, 2, 8, 2, 2]) - if col1.button('Top'): - update_index('page_index', 0) - if col2.button('Prev'): - update_index('page_index', -1) - if col4.button('Next'): - update_index('page_index', 1, len(ss['data'])) - if col5.button('Last'): - update_index('page_index', len(ss['data']), len(ss['data'])) - col3.markdown( - f''' -
- {ss['page_index'] + 1} / {len(ss['data'])} -
- ''', - unsafe_allow_html=True, - ) - -def display_selected_image() -> None: - global ss - image_info: Dict[str, Any] = ss['selected_image_info'] - col1, col2 = st.columns([3, 1]) - with col1: - st.image(image_info['file_path'], use_column_width=True) - with col2: - st.write("Matching Score:") - st.write("{:.2f}%".format(image_info['similarity'] * 100)) - st.write("File Path:") - st.code(image_info['file_path']) - st.write("Tags:") - st.write(' \n'.join(image_info['tags'])) - if st.button('Close'): - ss['selected_image_info'] = None - ss['text_input'] = ss['last_search_tags'] - st.rerun() - -def show_search_result() -> None: - global image_files_name_tags_arr - global args - - load_model() - similar_docs: List[Tuple[int, float]] = find_similar_documents(search_tags, topn=2000) - - found_docs_info: List[Dict[str, Any]] = [] - for doc_id, similarity in similar_docs: - try: - found_img_info_splited: List[str] = image_files_name_tags_arr[doc_id].split(',') - if is_include_ng_word(found_img_info_splited): - continue - found_fpath: str = found_img_info_splited[0] - if args is not None and args.rep: - found_fpath = found_fpath.replace(args.rep[0], args.rep[1]) - found_docs_info.append({ - 'file_path': found_fpath, - 'doc_id': doc_id, - 'similarity': similarity, - 'tags': found_img_info_splited[1:] - }) - except Exception as e: - print(f'Error: {e}') - continue - - pages: List[List[List[Dict[str, Any]]]] = convert_data_structure(found_docs_info) - init_session_state(pages) - -def load_model() -> None: - global model - global image_files_name_tags_arr - global index - global dictionary - - tag_file_path: str = 'tags-wd-tagger_lsi_idx.csv' - image_files_name_tags_arr = [] - with open(tag_file_path, 'r', encoding='utf-8') as f: - for line in f: - image_files_name_tags_arr.append(line.strip()) - - model = LsiModel.load("lsi_model") - index = MatrixSimilarity.load("lsi_index") - dictionary = pickle.load(open("lsi_dictionary", "rb")) - -def main() -> None: - global search_tags - global args - global ss - - parser: argparse.ArgumentParser = argparse.ArgumentParser() - parser.add_argument('--rep', nargs=2, required=False, help='replace the string in file path to one you want') - args = parser.parse_args() - - init_session_state() - - if is_now_slideshow(): - slideshow() - else: - if 'selected_image_info' in ss and ss['selected_image_info']: - display_selected_image() - else: - search_tags = st.text_input('Enter search tags', value='', key='text_input') - if search_tags and ss['last_search_tags'] != search_tags: - ss['last_search_tags'] = search_tags - show_search_result() - st.rerun() - display_images() - +import sys + +from gensim.models.lsimodel import LsiModel +from gensim.similarities import MatrixSimilarity +from numpy import ndarray +from streamlit.runtime.state import SessionStateProxy +import pickle + +import numpy as np +import argparse +import streamlit as st +import time +from typing import List, Tuple, Dict, Any, Optional, Protocol + +# $ streamlit run webui.py + +ss: SessionStateProxy = st.session_state +search_tags: str = '' +image_files_name_tags_arr: List[str] = [] +model: Optional[LsiModel] = None +index: Optional[MatrixSimilarity] = None +dictionary: Optional[Any] = None + +SIMILARITY_THRESHOLD: float = 0.1 + +NG_WORDS: List[str] = ['language', 'english_text', 'pixcel_art'] + +class Arguments(Protocol): + rep: List[str] + +args: Optional[Arguments] = None + +# sorted_scores: sorted_scores[N] >= sorted_scores[N+1] +def filter_searched_result(sorted_scores: List[Tuple[int, float]]) -> List[Tuple[int,float]]: + # sorted_scores: Any = scores[scores.argsort()[:-1]] + # difs: ndarray = sorted_scores[:-1] - sorted_scores[1:] + scores: List[float] = [sorted_scores[i][1] for i in range(len(sorted_scores))] + scores_ndarr: ndarray = np.array(scores) + max_val = scores_ndarr.max() + scores_ndarr = scores_ndarr / max_val + idxes_ndarr = np.where(scores_ndarr > SIMILARITY_THRESHOLD) + + return [(sorted_scores[idx][0], sorted_scores[idx][1] / float(max_val)) for idx in idxes_ndarr[0]] + +# # sorted_scores: sorted_scores[N] >= sorted_scores[N+1] +# def mcut_threshold(sorted_scores: List[Tuple[int, float]]) -> float: +# """ +# Maximum Cut Thresholding (MCut) +# Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy +# for Multi-label Classification. In 11th International Symposium, IDA 2012 +# (pp. 172-183). +# """ +# # sorted_scores: Any = scores[scores.argsort()[:-1]] +# # difs: ndarray = sorted_scores[:-1] - sorted_scores[1:] +# difs: List[float] = [sorted_scores[i + 1][1] - sorted_scores[i][1] for i in range(len(sorted_scores) - 1)] +# tmp_list : List[float] = [] +# # Replace 0 with -inf (same image files exist case) +# for idx, val in enumerate(difs): +# if val == 0: +# tmp_list.append(-np.inf) +# else: +# tmp_list.append(val) +# difs_ndarr: ndarray = np.array(difs) +# +# t: signedinteger = difs_ndarr.argmax() +# thresh: float = (sorted_scores[t][1] + sorted_scores[t + 1][1]) / 2 +# +# # score should be >= thresh +# return thresh + +def normalize_and_apply_weight_lsi(query_bow: List[Tuple[int, int]], new_doc: str) -> List[Tuple[int, float]]: + tags: List[str] = new_doc.split(" ") + + # parse tag:weight format + is_exist_negative_weight: bool = False + tag_and_weight_list: List[Tuple[str, int]] = [] + # all_weight: int = 0 + for tag in tags: + tag_splited: List[str] = tag.split(":") + if len(tag_splited) == 2: + # replace is for specific type of tags + tag_elem: str = tag_splited[0].replace('\\(', '(').replace('\\)', ')') + tag_and_weight_list.append((tag_elem.replace('(', '\\(').replace(')', '\\)'), int(tag_splited[1]))) + # all_weight += int(tag_splited[1]) + else: + # replace is for specific type of tags + tag_elem: str = tag_splited[0].replace('\\(', '(').replace('\\)', ')') + tag_and_weight_list.append((tag_elem.replace('(', '\\(').replace(')', '\\)'), 1)) + # all_weight += 1 + + query_bow_local: List[Tuple[int, int]] = [] + # apply weight to query_bow + for tag, weight in tag_and_weight_list: + tag_id: int = dictionary.token2id[tag] + for ii, _ in enumerate(query_bow): + if query_bow[ii][0] == tag_id: + if weight >= 1: + query_bow_local.append((query_bow[ii][0], query_bow[ii][1]*weight)) + elif weight < 0: + # ignore this elem weight here + query_bow_local.append((query_bow[ii][0], 0)) + is_exist_negative_weight = True + + break + + query_lsi: List[Tuple[int, float]] = model[query_bow_local] + # query_lsi: List[Tuple[int, float]] = model.__getitem__(query_bow_local, scaled=True) + + # reset + query_bow_local = [] + + if is_exist_negative_weight: + for tag, weight in tag_and_weight_list: + tag_id: int = dictionary.token2id[tag] + for ii, _ in enumerate(query_bow): + if query_bow[ii][0] == tag_id: + if weight >= 1: + query_bow_local.append((query_bow[ii][0], 0)) + elif weight < 0: + # negative weighted tags value is changed to positive and multiplied by weight + query_bow_local.append((query_bow[ii][0], -1*weight)) + + break + + query_lsi_neg: List[Tuple[int, float]] = model[query_bow_local] + # query_lsi_neg: List[Tuple[int, float]] = model.__getitem__(query_bow_local, scaled=True) + + # query_lsi - query_lsi_neg + query_lsi_tmp: List[Tuple[int, float]] = [] + for ii, _ in query_lsi: + query_lsi_tmp.append((query_lsi[ii][0], query_lsi[ii][1] - query_lsi_neg[ii][1])) + query_lsi = query_lsi_tmp + + # # normalize query with tag num + # if all_weight > 0: + # query_lsi = [(tag_id, tag_value / all_weight) for tag_id, tag_value in query_lsi] + return query_lsi + +def find_similar_documents(new_doc: str, topn: int = 50) -> List[Tuple[int, float]]: + # when getting bow presentaton, weight description is removed + # because without it, weighted tag is not found in the dictionary + splited_doc = [x.split(":")[0] for x in new_doc.split(' ')] + query_bow: List[Tuple[int, int]] = dictionary.doc2bow(splited_doc) + + query_lsi = normalize_and_apply_weight_lsi(query_bow, new_doc) + #query_lsi: List[Tuple[int, float]] = model[query_bow] + + sims: List[Tuple[int, float]] = index[query_lsi] + + sims = sorted(enumerate(sims), key=lambda item: -item[1]) + # sims = [x for x in sims if x[1] > 0.01] + + # thresh = mcut_threshold(sims) + # sims = [x for x in sims if x[1] >= thresh] + + sims = filter_searched_result(sims) + + ret_len: int = topn + if ret_len > len(sims): + ret_len = len(sims) + return sims[:ret_len] + +def is_include_ng_word(tags: List[str]) -> bool: + for ng_word in NG_WORDS: + if ng_word in tags: + return True + return False + +def init_session_state(data: List[Any] = []) -> None: + global ss + if 'data' not in ss: + ss['data'] = [] + ss['last_search_tags'] = '' + if 'selected_image_info' not in ss: + ss['selected_image_info'] = None + if len(data) > 0: + ss['data'] = data + ss['page_index'] = 0 + return + + if 'page_index' not in ss: + ss['page_index'] = 0 + +def update_index(session_key: str, num: int, max_val: Optional[int] = None) -> None: + global ss + + if max_val: + # to Last + if num == max_val: + ss[session_key] = max_val - 1 + st.rerun() + # Next + if ss[session_key] < max_val - num: + ss[session_key] += num + st.rerun() + else: + # to Top + if num == 0: + ss[session_key] = 0 + st.rerun() + # Prev + if ss[session_key] >= -num: + ss[session_key] += num + st.rerun() + +def convert_data_structure(image_info_list: List[Dict[str, Any]]) -> List[List[List[Dict[str, Any]]]]: + pages: List[List[List[Dict[str, Any]]]] = [] + rows: List[List[Dict[str, Any]]] = [] + cols: List[Dict[str, Any]] = [] + + for ii in range(len(image_info_list)): + cols.append(image_info_list[ii]) + if len(cols) >= 5: + rows.append(cols) + cols = [] + if len(rows) >= 5: + pages.append(rows) + rows = [] + + if cols: + rows.append(cols) + if rows: + pages.append(rows) + + return pages + +def get_all_images() -> List[str]: + images: List[str] = [] + for page in ss['data']: + for row in page: + for image_info in row: + images.append(image_info['file_path']) + return images + +def slideshow() -> None: + images: List[str] = get_all_images() + if len(images) == 0: + st.write("No images to display in slideshow.") + ss['slideshow_active'] = False + st.rerun() + if 'slideshow_index' not in ss: + ss['slideshow_index'] = 0 + cols: Any = st.columns([1]) + + try: + cols[0].image(images[ss['slideshow_index']], use_column_width=True) + except Exception as e: + print(f'Error: {e}') + ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images) + st.rerun() + + if st.button('Stop'): + ss['slideshow_active'] = False + ss['slideshow_index'] = 0 + ss['text_input'] = ss['last_search_tags'] + else: + time.sleep(5) + ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images) + st.rerun() + +def is_now_slideshow() -> bool: + return 'slideshow_active' in ss and ss['slideshow_active'] + +def export_result_to_file() -> None: + if sys.platform == 'win32': + encoding = 'shift_jis' + else: + encoding = 'utf-8' + + # name convention: "{search_tags}" + "_" + {timestamp} + ".txt" + output_file_path: str = f"{search_tags.replace(' ', '_').replace(':', '_') }_{int(time.time())}.txt" + + with open(output_file_path, 'w', encoding=encoding) as f: + for page in ss['data']: + for row in page: + for image_info in row: + try: + f.write(f"{image_info['file_path']}\n") + except Exception as e: + print(f'Error: {e}') + continue + +def display_images() -> None: + global ss + + if 'data' in ss and len(ss['data']) > 0: + cols: Any = st.columns([10]) + with cols[0]: + if st.button('Slideshow'): + ss['slideshow_active'] = True + ss['slideshow_index'] = 0 + st.rerun() + if st.button('Export'): + export_result_to_file() + st.rerun() + + for data_per_page in ss['data'][ss['page_index']]: + cols = st.columns(5) + for col_index, col_ph in enumerate(cols): + try: + image_info: Dict[str, Any] = data_per_page[col_index] + key: str = f"img_{ss['page_index']}_{image_info['doc_id']}_{col_index}" + if col_ph.button('info', key=key): + ss['selected_image_info'] = image_info + st.rerun() + col_ph.image(image_info['file_path'], use_column_width=True) + except Exception as e: + print(f'Error: {e}') + continue + pagination() + +def pagination() -> None: + col1, col2, col3, col4, col5 = st.columns([2, 2, 8, 2, 2]) + if col1.button('Top'): + update_index('page_index', 0) + if col2.button('Prev'): + update_index('page_index', -1) + if col4.button('Next'): + update_index('page_index', 1, len(ss['data'])) + if col5.button('Last'): + update_index('page_index', len(ss['data']), len(ss['data'])) + col3.markdown( + f''' +
+ {ss['page_index'] + 1} / {len(ss['data'])} +
+ ''', + unsafe_allow_html=True, + ) + +def display_selected_image() -> None: + global ss + image_info: Dict[str, Any] = ss['selected_image_info'] + col1, col2 = st.columns([3, 1]) + with col1: + st.image(image_info['file_path'], use_column_width=True) + with col2: + st.write("Matching Score:") + st.write("{:.2f}%".format(image_info['similarity'] * 100)) + st.write("File Path:") + st.code(image_info['file_path']) + st.write("Tags:") + st.write(' \n'.join(image_info['tags'])) + if st.button('Close'): + ss['selected_image_info'] = None + ss['text_input'] = ss['last_search_tags'] + st.rerun() + +def show_search_result() -> None: + global image_files_name_tags_arr + global args + + load_model() + similar_docs: List[Tuple[int, float]] = find_similar_documents(search_tags, topn=800) + + found_docs_info: List[Dict[str, Any]] = [] + for doc_id, similarity in similar_docs: + try: + found_img_info_splited: List[str] = image_files_name_tags_arr[doc_id].split(',') + if is_include_ng_word(found_img_info_splited): + continue + found_fpath: str = found_img_info_splited[0] + if args is not None and args.rep: + found_fpath = found_fpath.replace(args.rep[0], args.rep[1]) + found_docs_info.append({ + 'file_path': found_fpath, + 'doc_id': doc_id, + 'similarity': similarity, + 'tags': found_img_info_splited[1:] + }) + except Exception as e: + print(f'Error: {e}') + continue + + pages: List[List[List[Dict[str, Any]]]] = convert_data_structure(found_docs_info) + init_session_state(pages) + +def load_model() -> None: + global model + global image_files_name_tags_arr + global index + global dictionary + + tag_file_path: str = 'tags-wd-tagger_lsi_idx.csv' + image_files_name_tags_arr = [] + with open(tag_file_path, 'r', encoding='utf-8') as f: + for line in f: + image_files_name_tags_arr.append(line.strip()) + + model = LsiModel.load("lsi_model") + index = MatrixSimilarity.load("lsi_index") + dictionary = pickle.load(open("lsi_dictionary", "rb")) + +def main() -> None: + global search_tags + global args + global ss + + parser: argparse.ArgumentParser = argparse.ArgumentParser() + parser.add_argument('--rep', nargs=2, required=False, help='replace the string in file path to one you want') + args = parser.parse_args() + + init_session_state() + + if is_now_slideshow(): + slideshow() + else: + if 'selected_image_info' in ss and ss['selected_image_info']: + display_selected_image() + else: + search_tags = st.text_input('Enter search tags', value='', key='text_input') + if search_tags and ss['last_search_tags'] != search_tags: + ss['last_search_tags'] = search_tags + show_search_result() + st.rerun() + display_images() + main() \ No newline at end of file