Skip to content

Commit

Permalink
Working hashing for gaussian model, repeat for others
Browse files Browse the repository at this point in the history
  • Loading branch information
justincdavis committed Aug 6, 2024
1 parent a8047f2 commit 79b7a4f
Show file tree
Hide file tree
Showing 16 changed files with 146 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/oakutils/blobs/_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def get_model_path(
if len(candidates) == 0:
err_msg = f"No model blob paths could be formed from the attributes {model_attributes} and shaves {shaves}."
raise ValueError(err_msg)
blobpath = Path(candidates[0])
_, _, path = candidates[0]
blobpath = Path(path)
if not blobpath.exists():
err_msg = f"The model blob path {blobpath} does not exists."
raise FileNotFoundError(err_msg)
Expand Down
34 changes: 29 additions & 5 deletions tests/blobs/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,39 @@
# MIT License
from __future__ import annotations

import pickle
from pathlib import Path

from .hashs import create_file_hash_table, create_bulk_hash_table
from .hashs import (
create_file_hash_table,
create_bulk_hash_table,
HASH_TABLE_PATH,
MODEL_RUN_TABLE_PATH,
BULK_HASH_TABLE_PATH,
BULK_RUN_TABLE_PATH,
)

# handle the creation of the hash files if they do not exist
hash_table_path = Path(__file__).parent / "hash_table.pkl"
if not hash_table_path.exists():
# hash table for model to hash and model to successful run
if not HASH_TABLE_PATH.exists():
create_file_hash_table()
if not MODEL_RUN_TABLE_PATH.exists():
model_run_table = {}
with Path.open(HASH_TABLE_PATH, "rb") as file:
hash_table = pickle.load(file)
for key in hash_table:
model_run_table[key] = False
with Path.open(MODEL_RUN_TABLE_PATH, "wb") as file:
pickle.dump(model_run_table, file, protocol=pickle.HIGHEST_PROTOCOL)

bulk_hash_table_path = Path(__file__).parent / "bulk_hash_table.pkl"
if not bulk_hash_table_path.exists():
# hash table for bulk model to hash and bulk model to successful run
if not BULK_HASH_TABLE_PATH.exists():
create_bulk_hash_table()
if not BULK_RUN_TABLE_PATH.exists():
bulk_run_table = {}
with Path.open(BULK_HASH_TABLE_PATH, "rb") as file:
bulk_hash_table = pickle.load(file)
for key in bulk_hash_table:
bulk_run_table[key] = False
with Path.open(BULK_RUN_TABLE_PATH, "wb") as file:
pickle.dump(bulk_run_table, file, protocol=pickle.HIGHEST_PROTOCOL)
24 changes: 19 additions & 5 deletions tests/blobs/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@
# MIT License
from __future__ import annotations

import sys
from collections.abc import Callable
from functools import partial
from pathlib import Path

from stdlib_list import stdlib_list
from oakutils.nodes import get_nn_frame
from oakutils.blobs import get_model_path
from oakutils.blobs.models import bulk
from oakutils.blobs.testing import BlobEvaluater

from .load import create_model, run_model
from ...device import get_device_count
try:
from ...device import get_device_count
from .load import create_model, run_model
except ImportError:
devicefile = Path(__file__).parent.parent.parent / "device.py"
sys.path.append(str(devicefile.parent))
from device import get_device_count

from load import create_model, run_model


def create_model_ghhs(createmodelfunc: Callable) -> None:
Expand All @@ -28,14 +37,17 @@ def create_model_ghhs(createmodelfunc: Callable) -> None:
use_blur=use_blur,
grayscale_out=use_gs,
)
assert create_model(modelfunc) == 0, f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"
assert (
create_model(modelfunc) == 0
), f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"


def run_model_ghhs(createmodelfunc: Callable) -> None:
def run_model_ghhs(createmodelfunc: Callable, modelname: str) -> None:
for use_blur in [True, False]:
for ks in [3, 5, 7, 9, 11, 13, 15]:
for shave in [1, 2, 3, 4, 5, 6]:
for use_gs in [True, False]:
# check if the model
modelfunc = partial(
createmodelfunc,
blur_kernel_size=ks,
Expand All @@ -48,7 +60,9 @@ def run_model_ghhs(createmodelfunc: Callable) -> None:
get_nn_frame,
channels=channels,
)
assert run_model(modelfunc, decodefunc) == 0, f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"
assert (
run_model(modelfunc, decodefunc) == 0
), f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"


def get_models(model_type: str) -> list[tuple[Path, ...]]:
Expand Down
Binary file removed tests/blobs/models/bulk_hash_table.pkl
Binary file not shown.
Binary file added tests/blobs/models/cache/bulk_hash_table.pkl
Binary file not shown.
Binary file added tests/blobs/models/cache/bulk_run_table.pkl
Binary file not shown.
Binary file added tests/blobs/models/cache/hash_table.pkl
Binary file not shown.
Binary file added tests/blobs/models/cache/model_run_table.pkl
Binary file not shown.
Binary file removed tests/blobs/models/hash_table.pkl
Binary file not shown.
35 changes: 27 additions & 8 deletions tests/blobs/models/hashs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from oakutils.blobs.models.bulk import ALL_MODELS


