Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: anonymised telemetry to track usage patterns #131

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified assets/bria.mp3
Binary file not shown.
33 changes: 30 additions & 3 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
get_device,
normalize_text,
)
from fam.telemetry import TelemetryEvent
from fam.telemetry.posthog import PosthogClient

posthog = PosthogClient() # see fam/telemetry/README.md for more information


class TTS:
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(
os.makedirs(self.output_dir, exist_ok=True)
if first_stage_path:
print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"
self._first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
Expand All @@ -90,13 +94,16 @@ def __init__(
self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
precision=self.precision,
checkpoint_path=Path(first_stage_ckpt),
checkpoint_path=Path(self._first_stage_ckpt),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device=self._device,
compile=True,
compile_prefill=True,
quantisation_mode=quantisation_mode,
)
self._seed = seed
self._quantisation_mode = quantisation_mode
self._model_name = model_name

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
Expand Down Expand Up @@ -156,8 +163,28 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
time_to_synth_s = time.time() - start
audio, sr = librosa.load(str(wav_file) + ".wav")
duration_s = librosa.get_duration(y=audio, sr=sr)
real_time_factor = time_to_synth_s / duration_s
print(f"\nTotal time to synth (s): {time_to_synth_s}")
print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
print(f"Real-time factor: {real_time_factor:.2f}")

posthog.capture(
TelemetryEvent(
name="user_ran_tts",
properties={
sidroopdaska marked this conversation as resolved.
Show resolved Hide resolved
"text": text,
"temperature": temperature,
"guidance_scale": guidance_scale,
"top_p": top_p,
"spk_ref_path": spk_ref_path,
"speech_duration_s": duration_s,
"time_to_synth_s": time_to_synth_s,
"real_time_factor": round(real_time_factor, 2),
"quantisation_mode": self._quantisation_mode,
"seed": self._seed,
"first_stage_ckpt": self._first_stage_ckpt,
},
)
)

return str(wav_file) + ".wav"

Expand Down
69 changes: 45 additions & 24 deletions fam/llm/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import itertools
import math
from pathlib import Path
import time
from pathlib import Path
from typing import Any, Dict, Optional

import click
Expand All @@ -19,7 +19,12 @@
from fam.llm.model import GPT, GPTConfig
from fam.llm.preprocessing.audio_token_mode import get_params_for_mode
from fam.llm.preprocessing.data_pipeline import get_training_tuple
from fam.llm.utils import hash_dictionary
from fam.telemetry import TelemetryEvent
from fam.telemetry.posthog import PosthogClient

# see fam/telemetry/README.md for more information
posthog = PosthogClient()

dtype: Literal["bfloat16", "float16", "tfloat32", "float32"] = (
"bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
Expand Down Expand Up @@ -50,11 +55,13 @@
ckpts_save_dir = ckpts_base_dir / out_dir
os.makedirs(ckpts_save_dir, exist_ok=True)


def get_globals_state():
    """Collect this module's public scalar globals into a dict, e.g. for logging.

    A global is included when its name does not start with an underscore and
    its value is an ``int``, ``float``, ``bool`` or ``str``.
    """
    scalar_types = (int, float, bool, str)
    # Snapshot the items so the comprehension iterates over a stable view.
    snapshot = list(globals().items())
    return {
        key: value
        for key, value in snapshot
        if not key.startswith("_") and isinstance(value, scalar_types)
    }


model_args: dict = dict(
n_layer=n_layer,
n_head=n_head,
Expand All @@ -72,6 +79,7 @@ def get_globals_state():
swiglu_multiple_of=swiglu_multiple_of,
) # start with model_args from command line


def strip_prefix(state_dict: Dict[str, Any], unwanted_prefix: str):
# TODO: this also appears in fast_inference_utils._load_model, it should be moved to a common place.
for k, v in list(state_dict.items()):
Expand Down Expand Up @@ -146,19 +154,13 @@ def main(train: Path, val: Path, model_id: str, ckpt: Optional[Path], spk_emb_ck
allow_ops_in_compiled_graph()
model = torch.compile(model) # type: ignore

def estimate_loss(dataset, iters: int=eval_iters):
""" Estimate loss on a dataset by running on `iters` batches. """
def estimate_loss(dataset, iters: int = eval_iters):
"""Estimate loss on a dataset by running on `iters` batches."""
if dataset is None:
return torch.nan
losses = []
for _, batch in zip(tqdm(range(iters)), dataset):
X, Y, SE = get_training_tuple(
batch,
causal,
num_codebooks,
speaker_cond,
device
)
X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)
with ctx:
_, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None)
losses.append(loss.item())
Expand Down Expand Up @@ -206,9 +208,7 @@ def get_lr(it):
mode_params["ctx_window"],
device,
)
train_dataloader = itertools.cycle(
DataLoader(train_dataset, batch_size, shuffle=True)
)
train_dataloader = itertools.cycle(DataLoader(train_dataset, batch_size, shuffle=True))
train_data = iter(train_dataloader)
# we do not perform any explicit checks for dataset overlap & leave it to the user
# to handle this
Expand All @@ -219,13 +219,7 @@ def get_lr(it):
eval_train_data = DataLoader(train_dataset, batch_size, shuffle=True)

