Model infrastructure complete, commence test running
justincdavis committed Aug 6, 2024
1 parent 79b7a4f commit 9886fe2
Showing 11 changed files with 242 additions and 31 deletions.
13 changes: 12 additions & 1 deletion src/oakutils/blobs/testing/_evaluator.py
@@ -231,6 +231,7 @@ def allclose(
data: list[np.ndarray | list[np.ndarray]] | None = None,
rdiff: float = 1e-4,
adiff: float = 1e-4,
percentage: float = 99.0,
*,
image_output: bool | None = None,
u8_input: bool | None = None,
@@ -249,6 +250,9 @@ def allclose(
The relative tolerance, by default 1e-4
adiff : float, optional
The absolute tolerance, by default 1e-4
percentage : float, optional
The percentage of elements which must be close, by default 99.0
This is only checked if the np.allclose call fails.
image_output : bool, optional
Whether the output is an image, by default None
If None, will assume image outputs with shape
@@ -310,7 +314,7 @@ def allclose(
else:
converted_data = data

compare_data = []
compare_data: list[tuple[tuple[int, np.ndarray], tuple[int, np.ndarray]]] = []
for idx1, idx2 in itertools.combinations(range(len(converted_data)), 2):
compare_data.append(
(
@@ -322,5 +326,12 @@
non_matches = []
for (idx1, d1), (idx2, d2) in compare_data:
if not np.allclose(d1, d2, rtol=rdiff, atol=adiff):
_log.debug(f"Data {idx1} and {idx2} are not close via np.allclose checking %.")
# assess if a percentage of the results are close
close_count = np.sum(np.isclose(d1, d2, rtol=rdiff, atol=adiff))
close_percentage = close_count / np.prod(d1.shape)
_log.debug(f"Data {idx1} and {idx2} are {close_percentage * 100:.3f}% close.")
if close_percentage > (percentage / 100.0):
continue
non_matches.append((idx1, idx2))
return len(non_matches) == 0, non_matches
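The new percentage argument gives allclose a softer fallback: when np.allclose fails outright, two outputs are still treated as matching if at least the given percentage of their elements are element-wise close. A minimal, self-contained sketch of that fallback check (the fraction_close helper is illustrative and not part of oakutils):

import numpy as np


def fraction_close(d1: np.ndarray, d2: np.ndarray, rdiff: float = 1e-4, adiff: float = 1e-4) -> float:
    # fraction of elements that are element-wise close under the given tolerances
    close_count = np.sum(np.isclose(d1, d2, rtol=rdiff, atol=adiff))
    return float(close_count / np.prod(d1.shape))


# two arrays that fail np.allclose because of a single outlier element
a = np.zeros(1000)
b = np.zeros(1000)
b[0] = 1.0
assert not np.allclose(a, b, rtol=1e-4, atol=1e-4)
assert fraction_close(a, b) > 0.99  # 99.9% of elements are close, so a 99% threshold passes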
46 changes: 41 additions & 5 deletions tests/blobs/models/basic.py
@@ -17,12 +17,13 @@
try:
from ...device import get_device_count
from .load import create_model, run_model
from .hashs import get_bulk_tables, write_bulk_tables, hash_file, hash_bulk_entry, get_run_tables, write_model_tables
except ImportError:
devicefile = Path(__file__).parent.parent.parent / "device.py"
sys.path.append(str(devicefile.parent))
from device import get_device_count

from load import create_model, run_model
from hashs import get_bulk_tables, write_bulk_tables, hash_file, hash_bulk_entry, get_run_tables, write_model_tables


def create_model_ghhs(createmodelfunc: Callable) -> None:
@@ -43,10 +44,29 @@ def create_model_ghhs(createmodelfunc: Callable) -> None:


def run_model_ghhs(createmodelfunc: Callable, modelname: str) -> None:
hash_table, run_table = get_run_tables()
for use_blur in [True, False]:
for ks in [3, 5, 7, 9, 11, 13, 15]:
for shave in [1, 2, 3, 4, 5, 6]:
for use_gs in [True, False]:
attributes = []
# use a local copy so the base model name is not mutated across loop iterations
name = modelname
if use_blur:
name += "blur"
attributes.append(str(ks))
if use_gs:
name += "gray"
modelpath = get_model_path(name, attributes, shave)
model_hash = hash_file(modelpath)
modelkey = modelpath.stem
# if the hash is the same and we have already gotten a successful run, continue
if hash_table[modelkey] == model_hash and run_table[modelkey]:
continue
# if the hash is not the same update the hash and set the run to false
existing_hash = hash_table[modelkey]
if existing_hash != model_hash:
hash_table[modelkey] = model_hash
run_table[modelkey] = False

# build the model and decode functions for this configuration
modelfunc = partial(
createmodelfunc,
@@ -60,9 +80,11 @@ def run_model_ghhs(createmodelfunc: Callable, modelname: str) -> None:
get_nn_frame,
channels=channels,
)
assert (
run_model(modelfunc, decodefunc) == 0
), f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"
retcode = run_model(modelfunc, decodefunc)
tableval = retcode == 0
run_table[modelkey] = tableval
write_model_tables(hash_table, run_table)
assert retcode == 0, f"Failed for {ks}, {shave}, {use_blur}, {use_gs}"


def get_models(model_type: str) -> list[tuple[Path, ...]]:
@@ -86,9 +108,23 @@ def get_models(model_type: str) -> list[tuple[Path, ...]]:

def check_model_equivalence(model_type: str) -> None:
models = get_models(model_type)
hash_table, run_table = get_bulk_tables()
for model_paths in models:
if get_device_count() == 0:
return
modelkey = model_paths[0].stem[:-8]
entryhash = hash_bulk_entry(model_paths)
# if the hash is the same and the run flag is True, we can skip
existinghash = hash_table[modelkey]
if existinghash == entryhash and run_table[modelkey]:
continue
if existinghash != entryhash:
hash_table[modelkey] = entryhash
run_table[modelkey] = False
# run all blob variants and check their outputs for equivalence
evaluator = BlobEvaluater([*model_paths])
evaluator.run()
assert evaluator.allclose()[0], f"Failed allclose check for {model_paths}"
success = evaluator.allclose()[0]
run_table[modelkey] = success
write_bulk_tables(hash_table, run_table)
assert success, f"Failed allclose check for {model_paths}"
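Both run_model_ghhs and check_model_equivalence gate work behind the same caching pattern: hash the blob (or blob set), skip when the stored hash matches and the last run succeeded, otherwise record the new hash, clear the run flag, run, and persist the result. A minimal sketch of that gate, assuming plain in-memory dict tables (the should_skip name is illustrative):

from pathlib import Path


def should_skip(modelpath: Path, file_hash: str, hash_table: dict[str, str], run_table: dict[str, bool]) -> bool:
    # skip when the blob is unchanged and a previous run already succeeded
    key = modelpath.stem
    if hash_table.get(key) == file_hash and run_table.get(key, False):
        return True
    # blob changed (or is new): store the new hash and reset the run flag
    if hash_table.get(key) != file_hash:
        hash_table[key] = file_hash
        run_table[key] = False
    return False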
Binary file modified tests/blobs/models/cache/bulk_hash_table.pkl
28 changes: 24 additions & 4 deletions tests/blobs/models/hashs.py
@@ -40,14 +40,19 @@ def compare_entry(entry: Path) -> bool:
return table[entry] == hash_file(entry)


def hash_bulk_entry(entry: tuple[Path, ...]) -> str:
hashes = [hash_file(bp) for bp in sorted(entry)]
hashstr = "".join(hashes)
return hashlib.md5(hashstr.encode()).hexdigest()


def create_bulk_hash_table() -> None:
hash_table: dict[str, str] = {}
for blob_tuple in ALL_MODELS:
# get the stem file path without the suffix
# then remove the _shavesN part at the end
key = blob_tuple[0].stem[:-8]
hashes = [hash_file(bp) for bp in blob_tuple]
hash_table[key] = hash(tuple(hashes))
hash_table[key] = hash_bulk_entry(blob_tuple)
with Path.open(BULK_HASH_TABLE_PATH, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)

@@ -59,16 +64,31 @@ def compare_bulk_entry(entry: tuple[Path, ...]) -> bool:
return hash_bulk_entry(entry) == table[key]


def get_run_tables() -> tuple[dict[str, str], dict[str, str]]:
def get_run_tables() -> tuple[dict[str, str], dict[str, bool]]:
with Path.open(HASH_TABLE_PATH, "rb") as file:
hash_table = pickle.load(file)
with Path.open(MODEL_RUN_TABLE_PATH, "rb") as file:
model_run_table = pickle.load(file)
return hash_table, model_run_table


def write_model_tables(hash_table: dict[str, str], model_run_table: dict[str, str]) -> None:
def write_model_tables(hash_table: dict[str, str], model_run_table: dict[str, bool]) -> None:
with Path.open(HASH_TABLE_PATH, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)
with Path.open(MODEL_RUN_TABLE_PATH, "wb") as file:
pickle.dump(model_run_table, file, protocol=pickle.HIGHEST_PROTOCOL)


def get_bulk_tables() -> tuple[dict[str, str], dict[str, bool]]:
with Path.open(BULK_HASH_TABLE_PATH, "rb") as file:
hash_table = pickle.load(file)
with Path.open(BULK_RUN_TABLE_PATH, "rb") as file:
run_table = pickle.load(file)
return hash_table, run_table


def write_bulk_tables(hash_table: dict[str, str], run_table: dict[str, bool]) -> None:
with Path.open(BULK_HASH_TABLE_PATH, "wb") as file:
pickle.dump(hash_table, file, protocol=pickle.HIGHEST_PROTOCOL)
with Path.open(BULK_RUN_TABLE_PATH, "wb") as file:
pickle.dump(run_table, file, protocol=pickle.HIGHEST_PROTOCOL)
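The bulk table key drops the trailing _shavesN suffix (eight characters) from a blob's stem, so every shave variant of a model shares one entry, and hash_bulk_entry folds the per-file hashes into a single MD5 digest. A small sketch of the key derivation, using a hypothetical blob name:

from pathlib import Path

# hypothetical blob path; real paths come from the compiled models shipped with oakutils
blob = Path("cache/gaussian_15_shaves6.blob")
key = blob.stem[:-8]  # strips "_shaves6" (8 characters)
print(key)  # -> gaussian_15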
13 changes: 11 additions & 2 deletions tests/blobs/models/test_gftt.py
@@ -5,16 +5,25 @@

from oakutils.nodes.models import create_gftt

from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
try:
from basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
except ModuleNotFoundError:
from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence


def test_create() -> None:
create_model_ghhs(create_gftt)


def test_run() -> None:
run_model_ghhs(create_gftt)
run_model_ghhs(create_gftt, "gftt")


def test_equivalence() -> None:
check_model_equivalence("gftt")


if __name__ == "__main__":
test_create()
test_run()
test_equivalence()
13 changes: 11 additions & 2 deletions tests/blobs/models/test_harris.py
@@ -5,16 +5,25 @@

from oakutils.nodes.models import create_harris

from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
try:
from basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
except ModuleNotFoundError:
from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence


def test_create() -> None:
create_model_ghhs(create_harris)


def test_run() -> None:
run_model_ghhs(create_harris)
run_model_ghhs(create_harris, "harris")


def test_equivalence() -> None:
check_model_equivalence("harris")


if __name__ == "__main__":
test_create()
test_run()
test_equivalence()
13 changes: 11 additions & 2 deletions tests/blobs/models/test_hessian.py
@@ -5,16 +5,25 @@

from oakutils.nodes.models import create_hessian

from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
try:
from basic import create_model_ghhs, run_model_ghhs, check_model_equivalence
except ModuleNotFoundError:
from .basic import create_model_ghhs, run_model_ghhs, check_model_equivalence


def test_create() -> None:
create_model_ghhs(create_hessian)


def test_run() -> None:
run_model_ghhs(create_hessian)
run_model_ghhs(create_hessian, "hessian")


def test_equivalence() -> None:
check_model_equivalence("hessian")


if __name__ == "__main__":
test_create()
test_run()
test_equivalence()
46 changes: 41 additions & 5 deletions tests/blobs/models/test_laplacian.py
@@ -5,11 +5,18 @@

from functools import partial

from oakutils.blobs import get_model_path
from oakutils.nodes import get_nn_frame
from oakutils.nodes.models import create_laplacian

from .basic import check_model_equivalence
from .load import create_model, run_model
try:
from basic import check_model_equivalence
from load import create_model, run_model
from hashs import get_run_tables, hash_file, write_model_tables
except ModuleNotFoundError:
from .basic import check_model_equivalence
from .load import create_model, run_model
from .hashs import get_run_tables, hash_file, write_model_tables


def test_create() -> None:
@@ -32,11 +39,32 @@ def test_create() -> None:


def test_run() -> None:
hash_table, run_table = get_run_tables()
for ks1 in [3, 5, 7, 9, 11, 13, 15]:
for ks2 in [3, 5, 7, 9, 11, 13, 15]:
for shave in [1, 2, 3, 4, 5, 6]:
for use_blur in [True, False]:
for use_gs in [True, False]:
modelname = "laplacian"
attributes = [str(ks1)]
if use_blur:
modelname += "blur"
attributes.append(str(ks2))
if use_gs:
modelname += "gray"
modelpath = get_model_path(modelname, attributes, shave)
model_hash = hash_file(modelpath)
modelkey = modelpath.stem
# if the hash is the same and we have already gotten a successful run, continue
if hash_table[modelkey] == model_hash and run_table[modelkey]:
continue
# if the hash is not the same update the hash and set the run to false
existing_hash = hash_table[modelkey]
if existing_hash != model_hash:
hash_table[modelkey] = model_hash
run_table[modelkey] = False

# perform the actual run
modelfunc = partial(
create_laplacian,
kernel_size=ks1,
@@ -50,10 +78,18 @@ def test_run() -> None:
get_nn_frame,
channels=channels,
)
assert (
run_model(modelfunc, decodefunc) == 0
), f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"
retcode = run_model(modelfunc, decodefunc)
tableval = retcode == 0
run_table[modelkey] = tableval
write_model_tables(hash_table, run_table)
assert retcode == 0, f"Failed for {ks1}, {ks2}, {shave}, {use_blur}, {use_gs}"


def test_equivalence() -> None:
check_model_equivalence("laplacian")


if __name__ == "__main__":
test_create()
test_run()
test_equivalence()