
Commit

Repo struct change and fix undefined names
kmazrolina committed Jun 16, 2024
1 parent bc82135 commit 72906f7
Showing 100 changed files with 86 additions and 88 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/python-package.yml
@@ -35,6 +35,4 @@ jobs:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
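
The flake8 step kept in this workflow is the check behind the "fix undefined names" part of the commit message: --select=E9,F63,F7,F82 restricts the blocking run to syntax errors and the pyflakes F-codes, and the F82 prefix covers F821 (undefined name). A minimal, hypothetical example of the pattern F821 flags, the same shape as the undefined key in the model code further down:

def audio_infer(audio):
    output = {}
    output[key] = audio  # F821: 'key' is read here but never assigned or imported
    return output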
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -780,63 +780,63 @@ def get_audio_embedding(self, data):

return audio_embeds

def audio_infer(self, audio, hopsize=None, device=None):
"""Forward one audio and produce the audio embedding
Parameters
----------
audio: (audio_length)
the time-domain audio input, notice that it must be only one input
hopsize: int
the overlap hopsize as the sliding window
Returns
----------
output_dict: {
key: [n, (embedding_shape)] if "HTS-AT"
or
key: [(embedding_shape)] if "PANN"
}
the list of key values of the audio branch
"""

assert not self.training, "the inference mode must be run at eval stage"
output_dict = {}
# PANN
if self.audio_cfg.model_type == "PANN":
audio_input = audio.unsqueeze(dim=0)
output_dict[key] = self.encode_audio(audio_input, device=device)[
key
].squeeze(dim=0)
elif self.audio_cfg.model_type == "HTSAT":
# repeat
audio_len = len(audio)
k = self.audio_cfg.clip_samples // audio_len
if k > 1:
audio = audio.repeat(k)
audio_len = len(audio)

if hopsize is None:
hopsize = min(hopsize, audio_len)

if audio_len > self.audio_cfg.clip_samples:
audio_input = [
audio[pos : pos + self.audio_cfg.clip_samples].clone()
for pos in range(
0, audio_len - self.audio_cfg.clip_samples, hopsize
)
]
audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
audio_input = torch.stack(audio_input)
output_dict[key] = self.encode_audio(audio_input, device=device)[key]
else:
audio_input = audio.unsqueeze(dim=0)
output_dict[key] = self.encode_audio(audio_input, device=device)[
key
].squeeze(dim=0)

return output_dict
# def audio_infer(self, audio, hopsize=None, device=None):
# """Forward one audio and produce the audio embedding

# Parameters
# ----------
# audio: (audio_length)
# the time-domain audio input, notice that it must be only one input
# hopsize: int
# the overlap hopsize as the sliding window

# Returns
# ----------
# output_dict: {
# key: [n, (embedding_shape)] if "HTS-AT"
# or
# key: [(embedding_shape)] if "PANN"
# }
# the list of key values of the audio branch

# """

# assert not self.training, "the inference mode must be run at eval stage"
# output_dict = {}
# # PANN
# if self.audio_cfg.model_type == "PANN":
# audio_input = audio.unsqueeze(dim=0)
# output_dict[key] = self.encode_audio(audio_input, device=device)[
# key
# ].squeeze(dim=0)
# elif self.audio_cfg.model_type == "HTSAT":
# # repeat
# audio_len = len(audio)
# k = self.audio_cfg.clip_samples // audio_len
# if k > 1:
# audio = audio.repeat(k)
# audio_len = len(audio)

# if hopsize is None:
# hopsize = min(hopsize, audio_len)

# if audio_len > self.audio_cfg.clip_samples:
# audio_input = [
# audio[pos : pos + self.audio_cfg.clip_samples].clone()
# for pos in range(
# 0, audio_len - self.audio_cfg.clip_samples, hopsize
# )
# ]
# audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
# audio_input = torch.stack(audio_input)
# output_dict[key] = self.encode_audio(audio_input, device=device)[key]
# else:
# audio_input = audio.unsqueeze(dim=0)
# output_dict[key] = self.encode_audio(audio_input, device=device)[
# key
# ].squeeze(dim=0)

# return output_dict
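
The block above is the "undefined names" part of the commit: audio_infer reads a variable named key that is never assigned anywhere in the method (flake8 F821), and its hopsize handling calls min(hopsize, audio_len) only when hopsize is None, which would raise a TypeError. The commit disables the method by commenting it out rather than repairing it. Purely as a hedged sketch, one way the HTSAT sliding-window path could be written with those two issues addressed; the dictionary key "embedding" and the non-overlapping default hop size are assumptions, not anything this commit introduces:

# Sketch only -- not the commit's fix. Assumes the surrounding CLAP model class
# (self.audio_cfg, self.encode_audio) and that torch is imported in this module.
def audio_infer(self, audio, hopsize=None, device=None):
    """Embed a single time-domain audio tensor of shape (audio_length,)."""
    assert not self.training, "the inference mode must be run at eval stage"
    key = "embedding"  # assumed key name; the original code leaves `key` undefined
    output_dict = {}
    if self.audio_cfg.model_type == "PANN":
        audio_input = audio.unsqueeze(dim=0)
        output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
    elif self.audio_cfg.model_type == "HTSAT":
        # Tile short inputs until they reach one clip length.
        audio_len = len(audio)
        k = self.audio_cfg.clip_samples // audio_len
        if k > 1:
            audio = audio.repeat(k)
            audio_len = len(audio)
        if hopsize is None:
            hopsize = self.audio_cfg.clip_samples  # assumed default: non-overlapping windows
        hopsize = min(hopsize, audio_len)
        if audio_len > self.audio_cfg.clip_samples:
            # Slide a fixed-length window over the long input and batch the clips.
            audio_input = [
                audio[pos:pos + self.audio_cfg.clip_samples].clone()
                for pos in range(0, audio_len - self.audio_cfg.clip_samples, hopsize)
            ]
            audio_input.append(audio[-self.audio_cfg.clip_samples:].clone())
            audio_input = torch.stack(audio_input)
            output_dict[key] = self.encode_audio(audio_input, device=device)[key]
        else:
            audio_input = audio.unsqueeze(dim=0)
            output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
    return output_dict

Each stacked clip is clip_samples long and successive clips start every hopsize samples, with one extra clip anchored to the end of the signal, so the whole input is embedded without truncation.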


def convert_weights_to_fp16(model: nn.Module):
File renamed without changes.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
56 changes: 28 additions & 28 deletions src/AudioSep/utils.py → src/v_audio_cc/AudioSep/utils.py
@@ -86,43 +86,43 @@ def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
return audioset632_id_to_lb


def load_pretrained_panns(
model_type: str,
checkpoint_path: str,
freeze: bool
) -> nn.Module:
r"""Load pretrained pretrained audio neural networks (PANNs).
# def load_pretrained_panns(
# model_type: str,
# checkpoint_path: str,
# freeze: bool
# ) -> nn.Module:
# r"""Load pretrained pretrained audio neural networks (PANNs).

Args:
model_type: str, e.g., "Cnn14"
checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth"
freeze: bool
# Args:
# model_type: str, e.g., "Cnn14"
# checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth"
# freeze: bool

Returns:
model: nn.Module
"""
# Returns:
# model: nn.Module
# """

if model_type == "Cnn14":
Model = Cnn14
# if model_type == "Cnn14":
# Model = Cnn14

elif model_type == "Cnn14_DecisionLevelMax":
Model = Cnn14_DecisionLevelMax
# elif model_type == "Cnn14_DecisionLevelMax":
# Model = Cnn14_DecisionLevelMax

else:
raise NotImplementedError
# else:
# raise NotImplementedError

model = Model(sample_rate=32000, window_size=1024, hop_size=320,
mel_bins=64, fmin=50, fmax=14000, classes_num=527)
# model = Model(sample_rate=32000, window_size=1024, hop_size=320,
# mel_bins=64, fmin=50, fmax=14000, classes_num=527)

if checkpoint_path:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
# if checkpoint_path:
# checkpoint = torch.load(checkpoint_path, map_location="cpu")
# model.load_state_dict(checkpoint["model"])

if freeze:
for param in model.parameters():
param.requires_grad = False
# if freeze:
# for param in model.parameters():
# param.requires_grad = False

return model
# return model
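
load_pretrained_panns is disabled for the same reason: after the restructuring, Cnn14 and Cnn14_DecisionLevelMax are undefined names in this module, so flake8 F821 rejects the function. For reference, a sketch of what the function would need to become usable again; the panns_models import path below is an assumption, since the PANNs model definitions are not part of this diff and the commit itself leaves the function commented out:

import torch
import torch.nn as nn

# Hypothetical module path: Cnn14 / Cnn14_DecisionLevelMax must be imported from
# wherever the PANNs architectures live after the repo restructuring.
from panns_models import Cnn14, Cnn14_DecisionLevelMax


def load_pretrained_panns(model_type: str, checkpoint_path: str, freeze: bool) -> nn.Module:
    r"""Load a pretrained audio neural network (PANN)."""
    if model_type == "Cnn14":
        Model = Cnn14
    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax
    else:
        raise NotImplementedError(model_type)

    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
                  mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model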


def energy(x):
Empty file added src/v_audio_cc/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
