
Commit

Repo struct change and fix undefined names
kmazrolina committed Jun 16, 2024
1 parent bc82135 commit f97c677
Showing 99 changed files with 85 additions and 85 deletions.
23 files renamed without changes; 1 empty file.
@@ -780,63 +780,63 @@ def get_audio_embedding(self, data):

        return audio_embeds

-    def audio_infer(self, audio, hopsize=None, device=None):
-        """Forward one audio and produce the audio embedding
-        Parameters
-        ----------
-        audio: (audio_length)
-            the time-domain audio input, notice that it must be only one input
-        hopsize: int
-            the overlap hopsize as the sliding window
-        Returns
-        ----------
-        output_dict: {
-            key: [n, (embedding_shape)] if "HTS-AT"
-            or
-            key: [(embedding_shape)] if "PANN"
-        }
-            the list of key values of the audio branch
-        """
-
-        assert not self.training, "the inference mode must be run at eval stage"
-        output_dict = {}
-        # PANN
-        if self.audio_cfg.model_type == "PANN":
-            audio_input = audio.unsqueeze(dim=0)
-            output_dict[key] = self.encode_audio(audio_input, device=device)[
-                key
-            ].squeeze(dim=0)
-        elif self.audio_cfg.model_type == "HTSAT":
-            # repeat
-            audio_len = len(audio)
-            k = self.audio_cfg.clip_samples // audio_len
-            if k > 1:
-                audio = audio.repeat(k)
-                audio_len = len(audio)
-
-            if hopsize is None:
-                hopsize = min(hopsize, audio_len)
-
-            if audio_len > self.audio_cfg.clip_samples:
-                audio_input = [
-                    audio[pos : pos + self.audio_cfg.clip_samples].clone()
-                    for pos in range(
-                        0, audio_len - self.audio_cfg.clip_samples, hopsize
-                    )
-                ]
-                audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
-                audio_input = torch.stack(audio_input)
-                output_dict[key] = self.encode_audio(audio_input, device=device)[key]
-            else:
-                audio_input = audio.unsqueeze(dim=0)
-                output_dict[key] = self.encode_audio(audio_input, device=device)[
-                    key
-                ].squeeze(dim=0)
-
-        return output_dict
+# def audio_infer(self, audio, hopsize=None, device=None):
+# """Forward one audio and produce the audio embedding
+
+# Parameters
+# ----------
+# audio: (audio_length)
+# the time-domain audio input, notice that it must be only one input
+# hopsize: int
+# the overlap hopsize as the sliding window
+
+# Returns
+# ----------
+# output_dict: {
+# key: [n, (embedding_shape)] if "HTS-AT"
+# or
+# key: [(embedding_shape)] if "PANN"
+# }
+# the list of key values of the audio branch
+
+# """
+
+# assert not self.training, "the inference mode must be run at eval stage"
+# output_dict = {}
+# # PANN
+# if self.audio_cfg.model_type == "PANN":
+# audio_input = audio.unsqueeze(dim=0)
+# output_dict[key] = self.encode_audio(audio_input, device=device)[
+# key
+# ].squeeze(dim=0)
+# elif self.audio_cfg.model_type == "HTSAT":
+# # repeat
+# audio_len = len(audio)
+# k = self.audio_cfg.clip_samples // audio_len
+# if k > 1:
+# audio = audio.repeat(k)
+# audio_len = len(audio)
+
+# if hopsize is None:
+# hopsize = min(hopsize, audio_len)
+
+# if audio_len > self.audio_cfg.clip_samples:
+# audio_input = [
+# audio[pos : pos + self.audio_cfg.clip_samples].clone()
+# for pos in range(
+# 0, audio_len - self.audio_cfg.clip_samples, hopsize
+# )
+# ]
+# audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+# audio_input = torch.stack(audio_input)
+# output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+# else:
+# audio_input = audio.unsqueeze(dim=0)
+# output_dict[key] = self.encode_audio(audio_input, device=device)[
+# key
+# ].squeeze(dim=0)
+
+# return output_dict


def convert_weights_to_fp16(model: nn.Module):
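Note on the change above: the removed audio_infer indexes output_dict with a bare name key that is never assigned inside the function, and its guard `if hopsize is None: hopsize = min(hopsize, audio_len)` evaluates min(None, audio_len), which raises a TypeError. Commenting the method out is consistent with the commit message "fix undefined names". For reference, here is a minimal sketch of the same sliding-window idea written as a standalone helper; the name sliding_window_embed, the encode_audio callable, and the clip_samples parameter are illustrative assumptions, not code from this repository.

import torch

def sliding_window_embed(audio, encode_audio, clip_samples, hopsize=None, device=None):
    """Embed one 1-D waveform by slicing it into fixed-length windows (sketch)."""
    audio_len = len(audio)
    # Tile short clips so they fill at least one full window.
    k = clip_samples // max(audio_len, 1)
    if k > 1:
        audio = audio.repeat(k)
        audio_len = len(audio)
    # Default the hop to a full window; the original guard would have
    # evaluated min(None, audio_len) and failed.
    if hopsize is None:
        hopsize = clip_samples
    hopsize = min(hopsize, audio_len)
    if audio_len > clip_samples:
        windows = [
            audio[pos : pos + clip_samples].clone()
            for pos in range(0, audio_len - clip_samples, hopsize)
        ]
        windows.append(audio[-clip_samples:].clone())
        batch = torch.stack(windows)  # [n_windows, clip_samples]
    else:
        batch = audio.unsqueeze(dim=0)  # [1, audio_len]
    embeds = encode_audio(batch, device=device)
    # Return every entry instead of relying on an undefined `key`.
    return dict(embeds)

Called as sliding_window_embed(waveform, model.encode_audio, model.audio_cfg.clip_samples), this would return one tensor per key with a leading window dimension when the clip is long enough to be split, matching the [n, (embedding_shape)] shape described in the old docstring.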
14 files renamed without changes; 1 empty file.
56 changes: 28 additions & 28 deletions src/AudioSep/utils.py → src/v_audio_cc/AudioSep/utils.py
@@ -86,43 +86,43 @@ def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    return audioset632_id_to_lb


-def load_pretrained_panns(
-    model_type: str,
-    checkpoint_path: str,
-    freeze: bool
-) -> nn.Module:
-    r"""Load pretrained pretrained audio neural networks (PANNs).
+# def load_pretrained_panns(
+# model_type: str,
+# checkpoint_path: str,
+# freeze: bool
+# ) -> nn.Module:
+# r"""Load pretrained pretrained audio neural networks (PANNs).

-    Args:
-        model_type: str, e.g., "Cnn14"
-        checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth"
-        freeze: bool
+# Args:
+# model_type: str, e.g., "Cnn14"
+# checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth"
+# freeze: bool

-    Returns:
-        model: nn.Module
-    """
+# Returns:
+# model: nn.Module
+# """

-    if model_type == "Cnn14":
-        Model = Cnn14
+# if model_type == "Cnn14":
+# Model = Cnn14

-    elif model_type == "Cnn14_DecisionLevelMax":
-        Model = Cnn14_DecisionLevelMax
+# elif model_type == "Cnn14_DecisionLevelMax":
+# Model = Cnn14_DecisionLevelMax

-    else:
-        raise NotImplementedError
+# else:
+# raise NotImplementedError

-    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
-        mel_bins=64, fmin=50, fmax=14000, classes_num=527)
+# model = Model(sample_rate=32000, window_size=1024, hop_size=320,
+# mel_bins=64, fmin=50, fmax=14000, classes_num=527)

-    if checkpoint_path:
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-        model.load_state_dict(checkpoint["model"])
+# if checkpoint_path:
+# checkpoint = torch.load(checkpoint_path, map_location="cpu")
+# model.load_state_dict(checkpoint["model"])

-    if freeze:
-        for param in model.parameters():
-            param.requires_grad = False
+# if freeze:
+# for param in model.parameters():
+# param.requires_grad = False

-    return model
+# return model


def energy(x):
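Note on the change above: the removed load_pretrained_panns refers to Cnn14 and Cnn14_DecisionLevelMax, which appear to be among the undefined names this commit works around by commenting the function out. For reference, here is a minimal sketch of the same loading logic with the model classes passed in explicitly, so the function itself references nothing undefined; the registry argument and the name load_panns_checkpoint are illustrative assumptions rather than this repository's API.

from typing import Dict, Type

import torch
import torch.nn as nn

def load_panns_checkpoint(
    model_registry: Dict[str, Type[nn.Module]],
    model_type: str,
    checkpoint_path: str,
    freeze: bool,
) -> nn.Module:
    """Instantiate a PANNs model from a registry, then optionally load and freeze weights (sketch)."""
    if model_type not in model_registry:
        raise NotImplementedError(f"Unknown model_type: {model_type}")
    # Hyperparameters taken from the commented-out function above.
    model = model_registry[model_type](
        sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=527,
    )
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
    return model

A caller would supply something like {"Cnn14": Cnn14} from whichever module actually defines the PANNs architectures.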
Empty file added src/v_audio_cc/__init__.py
Empty file.
6 files renamed without changes.