# Pickled bookkeeping tables for the blob tests, stored in the package-local
# "cache" directory next to this file.
# hash_table.pkl: blob file stem -> md5 hex digest of the blob file
HASH_TABLE_PATH = Path(__file__).parent / "cache" / "hash_table.pkl"
# model_run_table.pkl: blob file stem -> whether the model ran successfully
MODEL_RUN_TABLE_PATH = Path(__file__).parent / "cache" / "model_run_table.pkl"
# bulk_hash_table.pkl: bulk model key (stem minus "_shavesN") -> hash of the per-shave digests
BULK_HASH_TABLE_PATH = Path(__file__).parent / "cache" / "bulk_hash_table.pkl"
# bulk_run_table.pkl: bulk model key -> whether the bulk model ran successfully
BULK_RUN_TABLE_PATH = Path(__file__).parent / "cache" / "bulk_run_table.pkl"


def hash_file(file_path: Path) -> str:
hasher = hashlib.md5()
with file_path.open("rb") as file:
Expand All @@ -23,33 +29,46 @@ def create_file_hash_table() -> None:
hash_table: dict[str, str] = {}
for blob_tuple in ALL_MODELS:
for blob_path in blob_tuple:
hash_table[blob_path] = hash_file(blob_path)
table_file = Path(__file__).parent / "hash_table.pkl"
with Path.open(table_file, "wb") as file:
hash_table[blob_path.stem] = hash_file(blob_path)
with Path.open(HASH_TABLE_PATH, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)


def compare_entry(entry: Path) -> bool:
with Path.open(Path(__file__).parent / "hash_table.pkl", "rb") as file:
with Path.open(HASH_TABLE_PATH, "rb") as file:
table = pickle.load(file)
return table[entry] == hash_file(entry)


def create_bulk_hash_table() -> None:
hash_table: dict[str, str] = {}
hash_table: dict[str, int] = {}
for blob_tuple in ALL_MODELS:
# get the stem file path without the suffix
# then remove the _shavesN part at the end
key = blob_tuple[0].stem[:-8]
hashes = [hash_file(bp) for bp in blob_tuple]
hash_table[key] = hash(tuple(hashes))
table_file = Path(__file__).parent / "bulk_hash_table.pkl"
with Path.open(table_file, "wb") as file:
with Path.open(BULK_HASH_TABLE_PATH, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)


def compare_bulk_entry(entry: tuple[Path, ...]) -> bool:
key = entry[0].stem[:-8]
with Path.open(Path(__file__).parent / "bulk_hash_table.pkl", "rb") as file:
with Path.open(BULK_HASH_TABLE_PATH, "rb") as file:
table = pickle.load(file)
return hash(tuple(hash_file(bp) for bp in entry)) == table[key]


def get_run_tables() -> tuple[dict[str, str], dict[str, bool]]:
    """Load the model hash table and the model run table from the cache.

    Returns
    -------
    tuple[dict[str, str], dict[str, bool]]
        The hash table mapping blob stem to md5 digest, and the run table
        mapping blob stem to whether the model has run successfully.
        (The run table holds booleans, not strings, so the annotation
        reflects that.)

    """
    with Path.open(HASH_TABLE_PATH, "rb") as file:
        hash_table = pickle.load(file)
    with Path.open(MODEL_RUN_TABLE_PATH, "rb") as file:
        model_run_table = pickle.load(file)
    return hash_table, model_run_table


def write_model_tables(
    hash_table: dict[str, str],
    model_run_table: dict[str, bool],
) -> None:
    """Persist the model hash table and model run table to the cache.

    Parameters
    ----------
    hash_table : dict[str, str]
        Mapping of blob stem to md5 digest.
    model_run_table : dict[str, bool]
        Mapping of blob stem to whether the model has run successfully.
        (The run table holds booleans, not strings, so the annotation
        reflects that.)

    """
    with Path.open(HASH_TABLE_PATH, "wb") as file:
        pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)
    with Path.open(MODEL_RUN_TABLE_PATH, "wb") as file:
        pickle.dump(model_run_table, file, protocol=pickle.HIGHEST_PROTOCOL)
11 changes: 9 additions & 2 deletions tests/blobs/models/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
# MIT License
from __future__ import annotations

import sys
from collections.abc import Callable
from pathlib import Path

import depthai as dai
from oakutils.nodes import create_color_camera, create_xout

from ...device import get_device_count
try:
from ...device import get_device_count
except ImportError:
devicefile = Path(__file__).parent.parent.parent / "device.py"
sys.path.append(str(devicefile.parent))
from device import get_device_count


def create_model(modelfunc: Callable) -> int:
Expand Down Expand Up @@ -45,7 +52,7 @@ def run_model(modelfunc: Callable, decodefunc: Callable) -> int:

if get_device_count() == 0:
return 0

with dai.Device(pipeline) as device:
queue: dai.DataOutputQueue = device.getOutputQueue("model_out")

Expand Down
45 changes: 41 additions & 4 deletions tests/blobs/models/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@

from functools import partial

