diff --git a/example/pkgname/userloader_gen.go b/example/pkgname/userloader_gen.go index 3495d73..72b3f38 100644 --- a/example/pkgname/userloader_gen.go +++ b/example/pkgname/userloader_gen.go @@ -3,6 +3,7 @@ package differentpkg import ( + "fmt" "sync" "time" @@ -19,6 +20,10 @@ type UserLoaderConfig struct { // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit MaxBatch int + + // Recover is a function to transform a recovered value into an error. + // If a function is not supplied, the value is formatted with fmt.Errorf("%v", v). + Recover func(v interface{}) error } // NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch @@ -27,6 +32,7 @@ func NewUserLoader(config UserLoaderConfig) *UserLoader { fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, + recover: config.Recover, } } @@ -41,6 +47,9 @@ type UserLoader struct { // this will limit the maximum number of keys to send in one batch, 0 = no limit maxBatch int + // this transforms recovered panic values into errors + recover func(v interface{}) error + // INTERNAL // lazily created cache @@ -219,6 +228,15 @@ func (b *userLoaderBatch) startTimer(l *UserLoader) { } func (b *userLoaderBatch) end(l *UserLoader) { + defer func() { + if r := recover(); r != nil { + if l.recover != nil { + b.error = []error{l.recover(r)} + } else { + b.error = []error{fmt.Errorf("%v", r)} + } + } + close(b.done) + }() b.data, b.error = l.fetch(b.keys) - close(b.done) } diff --git a/example/slice/usersliceloader_gen.go b/example/slice/usersliceloader_gen.go index c2d6e83..56df6ae 100644 --- a/example/slice/usersliceloader_gen.go +++ b/example/slice/usersliceloader_gen.go @@ -3,6 +3,7 @@ package slice import ( + "fmt" "sync" "time" @@ -19,6 +20,10 @@ type UserSliceLoaderConfig struct { // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit MaxBatch int + + // Recover is a function to transform a recovered value into an error. + // If a function is not supplied, the value is formatted with fmt.Errorf("%v", v). + Recover func(v interface{}) error } // NewUserSliceLoader creates a new UserSliceLoader given a fetch, wait, and maxBatch @@ -27,6 +32,7 @@ func NewUserSliceLoader(config UserSliceLoaderConfig) *UserSliceLoader { fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, + recover: config.Recover, } } @@ -41,6 +47,9 @@ type UserSliceLoader struct { // this will limit the maximum number of keys to send in one batch, 0 = no limit maxBatch int + // this transforms recovered panic values into errors + recover func(v interface{}) error + // INTERNAL // lazily created cache @@ -220,6 +229,15 @@ func (b *userSliceLoaderBatch) startTimer(l *UserSliceLoader) { } func (b *userSliceLoaderBatch) end(l *UserSliceLoader) { + defer func() { + if r := recover(); r != nil { + if l.recover != nil { + b.error = []error{l.recover(r)} + } else { + b.error = []error{fmt.Errorf("%v", r)} + } + } + close(b.done) + }() b.data, b.error = l.fetch(b.keys) - close(b.done) } diff --git a/example/user_test.go b/example/user_test.go index 342a4ad..b390361 100644 --- a/example/user_test.go +++ b/example/user_test.go @@ -1,6 +1,7 @@ package example import ( + "errors" "fmt" "strings" "sync" @@ -29,6 +30,8 @@ func TestUserLoader(t *testing.T) { for i, key := range keys { if strings.HasPrefix(key, "E") { errors[i] = fmt.Errorf("user not found") + } else if strings.HasPrefix(key, "P") { + panic("something bad happened") } else { users[i] = &User{ID: key, Name: "user " + key} } @@ -193,4 +196,21 @@ func TestUserLoader(t *testing.T) { require.Error(t, err2[1]) require.Equal(t, "user U6", users2[0].Name) }) + + t.Run("fetch panic with recover func", func(t *testing.T) { + expectedErr := errors.New("transformed") + dl.recover = func(interface{}) error { + return expectedErr + } + u, err := dl.Load("P1") + require.Nil(t, u) + require.Equal(t, err, expectedErr) + dl.recover = nil + }) + + t.Run("fetch panic with no recover func", func(t *testing.T) { + u, err := dl.Load("P1") + require.Nil(t, u) + require.Error(t, err) + }) } diff --git a/example/userloader_gen.go b/example/userloader_gen.go index 470ba6a..42bb1e9 100644 --- a/example/userloader_gen.go +++ b/example/userloader_gen.go @@ -3,6 +3,7 @@ package example import ( + "fmt" "sync" "time" ) @@ -17,6 +18,10 @@ type UserLoaderConfig struct { // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit MaxBatch int + + // Recover is a function to transform a recovered value into an error. + // If a function is not supplied, the value is formatted with fmt.Errorf("%v", v). + Recover func(v interface{}) error } // NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch @@ -25,6 +30,7 @@ func NewUserLoader(config UserLoaderConfig) *UserLoader { fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, + recover: config.Recover, } } @@ -39,6 +45,9 @@ type UserLoader struct { // this will limit the maximum number of keys to send in one batch, 0 = no limit maxBatch int + // this transforms recovered panic values into errors + recover func(v interface{}) error + // INTERNAL // lazily created cache @@ -217,6 +226,15 @@ func (b *userLoaderBatch) startTimer(l *UserLoader) { } func (b *userLoaderBatch) end(l *UserLoader) { + defer func() { + if r := recover(); r != nil { + if l.recover != nil { + b.error = []error{l.recover(r)} + } else { + b.error = []error{fmt.Errorf("%v", r)} + } + } + close(b.done) + }() b.data, b.error = l.fetch(b.keys) - close(b.done) } diff --git a/pkg/generator/template.go b/pkg/generator/template.go index 48f5ba2..c0cc38f 100644 --- a/pkg/generator/template.go +++ b/pkg/generator/template.go @@ -12,6 +12,7 @@ var tpl = template.Must(template.New("generated"). package {{.Package}} import ( + "fmt" "sync" "time" @@ -21,7 +22,7 @@ import ( // {{.Name}}Config captures the config to create a new {{.Name}} type {{.Name}}Config struct { - // Fetch is a method that provides the data for the loader + // Fetch is a method that provides the data for the loader Fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) // Wait is how long wait before sending a batch @@ -29,6 +30,10 @@ type {{.Name}}Config struct { // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit MaxBatch int + + // Recover is a function to transform a recovered value into an error. + // If a function is not supplied, the value is formatted with fmt.Errorf("%v", v). + Recover func(v interface{}) error } // New{{.Name}} creates a new {{.Name}} given a fetch, wait, and maxBatch @@ -37,10 +42,11 @@ func New{{.Name}}(config {{.Name}}Config) *{{.Name}} { fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, + recover: config.Recover, } } -// {{.Name}} batches and caches requests +// {{.Name}} batches and caches requests type {{.Name}} struct { // this method provides the data for the loader fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) @@ -51,6 +57,9 @@ type {{.Name}} struct { // this will limit the maximum number of keys to send in one batch, 0 = no limit maxBatch int + // this transforms recovered panic values into errors + recover func(v interface{}) error + // INTERNAL // lazily created cache @@ -239,7 +248,16 @@ func (b *{{.Name|lcFirst}}Batch) startTimer(l *{{.Name}}) { } func (b *{{.Name|lcFirst}}Batch) end(l *{{.Name}}) { + defer func() { + if r := recover(); r != nil { + if l.recover != nil { + b.error = []error{l.recover(r)} + } else { + b.error = []error{fmt.Errorf("%v", r)} + } + } + close(b.done) + }() b.data, b.error = l.fetch(b.keys) - close(b.done) } `))