Skip to content
23 changes: 15 additions & 8 deletions databuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
*/

type builder struct {
Builder any
fnValue reflect.Value // cached reflect.ValueOf(builder func) to avoid repeated reflection
In []string
Out string
Name string
Expand Down Expand Up @@ -88,11 +88,17 @@ func (d *db) Compile(init ...any) (Plan, error) {

// IsValidBuilder checks if the given function is valid or not
func IsValidBuilder(builder any) error {
if builder == nil {
return ErrInvalidBuilder
}
t := reflect.TypeOf(builder)
if t.Kind() != reflect.Func {
// Input can only be a function
return ErrInvalidBuilderKind
}
if reflect.ValueOf(builder).IsNil() {
return ErrInvalidBuilder
}
if t.NumOut() != 2 {
// should return a struct and an error
return ErrInvalidBuilderNumOutput
Expand Down Expand Up @@ -139,13 +145,18 @@ func getBuilder(bldr any) (*builder, error) {
return nil, err
}

t := reflect.TypeOf(bldr)
fnValue := reflect.ValueOf(bldr)
if fnValue.IsNil() {
return nil, ErrInvalidBuilder
}

t := fnValue.Type()
out := getStructName(t.Out(0))
name := getFuncName(bldr)
name := runtime.FuncForPC(fnValue.Pointer()).Name()

b := &builder{
Out: out,
Builder: bldr,
fnValue: fnValue,
Name: name,
}
// first in context.Context so we start from second
Expand All @@ -155,10 +166,6 @@ func getBuilder(bldr any) (*builder, error) {
return b, nil
}

func getFuncName(bldr any) string {
return runtime.FuncForPC(reflect.ValueOf(bldr).Pointer()).Name()
}

func getStructName(t reflect.Type) string {
return t.PkgPath() + "." + t.Name()
}
Expand Down
46 changes: 46 additions & 0 deletions databuilder_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package databuilder

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -61,3 +62,48 @@ func TestCompileCyclic(t *testing.T) {
_, err = d.Compile()
assert.Error(t, err, "cyclic dependency should return an error")
}

func TestTypedNilBuilderRejected(t *testing.T) {
var nilFunc func(context.Context) (TestStruct1, error)
d := testNew(t)
err := d.AddBuilders(nilFunc)
assert.Error(t, err, "typed-nil func should be rejected")
assert.ErrorIs(t, err, ErrInvalidBuilder)
Comment thread
ankurs marked this conversation as resolved.

// IsValidBuilder should also reject typed-nil builders directly
err = IsValidBuilder(nilFunc)
assert.Error(t, err, "IsValidBuilder should reject typed-nil func")
assert.ErrorIs(t, err, ErrInvalidBuilder)
}

func TestContextCancellation(t *testing.T) {
d := testNew(t)
err := d.AddBuilders(DBTestFunc, DBTestFunc4)
assert.NoError(t, err)
plan, err := d.Compile(TestStruct1{})
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately

_, err = plan.Run(ctx, TestStruct1{Value: "test"})
assert.Error(t, err, "cancelled context should return an error")
assert.ErrorIs(t, err, context.Canceled)
}

func TestJoinErrors(t *testing.T) {
// Single error should be returned unwrapped
sentinel := ErrWTF
err := joinErrors([]error{sentinel})
assert.Equal(t, sentinel, err, "single error should be returned as-is, not wrapped")

// No errors
err = joinErrors(nil)
assert.NoError(t, err)

// Multiple errors
err = joinErrors([]error{ErrWTF, ErrInvalidBuilder})
assert.Error(t, err)
assert.ErrorIs(t, err, ErrWTF)
assert.ErrorIs(t, err, ErrInvalidBuilder)
}
32 changes: 21 additions & 11 deletions plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func processWork(ctx context.Context, w work) {
w.out <- o
}
}()
fn := reflect.ValueOf(w.builder.Builder)
fn := w.builder.fnValue
// allow builders to access already built data
ctx = AddResultToCtx(ctx, w.dataMap)
args := make([]reflect.Value, 1)
Expand Down Expand Up @@ -183,12 +183,21 @@ func doWorkAndGetResult(ctx context.Context, builders []*builder, dataMap map[st
name := getStructName(outputs[0].Type())
dataMap[name] = outputs[0].Interface()
}
if len(errs) > 0 {
// we only return the first error
// only the first error is returned; aggregate if needed
return joinErrors(errs)
}

// joinErrors returns nil for no errors, the error itself for a single error,
// or a joined error for multiple errors. This avoids wrapping single errors
// which would break sentinel checks like err == context.Canceled.
func joinErrors(errs []error) error {
switch len(errs) {
case 0:
return nil
case 1:
return errs[0]
default:
return errors.Join(errs...)
}
Comment thread
ankurs marked this conversation as resolved.
return nil
}

func (p *plan) run(ctx context.Context, workers uint, dataMap map[string]any) error {
Expand All @@ -205,17 +214,18 @@ func (p *plan) run(ctx context.Context, workers uint, dataMap map[string]any) er

errs := make([]error, 0)
for i := range p.order {
if err := ctx.Err(); err != nil {
if len(errs) == 0 {
return err
}
return joinErrors(append(errs, err))
}
err := doWorkAndGetResult(ctx, p.order[i], dataMap, wChan)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
// we only return the first error
// only the first error is returned; aggregate if needed
return errs[0]
}
return nil
return joinErrors(errs)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// Result.Get returns the value of the struct from the result
Expand Down
Loading