From a0b5acb71ceb548447bbd769d48cbe40a717cfa9 Mon Sep 17 00:00:00 2001 From: Michelangelo Mori Date: Tue, 17 Dec 2024 18:04:50 +0100 Subject: [PATCH] Add REGO debugger to Mindev. This change adds the possibility to evaluate a REGO-based rule type in a debugger. The debugger allows setting breakpoints, stepping, printing source, and a few other simple utilities. The debugger is currently very, very, VERY rough around the edges and could use some love, especially in the reception of events from the debuggee, which is done inline and not asynchronously. --- cmd/dev/app/rule_type/rttst.go | 7 +- internal/engine/eval/rego/debug.go | 589 ++++++++++++++++++++++++++++ internal/engine/eval/rego/eval.go | 37 +- internal/engine/eval/rego/result.go | 17 +- internal/engine/options/options.go | 14 + internal/util/cli/styles.go | 3 +- 6 files changed, 654 insertions(+), 13 deletions(-) create mode 100644 internal/engine/eval/rego/debug.go diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index 7c263bb956..b1d7aae037 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -67,6 +67,7 @@ func CmdTest() *cobra.Command { testCmd.Flags().StringP("token", "t", "", "token to authenticate to the provider."+ "Can also be set via the TEST_AUTH_TOKEN environment variable.") testCmd.Flags().StringArrayP("data-source", "d", []string{}, "YAML file containing the data source to test the rule with") + testCmd.Flags().BoolP("debug", "", false, "Start REGO debugger (only works for REGO-based rules types)") if err := testCmd.MarkFlagRequired("rule-type"); err != nil { fmt.Fprintf(os.Stderr, "Error marking flag as required: %s\n", err) @@ -98,6 +99,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { token := viper.GetString("test.auth.token") providerclass := cmd.Flag("provider") providerconfig := cmd.Flag("provider-config") + debug := cmd.Flag("debug").Value.String() == "true" dataSourceFileStrings, err := 
cmd.Flags().GetStringArray("data-source") if err != nil { @@ -197,7 +199,10 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { // TODO: use cobra context here ctx := context.Background() - eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil /*experiments*/, options.WithDataSources(dsRegistry)) + eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil, /*experiments*/ + options.WithDataSources(dsRegistry), + options.WithRegoDebugger(debug), + ) if err != nil { return fmt.Errorf("cannot create rule type engine: %w", err) } diff --git a/internal/engine/eval/rego/debug.go b/internal/engine/eval/rego/debug.go new file mode 100644 index 0000000000..cb1f604e46 --- /dev/null +++ b/internal/engine/eval/rego/debug.go @@ -0,0 +1,589 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package rego provides the rego rule evaluator +package rego + +import ( + "bufio" + "context" + "errors" + "fmt" + "math" + "os" + "strconv" + "strings" + "time" + + "github.com/open-policy-agent/opa/ast/location" + "github.com/open-policy-agent/opa/debug" + "github.com/open-policy-agent/opa/rego" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/mindersec/minder/internal/util/cli" + "github.com/mindersec/minder/pkg/engine/v1/interfaces" +) + +type eventHandler struct { + ch chan *debug.Event +} + +func newEventHandler() *eventHandler { + return &eventHandler{ + ch: make(chan *debug.Event), + } +} + +func (eh *eventHandler) HandleEvent(event debug.Event) { + eh.ch <- &event +} + +// Actual client interface + +func (eh *eventHandler) NextBlocking() *debug.Event { + return <-eh.ch +} + +func (eh *eventHandler) WaitFor( + ctx context.Context, + eventType debug.EventType, +) *debug.Event { + for { + select { + case e := <-eh.ch: + if e.Type == eventType { + return e + } + case <-ctx.Done(): + return nil + } + } +} + +var ( + errInvalidInstr = errors.New("invalid instruction") + errInvalidBP = 
errors.New("invalid breakpoint") +) + +func (e *Evaluator) Debug( + ctx context.Context, + pol map[string]any, + entity protoreflect.ProtoMessage, + res *interfaces.Result, + input *Input, + funcs ...func(*rego.Rego), +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + eh := newEventHandler() + debugger := debug.NewDebugger( + debug.SetEventHandler(eh.HandleEvent), + ) + launchProps := debug.LaunchEvalProperties{ + LaunchProperties: debug.LaunchProperties{ + StopOnResult: true, + StopOnEntry: false, + StopOnFail: false, + EnablePrint: true, + RuleIndexing: false, + }, + Input: input, + Query: e.reseval.getQueryString(), + } + + regoOpts := make([]debug.LaunchOption, 0, len(e.regoOpts)+len(funcs)) + for _, f := range e.regoOpts { + regoOpts = append(regoOpts, debug.RegoOption(f)) + } + for _, f := range funcs { + regoOpts = append(regoOpts, debug.RegoOption(f)) + } + + session, err := debugger.LaunchEval(ctx, launchProps, regoOpts...) + if err != nil { + return fmt.Errorf("error launching debugger: %w", err) + } + + // initial breakpoint + if _, err := session.AddBreakpoint(location.Location{File: "minder.rego", Row: 1}); err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + + thr := debug.ThreadID(1) + fmt.Print("(mindbg) ") + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := scanner.Text() + + var b strings.Builder + switch { + case line == "": + + case line == "r", + line == "c": + if err := session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + EVENTS: + for { + evt := eh.NextBlocking() + switch evt.Type { + case debug.ExceptionEventType: + fmt.Fprintf(&b, "\nException %+v\n", evt) + printLocals(&b, session, thr) + break EVENTS + case debug.StoppedEventType: + fmt.Fprintf(&b, "\nStopped %+v\n", evt) + printLocals(&b, session, thr) + break EVENTS + case debug.StdoutEventType: + fmt.Fprintf(&b, "\nFinished %+v\n", evt) + printLocals(&b, session, thr) + break 
EVENTS + } + } + case line == "locals": + printLocals(&b, session, thr) + case line == "bp": + bps, err := session.Breakpoints() + if err != nil { + return fmt.Errorf("error getting breakpoints: %w", err) + } + printBreakpoints(&b, bps) + case line == "list", line == "l": + stack, err := session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, e.cfg.Def, stack) + case line == "trs": + threads, err := session.Threads() + if err != nil { + return fmt.Errorf("error getting threads: %w", err) + } + printThreads(&b, threads) + case line == "cla", + line == "clearall": + if err := session.ClearBreakpoints(); err != nil { + return fmt.Errorf("error clearing breakpoints: %w", err) + } + // "next" is a bit quirky, since it requires adding an + // internal breakpoint, running until it's reached, + // and finally removing the breakpoint. + case line == "n", + line == "next": + stack, err := session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + if loc := getCurrentLocation(stack); loc != nil { + loc.Row += 1 // let's hope it always exists... + loc.Col = 0 + + // add internal breakpoint + bp, err := session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + + // resume execution + if err := session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + evt := eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + + // clear internal breakpoint, even if + // we stopped for another reason. 
+ session.RemoveBreakpoint(bp.ID()) + + printStackTrace(&b, e.cfg.Def, stack) + } + case line == "s", + line == "sv": + go func() { + if err := session.StepOver(thr); err != nil { + panic(err) + } + }() + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + evt := eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, e.cfg.Def, stack) + case line == "si": + go func() { + if err := session.StepIn(thr); err != nil { + panic(err) + } + }() + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + evt := eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, e.cfg.Def, stack) + case line == "so": + go func() { + if err := session.StepOut(thr); err != nil { + panic(err) + } + }() + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + evt := eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, e.cfg.Def, stack) + case line == "q": + return fmt.Errorf("user abort") + case line == "h", + line == "help": + printHelp(&b) + case strings.HasPrefix(line, "p"): + varname, err := toVarName(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + printVar(&b, varname, session, thr) + case strings.HasPrefix(line, "b"): + loc, err := toLocation(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + bp, err := session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + fmt.Fprintln(&b) + printBreakpoint(&b, bp) + case strings.HasPrefix(line, "cl "), + strings.HasPrefix(line, "clear "): + id, err := toBreakpointID(line) + if err != nil { + fmt.Fprintln(&b, err) + 
continue + } + session.RemoveBreakpoint(id) + } + + output := b.String() + if output != "" { + fmt.Printf("%s\n(mindbg) ", output) + } else { + fmt.Printf("(mindbg) ") + } + } + + return scanner.Err() +} + +func toLocation(line string) (*location.Location, error) { + num, ok := strings.CutPrefix(line, "b ") + if !ok { + return nil, fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return nil, fmt.Errorf(`%w: invalid line %s`, errInvalidBP, num) + } + return &location.Location{File: "minder.rego", Row: int(i)}, nil +} + +func toBreakpointID(line string) (debug.BreakpointID, error) { + num1, ok1 := strings.CutPrefix(line, "cl ") + num2, ok2 := strings.CutPrefix(line, "clear ") + if !ok1 && !ok2 { + return debug.BreakpointID(-1), fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + + var num string + if !ok1 { + num = num2 + } + if !ok2 { + num = num1 + } + + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return debug.BreakpointID(-1), fmt.Errorf(`%w: invalid breakpoint id %s`, errInvalidBP, num) + } + return debug.BreakpointID(i), nil +} + +func toVarName(line string) (string, error) { + varname, ok := strings.CutPrefix(line, "p ") + if !ok { + return "", fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + return varname, nil +} + +func printBreakpoints(b *strings.Builder, bps []debug.Breakpoint) { + fmt.Fprintln(b) + for _, bp := range bps { + printBreakpoint(b, bp) + } +} + +func printBreakpoint(b *strings.Builder, bp debug.Breakpoint) { + fmt.Fprintf(b, "Breakpoint %d set at %s:%d\n", bp.ID(), bp.Location().File, bp.Location().Row) +} + +func printThreads(b *strings.Builder, threads []debug.Thread) { + fmt.Fprintln(b) + for _, thread := range threads { + fmt.Fprintf(b, "Thread %d\n", thread.ID()) + } +} + +func getCurrentLocation(stack debug.StackTrace) *location.Location { + if len(stack) == 0 { + return nil + } + + frame := stack[0] + return frame.Location() +} + +func printStackTrace(b 
*strings.Builder, src string, stack debug.StackTrace) { + if len(stack) == 0 { + printSource(b, src) + return + } + + lines := strings.Split(src, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + + fmt.Fprintln(b) + frame := stack[0] + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", frame.ID(), loc.File, loc.Row, loc.Col) + + for idx, line := range strings.Split(src, "\n") { + fmt.Fprintf(b, "%*d: %s", padding, idx+1, line) + if idx+1 == loc.Row { + theline := strings.Split(string(loc.Text), "\n")[0] + fmt.Fprintf(b, "\n%s%s", + strings.Repeat(" ", loc.Col+int(padding)+2-1), + cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline))), + ) + } + fmt.Fprintln(b) + } + } +} + +func printSource(b *strings.Builder, source string) { + fmt.Fprintln(b) + lines := strings.Split(source, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + for idx, line := range lines { + fmt.Fprintf(b, "%*d: %s\n", padding, idx+1, line) + } +} + +func printLocals(b *strings.Builder, s debug.Session, thrID debug.ThreadID) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. 
+ scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + vars, err := s.Variables(scope.VariablesReference()) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), v.Value()) + } + } + + return nil +} + +func printResult(b *strings.Builder, s debug.Session, thrID debug.ThreadID) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + if scope.Name() == "Result Set" { + printVariablesInScope(b, "violations", s, scope.VariablesReference()) + } + } + + return nil +} + +func printVar( + b *strings.Builder, + varname string, + s debug.Session, + thrID debug.ThreadID, +) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. 
+ scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + printVariablesInScope(b, varname, s, scope.VariablesReference()) + } + + return nil +} + +func printVariablesInScope( + b *strings.Builder, + varname string, + s debug.Session, + varRef debug.VarRef, +) error { + if varRef == 0 { + return nil + } + + vars, err := s.Variables(varRef) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + if v.Name() == varname { + var b1 strings.Builder + varToString(&b1, v, s) + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), b1.String()) + + // We break early here despite the fact that + // multiple variables might match the given + // `varname`. This is done to honour lexical + // scope, showing just the only variable that + // is actually being used for evaluation in + // the given frame. + return nil + } + } + + return nil +} + +func varToString(b *strings.Builder, v debug.Variable, s debug.Session) error { + switch v.Type() { + case "array": + fmt.Fprint(b, "[\n ") + elems, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for i, elem := range elems { + varToString(b, elem, s) + if i < len(elems)-1 { + fmt.Fprintf(b, ",\n ") + } + } + fmt.Fprint(b, "\n]") + case "set": + fmt.Fprint(b, "{\n ") + elems, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for i, elem := range elems { + varToString(b, elem, s) + if i < len(elems)-1 { + fmt.Fprintf(b, ",\n ") + } + } + fmt.Fprint(b, "\n}") + case "object": + fmt.Fprint(b, "{\n ") + fields, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for i, field := range fields { + fmt.Fprintf(b, " %s: ", field.Name()) + varToString(b, field, s) + if i < len(fields)-1 { + fmt.Fprintf(b, ",\n ") + } + } + fmt.Fprint(b, "\n}") + default: + fmt.Fprint(b, v.Value()) + } + + return nil +} + +var helpMsg = ` 
+Available commands: + r/c ----------- continue + b ------- set breakpoint at line + bp ------------ show breakpoints + clear/cl - clear breakpoint with id + clearall/cla -- clear all breakpoints + trs ----------- print threads + s/sv ---------- step over + so ------------ step out + si ------------ step into + list/l -------- list source + locals -------- print local variables + q ------------- quit + help/h -------- print help +` + +func printHelp(b *strings.Builder) { + fmt.Fprintln(b, helpMsg) +} diff --git a/internal/engine/eval/rego/eval.go b/internal/engine/eval/rego/eval.go index fd2a597360..0832cb16cd 100644 --- a/internal/engine/eval/rego/eval.go +++ b/internal/engine/eval/rego/eval.go @@ -44,6 +44,14 @@ type Evaluator struct { regoOpts []func(*rego.Rego) reseval resultEvaluator datasources *v1datasources.DataSourceRegistry + debug bool +} + +var _ eoptions.RegoBased = (*Evaluator)(nil) + +func (e *Evaluator) SetDebugFlag(flag bool) error { + e.debug = flag + return nil } // Input is the input for the rego evaluator @@ -132,6 +140,28 @@ func (e *Evaluator) Eval( // If the evaluator has data sources defined, expose their functions regoFuncOptions = append(regoFuncOptions, buildDataSourceOptions(res, e.datasources)...) 
+ input := &Input{ + Profile: pol, + Ingested: obj, + OutputFormat: e.cfg.ViolationFormat, + } + enrichInputWithEntityProps(input, entity) + + if e.debug { + err := e.Debug( + ctx, + pol, + entity, + res, + input, + regoFuncOptions..., + ) + if err != nil { + return nil, fmt.Errorf("error initializing debugger: %w", err) + } + return nil, nil + } + // Create the rego object r := e.newRegoFromOptions( regoFuncOptions..., @@ -142,13 +172,6 @@ func (e *Evaluator) Eval( return nil, fmt.Errorf("could not prepare Rego: %w", err) } - input := &Input{ - Profile: pol, - Ingested: obj, - OutputFormat: e.cfg.ViolationFormat, - } - - enrichInputWithEntityProps(input, entity) rs, err := pq.Eval(ctx, rego.EvalInput(input)) if err != nil { return nil, fmt.Errorf("error evaluating profile. Might be wrong input: %w", err) diff --git a/internal/engine/eval/rego/result.go b/internal/engine/eval/rego/result.go index d713aa2557..300b903b3e 100644 --- a/internal/engine/eval/rego/result.go +++ b/internal/engine/eval/rego/result.go @@ -53,6 +53,7 @@ func (c ConstraintsViolationsFormat) String() string { } type resultEvaluator interface { + getQueryString() string getQuery() func(*rego.Rego) parseResult(rego.ResultSet, protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) } @@ -60,8 +61,12 @@ type resultEvaluator interface { type denyByDefaultEvaluator struct { } -func (*denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(RegoQueryPrefix) +func (*denyByDefaultEvaluator) getQueryString() string { + return RegoQueryPrefix +} + +func (d *denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(d.getQueryString()) } func (*denyByDefaultEvaluator) parseResult(rs rego.ResultSet, entity protoreflect.ProtoMessage, @@ -168,8 +173,12 @@ type constraintsEvaluator struct { format ConstraintsViolationsFormat } -func (*constraintsEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(fmt.Sprintf("%s.violations[details]", RegoQueryPrefix)) 
+func (*constraintsEvaluator) getQueryString() string { + return fmt.Sprintf("%s.violations[details]", RegoQueryPrefix) +} + +func (c *constraintsEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(c.getQueryString()) } func (c *constraintsEvaluator) parseResult(rs rego.ResultSet, _ protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) { diff --git a/internal/engine/options/options.go b/internal/engine/options/options.go index 0da6223418..8fdc2ca767 100644 --- a/internal/engine/options/options.go +++ b/internal/engine/options/options.go @@ -35,6 +35,20 @@ func WithFlagsClient(client openfeature.IClient) Option { } } +type RegoBased interface { + SetDebugFlag(bool) error +} + +func WithRegoDebugger(flag bool) Option { + return func(e interfaces.Evaluator) error { + inner, ok := e.(RegoBased) + if !ok { + return nil + } + return inner.SetDebugFlag(flag) + } +} + // SupportsDataSources interface advertises the fact that the implementer // can register data sources with the evaluator. type SupportsDataSources interface { diff --git a/internal/util/cli/styles.go b/internal/util/cli/styles.go index 4973c54b85..453db8c72b 100644 --- a/internal/util/cli/styles.go +++ b/internal/util/cli/styles.go @@ -27,7 +27,8 @@ var ( // Common styles var ( - CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + SimpleBoldStyle = lipgloss.NewStyle().Bold(true) ) // Banner styles