from oakutils.blobs import get_model_path
from oakutils.nodes import get_nn_frame
from oakutils.nodes.models import create_gaussian

from .basic import check_model_equivalence
from .load import create_model, run_model
try:
from basic import check_model_equivalence
from hashs import get_run_tables, hash_file, write_model_tables
from load import create_model, run_model
except ModuleNotFoundError:
from .basic import check_model_equivalence
from .hashs import get_run_tables, hash_file, write_model_tables
from .load import create_model, run_model


def test_create() -> None:
Expand All @@ -22,13 +29,33 @@ def test_create() -> None:
shaves=shave,
grayscale_out=use_gs,
)
assert create_model(modelfunc) == 0, f"Failed for {ks}, {shave}, {use_gs}"
assert (
create_model(modelfunc) == 0
), f"Failed for {ks}, {shave}, {use_gs}"


def test_run() -> None:
hash_table, run_table = get_run_tables()
for ks in [3, 5, 7, 9, 11, 13, 15]:
for shave in [1, 2, 3, 4, 5, 6]:
for use_gs in [True, False]:
# assess if the model has already been run
modelname = "gaussian"
if use_gs:
modelname += "gray"
modelpath = get_model_path(modelname, [str(ks)], shave)
model_hash = hash_file(modelpath)
modelkey = modelpath.stem
# if the hash is the same and we have already gotten a successful run, continue
if hash_table[modelkey] == model_hash and run_table[modelkey]:
continue
# if the hash is not the same update the hash and set the run to false
existing_hash = hash_table[modelkey]
if existing_hash != model_hash:
hash_table[modelkey] = model_hash
run_table[modelkey] = False

# perform the actual run
modelfunc = partial(
create_gaussian,
kernel_size=ks,
Expand All @@ -40,8 +67,18 @@ def test_run() -> None:
get_nn_frame,
channels=channels,
)
assert run_model(modelfunc, decodefunc) == 0, f"Failed for {ks}, {shave}, {use_gs}"
retcode = run_model(modelfunc, decodefunc)
tableval = retcode == 0
run_table[modelkey] = tableval
write_model_tables(hash_table, run_table)
assert retcode == 0, f"Failed for {ks}, {shave}, {use_gs}"


def test_equivalence() -> None:
    """Run the shared equivalence check (basic.check_model_equivalence) for the gaussian model family."""
    check_model_equivalence("gaussian")


# Allow running this test file directly as a script (the try/except import
# fallback at the top of the file supports execution outside of pytest).
if __name__ == "__main__":
    test_create()
    test_run()
    test_equivalence()
8 changes: 6 additions & 2 deletions tests/blobs/models/test_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def test_create() -> None:
grayscale_out=use_gs,
use_blur=use_blur,
)
assert create_model(modelfunc) == 0, f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"
assert (
create_model(modelfunc) == 0
), f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"


def test_run() -> None:
Expand All @@ -48,7 +50,9 @@ def test_run() -> None:
get_nn_frame,
channels=channels,
)
assert run_model(modelfunc, decodefunc) == 0, f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"
assert (
run_model(modelfunc, decodefunc) == 0
), f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"


def test_equivalence() -> None:
Expand Down
8 changes: 6 additions & 2 deletions tests/blobs/models/test_laserscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def test_create() -> None:
scans=scan,
shaves=shave,
)
assert create_model(modelfunc) == 0, f"Failed for {width}, {scan}, {shave}"
assert (
create_model(modelfunc) == 0
), f"Failed for {width}, {scan}, {shave}"


def test_run() -> None:
Expand All @@ -35,7 +37,9 @@ def test_run() -> None:
shaves=shave,
)
decodefunc = get_laserscan
assert run_model(modelfunc, decodefunc) == 0, f"Failed for {width}, {scan}, {shave}"
assert (
run_model(modelfunc, decodefunc) == 0
), f"Failed for {width}, {scan}, {shave}"


def test_equivalence() -> None:
Expand Down
4 changes: 3 additions & 1 deletion tests/blobs/models/test_pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def test_create_and_run() -> None:
for shave in [1, 2, 3, 4, 5, 6]:
pipeline = dai.Pipeline()
stereo, left, right = create_stereo_depth(pipeline)
pcl, xin_pcl, device_call = create_point_cloud(pipeline, stereo.depth, calib_data, shaves=shave)
pcl, xin_pcl, device_call = create_point_cloud(
pipeline, stereo.depth, calib_data, shaves=shave
)
xout_pcl = create_xout(pipeline, pcl.out, "pcl_out")

all_nodes = [
Expand Down
5 changes: 4 additions & 1 deletion tests/blobs/test_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,12 @@ def test_model_shave_dirs_equal() -> None:
continue
for file in module_contents1:
if file not in module_contents2:
print(f"File {file} from shave {idx1+1} not in other module shave {idx2+1}")
print(
f"File {file} from shave {idx1+1} not in other module shave {idx2+1}"
)
raise err


def test_model_shave_dirs_equivalent() -> None:
"""Tests all the shave modules have the same models"""
assert os.path.exists(models.__file__)
Expand Down

0 comments on commit 79b7a4f

Please sign in to comment.