batch = next(train_data)
X, Y, SE = get_training_tuple(
batch,
causal,
num_codebooks,
speaker_cond,
device
)
X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)

t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
Expand All @@ -244,11 +238,29 @@ def get_lr(it):
for param in model.parameters():
param.requires_grad = False
for param in itertools.chain(
model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune*-1:].parameters()
model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune * -1 :].parameters()
):
param.requires_grad = True
print(f"After freezing excl. last {last_n_blocks_to_finetune} transformer blocks: {trainable_count(model)=}...")

# log start of finetuning event
properties = {
**config,
**model_args,
"train": str(train),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd also be interesting to know:

  1. How much data (minutes ideal, row count ok) in datasets
  2. What % of the network is being finetuned - we're already somewhat tracking this via last_n_blocks_to_finetune being captured but wouldn't cover any custom logic changes - a quick sum of trainable params before finetuning starts vs their config settings would let us observe that custom network changes were made (interesting data point).
  3. Hardware being used

"val": str(val),
"model_id": model_id,
"ckpt": ckpt,
"spk_emb_ckpt": spk_emb_ckpt,
}
finetune_jobid = hash_dictionary(properties)
posthog.capture(
TelemetryEvent(
name="user_started_finetuning",
properties={"finetune_jobid": finetune_jobid, **properties},
)
)

while True:
lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
Expand Down Expand Up @@ -278,7 +290,9 @@ def get_lr(it):
if losses["val"] < best_val_loss:
best_val_loss = losses["val"]
if iter_num > 0:
ckpt_save_name = ckpt_save_name.replace(".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt")
ckpt_save_name = ckpt_save_name.replace(
".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt"
)
save_checkpoint = True

save_checkpoint = save_checkpoint or iter_num % save_interval == 0
Expand Down Expand Up @@ -352,7 +366,14 @@ def get_lr(it):

# termination conditions
if iter_num > max_iters:
break
# log end of finetuning event
posthog.capture(
TelemetryEvent(
name="user_completed_finetuning",
properties={"finetune_jobid": finetune_jobid},
sidroopdaska marked this conversation as resolved.
Show resolved Hide resolved
)
)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions fam/llm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,8 @@ def get_cached_file(file_or_uri: str):
# hash the file path to get the cache name
_cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext

os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
os.makedirs(os.path.expanduser("~/.cache/metavoice/"), exist_ok=True)
cache_path = os.path.expanduser(f"~/.cache/metavoice/{_cache_name}")

if not os.path.exists(cache_path):
command = f"curl -o {cache_path} {file_or_uri}"
Expand Down
14 changes: 14 additions & 0 deletions fam/llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import json
import os
import re
import subprocess
Expand Down Expand Up @@ -87,3 +89,15 @@ def get_default_dtype() -> str:

def get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"


def hash_dictionary(d: dict):
    """Return a deterministic SHA-256 hex digest of a dictionary.

    The dictionary is serialised to JSON with sorted keys, so two dictionaries
    with the same contents hash identically regardless of insertion order.

    Values that JSON cannot represent natively (for example ``pathlib.Path``)
    are serialised through ``str()`` instead of raising ``TypeError``. This is
    deliberate: the output is only used as an identifier, never decoded back.

    Args:
        d: Dictionary to hash.

    Returns:
        The 64-character hexadecimal SHA-256 digest of the serialised dictionary.
    """
    # `default=str` is only consulted for values json cannot encode itself, so
    # hashes of already-serialisable dictionaries are unchanged.
    serialized = json.dumps(d, sort_keys=True, default=str)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
5 changes: 5 additions & 0 deletions fam/telemetry/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Telemetry

This directory holds all the telemetry code for MetaVoice. We, MetaVoice, capture anonymized telemetry to understand usage patterns.

To opt out of telemetry, set `ANONYMIZED_TELEMETRY=False` in a `.env` file at the root level of this repo.
43 changes: 43 additions & 0 deletions fam/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import abc
from abc import abstractmethod
from dataclasses import dataclass
import os
import uuid
from pathlib import Path


@dataclass(frozen=True)
class TelemetryEvent:
    """Immutable description of a single telemetry event."""

    # Event identifier, e.g. "user_ran_tts".
    name: str
    # Payload forwarded to the telemetry backend together with the name.
    properties: dict


