// Package massert implements an assertion framework which is useful in tests. package massert import ( "bytes" "errors" "fmt" "path/filepath" "reflect" "runtime" "strings" "text/tabwriter" ) // AssertErr is an error returned by Assertions which have failed, containing // information about both the reason for failure and the Assertion itself. type AssertErr struct { Err error // The error which occurred Assertion Assertion // The Assertion which failed } func fmtBlock(str string) string { if strings.Index(str, "\n") == -1 { return str } return "\n\t" + strings.Replace(str, "\n", "\n\t", -1) + "\n" } func fmtMultiDescr(prefix string, aa ...Assertion) string { if len(aa) == 0 { return prefix + "()" } else if len(aa) == 1 { return prefix + "(" + fmtBlock(aa[0].Description()) + ")" } buf := new(bytes.Buffer) fmt.Fprintf(buf, "%s(\n", prefix) for _, a := range aa { descrStr := "\t" + strings.Replace(a.Description(), "\n", "\n\t", -1) fmt.Fprintf(buf, "%s,\n", descrStr) } fmt.Fprintf(buf, ")") return buf.String() } func fmtStack(frames []runtime.Frame) string { buf := new(bytes.Buffer) tw := tabwriter.NewWriter(buf, 0, 4, 2, ' ', 0) for _, frame := range frames { file := filepath.Base(frame.File) fmt.Fprintf(tw, "%s:%d\t%s\n", file, frame.Line, frame.Function) } if err := tw.Flush(); err != nil { panic(err) // fuck it } return buf.String() } func (ae AssertErr) Error() string { buf := new(bytes.Buffer) fmt.Fprintf(buf, "\n") fmt.Fprintf(buf, "Assertion: %s\n", fmtBlock(ae.Assertion.Description())) fmt.Fprintf(buf, "Error: %s\n", fmtBlock(ae.Err.Error())) fmt.Fprintf(buf, "Stack: %s\n", fmtBlock(fmtStack(ae.Assertion.Stack()))) return buf.String() } //////////////////////////////////////////////////////////////////////////////// // Assertion is an entity which will make some kind of assertion and produce an // error if that assertion does not hold true. The error returned will generally // be of type AssertErr. type Assertion interface { Assert() error Description() string // A description of the Assertion // Returns the callstack of where the Assertion was created, ordered from // closest to farthest. This may not necessarily contain the entire // callstack if that would be inconveniently cumbersome. Stack() []runtime.Frame } const maxStackLen = 8 type assertion struct { fn func() error descr string stack []runtime.Frame } func newAssertion(assertFn func() error, descr string, skip int) Assertion { pcs := make([]uintptr, maxStackLen) // first skip is for runtime.Callers, second is for newAssertion, third is // for whatever is calling newAssertion numPCs := runtime.Callers(skip+3, pcs) stack := make([]runtime.Frame, 0, maxStackLen) frames := runtime.CallersFrames(pcs[:numPCs]) for { frame, more := frames.Next() stack = append(stack, frame) if !more || len(stack) == maxStackLen { break } } a := &assertion{ descr: descr, stack: stack, } a.fn = func() error { err := assertFn() if err == nil { return nil } else if ae, ok := err.(AssertErr); ok { return ae } return AssertErr{ Err: err, Assertion: a, } } return a } func (a *assertion) Assert() error { return a.fn() } func (a *assertion) Description() string { return a.descr } func (a *assertion) Stack() []runtime.Frame { return a.stack } //////////////////////////////////////////////////////////////////////////////// // Assertion wrappers // if the Assertion is a wrapper for another, this makes sure that if the // underlying one returns an AssertErr that this Assertion is what ends up in // that AssertErr type wrap struct { Assertion } func (wa wrap) Assert() error { err := wa.Assertion.Assert() if err == nil { return nil } ae := err.(AssertErr) ae.Assertion = wa.Assertion return ae } type descrWrap struct { Assertion descr string } func (dw descrWrap) Description() string { return dw.descr } // Comment prepends a formatted string to the given Assertion's string // description. func Comment(a Assertion, msg string, args ...interface{}) Assertion { msg = strings.TrimSpace(msg) descr := fmt.Sprintf("/* "+msg+" */\n", args...) descr += a.Description() return wrap{descrWrap{Assertion: a, descr: descr}} } // Not negates an Assertion, so that it fails if the given Assertion does not, // and vice-versa. func Not(a Assertion) Assertion { fn := func() error { if err := a.Assert(); err == nil { return errors.New("assertion should have failed") } return nil } return newAssertion(fn, fmtMultiDescr("Not", a), 0) } // Any asserts that at least one of the given Assertions succeeds. func Any(aa ...Assertion) Assertion { fn := func() error { for _, a := range aa { if err := a.Assert(); err == nil { return nil } } return errors.New("no assertions succeeded") } return newAssertion(fn, fmtMultiDescr("Any", aa...), 0) } // AnyOne asserts that exactly one of the given Assertions succeeds. func AnyOne(aa ...Assertion) Assertion { fn := func() error { any := -1 for i, a := range aa { if err := a.Assert(); err == nil { if any >= 0 { return fmt.Errorf("assertions indices %d and %d both succeeded", any, i) } any = i } } if any == -1 { return errors.New("no assertions succeeded") } return nil } return newAssertion(fn, fmtMultiDescr("AnyOne", aa...), 0) } // All asserts that at all of the given Assertions succeed. Its Assert method // will return the error of whichever Assertion failed. func All(aa ...Assertion) Assertion { fn := func() error { for _, a := range aa { if err := a.Assert(); err != nil { // newAssertion will pass this error through, so that its // description and callstack is what gets displayed as the // error. This isn't totally consistent with Any's behavior, but // it's fine. return err } } return nil } return newAssertion(fn, fmtMultiDescr("All", aa...), 0) } // None asserts that all of the given Assertions fail. // // NOTE this is functionally equivalent to doing `Not(Any(aa...))`, but the // error returned is more helpful. func None(aa ...Assertion) Assertion { fn := func() error { for _, a := range aa { if err := a.Assert(); err == nil { return AssertErr{ Err: errors.New("assertion should not have succeeded"), Assertion: a, } } } return nil } return newAssertion(fn, fmtMultiDescr("None", aa...), 0) } //////////////////////////////////////////////////////////////////////////////// var typeOfInt64 = reflect.TypeOf(int64(0)) func toStr(i interface{}) string { return fmt.Sprintf("%T(%#v)", i, i) } // Equal asserts that the two values given are equal. The equality checking // done is to some degree fuzzy in the following ways: // // * All pointers are dereferenced. // * All ints and uints are converted to int64. // func Equal(a, b interface{}) Assertion { normalize := func(v reflect.Value) reflect.Value { v = reflect.Indirect(v) switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: v = v.Convert(typeOfInt64) } return v } fn := func() error { aV, bV := reflect.ValueOf(a), reflect.ValueOf(b) aV, bV = normalize(aV), normalize(bV) if !reflect.DeepEqual(aV.Interface(), bV.Interface()) { return errors.New("not equal") } return nil } return newAssertion(fn, toStr(a)+" == "+toStr(b), 0) } // Exactly asserts that the two values are exactly equal, and uses the // reflect.DeepEquals function to determine if they are. func Exactly(a, b interface{}) Assertion { return newAssertion(func() error { if !reflect.DeepEqual(a, b) { return errors.New("not exactly equal") } return nil }, toStr(a)+" === "+toStr(b), 0) }