Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai/live: Configure parameters from auth callback #3264

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions server/ai_live_video.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if err != nil {
slog.Info("error publishing trickle", "err", err)
}
params.segmentReader.SwitchReader(func(reader io.Reader) {
params.liveParams.segmentReader.SwitchReader(func(reader io.Reader) {

Check warning on line 23 in server/ai_live_video.go

View check run for this annotation

Codecov / codecov/patch

server/ai_live_video.go#L23

Added line #L23 was not covered by tests
// check for end of stream
if _, eos := reader.(*media.EOSReader); eos {
if err := publisher.Close(); err != nil {
Expand Down Expand Up @@ -67,7 +67,7 @@
ffmpeg.Transcode3(&ffmpeg.TranscodeOptionsIn{
Fname: fmt.Sprintf("pipe:%d", r.Fd()),
}, []ffmpeg.TranscodeOptions{{
Oname: params.outputRTMPURL,
Oname: params.liveParams.outputRTMPURL,

Check warning on line 70 in server/ai_live_video.go

View check run for this annotation

Codecov / codecov/patch

server/ai_live_video.go#L70

Added line #L70 was not covered by tests
AudioEncoder: ffmpeg.ComponentOptions{Name: "copy"},
VideoEncoder: ffmpeg.ComponentOptions{Name: "copy"},
Muxer: ffmpeg.ComponentOptions{Name: "flv"},
Expand All @@ -88,15 +88,21 @@

func startControlPublish(control *url.URL, params aiRequestParams) {
controlPub, err := trickle.NewTricklePublisher(control.String())
stream := params.liveParams.stream

Check warning on line 91 in server/ai_live_video.go

View check run for this annotation

Codecov / codecov/patch

server/ai_live_video.go#L91

Added line #L91 was not covered by tests
if err != nil {
slog.Info("error starting control publisher", "stream", params.stream, "err", err)
slog.Info("error starting control publisher", "stream", stream, "err", err)

Check warning on line 93 in server/ai_live_video.go

View check run for this annotation

Codecov / codecov/patch

server/ai_live_video.go#L93

Added line #L93 was not covered by tests
return
}
params.node.LiveMu.Lock()
defer params.node.LiveMu.Unlock()
params.node.LivePipelines[params.stream] = &core.LivePipeline{ControlPub: controlPub}
params.node.LivePipelines[stream] = &core.LivePipeline{ControlPub: controlPub}

Check warning on line 98 in server/ai_live_video.go

View check run for this annotation

Codecov / codecov/patch

server/ai_live_video.go#L98

Added line #L98 was not covered by tests
}

const (
mediaMTXControlPort = "9997"
mediaMTXControlUser = "admin"
)

func (ls *LivepeerServer) kickInputConnection(mediaMTXHost, sourceID, sourceType string) error {
var apiPath string
switch sourceType {
Expand Down
73 changes: 49 additions & 24 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,19 +420,45 @@
return
}

err = authenticateAIStream(AuthWebhookURL, AIAuthRequest{
Stream: streamName,
Type: sourceTypeStr,
QueryParams: queryParams,
})
if err != nil {
kickErr := ls.kickInputConnection(remoteHost, sourceID, sourceType)
if kickErr != nil {
clog.Errorf(ctx, "failed to kick input connection: %s", kickErr.Error())
// if auth webhook returns pipeline config these will be replaced
pipeline := qp.Get("pipeline")
rawParams := qp.Get("params")
var pipelineParams map[string]interface{}
if rawParams != "" {
if err := json.Unmarshal([]byte(rawParams), &pipelineParams); err != nil {
clog.Errorf(ctx, "Invalid pipeline params: %s", err)
http.Error(w, "Invalid model params", http.StatusBadRequest)
return
}

Check warning on line 432 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L424-L432

Added lines #L424 - L432 were not covered by tests
}

if AuthWebhookURL != nil {
authResp, err := authenticateAIStream(AuthWebhookURL, AIAuthRequest{
Stream: streamName,
Type: sourceTypeStr,
QueryParams: queryParams,
})
if err != nil {
kickErr := ls.kickInputConnection(remoteHost, sourceID, sourceType)
if kickErr != nil {
clog.Errorf(ctx, "failed to kick input connection: %s", kickErr.Error())
}
clog.Errorf(ctx, "Live AI auth failed: %s", err.Error())
http.Error(w, "Forbidden", http.StatusForbidden)
return

Check warning on line 448 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L435-L448

Added lines #L435 - L448 were not covered by tests
}

if authResp.RTMPOutputURL != "" {
outputURL = authResp.RTMPOutputURL
}

Check warning on line 453 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L451-L453

Added lines #L451 - L453 were not covered by tests

if authResp.Pipeline != "" {
pipeline = authResp.Pipeline
}

Check warning on line 457 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L455-L457

Added lines #L455 - L457 were not covered by tests

if len(authResp.paramsMap) > 0 {
pipelineParams = authResp.paramsMap

Check warning on line 460 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L459-L460

Added lines #L459 - L460 were not covered by tests
}
clog.Errorf(ctx, "Live AI auth failed: %s", err.Error())
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

requestID := string(core.RandomManifestID())
Expand All @@ -449,16 +475,20 @@
}()

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
segmentReader: ssr,
outputRTMPURL: outputURL,
stream: streamName,
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,

liveParams: liveRequestParams{
segmentReader: ssr,
outputRTMPURL: outputURL,
stream: streamName,
},

Check warning on line 486 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L478-L486

Added lines #L478 - L486 were not covered by tests
}

req := worker.GenLiveVideoToVideoJSONRequestBody{
// TODO set model and initial parameters here if necessary (eg, prompt)
ModelId: &pipeline,
Params: &pipelineParams,

Check warning on line 491 in server/ai_mediaserver.go

View check run for this annotation

Codecov / codecov/patch

server/ai_mediaserver.go#L490-L491

Added lines #L490 - L491 were not covered by tests
}
processAIRequest(ctx, params, req)
})
Expand Down Expand Up @@ -523,8 +553,3 @@
}
}
}

const (
mediaMTXControlPort = "9997"
mediaMTXControlUser = "admin"
)
8 changes: 6 additions & 2 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@
os drivers.OSSession
sessManager *AISessionManager

// For live video pipelines
liveParams liveRequestParams
}

// For live video pipelines
type liveRequestParams struct {
segmentReader *media.SwitchableSegmentReader
outputRTMPURL string
stream string
Expand Down Expand Up @@ -1402,7 +1406,7 @@
case worker.GenLiveVideoToVideoJSONRequestBody:
cap = core.Capability_LiveVideoToVideo
modelID = defaultLiveVideoToVideoModelID
if v.ModelId != nil {
if v.ModelId != nil && *v.ModelId != "" {

Check warning on line 1409 in server/ai_process.go

View check run for this annotation

Codecov / codecov/patch

server/ai_process.go#L1409

Added line #L1409 was not covered by tests
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
Expand Down
35 changes: 29 additions & 6 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,39 @@
// TODO not sure what params we need yet
}

func authenticateAIStream(authURL *url.URL, req AIAuthRequest) error {
// Contains the configuration parameters for this AI job
type AIAuthResponse struct {
// Where to send the output video
RTMPOutputURL string `json:"rtmp_output_url""`

// Name of the pipeline to run
Pipeline string `json:"pipeline"`

// Parameters for the pipeline
PipelineParams json.RawMessage `json:"pipeline_parameters"`
paramsMap map[string]interface{} // unmarshaled params
}

func authenticateAIStream(authURL *url.URL, req AIAuthRequest) (*AIAuthResponse, error) {
if authURL == nil {
return nil
return nil, fmt.Errorf("No auth URL configured")

Check warning on line 127 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L127

Added line #L127 was not covered by tests
}
started := time.Now()

jsonValue, err := json.Marshal(req)
if err != nil {
return err
return nil, err

Check warning on line 133 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L133

Added line #L133 was not covered by tests
}

resp, err := http.Post(authURL.String(), "application/json", bytes.NewBuffer(jsonValue))
if err != nil {
return err
return nil, err

Check warning on line 138 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L138

Added line #L138 was not covered by tests
}

rbody, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("status=%d error=%s", resp.StatusCode, string(rbody))
return nil, fmt.Errorf("status=%d error=%s", resp.StatusCode, string(rbody))

Check warning on line 144 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L144

Added line #L144 was not covered by tests
}

took := time.Since(started)
Expand All @@ -137,5 +150,15 @@
monitor.AuthWebhookFinished(took)
}

return nil
var authResp AIAuthResponse
if err := json.Unmarshal(rbody, &authResp); err != nil {
return nil, err
}

Check warning on line 156 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L155-L156

Added lines #L155 - L156 were not covered by tests
if len(authResp.PipelineParams) > 0 {
if err := json.Unmarshal([]byte(authResp.PipelineParams), &authResp.paramsMap); err != nil {
return nil, err
}

Check warning on line 160 in server/auth.go

View check run for this annotation

Codecov / codecov/patch

server/auth.go#L158-L160

Added lines #L158 - L160 were not covered by tests
}

return &authResp, nil
}
3 changes: 2 additions & 1 deletion server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ func TestAILiveAuthSucceeds(t *testing.T) {
s, serverURL := stubAuthServer(t, http.StatusOK, `{}`)
defer s.Close()

err := authenticateAIStream(serverURL, AIAuthRequest{
resp, err := authenticateAIStream(serverURL, AIAuthRequest{
Stream: "stream",
})
require.NoError(t, err)
require.Equal(t, AIAuthResponse{}, *resp)
}

func TestNoErrorWhenTranscodeAuthHeaderNotPassed(t *testing.T) {
Expand Down
Loading