class TelemetryClient(abc.ABC):
    """Abstract telemetry backend with a persistent, randomly generated user id.

    Subclasses implement `capture`. The `user_id` property lazily resolves an
    id that is stored on disk at `USER_ID_PATH`, so it stays the same across
    runs on the same machine.
    """

    # File in which the generated user id is persisted.
    USER_ID_PATH = str(Path.home() / ".cache" / "metavoice" / "telemetry_user_id")
    # Fallback id used when the id file cannot be read or written.
    UNKNOWN_USER_ID = "UNKNOWN"
    # In-memory cache of the resolved id; set on the instance on first access.
    _curr_user_id = None

    @abstractmethod
    def capture(self, event: "TelemetryEvent") -> None:
        """Send `event` to the telemetry backend."""

    @property
    def user_id(self) -> str:
        """Return the persistent user id, creating and storing one if needed.

        Returns:
            The id stored at `USER_ID_PATH` (surrounding whitespace removed),
            a freshly generated UUID4 string if no usable id was stored, or
            `UNKNOWN_USER_ID` if the file could not be accessed.
        """
        if self._curr_user_id:
            return self._curr_user_id

        # File access may fail due to permissions or other reasons. Telemetry
        # must never crash the host program, so every exception is caught here.
        try:
            resolved = self._load_user_id()
            if not resolved:
                # Missing, empty or whitespace-only file: mint a new id rather
                # than reporting events under an empty identifier.
                resolved = str(uuid.uuid4())
                self._store_user_id(resolved)
            self._curr_user_id = resolved
        except Exception:
            self._curr_user_id = self.UNKNOWN_USER_ID
        return self._curr_user_id

    def _load_user_id(self) -> str:
        """Return the stripped id stored on disk, or "" if the file is absent."""
        try:
            with open(self.USER_ID_PATH, "r") as f:
                return f.read().strip()
        except FileNotFoundError:
            return ""

    def _store_user_id(self, user_id: str) -> None:
        """Write `user_id` to `USER_ID_PATH`, creating parent directories."""
        os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
        with open(self.USER_ID_PATH, "w") as f:
            f.write(user_id)
sidroopdaska marked this conversation as resolved.
Show resolved Hide resolved
40 changes: 40 additions & 0 deletions fam/telemetry/posthog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
import os
import sys

from dotenv import load_dotenv
from posthog import Posthog

from fam.telemetry import TelemetryClient, TelemetryEvent

# Load variables from a .env file (if any) so ANONYMIZED_TELEMETRY can be read via os.getenv below.
load_dotenv()
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger at import time with both a stdout and a stderr
# handler, so each INFO+ record appears to be emitted twice (unless the root logger was already
# configured, in which case basicConfig does nothing) - confirm this duplication is intended.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout), logging.StreamHandler(sys.stderr)])


class PosthogClient(TelemetryClient):
    """Telemetry client that forwards events to PostHog.

    Sending is disabled when the ``ANONYMIZED_TELEMETRY`` environment variable
    holds an opt-out value (see `_OPT_OUT_VALUES`) or when running under pytest.
    """

    # Case-insensitive values of ANONYMIZED_TELEMETRY that opt out of telemetry.
    # The empty string is kept for backward compatibility: it was the only value
    # that disabled telemetry before the variable was parsed explicitly.
    _OPT_OUT_VALUES = frozenset({"", "0", "false", "no", "off"})

    def __init__(self):
        self._posthog = Posthog(
            project_api_key="phc_tk7IUlV7Q7lEa9LNbXxyC1sMWlCqiW6DkHyhJrbWMCS", host="https://eu.posthog.com"
        )

        if self._opted_out() or "pytest" in sys.modules:
            self._posthog.disabled = True
            logger.info("Anonymized telemetry disabled. See fam/telemetry/README.md for more information.")
        else:
            logger.info("Anonymized telemetry enabled. See fam/telemetry/README.md for more information.")

        posthog_logger = logging.getLogger("posthog")
        posthog_logger.disabled = True  # Silence posthog's logging

        super().__init__()

    @classmethod
    def _opted_out(cls) -> bool:
        """Return True if ANONYMIZED_TELEMETRY is set to an opt-out value.

        Environment variables are always strings, so a plain truthiness check
        would treat "False" as enabled; the value is compared explicitly instead.
        """
        raw_value = os.getenv("ANONYMIZED_TELEMETRY")
        if raw_value is None:
            # Variable not set: telemetry stays enabled by default.
            return False
        return raw_value.strip().lower() in cls._OPT_OUT_VALUES

    def capture(self, event: TelemetryEvent) -> None:
        """Send `event` to PostHog under this client's user id.

        Failures are logged and swallowed so that telemetry can never break
        the calling code.
        """
        try:
            self._posthog.capture(
                self.user_id,
                event.name,
                {**event.properties},
            )
        except Exception as e:
            logger.error("Failed to send telemetry event %s: %s", event.name, e)
Loading
Loading