Skip to content

Commit

Permalink
Chore: Docstrings partially done for analyse_vault.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Aug 20, 2024
1 parent 329323d commit 94f3d4d
Showing 1 changed file with 56 additions and 1 deletion.
57 changes: 56 additions & 1 deletion og_marl/vault_utils/analyse_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ def get_saco(experience: Dict[str, Array]) -> Tuple[float, Array, Array]:
def plot_count_frequencies(
all_count_vals: Dict[str, Array], all_count_freq: Dict[str, Array], save_path: str = ""
) -> None:
"""Plots the frequencies of counts of state-action pairs.
Args:
Dict[str, Array]: for each uid (key), the counts of state-action pairs
Dict[str, Array]: for each uid (key), the number of times a state-action pair appears a specific number of times
string: path to save the plot to. If empty, the figure is unsaved.
Artefacts:
plt shows a log-log plot of state-action pair count frequencies per dataset
if save_plot is True, plt saves the figure as a pdf at location save_path
"""
vault_uids = list(all_count_vals.keys())
colors = sns.color_palette()

Expand Down Expand Up @@ -304,6 +316,22 @@ def describe_coverage(
plot_count_freq: bool = True,
save_plot: bool = False,
) -> None:
"""Provides coverage, structural and episode return descriptors of a Vault of datasets.
Args:
string: the name of the Vault, not containing the .vlt suffix
List[str]: a list of uids of datasets in the Vault, use if we only describe a subset of all datasets in the Vault
string: relative directory of the Vault
bool: True when the user wants to generate a plot of state-action count frequencies
bool: True when the user wants to save the generated plot
Artefacts:
A table is printed containing for each dataset in the list of uids:
- Joint SACo
if plot_count_freq is True, plt shows a log-log plot of state-action pair count frequencies per dataset
if save_plot is True, plt saves the figure as a pdf under the vault_name directory
"""
# get all uids if not specified
if len(vault_uids) == 0:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
Expand Down Expand Up @@ -342,8 +370,35 @@ def descriptive_summary(
rel_dir: str = "vaults",
plot_hist: bool = True,
save_hist: bool = False,
n_bins=40,
n_bins: int = 40,
) -> Dict[str, Array]:
"""Provides coverage, structural and episode return descriptors of a Vault of datasets.
Args:
string: the name of the Vault, not containing the .vlt suffix
List[str]: a list of uids of datasets in the Vault, use if we only describe a subset of all datasets in the Vault
string: relative directory of the Vault
bool: True when the user wants to generate a histogram
bool: True when the user wants to save a generated histogram
integer: number of bins to use when generating a histogram
Returns:
Dict[str, Array]: for each uid (key), an Array of all episode returns in that dataset
Artefacts:
A table is printed containing for each dataset in the list of uids:
- Episode return:
-- mean
-- standard deviation
-- min
-- max
- Num of trajectories
- Num of transitions
- Joint SACo
if plot_hist is True, plt shows a histogram per dataset in the list of uids
if save_hist is True, plt saves the figure as a pdf under the vault_name directory
"""
if len(vault_uids) == 0:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")

Expand Down

0 comments on commit 94f3d4d

Please sign in to comment.