Skip to content

Commit

Permalink
tweak + test destroy (#65)
Browse files Browse the repository at this point in the history
* tweak + test destroy

* bump versions
  • Loading branch information
ricky0123 authored Dec 2, 2023
1 parent 7e9ee5f commit 60b0b6c
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 107 deletions.
7 changes: 4 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/_common/src/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ export enum Message {
SpeechStart = "SPEECH_START",
VADMisfire = "VAD_MISFIRE",
SpeechEnd = "SPEECH_END",
SpeechStop = "SPEECH_STOP"
SpeechStop = "SPEECH_STOP",
}
2 changes: 1 addition & 1 deletion packages/react/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
"dependencies": {
"onnxruntime-web": "^1.14.0",
"@ricky0123/vad-web": "^0.0.13"
"@ricky0123/vad-web": "^0.0.14"
},
"peerDependencies": {
"react": "^18",
Expand Down
4 changes: 2 additions & 2 deletions packages/react/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ export function useMicVAD(options: Partial<ReactRealTimeVADOptions>) {
vadOptions.onVADMisfire = _onVADMisfire

useEffect(() => {
let myvad: MicVAD | null
const setup = async (): Promise<void> => {
let myvad: MicVAD | null
try {
myvad = await MicVAD.new(vadOptions)
} catch (e) {
Expand All @@ -105,8 +105,8 @@ export function useMicVAD(options: Partial<ReactRealTimeVADOptions>) {
console.log("Well that didn't work")
})
return function cleanUp() {
myvad?.destroy()
if (!loading && !errored) {
vad?.destroy()
setListening(false)
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"offline-speech-recognition"
],
"homepage": "https://github.com/ricky0123/vad",
"version": "0.0.13",
"version": "0.0.14",
"license": "ISC",
"main": "dist/index.js",
"unpkg": "dist/bundle.min.js",
Expand Down
10 changes: 5 additions & 5 deletions packages/web/src/asset-path.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// nextjs@14 bundler may attempt to execute this during SSR and crash
const isWeb = typeof window !== 'undefined' && typeof window.document !== 'undefined';
const currentScript =
isWeb
? window.document.currentScript as HTMLScriptElement
: null
const isWeb =
typeof window !== "undefined" && typeof window.document !== "undefined"
const currentScript = isWeb
? (window.document.currentScript as HTMLScriptElement)
: null

let basePath = ""
if (currentScript) {
Expand Down
5 changes: 2 additions & 3 deletions packages/web/src/default-model-fetcher.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
export const defaultModelFetcher = (path: string) => {
return fetch(path)
.then(model=>model.arrayBuffer())
};
return fetch(path).then((model) => model.arrayBuffer())
}
20 changes: 9 additions & 11 deletions packages/web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@ import { audioFileToArray } from "./utils"
import { defaultModelFetcher } from "./default-model-fetcher"
import { assetPath } from "./asset-path"


export interface NonRealTimeVADOptionsWeb extends NonRealTimeVADOptions {
modelURL: string,
modelFetcher: (path: string) => Promise<ArrayBuffer>,
}
modelURL: string
modelFetcher: (path: string) => Promise<ArrayBuffer>
}

export const defaultNonRealTimeVADOptions = {
modelURL: assetPath("silero_vad.onnx"),
modelFetcher: defaultModelFetcher
modelFetcher: defaultModelFetcher,
}

class NonRealTimeVAD extends PlatformAgnosticNonRealTimeVAD {
static async new(
options: Partial<NonRealTimeVADOptionsWeb> = {}
): Promise<NonRealTimeVAD> {
const {modelURL, modelFetcher} = {...defaultNonRealTimeVADOptions, ...options};
return await this._new(
() => modelFetcher(modelURL),
ort,
options
)
const { modelURL, modelFetcher } = {
...defaultNonRealTimeVADOptions,
...options,
}
return await this._new(() => modelFetcher(modelURL), ort, options)
}
}

Expand Down
174 changes: 96 additions & 78 deletions packages/web/src/real-time-vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type AudioConstraints = Omit<

type AssetOptions = {
workletURL: string
modelURL: string,
modelURL: string
modelFetcher: (path: string) => Promise<ArrayBuffer>
}

Expand All @@ -66,7 +66,6 @@ export type RealTimeVADOptions =
| RealTimeVADOptionsWithStream
| RealTimeVADOptionsWithoutStream


export const defaultRealTimeVADOptions: RealTimeVADOptions = {
...defaultFrameProcessorOptions,
onFrameProcessed: (probabilities) => {},
Expand All @@ -79,52 +78,59 @@ export const defaultRealTimeVADOptions: RealTimeVADOptions = {
onSpeechEnd: () => {
log.debug("Detected speech end")
},
workletURL: assetPath("vad.worklet.bundle.min.js"),
workletURL: assetPath("vad.worklet.bundle.min.js"),
modelURL: assetPath("silero_vad.onnx"),
modelFetcher: defaultModelFetcher,
stream: undefined,
}

export class MicVAD {
audioContext: AudioContext | null = null
// @ts-ignore
stream: MediaStream
// @ts-ignore
audioNodeVAD: AudioNodeVAD
listening = false

static async new(options: Partial<RealTimeVADOptions> = {}) {
const vad = new MicVAD({ ...defaultRealTimeVADOptions, ...options })
await vad.init()
return vad
}

constructor(public options: RealTimeVADOptions) {
validateOptions(options)
}
const fullOptions: RealTimeVADOptions = {
...defaultRealTimeVADOptions,
...options,
}
validateOptions(fullOptions)

init = async () => {
if (this.options.stream === undefined)
this.stream = await navigator.mediaDevices.getUserMedia({
let stream: MediaStream
if (fullOptions.stream === undefined)
stream = await navigator.mediaDevices.getUserMedia({
audio: {
...this.options.additionalAudioConstraints,
...fullOptions.additionalAudioConstraints,
channelCount: 1,
echoCancellation: true,
autoGainControl: true,
noiseSuppression: true,
},
})
else this.stream = this.options.stream
else stream = fullOptions.stream

this.audioContext = new AudioContext()
const source = new MediaStreamAudioSourceNode(this.audioContext, {
mediaStream: this.stream,
const audioContext = new AudioContext()
const sourceNode = new MediaStreamAudioSourceNode(audioContext, {
mediaStream: stream,
})

this.audioNodeVAD = await AudioNodeVAD.new(this.audioContext, this.options)
this.audioNodeVAD.receive(source)
const audioNodeVAD = await AudioNodeVAD.new(audioContext, fullOptions)
audioNodeVAD.receive(sourceNode)

return new MicVAD(
fullOptions,
audioContext,
stream,
audioNodeVAD,
sourceNode
)
}

/**
 * Stores resources that already exist; performs no setup of its own.
 *
 * The constructor is private, so instances are obtained through the static
 * async `new` factory, which creates the stream, audio context, source node
 * and AudioNodeVAD and then passes them in here.
 *
 * `listening` defaults to false when the argument is omitted.
 */
private constructor(
public options: RealTimeVADOptions,
private audioContext: AudioContext,
private stream: MediaStream,
private audioNodeVAD: AudioNodeVAD,
private sourceNode: MediaStreamAudioSourceNode,
private listening = false
) {}

pause = () => {
this.audioNodeVAD.pause()
this.listening = false
Expand All @@ -139,37 +145,77 @@ export class MicVAD {
if (this.listening) {
this.pause()
}
this.stream.getTracks().forEach((t) => t.stop())
this.audioContext?.close()
this.audioContext = null
this.audioNodeVAD.entryNode.port.postMessage({
message: Message.SpeechStop,
})
this.sourceNode.disconnect()
this.audioNodeVAD.destroy()
this.audioContext.close()
}
}

export class AudioNodeVAD {
// @ts-ignore
frameProcessor: FrameProcessor
// @ts-ignore
entryNode: AudioWorkletNode

static async new(
ctx: AudioContext,
options: Partial<RealTimeVADOptions> = {}
) {
const vad = new AudioNodeVAD(ctx, {
const fullOptions: RealTimeVADOptions = {
...defaultRealTimeVADOptions,
...options,
}
validateOptions(fullOptions)

await ctx.audioWorklet.addModule(fullOptions.workletURL)
const vadNode = new AudioWorkletNode(ctx, "vad-helper-worklet", {
processorOptions: {
frameSamples: fullOptions.frameSamples,
},
})
await vad.init()
return vad
}

constructor(public ctx: AudioContext, public options: RealTimeVADOptions) {
validateOptions(options)
const model = await Silero.new(ort, () =>
fullOptions.modelFetcher(fullOptions.modelURL)
)

const frameProcessor = new FrameProcessor(
model.process,
model.reset_state,
{
frameSamples: fullOptions.frameSamples,
positiveSpeechThreshold: fullOptions.positiveSpeechThreshold,
negativeSpeechThreshold: fullOptions.negativeSpeechThreshold,
redemptionFrames: fullOptions.redemptionFrames,
preSpeechPadFrames: fullOptions.preSpeechPadFrames,
minSpeechFrames: fullOptions.minSpeechFrames,
}
)

const audioNodeVAD = new AudioNodeVAD(
ctx,
fullOptions,
frameProcessor,
vadNode
)

// Handle messages arriving on the worklet node's port. Only AudioFrame
// messages are acted on; every other message is ignored.
vadNode.port.onmessage = async (ev: MessageEvent) => {
switch (ev.data?.message) {
case Message.AudioFrame:
// The payload in ev.data.data is an ArrayBuffer; it is viewed as a
// Float32Array and handed to processFrame, which is awaited.
const buffer: ArrayBuffer = ev.data.data
const frame = new Float32Array(buffer)
await audioNodeVAD.processFrame(frame)
break

default:
break
}
}

return audioNodeVAD
}

/**
 * Stores the objects this node works with; performs no setup itself.
 *
 * @param ctx - audio context to associate with this node
 * @param options - real-time VAD options, exposed publicly as `options`
 * @param frameProcessor - frame processor that `pause()` delegates to
 * @param entryNode - AudioWorkletNode kept as this node's entry point
 */
constructor(
public ctx: AudioContext,
public options: RealTimeVADOptions,
private frameProcessor: FrameProcessor,
private entryNode: AudioWorkletNode
) {}

// Delegates to the frame processor's pause(). Declared as an arrow-function
// property so `this` stays bound when the function is passed around.
pause = () => {
this.frameProcessor.pause()
}
Expand Down Expand Up @@ -197,46 +243,18 @@ export class AudioNodeVAD {
break

case Message.SpeechEnd:
// @ts-ignore
this.options.onSpeechEnd(audio)
this.options.onSpeechEnd(audio as Float32Array)
break

default:
break
}
}

init = async () => {
await this.ctx.audioWorklet.addModule(this.options.workletURL)
const vadNode = new AudioWorkletNode(this.ctx, "vad-helper-worklet", {
processorOptions: {
frameSamples: this.options.frameSamples,
},
})
this.entryNode = vadNode

const model = await Silero.new(ort, () => this.options.modelFetcher(this.options.modelURL))

this.frameProcessor = new FrameProcessor(model.process, model.reset_state, {
frameSamples: this.options.frameSamples,
positiveSpeechThreshold: this.options.positiveSpeechThreshold,
negativeSpeechThreshold: this.options.negativeSpeechThreshold,
redemptionFrames: this.options.redemptionFrames,
preSpeechPadFrames: this.options.preSpeechPadFrames,
minSpeechFrames: this.options.minSpeechFrames,
destroy = () => {
this.entryNode.port.postMessage({
message: Message.SpeechStop,
})

vadNode.port.onmessage = async (ev: MessageEvent) => {
switch (ev.data?.message) {
case Message.AudioFrame:
const buffer: ArrayBuffer = ev.data.data
const frame = new Float32Array(buffer)
await this.processFrame(frame)
break

default:
break
}
}
this.entryNode.disconnect()
}
}
Loading

0 comments on commit 60b0b6c

Please sign in to comment.