From 9e8603f53b83e326a2529b2c595b619e04f2b85b Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Wed, 22 Mar 2023 21:22:47 +0400 Subject: [PATCH] feat: implement new download URL variable `${code}` New variable value is coming from `META`, and it might be set using the interactive console (not implemented yet, but it will come soon). I had to refactor the URL expansion implementation: * simplify things where possible * provide more unit-tests for smaller units * handle expansion of all variables in parallel * allow parallel expansion on multiple variables Also I refactored download code to support proper passing of endpoint function with context. The end result: * Talos will try to download config for 3 hours before rebooting * Each attempt which includes URL expansion + download is limited to 3 minutes Signed-off-by: Andrey Smirnov --- .../platform/metal/internal/url/url.go | 118 +++++++++++ .../platform/metal/internal/url/url_test.go | 177 ++++++++++++++++ .../platform/metal/internal/url/value.go | 148 ++++++++++++++ .../platform/metal/internal/url/variable.go | 108 ++++++++++ .../metal/internal/url/variable_test.go | 160 +++++++++++++++ .../runtime/v1alpha1/platform/metal/match.go | 59 ------ .../runtime/v1alpha1/platform/metal/metal.go | 25 ++- .../runtime/v1alpha1/platform/metal/url.go | 193 ------------------ .../v1alpha1/platform/metal/url_test.go | 129 +----------- internal/pkg/meta/constants.go | 8 + pkg/download/download.go | 119 +++++++---- pkg/download/download_test.go | 151 ++++++++++++++ pkg/machinery/constants/constants.go | 6 +- 13 files changed, 974 insertions(+), 427 deletions(-) create mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url.go create mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url_test.go create mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/value.go create mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable.go create mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable_test.go delete mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/match.go delete mode 100644 internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url.go create mode 100644 pkg/download/download_test.go diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url.go new file mode 100644 index 0000000000..5684f9ca71 --- /dev/null +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url.go @@ -0,0 +1,118 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package url handles expansion of the download URL for the config. +package url + +import ( + "context" + "fmt" + "log" + "net/url" + + "github.com/cosi-project/runtime/pkg/state" + "github.com/siderolabs/gen/maps" + "github.com/siderolabs/gen/slices" +) + +// Populate populates the config download URL with values replacing variables. +func Populate(ctx context.Context, downloadURL string, st state.State) (string, error) { + return PopulateVariables(ctx, downloadURL, st, AllVariables()) +} + +// PopulateVariables populates the config download URL with values replacing variables. +// +//nolint:gocyclo +func PopulateVariables(ctx context.Context, downloadURL string, st state.State, variables []*Variable) (string, error) { + u, err := url.Parse(downloadURL) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + + query := u.Query() + + var activeVariables []*Variable + + for _, variable := range variables { + if variable.Matches(query) { + activeVariables = append(activeVariables, variable) + } + } + + // happy path: no variables + if len(activeVariables) == 0 { + return downloadURL, nil + } + + // setup watches + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + watchCh := make(chan state.Event) + + for _, variable := range activeVariables { + if err = variable.Value.RegisterWatch(ctx, st, watchCh); err != nil { + return "", fmt.Errorf("error watching variable %q: %w", variable.Key, err) + } + } + + pendingVariables := slices.ToSet(activeVariables) + + // wait for all variables to be populated + for len(pendingVariables) > 0 { + log.Printf("waiting for URL variables: %v", slices.Map(maps.Keys(pendingVariables), func(v *Variable) string { return v.Key })) + + var ev state.Event + + select { + case <-ctx.Done(): + // context was canceled, return the URL as is + u.RawQuery = query.Encode() + + return u.String(), ctx.Err() + case ev = <-watchCh: + } + + switch ev.Type { + case state.Errored: + return "", fmt.Errorf("error watching variables: %w", ev.Error) + case state.Bootstrapped: + // ignored + case state.Created, state.Updated, state.Destroyed: + anyHandled := false + + for _, variable := range activeVariables { + handled, err := variable.Value.EventHandler(ev) + if err != nil { + return "", fmt.Errorf("error handling variable %q: %w", variable.Key, err) + } + + if handled { + delete(pendingVariables, variable) + + anyHandled = true + } + } + + if !anyHandled { + continue + } + + // perform another round of replacing + query = u.Query() + + for _, variable := range activeVariables { + if _, pending := pendingVariables[variable]; pending { + continue + } + + variable.Replace(query) + } + } + } + + u.RawQuery = query.Encode() + + return u.String(), nil +} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url_test.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url_test.go new file mode 100644 index 0000000000..1bbbca29b4 --- /dev/null +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/url_test.go @@ -0,0 +1,177 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package url_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/cosi-project/runtime/pkg/safe" + "github.com/cosi-project/runtime/pkg/state" + "github.com/cosi-project/runtime/pkg/state/impl/inmem" + "github.com/cosi-project/runtime/pkg/state/impl/namespaced" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url" + "github.com/siderolabs/talos/internal/pkg/meta" + "github.com/siderolabs/talos/pkg/machinery/nethelpers" + "github.com/siderolabs/talos/pkg/machinery/resources/hardware" + "github.com/siderolabs/talos/pkg/machinery/resources/network" + "github.com/siderolabs/talos/pkg/machinery/resources/runtime" +) + +type setupFunc func(context.Context, *testing.T, state.State) + +func TestPopulate(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + name string + url string + + preSetup []setupFunc + parallelSetup []setupFunc + + expected string + }{ + { + name: "no variables", + url: "https://example.com?foo=bar", + expected: "https://example.com?foo=bar", + }, + { + name: "legacy UUID", + url: "https://example.com?uuid=", + expected: "https://example.com?uuid=0000-0000", + preSetup: []setupFunc{ + createSysInfo("0000-0000", ""), + }, + }, + { + name: "sys info", + url: "https://example.com?uuid=${uuid}&no=${serial}", + expected: "https://example.com?no=12345&uuid=0000-0000", + preSetup: []setupFunc{ + createSysInfo("0000-0000", "12345"), + }, + }, + { + name: "multiple variables", + url: "https://example.com?uuid=${uuid}&mac=${mac}&hostname=${hostname}&code=${code}", + expected: "https://example.com?code=top-secret&hostname=example-node&mac=12%3A34%3A56%3A78%3A90%3Aab&uuid=0000-0000", + preSetup: []setupFunc{ + createSysInfo("0000-0000", "12345"), + createMac("12:34:56:78:90:ab"), + createHostname("example-node"), + createCode("top-secret"), + }, + }, + { + name: "mixed wait variables", + url: "https://example.com?uuid=${uuid}&mac=${mac}&hostname=${hostname}&code=${code}", + expected: "https://example.com?code=top-secret&hostname=another-node&mac=12%3A34%3A56%3A78%3A90%3Aab&uuid=0000-1234", + preSetup: []setupFunc{ + createSysInfo("0000-1234", "12345"), + createMac("12:34:56:78:90:ab"), + createHostname("example-node"), + }, + parallelSetup: []setupFunc{ + sleep(time.Second), + updateHostname("another-node"), + sleep(time.Second), + createCode("top-secret"), + }, + }, + } { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + st := state.WrapCore(namespaced.NewState(inmem.Build)) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + for _, f := range test.preSetup { + f(ctx, t, st) + } + + errCh := make(chan error) + + var result string + + go func() { + var e error + + result, e = url.Populate(ctx, test.url, st) + errCh <- e + }() + + for _, f := range test.parallelSetup { + f(ctx, t, st) + } + + err := <-errCh + require.NoError(t, err) + + assert.Equal(t, test.expected, result) + }) + } +} + +func createSysInfo(uuid, serial string) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + sysInfo := hardware.NewSystemInformation(hardware.SystemInformationID) + sysInfo.TypedSpec().UUID = uuid + sysInfo.TypedSpec().SerialNumber = serial + require.NoError(t, st.Create(ctx, sysInfo)) + } +} + +func createMac(mac string) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + addr, err := net.ParseMAC(mac) + require.NoError(t, err) + + hwAddr := network.NewHardwareAddr(network.NamespaceName, network.FirstHardwareAddr) + hwAddr.TypedSpec().HardwareAddr = nethelpers.HardwareAddr(addr) + require.NoError(t, st.Create(ctx, hwAddr)) + } +} + +func createHostname(hostname string) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + hn := network.NewHostnameStatus(network.NamespaceName, network.HostnameID) + hn.TypedSpec().Hostname = hostname + require.NoError(t, st.Create(ctx, hn)) + } +} + +func updateHostname(hostname string) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + hn, err := safe.StateGet[*network.HostnameStatus](ctx, st, network.NewHostnameStatus(network.NamespaceName, network.HostnameID).Metadata()) + require.NoError(t, err) + + hn.TypedSpec().Hostname = hostname + require.NoError(t, st.Update(ctx, hn)) + } +} + +func createCode(code string) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + mk := runtime.NewMetaKey(runtime.NamespaceName, runtime.MetaKeyTagToID(meta.DownloadURLCode)) + mk.TypedSpec().Value = code + require.NoError(t, st.Create(ctx, mk)) + } +} + +func sleep(d time.Duration) setupFunc { + return func(ctx context.Context, t *testing.T, st state.State) { + time.Sleep(d) + } +} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/value.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/value.go new file mode 100644 index 0000000000..182a7bede3 --- /dev/null +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/value.go @@ -0,0 +1,148 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package url + +import ( + "context" + "sync" + + "github.com/cosi-project/runtime/pkg/state" + + "github.com/siderolabs/talos/internal/pkg/meta" + "github.com/siderolabs/talos/pkg/machinery/resources/hardware" + "github.com/siderolabs/talos/pkg/machinery/resources/network" + "github.com/siderolabs/talos/pkg/machinery/resources/runtime" +) + +// Value of a variable. +type Value interface { + // Get the value. + Get() string + // RegisterWatch handles registering a watch for the variable. + RegisterWatch(ctx context.Context, st state.State, ch chan<- state.Event) error + // EventHandler is called for each watch event, returns when the variable value is ready. + EventHandler(event state.Event) (bool, error) +} + +type value struct { + mu sync.Mutex + val string + + registerWatch func(ctx context.Context, st state.State, ch chan<- state.Event) error + eventHandler func(event state.Event) (string, error) +} + +func (v *value) Get() string { + v.mu.Lock() + defer v.mu.Unlock() + + return v.val +} + +func (v *value) RegisterWatch(ctx context.Context, st state.State, ch chan<- state.Event) error { + return v.registerWatch(ctx, st, ch) +} + +func (v *value) EventHandler(event state.Event) (bool, error) { + val, err := v.eventHandler(event) + if err != nil { + return false, err + } + + if val == "" { + return false, nil + } + + v.mu.Lock() + v.val = val + v.mu.Unlock() + + return true, nil +} + +// UUIDValue is a value for UUID variable. +func UUIDValue() Value { + return &value{ + registerWatch: func(ctx context.Context, st state.State, ch chan<- state.Event) error { + return st.Watch(ctx, hardware.NewSystemInformation(hardware.SystemInformationID).Metadata(), ch) + }, + eventHandler: func(event state.Event) (string, error) { + sysInfo, ok := event.Resource.(*hardware.SystemInformation) + if !ok { + return "", nil + } + + return sysInfo.TypedSpec().UUID, nil + }, + } +} + +// SerialNumberValue is a value for SerialNumber variable. +func SerialNumberValue() Value { + return &value{ + registerWatch: func(ctx context.Context, st state.State, ch chan<- state.Event) error { + return st.Watch(ctx, hardware.NewSystemInformation(hardware.SystemInformationID).Metadata(), ch) + }, + eventHandler: func(event state.Event) (string, error) { + sysInfo, ok := event.Resource.(*hardware.SystemInformation) + if !ok { + return "", nil + } + + return sysInfo.TypedSpec().SerialNumber, nil + }, + } +} + +// MACValue is a value for MAC variable. +func MACValue() Value { + return &value{ + registerWatch: func(ctx context.Context, st state.State, ch chan<- state.Event) error { + return st.Watch(ctx, network.NewHardwareAddr(network.NamespaceName, network.FirstHardwareAddr).Metadata(), ch) + }, + eventHandler: func(event state.Event) (string, error) { + hwAddr, ok := event.Resource.(*network.HardwareAddr) + if !ok { + return "", nil + } + + return hwAddr.TypedSpec().HardwareAddr.String(), nil + }, + } +} + +// HostnameValue is a value for Hostname variable. +func HostnameValue() Value { + return &value{ + registerWatch: func(ctx context.Context, st state.State, ch chan<- state.Event) error { + return st.Watch(ctx, network.NewHostnameStatus(network.NamespaceName, network.HostnameID).Metadata(), ch) + }, + eventHandler: func(event state.Event) (string, error) { + hostname, ok := event.Resource.(*network.HostnameStatus) + if !ok { + return "", nil + } + + return hostname.TypedSpec().Hostname, nil + }, + } +} + +// CodeValue is a value for Code variable. +func CodeValue() Value { + return &value{ + registerWatch: func(ctx context.Context, st state.State, ch chan<- state.Event) error { + return st.Watch(ctx, runtime.NewMetaKey(runtime.NamespaceName, runtime.MetaKeyTagToID(meta.DownloadURLCode)).Metadata(), ch) + }, + eventHandler: func(event state.Event) (string, error) { + code, ok := event.Resource.(*runtime.MetaKey) + if !ok { + return "", nil + } + + return code.TypedSpec().Value, nil + }, + } +} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable.go new file mode 100644 index 0000000000..f836841cc3 --- /dev/null +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable.go @@ -0,0 +1,108 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package url + +import ( + "net/url" + "regexp" + "strings" + "sync" + + "github.com/siderolabs/talos/pkg/machinery/constants" +) + +// Variable represents a variable substitution in the download URL. +type Variable struct { + // Key is the variable name. + Key string + // MatchOnArg is set for variables which are match on the arg name with empty value. + // + // Required to support legacy `?uuid=` style of the download URL. + MatchOnArg bool + // Value is the variable value. + Value Value + + rOnce sync.Once + r *regexp.Regexp +} + +// AllVariables is a list of all supported variables. +func AllVariables() []*Variable { + return []*Variable{ + { + Key: constants.UUIDKey, + MatchOnArg: true, + Value: UUIDValue(), + }, + { + Key: constants.SerialNumberKey, + Value: SerialNumberValue(), + }, + { + Key: constants.MacKey, + Value: MACValue(), + }, + { + Key: constants.HostnameKey, + Value: HostnameValue(), + }, + { + Key: constants.CodeKey, + Value: CodeValue(), + }, + } +} + +func keyToVar(key string) string { + return `${` + key + `}` +} + +func (v *Variable) init() { + v.rOnce.Do(func() { + v.r = regexp.MustCompile(`(?i)` + regexp.QuoteMeta(keyToVar(v.Key))) + }) +} + +// Matches checks if the variable is present in the URL. +func (v *Variable) Matches(query url.Values) bool { + v.init() + + for arg, values := range query { + if v.MatchOnArg { + if arg == v.Key && !(len(values) == 1 && strings.TrimSpace(values[0]) != "") { + return true + } + } + + for _, value := range values { + if v.r.MatchString(value) { + return true + } + } + } + + return false +} + +// Replace modifies the URL query replacing the variable with the value. +func (v *Variable) Replace(query url.Values) { + v.init() + + for arg, values := range query { + if v.MatchOnArg { + if arg == v.Key && !(len(values) == 1 && strings.TrimSpace(values[0]) != "") { + query.Set(arg, v.Value.Get()) + + continue + } + } + + for idx, value := range values { + values[idx] = v.r.ReplaceAllString(value, v.Value.Get()) + } + + query[arg] = values + } +} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable_test.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable_test.go new file mode 100644 index 0000000000..1ef9644b01 --- /dev/null +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url/variable_test.go @@ -0,0 +1,160 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package url_test + +import ( + "context" + neturl "net/url" + "testing" + + "github.com/cosi-project/runtime/pkg/state" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url" + "github.com/siderolabs/talos/pkg/machinery/constants" +) + +func TestVariableMatches(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + name string + url string + shouldMatch map[string]struct{} + }{ + { + name: "no matches", + url: "https://example.com?foo=bar", + }, + { + name: "legacy UUID", + url: "https://example.com?uuid=&foo=bar", + shouldMatch: map[string]struct{}{ + constants.UUIDKey: {}, + }, + }, + { + name: "UUID static", + url: "https://example.com?uuid=0000-0000&foo=bar", + }, + { + name: "more variables", + url: "https://example.com?uuid=${uuid}&foo=bar&serial=${serial}&mac=${mac}&hostname=fixed&hostname=${hostname}", + shouldMatch: map[string]struct{}{ + constants.UUIDKey: {}, + constants.SerialNumberKey: {}, + constants.MacKey: {}, + constants.HostnameKey: {}, + }, + }, + { + name: "case insensitive", + url: "https://example.com?uuid=${UUId}&foo=bar&serial=${SeRiaL}", + shouldMatch: map[string]struct{}{ + constants.UUIDKey: {}, + constants.SerialNumberKey: {}, + }, + }, + } { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + u, err := neturl.Parse(test.url) + require.NoError(t, err) + + for _, variable := range url.AllVariables() { + if _, ok := test.shouldMatch[variable.Key]; ok { + assert.True(t, variable.Matches(u.Query())) + } else { + assert.False(t, variable.Matches(u.Query())) + } + } + }) + } +} + +type mockValue struct { + value string +} + +func (v mockValue) Get() string { + return v.value +} + +func (v mockValue) RegisterWatch(context.Context, state.State, chan<- state.Event) error { + return nil +} + +func (v mockValue) EventHandler(state.Event) (bool, error) { + return true, nil +} + +func TestVariableReplace(t *testing.T) { + t.Parallel() + + var1 := &url.Variable{ + Key: "var1", + MatchOnArg: true, + Value: mockValue{ + value: "value1", + }, + } + + var2 := &url.Variable{ + Key: "var2", + Value: mockValue{ + value: "value2", + }, + } + + for _, test := range []struct { + name string + url string + expected string + }{ + { + name: "no matches", + url: "https://example.com?foo=bar", + expected: "https://example.com?foo=bar", + }, + { + name: "legacy match", + url: "https://example.com?var1=&foo=bar", + expected: "https://example.com?foo=bar&var1=value1", + }, + { + name: "variable match", + url: "https://example.com?a=${var1}-suffix&foo=bar&b=${var2}&b=xyz&b=${var2}", + expected: "https://example.com?a=value1-suffix&b=value2&b=xyz&b=value2&foo=bar", + }, + { + name: "case insensitive", + url: "https://example.com?a=${VAR1}", + expected: "https://example.com?a=value1", + }, + } { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + u, err := neturl.Parse(test.url) + require.NoError(t, err) + + query := u.Query() + + for _, variable := range []*url.Variable{var1, var2} { + variable.Replace(query) + } + + u.RawQuery = query.Encode() + + assert.Equal(t, test.expected, u.String()) + }) + } +} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/match.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/match.go deleted file mode 100644 index 2bb0283474..0000000000 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/match.go +++ /dev/null @@ -1,59 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package metal - -import "regexp" - -func keyToVar(key string) string { - return `${` + key + `}` -} - -type matcher struct { - Key string - Regexp *regexp.Regexp -} - -func newMatcher(key string) *matcher { - return &matcher{ - Key: keyToVar(key), - Regexp: regexp.MustCompile(`(?i)` + regexp.QuoteMeta(keyToVar(key))), - } -} - -type replacer struct { - original string - Regexp *regexp.Regexp - Matches [][]int -} - -func (m *matcher) process(original string) *replacer { - var r replacer - r.Regexp = m.Regexp - r.original = original - - r.Matches = m.Regexp.FindAllStringIndex(original, -1) - - return &r -} - -func (r *replacer) ReplaceMatches(replacement string) string { - var res string - - if len(r.Matches) < 1 { - return res - } - - res += r.original[:r.Matches[0][0]] - res += replacement - - for i := 0; i < len(r.Matches)-1; i++ { - res += r.original[r.Matches[i][1]:r.Matches[i+1][0]] - res += replacement - } - - res += r.original[r.Matches[len(r.Matches)-1][1]:] - - return res -} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/metal.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/metal.go index fbd08b2fef..bfaefea185 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/metal.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/metal.go @@ -18,12 +18,14 @@ import ( "github.com/siderolabs/go-blockdevice/blockdevice/probe" "github.com/siderolabs/go-pointer" "github.com/siderolabs/go-procfs/procfs" + "github.com/siderolabs/go-retry/retry" "golang.org/x/sys/unix" "gopkg.in/yaml.v3" "github.com/siderolabs/talos/internal/app/machined/pkg/runtime" "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/errors" "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/internal/netutils" + "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/internal/url" "github.com/siderolabs/talos/internal/pkg/meta" "github.com/siderolabs/talos/pkg/download" "github.com/siderolabs/talos/pkg/machinery/constants" @@ -54,15 +56,19 @@ func (m *Metal) Configuration(ctx context.Context, r state.State) ([]byte, error return nil, errors.ErrNoConfigSource } - getURL := func() string { - downloadEndpoint, err := PopulateURLParameters(ctx, *option, r) + getURL := func(ctx context.Context) (string, error) { + // give a shorter timeout to populate the URL, leave the rest of the time to the actual download + ctx, cancel := context.WithTimeout(ctx, constants.ConfigLoadAttemptTimeout/2) + defer cancel() + + downloadEndpoint, err := url.Populate(ctx, *option, r) if err != nil { - log.Printf("failed to populate talos.config fetch URL: %q ; %s", *option, err.Error()) + log.Printf("failed to populate talos.config fetch URL %q: %s", *option, err.Error()) } log.Printf("fetching machine config from: %q", downloadEndpoint) - return downloadEndpoint + return downloadEndpoint, nil } switch *option { @@ -73,7 +79,16 @@ func (m *Metal) Configuration(ctx context.Context, r state.State) ([]byte, error return nil, err } - return download.Download(ctx, *option, download.WithEndpointFunc(getURL)) + return download.Download( + ctx, + *option, + download.WithEndpointFunc(getURL), + download.WithTimeout(constants.ConfigLoadTimeout), + download.WithRetryOptions( + // give a timeout per attempt, max 50% of that is dedicated for URL interpolation, the rest is for the actual download + retry.WithAttemptTimeout(constants.ConfigLoadAttemptTimeout), + ), + ) } } diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url.go deleted file mode 100644 index 9ed04a2a59..0000000000 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url.go +++ /dev/null @@ -1,193 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package metal - -import ( - "context" - "fmt" - "net/url" - "strings" - "time" - - "github.com/cosi-project/runtime/pkg/resource" - "github.com/cosi-project/runtime/pkg/safe" - "github.com/cosi-project/runtime/pkg/state" - "github.com/hashicorp/go-multierror" - - "github.com/siderolabs/talos/pkg/machinery/constants" - hardwareResource "github.com/siderolabs/talos/pkg/machinery/resources/hardware" - "github.com/siderolabs/talos/pkg/machinery/resources/network" -) - -// PopulateURLParameters fills in empty parameters in the download URL. -// -//nolint:gocyclo -func PopulateURLParameters(ctx context.Context, downloadURL string, r state.State) (string, error) { - populatedURL := downloadURL - - genErr := func(varOfKey string, errToWrap error) error { - return fmt.Errorf("error while substituting %s: %w", varOfKey, errToWrap) - } - - u, err := url.Parse(populatedURL) - if err != nil { - return "", fmt.Errorf("failed to parse %s: %w", populatedURL, err) - } - - values := u.Query() - - substitute := func(varToSubstitute string, getFunc func(ctx context.Context, r state.State) (string, error)) error { - m := newMatcher(varToSubstitute) - - for qKey, qValues := range values { - if len(qValues) == 0 { - continue - } - - qVal := qValues[0] - - // backwards compatible behavior for the uuid key - if qKey == constants.UUIDKey && !(len(qValues) == 1 && len(strings.TrimSpace(qVal)) > 0) { - uid, err := getSystemUUID(ctx, r) - if err != nil { - return fmt.Errorf("error while substituting UUID: %w", err) - } - - values.Set(constants.UUIDKey, uid) - - continue - } - - replacer := m.process(qVal) - - if len(replacer.Matches) < 1 { - continue - } - - val, err := getFunc(ctx, r) - if err != nil { - return genErr(m.Key, err) - } - - qVal = replacer.ReplaceMatches(val) - - values.Set(qKey, qVal) - } - - return nil - } - - var multiErr *multierror.Error - - if err := substitute(constants.UUIDKey, getSystemUUID); err != nil { - multiErr = multierror.Append(multiErr, err) - } - - if err := substitute(constants.SerialNumberKey, getSystemSerialNumber); err != nil { - multiErr = multierror.Append(multiErr, err) - } - - if err := substitute(constants.MacKey, getMACAddress); err != nil { - multiErr = multierror.Append(multiErr, err) - } - - if err := substitute(constants.HostnameKey, getHostname); err != nil { - multiErr = multierror.Append(multiErr, err) - } - - u.RawQuery = values.Encode() - - return u.String(), multiErr.ErrorOrNil() -} - -//nolint:gocyclo -func getResource[T resource.Resource](ctx context.Context, r state.State, namespace, typ, valName string, isReadyFunc func(T) bool, checkAndGetFunc func(T) string) (string, error) { - metadata := resource.NewMetadata(namespace, typ, "", resource.VersionUndefined) - - watchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() - - events := make(chan safe.WrappedStateEvent[T]) - - err := safe.StateWatchKind(watchCtx, r, metadata, events, state.WithBootstrapContents(true)) - if err != nil { - return "", fmt.Errorf("failed to watch %s resources: %w", typ, err) - } - - for { - select { - case <-watchCtx.Done(): - return "", fmt.Errorf("failed to determine %s of %s: %w", valName, typ, watchCtx.Err()) - case event := <-events: - switch event.Type() { - case state.Created, state.Updated: - // ok, proceed - case state.Destroyed, state.Bootstrapped: - // ignore - continue - case state.Errored: - return "", event.Error() - } - - eventResource, err := event.Resource() - if err != nil { - return "", fmt.Errorf("incorrect resource: %w", err) - } - - if !isReadyFunc(eventResource) { - continue - } - - val := checkAndGetFunc(eventResource) - if val == "" { - return "", fmt.Errorf("%s property of resource %s is empty", valName, typ) - } - - return val, nil - } - } -} - -func getUUIDProperty(r *hardwareResource.SystemInformation) string { - return r.TypedSpec().UUID -} - -func getSerialNumberProperty(r *hardwareResource.SystemInformation) string { - return r.TypedSpec().SerialNumber -} - -func getSystemUUID(ctx context.Context, r state.State) (string, error) { - return getResource(ctx, r, hardwareResource.NamespaceName, hardwareResource.SystemInformationType, "UUID", func(*hardwareResource.SystemInformation) bool { return true }, getUUIDProperty) -} - -func getSystemSerialNumber(ctx context.Context, r state.State) (string, error) { - return getResource(ctx, - r, - hardwareResource.NamespaceName, - hardwareResource.SystemInformationType, - "Serial Number", - func(*hardwareResource.SystemInformation) bool { return true }, - getSerialNumberProperty) -} - -func getMACAddressProperty(r *network.LinkStatus) string { - return r.TypedSpec().HardwareAddr.String() -} - -func checkLinkUp(r *network.LinkStatus) bool { - return r.TypedSpec().LinkState -} - -func getMACAddress(ctx context.Context, r state.State) (string, error) { - return getResource(ctx, r, network.NamespaceName, network.LinkStatusType, "MAC Address", checkLinkUp, getMACAddressProperty) -} - -func getHostnameProperty(r *network.HostnameSpec) string { - return r.TypedSpec().Hostname -} - -func getHostname(ctx context.Context, r state.State) (string, error) { - return getResource(ctx, r, network.NamespaceName, network.HostnameSpecType, "Hostname", func(*network.HostnameSpec) bool { return true }, getHostnameProperty) -} diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url_test.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url_test.go index 3cc929e22f..6101ea7ac7 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url_test.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/metal/url_test.go @@ -6,11 +6,9 @@ package metal_test import ( "context" - "fmt" "net" "net/http" "net/http/httptest" - "net/url" "testing" "time" @@ -52,22 +50,20 @@ func createOrUpdate(ctx context.Context, st state.State, r resource.Resource) er } func setup(ctx context.Context, t *testing.T, st state.State, mockUUID, mockSerialNumber, mockHostname, mockMAC string) { - testID := "testID" - sysInfo := hardware.NewSystemInformation(testID) + sysInfo := hardware.NewSystemInformation(hardware.SystemInformationID) sysInfo.TypedSpec().UUID = mockUUID sysInfo.TypedSpec().SerialNumber = mockSerialNumber assert.NoError(t, createOrUpdate(ctx, st, sysInfo)) - hostnameSpec := network.NewHostnameSpec(network.NamespaceName, testID) + hostnameSpec := network.NewHostnameStatus(network.NamespaceName, network.HostnameID) hostnameSpec.TypedSpec().Hostname = mockHostname assert.NoError(t, createOrUpdate(ctx, st, hostnameSpec)) - linkStatusSpec := network.NewLinkStatus(network.NamespaceName, testID) + linkStatusSpec := network.NewHardwareAddr(network.NamespaceName, network.FirstHardwareAddr) parsedMockMAC, err := net.ParseMAC(mockMAC) assert.NoError(t, err) linkStatusSpec.TypedSpec().HardwareAddr = nethelpers.HardwareAddr(parsedMockMAC) - linkStatusSpec.TypedSpec().LinkState = true assert.NoError(t, createOrUpdate(ctx, st, linkStatusSpec)) netStatus := network.NewStatus(network.NamespaceName, network.StatusID) @@ -75,125 +71,6 @@ func setup(ctx context.Context, t *testing.T, st state.State, mockUUID, mockSeri assert.NoError(t, createOrUpdate(ctx, st, netStatus)) } -func TestPopulateURLParameters(t *testing.T) { - mockUUID := "40dcbd19-3b10-444e-bfff-aaee44a51fda" - - mockMAC := "52:2f:fd:df:fc:c0" - - mockSerialNumber := "0OCZJ19N65" - - mockHostname := "myTestHostname" - - for _, tt := range []struct { - name string - url string - expectedURL string - expectedError string - }{ - { - name: "no uuid", - url: "http://example.com/metadata", - expectedURL: "http://example.com/metadata", - }, - { - name: "empty uuid", - url: "http://example.com/metadata?uuid=", - expectedURL: fmt.Sprintf("http://example.com/metadata?uuid=%s", mockUUID), - }, - { - name: "uuid present", - url: "http://example.com/metadata?uuid=xyz", - expectedURL: "http://example.com/metadata?uuid=xyz", - }, - { - name: "multiple uuids in one query parameter", - url: "http://example.com/metadata?u=this-${uuid}-equals-${uuid}-exactly", - expectedURL: fmt.Sprintf("http://example.com/metadata?u=this-%s-equals-%s-exactly", mockUUID, mockUUID), - }, - { - name: "uuid and mac in one query parameter", - url: "http://example.com/metadata?u=this-${uuid}-and-${mac}-together", - expectedURL: fmt.Sprintf("http://example.com/metadata?u=this-%s-and-%s-together", mockUUID, mockMAC), - }, - { - name: "other parameters", - url: "http://example.com/metadata?foo=a", - expectedURL: "http://example.com/metadata?foo=a", - }, - { - name: "multiple uuids", - url: "http://example.com/metadata?uuid=xyz&uuid=foo", - expectedURL: fmt.Sprintf("http://example.com/metadata?uuid=%s", mockUUID), - }, - { - name: "single serial number", - url: "http://example.com/metadata?serial=${serial}", - expectedURL: fmt.Sprintf("http://example.com/metadata?serial=%s", mockSerialNumber), - }, - { - name: "single MAC", - url: "http://example.com/metadata?mac=${mac}", - expectedURL: fmt.Sprintf("http://example.com/metadata?mac=%s", mockMAC), - }, - { - name: "single hostname", - url: "http://example.com/metadata?host=${hostname}", - expectedURL: fmt.Sprintf("http://example.com/metadata?host=%s", mockHostname), - }, - { - name: "serial number, MAC and hostname", - url: "http://example.com/metadata?h=${hostname}&m=${mac}&s=${serial}", - expectedURL: fmt.Sprintf("http://example.com/metadata?h=%s&m=%s&s=%s", mockHostname, mockMAC, mockSerialNumber), - }, - { - name: "uuid, serial number, MAC and hostname; case-insensitive", - url: "http://example.com/metadata?h=${HOSTname}&m=${mAC}&s=${SERIAL}&u=${uUid}", - expectedURL: fmt.Sprintf("http://example.com/metadata?h=%s&m=%s&s=%s&u=%s", mockHostname, mockMAC, mockSerialNumber, mockUUID), - }, - { - name: "MAC and UUID without variable", - url: "http://example.com/metadata?macaddr=${mac}&uuid=", - expectedURL: fmt.Sprintf("http://example.com/metadata?macaddr=%s&uuid=%s", mockMAC, mockUUID), - }, - { - name: "serial number and UUID without variable, order is not preserved", - url: "http://example.com/metadata?uuid=&ser=${serial}", - expectedURL: fmt.Sprintf("http://example.com/metadata?ser=%s&uuid=%s", mockSerialNumber, mockUUID), - }, - { - name: "UUID variable", - url: "http://example.com/metadata?uuid=${uuid}", - expectedURL: fmt.Sprintf("http://example.com/metadata?uuid=%s", mockUUID), - }, - { - name: "serial number and UUID with variable, order is not preserved", - url: "http://example.com/metadata?uuid=${uuid}&ser=${serial}", - expectedURL: fmt.Sprintf("http://example.com/metadata?ser=%s&uuid=%s", mockSerialNumber, mockUUID), - }, - } { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - - st := state.WrapCore(namespaced.NewState(inmem.Build)) - - setup(ctx, t, st, mockUUID, mockSerialNumber, mockHostname, mockMAC) - - output, err := metal.PopulateURLParameters(ctx, tt.url, st) - - if tt.expectedError != "" { - assert.EqualError(t, err, tt.expectedError) - } else { - u, err := url.Parse(tt.expectedURL) - assert.NoError(t, err) - u.RawQuery = u.Query().Encode() - assert.Equal(t, u.String(), output) - } - }) - } -} - func TestRepopulateOnRetry(t *testing.T) { st := state.WrapCore(namespaced.NewState(inmem.Build)) diff --git a/internal/pkg/meta/constants.go b/internal/pkg/meta/constants.go index f97b9db5f3..adb5b088f1 100644 --- a/internal/pkg/meta/constants.go +++ b/internal/pkg/meta/constants.go @@ -15,4 +15,12 @@ const ( StateEncryptionConfig // MetalNetworkPlatformConfig stores serialized NetworkPlatformConfig for the `metal` platform. MetalNetworkPlatformConfig + // DownloadURLCode stores the value of the `${code}` variable in the download URL for talos.config= URL. + DownloadURLCode + // UserReserved1 is reserved for user-defined metadata. + UserReserved1 + // UserReserved2 is reserved for user-defined metadata. + UserReserved2 + // UserReserved3 is reserved for user-defined metadata. + UserReserved3 ) diff --git a/pkg/download/download.go b/pkg/download/download.go index cdcf2231e6..b8bb59f40e 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -2,6 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +// Package download provides a download with retries for machine configuration and userdata. package download import ( @@ -29,18 +30,26 @@ type downloadOptions struct { Headers map[string]string Format string LowSrcPort bool - Endpoint string + + EndpointFunc func(context.Context) (string, error) ErrorOnNotFound error ErrorOnEmptyResponse error + + Timeout time.Duration + RetryOptions []retry.Option } // Option configures the download options. type Option func(*downloadOptions) -func downloadDefaults() *downloadOptions { +func downloadDefaults(endpoint string) *downloadOptions { return &downloadOptions{ + EndpointFunc: func(context.Context) (string, error) { + return endpoint, nil + }, Headers: make(map[string]string), + Timeout: 3 * time.Minute, } } @@ -89,9 +98,23 @@ func WithErrorOnEmptyResponse(e error) Option { } // WithEndpointFunc provides a function that sets the endpoint of the download options. -func WithEndpointFunc(endpointFunc func() string) Option { +func WithEndpointFunc(endpointFunc func(context.Context) (string, error)) Option { + return func(d *downloadOptions) { + d.EndpointFunc = endpointFunc + } +} + +// WithTimeout sets the timeout for the download. +func WithTimeout(timeout time.Duration) Option { + return func(d *downloadOptions) { + d.Timeout = timeout + } +} + +// WithRetryOptions sets the retry options for the download. +func WithRetryOptions(opts ...retry.Option) Option { return func(d *downloadOptions) { - d.Endpoint = endpointFunc() + d.RetryOptions = append(d.RetryOptions, opts...) } } @@ -99,61 +122,71 @@ func WithEndpointFunc(endpointFunc func() string) Option { // //nolint:gocyclo func Download(ctx context.Context, endpoint string, opts ...Option) (b []byte, err error) { - dlOpts := downloadDefaults() + options := downloadDefaults(endpoint) - dlOpts.Endpoint = endpoint + for _, opt := range opts { + opt(options) + } err = retry.Exponential( - 180*time.Second, - retry.WithUnits(time.Second), - retry.WithJitter(time.Second), - retry.WithErrorLogging(true), + options.Timeout, + append([]retry.Option{ + retry.WithUnits(time.Second), + retry.WithJitter(time.Second), + retry.WithErrorLogging(true), + }, + options.RetryOptions..., + )..., ).RetryWithContext(ctx, func(ctx context.Context) error { - dlOpts = downloadDefaults() - - dlOpts.Endpoint = endpoint - - for _, opt := range opts { - opt(dlOpts) - } - - var u *url.URL - u, err = url.Parse(dlOpts.Endpoint) + var attemptEndpoint string + attemptEndpoint, err = options.EndpointFunc(ctx) if err != nil { return err } - if u.Scheme == "file" { - var fileContent []byte - fileContent, err = os.ReadFile(u.Path) + if err = func() error { + var u *url.URL + u, err = url.Parse(attemptEndpoint) if err != nil { return err } - b = fileContent + if u.Scheme == "file" { + var fileContent []byte + fileContent, err = os.ReadFile(u.Path) + if err != nil { + return err + } - return nil - } + b = fileContent - var req *http.Request + return nil + } - if req, err = http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil); err != nil { - return err - } + var req *http.Request - for k, v := range dlOpts.Headers { - req.Header.Set(k, v) - } + if req, err = http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil); err != nil { + return err + } + + for k, v := range options.Headers { + req.Header.Set(k, v) + } - b, err = download(req, dlOpts) + b, err = download(req, options) + + return err + }(); err != nil { + return fmt.Errorf("failed to download config from %q: %w", endpoint, err) + } - return err + return nil }) if err != nil { - return nil, fmt.Errorf("failed to download config from %q: %w", dlOpts.Endpoint, err) + return nil, err } - if dlOpts.Format == b64 { + if options.Format == b64 { var b64 []byte b64, err = base64.StdEncoding.DecodeString(string(b)) @@ -167,11 +200,11 @@ func Download(ctx context.Context, endpoint string, opts ...Option) (b []byte, e return b, nil } -func download(req *http.Request, dlOpts *downloadOptions) (data []byte, err error) { +func download(req *http.Request, options *downloadOptions) (data []byte, err error) { transport := httpdefaults.PatchTransport(cleanhttp.DefaultTransport()) transport.RegisterProtocol("tftp", NewTFTPTransport()) - if dlOpts.LowSrcPort { + if options.LowSrcPort { port := 100 + rand.Intn(512) localTCPAddr, tcperr := net.ResolveTCPAddr("tcp", ":"+strconv.Itoa(port)) @@ -200,8 +233,8 @@ func download(req *http.Request, dlOpts *downloadOptions) (data []byte, err erro //nolint:errcheck defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound && dlOpts.ErrorOnNotFound != nil { - return data, dlOpts.ErrorOnNotFound + if resp.StatusCode == http.StatusNotFound && options.ErrorOnNotFound != nil { + return data, options.ErrorOnNotFound } if resp.StatusCode != http.StatusOK { @@ -213,8 +246,8 @@ func download(req *http.Request, dlOpts *downloadOptions) (data []byte, err erro return data, retry.ExpectedError(fmt.Errorf("read config: %s", err.Error())) } - if len(data) == 0 && dlOpts.ErrorOnEmptyResponse != nil { - return data, dlOpts.ErrorOnEmptyResponse + if len(data) == 0 && options.ErrorOnEmptyResponse != nil { + return data, options.ErrorOnEmptyResponse } return data, nil diff --git a/pkg/download/download_test.go b/pkg/download/download_test.go new file mode 100644 index 0000000000..3b0040e210 --- /dev/null +++ b/pkg/download/download_test.go @@ -0,0 +1,151 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package download_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/siderolabs/go-retry/retry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/siderolabs/talos/pkg/download" +) + +type flipper struct { + srv *httptest.Server + val int + sleep time.Duration +} + +func (f *flipper) EndpointFunc() func(context.Context) (string, error) { + return func(context.Context) (string, error) { + time.Sleep(f.sleep) + + f.val++ + + if f.val%2 == 1 { + return f.srv.URL + "/404", nil + } + + return f.srv.URL + "/data", nil + } +} + +func TestDownload(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/empty": + w.WriteHeader(http.StatusOK) + case "/data": + w.WriteHeader(http.StatusOK) + w.Write([]byte("data")) //nolint:errcheck + case "/base64": + w.WriteHeader(http.StatusOK) + w.Write([]byte("ZGF0YQ==")) //nolint:errcheck + case "/404": + w.WriteHeader(http.StatusNotFound) + default: + w.WriteHeader(http.StatusInternalServerError) + } + })) + t.Cleanup(srv.Close) + + flip := flipper{ + srv: srv, + } + + sleepingFlip := flipper{ + srv: srv, + sleep: 100 * time.Millisecond, + } + + for _, test := range []struct { + name string + path string + opts []download.Option + + expected string + expectedError string + }{ + { + name: "empty download", + path: "/empty", + expected: "", + }, + { + name: "some data", + path: "/data", + expected: "data", + }, + { + name: "base64", + path: "/base64", + opts: []download.Option{download.WithFormat("base64")}, + expected: "data", + }, + { + name: "empty error", + path: "/empty", + opts: []download.Option{download.WithErrorOnEmptyResponse(fmt.Errorf("empty response"))}, + expectedError: "empty response", + }, + { + name: "not found error", + path: "/404", + opts: []download.Option{download.WithErrorOnNotFound(fmt.Errorf("gone forever"))}, + expectedError: "gone forever", + }, + { + name: "failure 404", + path: "/404", + opts: []download.Option{download.WithTimeout(2 * time.Second)}, + expectedError: "failed to download config, received 404", + }, + { + name: "retry endpoint change", + opts: []download.Option{ + download.WithTimeout(2 * time.Second), + download.WithEndpointFunc(flip.EndpointFunc()), + }, + expected: "data", + }, + { + name: "retry with attempt timeout", + opts: []download.Option{ + download.WithTimeout(2 * time.Second), + download.WithEndpointFunc(sleepingFlip.EndpointFunc()), + download.WithRetryOptions(retry.WithAttemptTimeout(200 * time.Millisecond)), + }, + expected: "data", + }, + } { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + b, err := download.Download(ctx, srv.URL+test.path, test.opts...) + + if test.expectedError != "" { + assert.ErrorContains(t, err, test.expectedError) + } else { + require.NoError(t, err) + + assert.Equal(t, test.expected, string(b)) + } + }) + } +} diff --git a/pkg/machinery/constants/constants.go b/pkg/machinery/constants/constants.go index 797202f28c..e30198922b 100644 --- a/pkg/machinery/constants/constants.go +++ b/pkg/machinery/constants/constants.go @@ -602,7 +602,10 @@ const ( DefaultDNSDomain = "cluster.local" // ConfigLoadTimeout is the timeout to wait for the config to be loaded from an external source. - ConfigLoadTimeout = 3 * time.Minute + ConfigLoadTimeout = 3 * time.Hour + + // ConfigLoadAttemptTimeout is the timeout for a single attempt to download config. + ConfigLoadAttemptTimeout = 3 * time.Minute // BootTimeout is the timeout to run all services. BootTimeout = 70 * time.Minute @@ -844,6 +847,7 @@ const ( SerialNumberKey = "serial" HostnameKey = "hostname" MacKey = "mac" + CodeKey = "code" ) // Overlays is the set of paths to create overlay mounts for.