parent
3e2713a850
commit
360d41e2b8
@ -1 +0,0 @@ |
||||
Dockerfile |
@ -1,11 +0,0 @@ |
||||
FROM golang:1.12 AS builder |
||||
WORKDIR /app |
||||
COPY . . |
||||
RUN GOBIN=$(pwd)/bin CGO_ENABLED=0 GOOS=linux go install -a -installsuffix cgo ./cmd/... |
||||
|
||||
FROM alpine:latest |
||||
RUN apk --no-cache add ca-certificates |
||||
WORKDIR /app/bin |
||||
COPY --from=builder /app/bin /app/bin |
||||
ENV PATH="/app/bin:${PATH}" |
||||
CMD echo "Available commands:" && ls |
@ -1,2 +0,0 @@ |
||||
- read through all docs, especially package docs, make sure they make sense |
||||
- write examples |
@ -1,122 +0,0 @@ |
||||
package main |
||||
|
||||
/* |
||||
totp-proxy is a reverse proxy which implements basic time-based one-time |
||||
password (totp) authentication for any website. |
||||
|
||||
It takes in a JSON object which maps usernames to totp secrets (generated at |
||||
a site like https://freeotp.github.io/qrcode.html), as well as a url to
|
||||
proxy requests to. Users are prompted with a basic-auth prompt, and if they |
||||
succeed their totp challenge a cookie is set and requests are proxied to the |
||||
destination. |
||||
*/ |
||||
|
||||
import ( |
||||
"context" |
||||
"net/http" |
||||
"net/url" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/m" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcrypto" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mhttp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtime" |
||||
"github.com/pquerna/otp/totp" |
||||
) |
||||
|
||||
func main() { |
||||
cmp := m.RootServiceComponent() |
||||
cookieName := mcfg.String(cmp, "cookie-name", |
||||
mcfg.ParamDefault("_totp_proxy"), |
||||
mcfg.ParamUsage("String to use as the name for cookies")) |
||||
cookieTimeout := mcfg.Duration(cmp, "cookie-timeout", |
||||
mcfg.ParamDefault(mtime.Duration{1 * time.Hour}), |
||||
mcfg.ParamUsage("Timeout for cookies")) |
||||
|
||||
var userSecrets map[string]string |
||||
mcfg.JSON(cmp, "users", &userSecrets, |
||||
mcfg.ParamRequired(), |
||||
mcfg.ParamUsage("JSON object which maps usernames to their TOTP secret strings")) |
||||
|
||||
var secret mcrypto.Secret |
||||
secretStr := mcfg.String(cmp, "secret", |
||||
mcfg.ParamUsage("String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.")) |
||||
mrun.InitHook(cmp, func(context.Context) error { |
||||
if *secretStr == "" { |
||||
*secretStr = mrand.Hex(32) |
||||
} |
||||
mlog.From(cmp).Info("generating secret") |
||||
secret = mcrypto.NewSecret([]byte(*secretStr)) |
||||
return nil |
||||
}) |
||||
|
||||
proxyHandler := new(struct{ http.Handler }) |
||||
proxyURL := mcfg.String(cmp, "dst-url", |
||||
mcfg.ParamRequired(), |
||||
mcfg.ParamUsage("URL to proxy requests to. Only the scheme and host should be set.")) |
||||
mrun.InitHook(cmp, func(context.Context) error { |
||||
u, err := url.Parse(*proxyURL) |
||||
if err != nil { |
||||
return merr.Wrap(err, cmp.Context()) |
||||
} |
||||
proxyHandler.Handler = mhttp.ReverseProxy(u) |
||||
return nil |
||||
}) |
||||
|
||||
authHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
// TODO mlog.FromHTTP?
|
||||
ctx := r.Context() |
||||
|
||||
unauthorized := func() { |
||||
mlog.From(cmp).Debug("connection is unauthorized") |
||||
w.Header().Add("WWW-Authenticate", "Basic") |
||||
w.WriteHeader(http.StatusUnauthorized) |
||||
} |
||||
|
||||
authorized := func() { |
||||
mlog.From(cmp).Debug("connection is authorized, rewriting cookies") |
||||
sig := mcrypto.SignString(secret, "") |
||||
http.SetCookie(w, &http.Cookie{ |
||||
Name: *cookieName, |
||||
Value: sig.String(), |
||||
MaxAge: int((*cookieTimeout).Seconds()), |
||||
}) |
||||
proxyHandler.ServeHTTP(w, r) |
||||
} |
||||
|
||||
if cookie, _ := r.Cookie(*cookieName); cookie != nil { |
||||
mlog.From(cmp).Debug("authenticating with cookie", |
||||
mctx.Annotate(ctx, "cookie", cookie.String())) |
||||
var sig mcrypto.Signature |
||||
if err := sig.UnmarshalText([]byte(cookie.Value)); err == nil { |
||||
err := mcrypto.VerifyString(secret, sig, "") |
||||
if err == nil && time.Since(sig.Time()) < (*cookieTimeout).Duration { |
||||
authorized() |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
if user, pass, ok := r.BasicAuth(); ok && pass != "" { |
||||
mlog.From(cmp).Debug("authenticating with user", |
||||
mctx.Annotate(ctx, "user", user)) |
||||
if userSecret, ok := userSecrets[user]; ok { |
||||
if totp.Validate(pass, userSecret) { |
||||
authorized() |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
unauthorized() |
||||
}) |
||||
|
||||
mhttp.InstListeningServer(cmp, authHandler) |
||||
m.Exec(cmp) |
||||
} |
@ -1,29 +0,0 @@ |
||||
export CLOUDSDK_CORE_PROJECT="test" |
||||
|
||||
if [ "$(ps aux | grep '[p]ubsub-emulator')" = "" ]; then |
||||
echo "starting pubsub emulator" |
||||
yes | gcloud beta emulators pubsub start >/dev/null 2>&1 & |
||||
fi |
||||
$(gcloud beta emulators pubsub env-init) |
||||
|
||||
if [ "$(ps aux | grep '[c]loud-datastore-emulator')" = "" ]; then |
||||
echo "starting datastore emulator" |
||||
yes | gcloud beta emulators datastore start >/dev/null 2>&1 & |
||||
fi |
||||
$(gcloud beta emulators datastore env-init) |
||||
|
||||
if [ "$(ps aux | grep '[b]igtable-emulator')" = "" ]; then |
||||
echo "starting bigtable emulator" |
||||
yes | gcloud beta emulators bigtable start --host-port 127.0.0.1:8086 >/dev/null 2>&1 & |
||||
fi |
||||
$(gcloud beta emulators bigtable env-init) |
||||
|
||||
if ! (sudo systemctl status mysqld 1>/dev/null); then |
||||
echo "starting mysqld" |
||||
sudo systemctl start mysqld |
||||
fi |
||||
|
||||
if ! (sudo systemctl status redis 1>/dev/null); then |
||||
echo "starting redis" |
||||
sudo systemctl start redis |
||||
fi |
@ -1,176 +0,0 @@ |
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= |
||||
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= |
||||
cloud.google.com/go v0.36.0 h1:+aCSj7tOo2LODWVEuZDZeGCckdt6MlSF+X/rB3wUiS8= |
||||
cloud.google.com/go v0.36.0/go.mod h1:RUoy9p/M4ge0HzT8L+SDZ8jg+Q6fth0CiBuhFJpSV40= |
||||
dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= |
||||
dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= |
||||
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= |
||||
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= |
||||
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= |
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= |
||||
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= |
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= |
||||
github.com/boombuler/barcode v1.0.0 h1:s1TvRnXwL2xJRaccrdcBQMZxq6X7DvsMogtmJeHDdrc= |
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= |
||||
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= |
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= |
||||
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= |
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= |
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= |
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= |
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= |
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= |
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= |
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= |
||||
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= |
||||
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= |
||||
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= |
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= |
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= |
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= |
||||
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= |
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= |
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= |
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= |
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= |
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= |
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= |
||||
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= |
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= |
||||
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= |
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= |
||||
github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= |
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= |
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= |
||||
github.com/googleapis/gax-go v2.0.0+incompatible h1:j0GKcs05QVmm7yesiZq2+9cxHkNK9YM6zKx4D2qucQU= |
||||
github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= |
||||
github.com/googleapis/gax-go/v2 v2.0.3 h1:siORttZ36U2R/WjiJuDz8znElWBiAlO9rVt+mqJt0Cc= |
||||
github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= |
||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= |
||||
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= |
||||
github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= |
||||
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= |
||||
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= |
||||
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= |
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= |
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= |
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= |
||||
github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= |
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= |
||||
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= |
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= |
||||
github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4= |
||||
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= |
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= |
||||
github.com/mediocregopher/radix/v3 v3.3.2 h1:2gAC5aDBWQr1LBgaNQiVLb2LGX4lvkARDkfjsuonKJE= |
||||
github.com/mediocregopher/radix/v3 v3.3.2/go.mod h1:RsC7cELtyL4TGkg0nwRPTa+J2TXZ0dh/ruohD3rnjMk= |
||||
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= |
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= |
||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= |
||||
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= |
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= |
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= |
||||
github.com/pquerna/otp v1.1.0 h1:q2gMsMuMl3JzneUaAX1MRGxLvOG6bzXV51hivBaStf0= |
||||
github.com/pquerna/otp v1.1.0/go.mod h1:Zad1CMQfSQZI5KLpahDiSUX4tMMREnXw98IvL1nhgMk= |
||||
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= |
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= |
||||
github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= |
||||
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= |
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= |
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= |
||||
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= |
||||
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= |
||||
github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= |
||||
github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= |
||||
github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= |
||||
github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= |
||||
github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= |
||||
github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= |
||||
github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= |
||||
github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= |
||||
github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= |
||||
github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= |
||||
github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= |
||||
github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= |
||||
github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= |
||||
github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= |
||||
github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= |
||||
github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= |
||||
github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= |
||||
github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= |
||||
github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= |
||||
github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= |
||||
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= |
||||
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= |
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= |
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= |
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= |
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= |
||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= |
||||
go.opencensus.io v0.18.0 h1:Mk5rgZcggtbvtAun5aJzAtjKKN/t0R3jJPlWILlv938= |
||||
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= |
||||
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= |
||||
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= |
||||
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= |
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= |
||||
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= |
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= |
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU= |
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= |
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= |
||||
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= |
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE= |
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= |
||||
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= |
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= |
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= |
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= |
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= |
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= |
||||
golang.org/x/sys v0.0.0-20181029174526-d69651ed3497 h1:GXMDsk4xWZCVzkAWCabrabzCCVmfiYSw72f1K/S9QIY= |
||||
golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= |
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= |
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To= |
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= |
||||
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= |
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= |
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= |
||||
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= |
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= |
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= |
||||
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= |
||||
google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= |
||||
google.golang.org/api v0.1.0 h1:K6z2u68e86TPdSdefXdzvXgR1zEMa+459vBSfWYAZkI= |
||||
google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= |
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= |
||||
google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= |
||||
google.golang.org/appengine v1.3.0 h1:FBSsiFRMz3LBeXIomRnVzrQwSDj4ibvcRexLG0LZGQk= |
||||
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= |
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= |
||||
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= |
||||
google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= |
||||
google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= |
||||
google.golang.org/genproto v0.0.0-20190201180003-4b09977fb922 h1:mBVYJnbrXLA/ZCBTCe7PtEgAUP+1bg92qTaFoPHdz+8= |
||||
google.golang.org/genproto v0.0.0-20190201180003-4b09977fb922/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4= |
||||
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= |
||||
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= |
||||
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= |
||||
google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA= |
||||
google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= |
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= |
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= |
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= |
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce h1:xcEWjVhvbDy+nHP67nPDDpbYrY+ILlfndk4bRioVHaU= |
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= |
||||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= |
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= |
||||
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= |
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= |
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= |
||||
sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= |
||||
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= |
@ -1,55 +0,0 @@ |
||||
package jstream |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/base64" |
||||
"io" |
||||
) |
||||
|
||||
type delimReader struct { |
||||
r io.Reader |
||||
delim byte |
||||
rest []byte |
||||
} |
||||
|
||||
func (dr *delimReader) Read(b []byte) (int, error) { |
||||
if dr.delim != 0 { |
||||
return 0, io.EOF |
||||
} |
||||
n, err := dr.r.Read(b) |
||||
if i := bytes.IndexAny(b[:n], bbDelims); i >= 0 { |
||||
dr.delim = b[i] |
||||
dr.rest = append([]byte(nil), b[i+1:n]...) |
||||
return i, err |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
type byteBlobReader struct { |
||||
dr *delimReader |
||||
dec io.Reader |
||||
} |
||||
|
||||
func newByteBlobReader(r io.Reader) *byteBlobReader { |
||||
dr := &delimReader{r: r} |
||||
return &byteBlobReader{ |
||||
dr: dr, |
||||
dec: base64.NewDecoder(base64.StdEncoding, dr), |
||||
} |
||||
} |
||||
|
||||
func (bbr *byteBlobReader) Read(into []byte) (int, error) { |
||||
n, err := bbr.dec.Read(into) |
||||
if bbr.dr.delim == bbEnd { |
||||
return n, io.EOF |
||||
} else if bbr.dr.delim == bbCancel { |
||||
return n, ErrCanceled |
||||
} |
||||
return n, err |
||||
} |
||||
|
||||
// returns the bytes which were read off the underlying io.Reader but which
|
||||
// haven't been consumed yet.
|
||||
func (bbr *byteBlobReader) buffered() io.Reader { |
||||
return bytes.NewBuffer(bbr.dr.rest) |
||||
} |
@ -1,186 +0,0 @@ |
||||
package jstream |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/base64" |
||||
"io" |
||||
"io/ioutil" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type bbrTest struct { |
||||
wsSuffix []byte // whitespace
|
||||
body []byte |
||||
shouldCancel bool |
||||
intoSize int |
||||
} |
||||
|
||||
func randBBRTest(minBodySize, maxBodySize int) bbrTest { |
||||
var whitespace = []byte{' ', '\n', '\t', '\r'} |
||||
genWhitespace := func(n int) []byte { |
||||
ws := make([]byte, n) |
||||
for i := range ws { |
||||
ws[i] = whitespace[mrand.Intn(len(whitespace))] |
||||
} |
||||
return ws |
||||
} |
||||
|
||||
body := mrand.Bytes(minBodySize + mrand.Intn(maxBodySize-minBodySize)) |
||||
return bbrTest{ |
||||
wsSuffix: genWhitespace(mrand.Intn(10)), |
||||
body: body, |
||||
intoSize: 1 + mrand.Intn(len(body)+1), |
||||
} |
||||
} |
||||
|
||||
func (bt bbrTest) msgAndArgs() []interface{} { |
||||
return []interface{}{"bt:%#v len(body):%d", bt, len(bt.body)} |
||||
} |
||||
|
||||
func (bt bbrTest) mkBytes() []byte { |
||||
buf := new(bytes.Buffer) |
||||
enc := base64.NewEncoder(base64.StdEncoding, buf) |
||||
|
||||
if bt.shouldCancel { |
||||
enc.Write(bt.body[:len(bt.body)/2]) |
||||
enc.Close() |
||||
buf.WriteByte(bbCancel) |
||||
} else { |
||||
enc.Write(bt.body) |
||||
enc.Close() |
||||
buf.WriteByte(bbEnd) |
||||
} |
||||
|
||||
buf.Write(bt.wsSuffix) |
||||
return buf.Bytes() |
||||
} |
||||
|
||||
func (bt bbrTest) do(t *T) bool { |
||||
buf := bytes.NewBuffer(bt.mkBytes()) |
||||
bbr := newByteBlobReader(buf) |
||||
|
||||
into := make([]byte, bt.intoSize) |
||||
outBuf := new(bytes.Buffer) |
||||
_, err := io.CopyBuffer(outBuf, bbr, into) |
||||
if bt.shouldCancel { |
||||
return assert.Equal(t, ErrCanceled, err, bt.msgAndArgs()...) |
||||
} |
||||
if !assert.NoError(t, err, bt.msgAndArgs()...) { |
||||
return false |
||||
} |
||||
if !assert.Equal(t, bt.body, outBuf.Bytes(), bt.msgAndArgs()...) { |
||||
return false |
||||
} |
||||
fullRest := append(bbr.dr.rest, buf.Bytes()...) |
||||
if len(bt.wsSuffix) == 0 { |
||||
return assert.Empty(t, fullRest, bt.msgAndArgs()...) |
||||
} |
||||
return assert.Equal(t, bt.wsSuffix, fullRest, bt.msgAndArgs()...) |
||||
} |
||||
|
||||
func TestByteBlobReader(t *T) { |
||||
// some sanity tests
|
||||
bbrTest{ |
||||
body: []byte{2, 3, 4, 5}, |
||||
intoSize: 4, |
||||
}.do(t) |
||||
bbrTest{ |
||||
body: []byte{2, 3, 4, 5}, |
||||
intoSize: 3, |
||||
}.do(t) |
||||
bbrTest{ |
||||
body: []byte{2, 3, 4, 5}, |
||||
shouldCancel: true, |
||||
intoSize: 3, |
||||
}.do(t) |
||||
|
||||
// fuzz this bitch
|
||||
for i := 0; i < 50000; i++ { |
||||
bt := randBBRTest(0, 1000) |
||||
if !bt.do(t) { |
||||
return |
||||
} |
||||
bt.shouldCancel = true |
||||
if !bt.do(t) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func BenchmarkByteBlobReader(b *B) { |
||||
type bench struct { |
||||
bt bbrTest |
||||
body []byte |
||||
buf *bytes.Reader |
||||
cpBuf []byte |
||||
} |
||||
|
||||
mkTestSet := func(minBodySize, maxBodySize int) []bench { |
||||
n := 100 |
||||
benches := make([]bench, n) |
||||
for i := range benches { |
||||
bt := randBBRTest(minBodySize, maxBodySize) |
||||
body := bt.mkBytes() |
||||
benches[i] = bench{ |
||||
bt: bt, |
||||
body: body, |
||||
buf: bytes.NewReader(nil), |
||||
cpBuf: make([]byte, bt.intoSize), |
||||
} |
||||
} |
||||
return benches |
||||
} |
||||
|
||||
testRaw := func(b *B, benches []bench) { |
||||
j := 0 |
||||
for i := 0; i < b.N; i++ { |
||||
if j >= len(benches) { |
||||
j = 0 |
||||
} |
||||
benches[j].buf.Reset(benches[j].body) |
||||
io.CopyBuffer(ioutil.Discard, benches[j].buf, benches[j].cpBuf) |
||||
j++ |
||||
} |
||||
} |
||||
|
||||
testBBR := func(b *B, benches []bench) { |
||||
j := 0 |
||||
for i := 0; i < b.N; i++ { |
||||
if j >= len(benches) { |
||||
j = 0 |
||||
} |
||||
benches[j].buf.Reset(benches[j].body) |
||||
bbr := newByteBlobReader(benches[j].buf) |
||||
io.CopyBuffer(ioutil.Discard, bbr, benches[j].cpBuf) |
||||
j++ |
||||
} |
||||
} |
||||
|
||||
benches := []struct { |
||||
name string |
||||
minBodySize, maxBodySize int |
||||
}{ |
||||
{"small", 0, 1000}, |
||||
{"medium", 1000, 10000}, |
||||
{"large", 10000, 100000}, |
||||
{"xlarge", 100000, 1000000}, |
||||
} |
||||
|
||||
b.StopTimer() |
||||
for i := range benches { |
||||
b.Run(benches[i].name, func(b *B) { |
||||
set := mkTestSet(benches[i].minBodySize, benches[i].maxBodySize) |
||||
b.StartTimer() |
||||
b.Run("raw", func(b *B) { |
||||
testRaw(b, set) |
||||
}) |
||||
b.Run("bbr", func(b *B) { |
||||
testBBR(b, set) |
||||
}) |
||||
b.StopTimer() |
||||
}) |
||||
} |
||||
} |
@ -1,410 +0,0 @@ |
||||
// Package jstream defines and implements the JSON Stream protocol
|
||||
//
|
||||
// Purpose
|
||||
//
|
||||
// The purpose of the jstream protocol is to provide a very simple layer on top
|
||||
// of an existing JSON implementation to allow for streaming arbitrary numbers
|
||||
// of JSON objects and byte blobs of arbitrary size in a standard way, and to
|
||||
// allow for embedding streams within each other.
|
||||
//
|
||||
// The order of priorities when designing jstream is as follows:
|
||||
// 1) Protocol simplicity
|
||||
// 2) Implementation simplicity
|
||||
// 3) Efficiency, both in parsing speed and bandwidth
|
||||
//
|
||||
// The justification for this is that protocol simplicity generally spills into
|
||||
// implementation simplicity anyway, and accounts for future languages which
|
||||
// have different properties than current ones. Parsing speed isn't much of a
|
||||
// concern when reading data off a network (the primary use-case here), as RTT
|
||||
// is always going to be the main blocker. Bandwidth can be a concern, but it's
|
||||
// one better solved by wrapping the byte stream with a compressor.
|
||||
//
|
||||
// jstream protocol
|
||||
//
|
||||
// The jstream protocol is carried over a byte stream (in go: an io.Reader). To
|
||||
// read the protocol a JSON object is read off the byte stream and inspected to
|
||||
// determine what kind of jstream element it is.
|
||||
//
|
||||
// Multiple jstream elements are sequentially read off the same byte stream.
|
||||
// Each element may be separated from the other by any amount of whitespace,
|
||||
// with whitespace being defined as spaces, tabs, carriage returns, and/or
|
||||
// newlines.
|
||||
//
|
||||
// jstream elements
|
||||
//
|
||||
// There are three jstream element types:
|
||||
//
|
||||
// * JSON Value: Any JSON value
|
||||
// * Byte Blob: A stream of bytes of unknown, and possibly infinite, size
|
||||
// * Stream: A heterogenous sequence of jstream elements of unknown, and
|
||||
// possibly infinite, size
|
||||
//
|
||||
// JSON Value elements are defined as being JSON objects with a `val` field. The
|
||||
// value of that field is the JSON Value.
|
||||
//
|
||||
// { "val":{"foo":"bar"} }
|
||||
//
|
||||
// Byte Blob elements are defined as being a JSON object with a `bytesStart`
|
||||
// field with a value of `true`. Immediately following the JSON object are the
|
||||
// bytes which are the Byte Blob, encoded using standard base64. Immediately
|
||||
// following the encoded bytes is the character `$`, to indicate the bytes have
|
||||
// been completely written. Alternatively the character `!` may be written
|
||||
// immediately after the bytes to indicate writing was canceled prematurely by
|
||||
// the writer.
|
||||
//
|
||||
// { "bytesStart":true }wXnxQHgUO8g=$
|
||||
// { "bytesStart":true }WGYcTI8=!
|
||||
//
|
||||
// The JSON object may also contain a `sizeHint` field, which gives the
|
||||
// estimated number of bytes in the Byte Blob (excluding the trailing
|
||||
// delimiter). The hint is neither required to exist or be accurate if it does.
|
||||
// The trailing delimeter (`$` or `!`) is required to be sent even if the hint
|
||||
// is sent.
|
||||
//
|
||||
// Stream elements are defined as being a JSON object with a `streamStart` field
|
||||
// with a value of `true`. Immediately following the JSON object will be zero
|
||||
// more jstream elements of any type, possibly separated by whitespace. Finally
|
||||
// the Stream is ended with another JSON object with a `streamEnd` field with a
|
||||
// value of `true`.
|
||||
//
|
||||
// { "streamStart":true }
|
||||
// { "val":{"foo":"bar"} }
|
||||
// { "bytesStart":true }7TdlDQOnA6isxD9C$
|
||||
// { "streamEnd":true }
|
||||
//
|
||||
// A Stream may also be prematurely canceled by the sending of a JSON object
|
||||
// with the `streamCancel` field set to `true` (in place of one with `streamEnd`
|
||||
// set to `true`).
|
||||
//
|
||||
// The Stream's original JSON object (the "head") may also have a `sizeHint`
|
||||
// field, which gives the estimated number of jstream elements in the Stream.
|
||||
// The hint is neither required to exist or be accurate if it does. The tail
|
||||
// JSON object (with the `streamEnd` field) is required even if `sizeHint` is
|
||||
// given.
|
||||
//
|
||||
// One of the elements in a Stream may itself be a Stream. In this way Streams
|
||||
// may be embedded within each other.
|
||||
//
|
||||
// Here's an example of a complex Stream, which carries within it two different
|
||||
// streams and some other elements:
|
||||
//
|
||||
// { "streamStart":true }
|
||||
// { "val":{"foo":"bar" }
|
||||
// { "streamStart":true, "sizeHint":2 }
|
||||
// { "val":{"foo":"baz"} }
|
||||
// { "val":{"foo":"biz"} }
|
||||
// { "streamEnd":true }
|
||||
// { "bytesStart":true }X7KCpLIjqIBJt9vA$
|
||||
// { "streamStart":true }
|
||||
// { "bytesStart":true }0jT+kNCuxHywUYy0$
|
||||
// { "bytesStart":true }LUqjR6OACB2p1BG4$
|
||||
// { "streamEnd":true }
|
||||
// { "streamEnd":true }
|
||||
//
|
||||
// Finally, the byte stream off of which the jstream is based (i.e. the
|
||||
// io.Reader) is implicitly treated as a Stream, with the Stream ending when the
|
||||
// byte stream is closed.
|
||||
//
|
||||
package jstream |
||||
|
||||
// TODO figure out how to expose the json.Encoder/Decoders so that users can set
|
||||
// custom options on them (like UseNumber and whatnot)
|
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
// byte blob constants
|
||||
const ( |
||||
bbEnd = '$' |
||||
bbCancel = '!' |
||||
bbDelims = string(bbEnd) + string(bbCancel) |
||||
) |
||||
|
||||
// Type is used to enumerate the types of jstream elements
|
||||
type Type string |
||||
|
||||
// The jstream element types
|
||||
const ( |
||||
TypeJSONValue Type = "jsonValue" |
||||
TypeByteBlob Type = "byteBlob" |
||||
TypeStream Type = "stream" |
||||
) |
||||
|
||||
// ErrWrongType is an error returned by the Decode* methods on Decoder when the
|
||||
// wrong decoding method has been called for the element which was read. The
|
||||
// error contains the actual type of the element.
|
||||
type ErrWrongType struct { |
||||
Actual Type |
||||
} |
||||
|
||||
func (err ErrWrongType) Error() string { |
||||
return fmt.Sprintf("wrong type, actual type is %q", err.Actual) |
||||
} |
||||
|
||||
var ( |
||||
// ErrCanceled is returned when reading either a Byte Blob or a Stream,
|
||||
// indicating that the writer has prematurely canceled the element.
|
||||
ErrCanceled = errors.New("canceled by writer") |
||||
|
||||
// ErrStreamEnded is returned from Next when the Stream being read has been
|
||||
// ended by the writer.
|
||||
ErrStreamEnded = errors.New("stream ended") |
||||
) |
||||
|
||||
type element struct { |
||||
Value json.RawMessage `json:"val,omitempty"` |
||||
|
||||
BytesStart bool `json:"bytesStart,omitempty"` |
||||
|
||||
StreamStart bool `json:"streamStart,omitempty"` |
||||
StreamEnd bool `json:"streamEnd,omitempty"` |
||||
StreamCancel bool `json:"streamCancel,omitempty"` |
||||
|
||||
SizeHint uint `json:"sizeHint,omitempty"` |
||||
} |
||||
|
||||
// Element is a single jstream element which is read off a StreamReader.
|
||||
//
|
||||
// If a method is called which expects a particular Element type (e.g.
|
||||
// DecodeValue, which expects a JSONValue Element) but the Element is not of
|
||||
// that type then an ErrWrongType will be returned.
|
||||
//
|
||||
// If there was an error reading the Element off the StreamReader that error is
|
||||
// kept in the Element and returned from any method call.
|
||||
type Element struct { |
||||
element |
||||
|
||||
// Err will be set if the StreamReader encountered an error while reading
|
||||
// the next Element. If set then the Element is otherwise unusable.
|
||||
//
|
||||
// Err may be ErrCanceled or ErrStreamEnded, which would indicate the end of
|
||||
// the stream but would not indicate the StreamReader is no longer usable,
|
||||
// depending on the behavior of the writer on the other end.
|
||||
Err error |
||||
|
||||
// needed for byte blobs and streams
|
||||
sr *StreamReader |
||||
} |
||||
|
||||
// Type returns the Element's Type, or an error
|
||||
func (el Element) Type() (Type, error) { |
||||
if el.Err != nil { |
||||
return "", el.Err |
||||
} else if el.element.StreamStart { |
||||
return TypeStream, nil |
||||
} else if el.element.BytesStart { |
||||
return TypeByteBlob, nil |
||||
} else if len(el.element.Value) > 0 { |
||||
return TypeJSONValue, nil |
||||
} |
||||
return "", errors.New("malformed Element, can't determine type") |
||||
} |
||||
|
||||
func (el Element) assertType(is Type) error { |
||||
typ, err := el.Type() |
||||
if err != nil { |
||||
return err |
||||
} else if typ != is { |
||||
return ErrWrongType{Actual: typ} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Value attempts to unmarshal a JSON Value Element's value into the given
|
||||
// receiver.
|
||||
//
|
||||
// This method should not be called more than once.
|
||||
func (el Element) Value(i interface{}) error { |
||||
if err := el.assertType(TypeJSONValue); err != nil { |
||||
return err |
||||
} |
||||
return json.Unmarshal(el.element.Value, i) |
||||
} |
||||
|
||||
// SizeHint returns the size hint which may have been optionally sent for
|
||||
// ByteBlob and Stream elements, or zero. The hint is never required to be
|
||||
// sent or to be accurate.
|
||||
func (el Element) SizeHint() uint { |
||||
return el.element.SizeHint |
||||
} |
||||
|
||||
// Bytes returns an io.Reader which will contain the contents of a ByteBlob
|
||||
// element. The io.Reader _must_ be read till io.EOF or ErrCanceled before the
|
||||
// StreamReader may be used again.
|
||||
//
|
||||
// This method should not be called more than once.
|
||||
func (el Element) Bytes() (io.Reader, error) { |
||||
if err := el.assertType(TypeByteBlob); err != nil { |
||||
return nil, err |
||||
} |
||||
return el.sr.readBytes(), nil |
||||
} |
||||
|
||||
// Stream returns the embedded stream represented by this Element as a
|
||||
// StreamReader. The returned StreamReader _must_ be iterated (via the Next
|
||||
// method) till ErrStreamEnded or ErrCanceled is returned before the original
|
||||
// StreamReader may be used again.
|
||||
//
|
||||
// This method should not be called more than once.
|
||||
func (el Element) Stream() (*StreamReader, error) { |
||||
if err := el.assertType(TypeStream); err != nil { |
||||
return nil, err |
||||
} |
||||
return el.sr, nil |
||||
} |
||||
|
||||
// StreamReader represents a Stream from which Elements may be read using the
|
||||
// Next method.
|
||||
type StreamReader struct { |
||||
orig io.Reader |
||||
|
||||
// only one of these can be set at a time
|
||||
dec *json.Decoder |
||||
bbr *byteBlobReader |
||||
} |
||||
|
||||
// NewStreamReader takes an io.Reader and interprets it as a jstream Stream.
|
||||
func NewStreamReader(r io.Reader) *StreamReader { |
||||
return &StreamReader{orig: r} |
||||
} |
||||
|
||||
// pulls buffered bytes out of either the json.Decoder or byteBlobReader, if
|
||||
// possible, and returns an io.MultiReader of those and orig. Will also set the
|
||||
// json.Decoder/byteBlobReader to nil if that's where the bytes came from.
|
||||
func (sr *StreamReader) multiReader() io.Reader { |
||||
if sr.dec != nil { |
||||
buf := sr.dec.Buffered() |
||||
sr.dec = nil |
||||
return io.MultiReader(buf, sr.orig) |
||||
} else if sr.bbr != nil { |
||||
buf := sr.bbr.buffered() |
||||
sr.bbr = nil |
||||
return io.MultiReader(buf, sr.orig) |
||||
} |
||||
return sr.orig |
||||
} |
||||
|
||||
// Next reads, decodes, and returns the next Element off the StreamReader. If
|
||||
// the Element is a ByteBlob or embedded Stream then it _must_ be fully consumed
|
||||
// before Next is called on this StreamReader again.
|
||||
//
|
||||
// The returned Element's Err field will be ErrStreamEnd if the Stream was
|
||||
// ended, or ErrCanceled if it was canceled, and this StreamReader should not be
|
||||
// used again in those cases.
|
||||
//
|
||||
// If the underlying io.Reader is closed the returned Err field will be io.EOF.
|
||||
func (sr *StreamReader) Next() Element { |
||||
if sr.dec == nil { |
||||
sr.dec = json.NewDecoder(sr.multiReader()) |
||||
} |
||||
|
||||
var el element |
||||
var err error |
||||
if err = sr.dec.Decode(&el); err != nil { |
||||
// welp
|
||||
} else if el.StreamEnd { |
||||
err = ErrStreamEnded |
||||
} else if el.StreamCancel { |
||||
err = ErrCanceled |
||||
} |
||||
if err != nil { |
||||
return Element{Err: err} |
||||
} |
||||
return Element{sr: sr, element: el} |
||||
} |
||||
|
||||
func (sr *StreamReader) readBytes() *byteBlobReader { |
||||
sr.bbr = newByteBlobReader(sr.multiReader()) |
||||
return sr.bbr |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// StreamWriter represents a Stream to which Elements may be written using any
|
||||
// of the Encode methods.
|
||||
type StreamWriter struct { |
||||
w io.Writer |
||||
enc *json.Encoder |
||||
} |
||||
|
||||
// NewStreamWriter takes an io.Writer and returns a StreamWriter which will
|
||||
// write to it.
|
||||
func NewStreamWriter(w io.Writer) *StreamWriter { |
||||
return &StreamWriter{w: w, enc: json.NewEncoder(w)} |
||||
} |
||||
|
||||
// EncodeValue marshals the given value and writes it to the Stream as a
|
||||
// JSONValue element.
|
||||
func (sw *StreamWriter) EncodeValue(i interface{}) error { |
||||
b, err := json.Marshal(i) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return sw.enc.Encode(element{Value: b}) |
||||
} |
||||
|
||||
// EncodeBytes copies the given io.Reader, until io.EOF, onto the Stream as a
|
||||
// ByteBlob element. This method will block until copying is completed or an
|
||||
// error is encountered.
|
||||
//
|
||||
// If the io.Reader returns any error which isn't io.EOF then the ByteBlob is
|
||||
// canceled and that error is returned from this method. Otherwise nil is
|
||||
// returned.
|
||||
//
|
||||
// sizeHint may be given if it's known or can be guessed how many bytes the
|
||||
// io.Reader will read out.
|
||||
func (sw *StreamWriter) EncodeBytes(sizeHint uint, r io.Reader) error { |
||||
if err := sw.enc.Encode(element{ |
||||
BytesStart: true, |
||||
SizeHint: sizeHint, |
||||
}); err != nil { |
||||
return err |
||||
|
||||
} |
||||
|
||||
enc := base64.NewEncoder(base64.StdEncoding, sw.w) |
||||
if _, err := io.Copy(enc, r); err != nil { |
||||
// if canceling doesn't work then the whole connection is broken and
|
||||
// it's not worth doing anything about. if nothing else the brokeness of
|
||||
// it will be discovered the next time it is used.
|
||||
sw.w.Write([]byte{bbCancel}) |
||||
return err |
||||
} else if err := enc.Close(); err != nil { |
||||
return err |
||||
} else if _, err := sw.w.Write([]byte{bbEnd}); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// EncodeStream encodes an embedded Stream element onto the Stream. The callback
|
||||
// is given a new StreamWriter which represents the embedded Stream and to which
|
||||
// any elemens may be written. This methods blocks until the callback has
|
||||
// returned.
|
||||
//
|
||||
// If the callback returns nil the Stream is ended normally. If it returns
|
||||
// anything else the embedded Stream is canceled and that error is returned from
|
||||
// this method.
|
||||
//
|
||||
// sizeHint may be given if it's known or can be guessed how many elements will
|
||||
// be in the embedded Stream.
|
||||
func (sw *StreamWriter) EncodeStream(sizeHint uint, fn func(*StreamWriter) error) error { |
||||
if err := sw.enc.Encode(element{ |
||||
StreamStart: true, |
||||
SizeHint: sizeHint, |
||||
}); err != nil { |
||||
return err |
||||
|
||||
} else if err := fn(sw); err != nil { |
||||
// as when canceling a byte blob, we don't really care if this errors
|
||||
sw.enc.Encode(element{StreamCancel: true}) |
||||
return err |
||||
} |
||||
return sw.enc.Encode(element{StreamEnd: true}) |
||||
} |
@ -1,246 +0,0 @@ |
||||
package jstream |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"io" |
||||
"io/ioutil" |
||||
"sync" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type cancelBuffer struct { |
||||
lr *io.LimitedReader |
||||
} |
||||
|
||||
func newCancelBuffer(b []byte) io.Reader { |
||||
return &cancelBuffer{ |
||||
lr: &io.LimitedReader{ |
||||
R: bytes.NewBuffer(b), |
||||
N: int64(len(b) / 2), |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (cb *cancelBuffer) Read(p []byte) (int, error) { |
||||
if cb.lr.N == 0 { |
||||
return 0, ErrCanceled |
||||
} |
||||
return cb.lr.Read(p) |
||||
} |
||||
|
||||
func TestEncoderDecoder(t *T) { |
||||
type testCase struct { |
||||
typ Type |
||||
val interface{} |
||||
bytes []byte |
||||
stream []testCase |
||||
cancel bool |
||||
} |
||||
|
||||
var randTestCase func(Type, bool) testCase |
||||
randTestCase = func(typ Type, cancelable bool) testCase { |
||||
// if typ isn't given then use a random one
|
||||
if typ == "" { |
||||
pick := mrand.Intn(5) |
||||
switch { |
||||
case pick == 0: |
||||
typ = TypeStream |
||||
case pick < 4: |
||||
typ = TypeJSONValue |
||||
default: |
||||
typ = TypeByteBlob |
||||
} |
||||
} |
||||
|
||||
tc := testCase{ |
||||
typ: typ, |
||||
cancel: cancelable && mrand.Intn(10) == 0, |
||||
} |
||||
|
||||
switch typ { |
||||
case TypeJSONValue: |
||||
tc.val = map[string]interface{}{ |
||||
mrand.Hex(8): mrand.Hex(8), |
||||
mrand.Hex(8): mrand.Hex(8), |
||||
mrand.Hex(8): mrand.Hex(8), |
||||
mrand.Hex(8): mrand.Hex(8), |
||||
mrand.Hex(8): mrand.Hex(8), |
||||
} |
||||
return tc |
||||
case TypeByteBlob: |
||||
tc.bytes = mrand.Bytes(mrand.Intn(256)) |
||||
return tc |
||||
case TypeStream: |
||||
for i := mrand.Intn(10); i > 0; i-- { |
||||
tc.stream = append(tc.stream, randTestCase("", true)) |
||||
} |
||||
return tc |
||||
} |
||||
panic("shouldn't get here") |
||||
} |
||||
|
||||
tcLog := func(tcs ...testCase) []interface{} { |
||||
return []interface{}{"%#v", tcs} |
||||
} |
||||
|
||||
var assertRead func(*StreamReader, Element, testCase) bool |
||||
assertRead = func(r *StreamReader, el Element, tc testCase) bool { |
||||
l, success := tcLog(tc), true |
||||
typ, err := el.Type() |
||||
success = success && assert.NoError(t, err, l...) |
||||
success = success && assert.Equal(t, tc.typ, typ, l...) |
||||
|
||||
switch typ { |
||||
case TypeJSONValue: |
||||
var val interface{} |
||||
success = success && assert.NoError(t, el.Value(&val), l...) |
||||
success = success && assert.Equal(t, tc.val, val, l...) |
||||
case TypeByteBlob: |
||||
br, err := el.Bytes() |
||||
success = success && assert.NoError(t, err, l...) |
||||
success = success && assert.Equal(t, uint(len(tc.bytes)), el.SizeHint(), l...) |
||||
all, err := ioutil.ReadAll(br) |
||||
if tc.cancel { |
||||
success = success && assert.Equal(t, ErrCanceled, err, l...) |
||||
} else { |
||||
success = success && assert.NoError(t, err, l...) |
||||
success = success && assert.Equal(t, tc.bytes, all, l...) |
||||
} |
||||
case TypeStream: |
||||
innerR, err := el.Stream() |
||||
success = success && assert.NoError(t, err, l...) |
||||
success = success && assert.Equal(t, uint(len(tc.stream)), el.SizeHint(), l...) |
||||
n := 0 |
||||
for { |
||||
el := innerR.Next() |
||||
if tc.cancel && el.Err == ErrCanceled { |
||||
break |
||||
} else if n == len(tc.stream) { |
||||
success = success && assert.Equal(t, ErrStreamEnded, el.Err, l...) |
||||
break |
||||
} |
||||
success = success && assertRead(innerR, el, tc.stream[n]) |
||||
n++ |
||||
} |
||||
} |
||||
return success |
||||
} |
||||
|
||||
var assertWrite func(*StreamWriter, testCase) bool |
||||
assertWrite = func(w *StreamWriter, tc testCase) bool { |
||||
l, success := tcLog(tc), true |
||||
switch tc.typ { |
||||
case TypeJSONValue: |
||||
success = success && assert.NoError(t, w.EncodeValue(tc.val), l...) |
||||
case TypeByteBlob: |
||||
if tc.cancel { |
||||
r := newCancelBuffer(tc.bytes) |
||||
err := w.EncodeBytes(uint(len(tc.bytes)), r) |
||||
success = success && assert.Equal(t, ErrCanceled, err, l...) |
||||
} else { |
||||
r := bytes.NewBuffer(tc.bytes) |
||||
err := w.EncodeBytes(uint(len(tc.bytes)), r) |
||||
success = success && assert.NoError(t, err, l...) |
||||
} |
||||
case TypeStream: |
||||
err := w.EncodeStream(uint(len(tc.stream)), func(innerW *StreamWriter) error { |
||||
if len(tc.stream) == 0 && tc.cancel { |
||||
return ErrCanceled |
||||
} |
||||
for i := range tc.stream { |
||||
if tc.cancel && i == len(tc.stream)/2 { |
||||
return ErrCanceled |
||||
} else if !assertWrite(w, tc.stream[i]) { |
||||
return errors.New("we got problems") |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if tc.cancel { |
||||
success = success && assert.Equal(t, ErrCanceled, err, l...) |
||||
} else { |
||||
success = success && assert.NoError(t, err, l...) |
||||
} |
||||
} |
||||
return success |
||||
} |
||||
|
||||
do := func(tcs ...testCase) bool { |
||||
// we keep a copy of all read/written bytes for debugging, but generally
|
||||
// don't actually log them
|
||||
ioR, ioW := io.Pipe() |
||||
cpR, cpW := new(bytes.Buffer), new(bytes.Buffer) |
||||
r, w := NewStreamReader(io.TeeReader(ioR, cpR)), NewStreamWriter(io.MultiWriter(ioW, cpW)) |
||||
|
||||
readCh, writeCh := make(chan bool, 1), make(chan bool, 1) |
||||
wg := new(sync.WaitGroup) |
||||
wg.Add(2) |
||||
go func() { |
||||
success := true |
||||
for _, tc := range tcs { |
||||
success = success && assertRead(r, r.Next(), tc) |
||||
} |
||||
success = success && assert.Equal(t, io.EOF, r.Next().Err) |
||||
readCh <- success |
||||
ioR.Close() |
||||
wg.Done() |
||||
}() |
||||
go func() { |
||||
success := true |
||||
for _, tc := range tcs { |
||||
success = success && assertWrite(w, tc) |
||||
} |
||||
writeCh <- success |
||||
ioW.Close() |
||||
wg.Done() |
||||
}() |
||||
wg.Wait() |
||||
|
||||
//log.Printf("data written:%q", cpW.Bytes())
|
||||
//log.Printf("data read: %q", cpR.Bytes())
|
||||
|
||||
if !(<-readCh && <-writeCh) { |
||||
assert.FailNow(t, "test case failed", tcLog(tcs...)...) |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// some basic test cases
|
||||
do() // empty stream
|
||||
do(randTestCase(TypeJSONValue, false)) |
||||
do(randTestCase(TypeByteBlob, false)) |
||||
do( |
||||
randTestCase(TypeJSONValue, false), |
||||
randTestCase(TypeJSONValue, false), |
||||
randTestCase(TypeJSONValue, false), |
||||
) |
||||
do( |
||||
randTestCase(TypeJSONValue, false), |
||||
randTestCase(TypeByteBlob, false), |
||||
randTestCase(TypeJSONValue, false), |
||||
) |
||||
do( |
||||
randTestCase(TypeByteBlob, false), |
||||
randTestCase(TypeByteBlob, false), |
||||
randTestCase(TypeByteBlob, false), |
||||
) |
||||
do( |
||||
randTestCase(TypeJSONValue, false), |
||||
randTestCase(TypeStream, false), |
||||
randTestCase(TypeJSONValue, false), |
||||
) |
||||
|
||||
// some special cases, empty elements which are canceled
|
||||
do(testCase{typ: TypeStream, cancel: true}) |
||||
do(testCase{typ: TypeByteBlob, cancel: true}) |
||||
|
||||
for i := 0; i < 1000; i++ { |
||||
tc := randTestCase(TypeStream, false) |
||||
do(tc.stream...) |
||||
} |
||||
} |
@ -1,152 +0,0 @@ |
||||
// Package m implements functionality specific to how I like my programs to
|
||||
// work. It acts as glue between many of the other packages in this framework,
|
||||
// putting them together in the way I find most useful.
|
||||
package m |
||||
|
||||
import ( |
||||
"context" |
||||
"os" |
||||
"os/signal" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
type cmpKey int |
||||
|
||||
const ( |
||||
cmpKeyCfgSrc cmpKey = iota |
||||
cmpKeyInfoLog |
||||
) |
||||
|
||||
func debugLog(cmp *mcmp.Component, msg string, ctxs ...context.Context) { |
||||
level := mlog.DebugLevel |
||||
if len(ctxs) > 0 { |
||||
if ok, _ := ctxs[0].Value(cmpKeyInfoLog).(bool); ok { |
||||
level = mlog.InfoLevel |
||||
} |
||||
} |
||||
|
||||
mlog.From(cmp).Log(mlog.Message{ |
||||
Level: level, |
||||
Description: msg, |
||||
Contexts: ctxs, |
||||
}) |
||||
} |
||||
|
||||
// RootComponent returns a Component which should be used as the root Component
|
||||
// when implementing most programs.
|
||||
//
|
||||
// The returned Component will automatically handle setting up global
|
||||
// configuration parameters like "log-level", as well as parsing those
|
||||
// and all other parameters when the Init even is triggered on it.
|
||||
func RootComponent() *mcmp.Component { |
||||
cmp := new(mcmp.Component) |
||||
|
||||
// embed confuration source which should be used into the context.
|
||||
cmp.SetValue(cmpKeyCfgSrc, mcfg.Source(new(mcfg.SourceCLI))) |
||||
|
||||
// set up log level handling
|
||||
logger := mlog.NewLogger() |
||||
mlog.SetLogger(cmp, logger) |
||||
|
||||
// set up parameter parsing
|
||||
mrun.InitHook(cmp, func(context.Context) error { |
||||
src, _ := cmp.Value(cmpKeyCfgSrc).(mcfg.Source) |
||||
if src == nil { |
||||
return merr.New("Component not sourced from m package", cmp.Context()) |
||||
} else if err := mcfg.Populate(cmp, src); err != nil { |
||||
return merr.Wrap(err, cmp.Context()) |
||||
} |
||||
return nil |
||||
}) |
||||
|
||||
logLevelStr := mcfg.String(cmp, "log-level", |
||||
mcfg.ParamDefault("info"), |
||||
mcfg.ParamUsage("Maximum log level which will be printed.")) |
||||
mrun.InitHook(cmp, func(context.Context) error { |
||||
logLevel := mlog.LevelFromString(*logLevelStr) |
||||
if logLevel == nil { |
||||
return merr.New("invalid log level", cmp.Context(), |
||||
mctx.Annotated("log-level", *logLevelStr)) |
||||
} |
||||
logger.SetMaxLevel(logLevel) |
||||
mlog.SetLogger(cmp, logger) |
||||
return nil |
||||
}) |
||||
|
||||
return cmp |
||||
} |
||||
|
||||
// RootServiceComponent extends RootComponent so that it better supports long
|
||||
// running processes which are expected to handle requests from outside clients.
|
||||
//
|
||||
// Additional behavior it adds includes setting up an http endpoint where debug
|
||||
// information about the running process can be accessed.
|
||||
func RootServiceComponent() *mcmp.Component { |
||||
cmp := RootComponent() |
||||
|
||||
// services expect to use many different configuration sources
|
||||
cmp.SetValue(cmpKeyCfgSrc, mcfg.Source(mcfg.Sources{ |
||||
new(mcfg.SourceEnv), |
||||
new(mcfg.SourceCLI), |
||||
})) |
||||
|
||||
// it's useful to show debug entries (from this package specifically) as
|
||||
// info logs for long-running services.
|
||||
cmp.SetValue(cmpKeyInfoLog, true) |
||||
|
||||
// TODO set up the debug endpoint.
|
||||
return cmp |
||||
} |
||||
|
||||
// MustInit will call mrun.Init on the given Component, which must have been
|
||||
// created in this package, and exit the process if mrun.Init does not complete
|
||||
// successfully.
|
||||
func MustInit(cmp *mcmp.Component) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||
defer cancel() |
||||
|
||||
debugLog(cmp, "initializing") |
||||
if err := mrun.Init(ctx, cmp); err != nil { |
||||
mlog.From(cmp).Fatal("initialization failed", merr.Context(err)) |
||||
} |
||||
debugLog(cmp, "initialization completed successfully") |
||||
} |
||||
|
||||
// MustShutdown is like MustInit, except that it triggers the Shutdown event on
|
||||
// the Component.
|
||||
func MustShutdown(cmp *mcmp.Component) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||
defer cancel() |
||||
|
||||
debugLog(cmp, "shutting down") |
||||
if err := mrun.Shutdown(ctx, cmp); err != nil { |
||||
mlog.From(cmp).Fatal("shutdown failed", merr.Context(err)) |
||||
} |
||||
debugLog(cmp, "shutting down completed successfully") |
||||
} |
||||
|
||||
// Exec calls MustInit on the given Component, then blocks until an interrupt
|
||||
// signal is received, then calls MustShutdown on the Component, until finally
|
||||
// exiting the process.
|
||||
func Exec(cmp *mcmp.Component) { |
||||
MustInit(cmp) |
||||
{ |
||||
ch := make(chan os.Signal, 1) |
||||
signal.Notify(ch, os.Interrupt) |
||||
s := <-ch |
||||
debugLog(cmp, "signal received, stopping", mctx.Annotated("signal", s)) |
||||
} |
||||
MustShutdown(cmp) |
||||
|
||||
debugLog(cmp, "exiting process") |
||||
os.Stdout.Sync() |
||||
os.Stderr.Sync() |
||||
os.Exit(0) |
||||
} |
@ -1,50 +0,0 @@ |
||||
package m |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
func TestServiceCtx(t *T) { |
||||
t.Run("log-level", func(t *T) { |
||||
cmp := RootComponent() |
||||
|
||||
// pull the Logger out of the component and set the Handler on it, so we
|
||||
// can check the log level
|
||||
var msgs []mlog.Message |
||||
logger := mlog.GetLogger(cmp) |
||||
logger.SetHandler(func(msg mlog.Message) error { |
||||
msgs = append(msgs, msg) |
||||
return nil |
||||
}) |
||||
mlog.SetLogger(cmp, logger) |
||||
|
||||
// create a child Component before running to ensure it the change propagates
|
||||
// correctly.
|
||||
cmpA := cmp.Child("A") |
||||
|
||||
params := mcfg.ParamValues{{Name: "log-level", Value: json.RawMessage(`"DEBUG"`)}} |
||||
cmp.SetValue(cmpKeyCfgSrc, params) |
||||
MustInit(cmp) |
||||
|
||||
mlog.From(cmpA).Info("foo") |
||||
mlog.From(cmpA).Debug("bar") |
||||
massert.Require(t, |
||||
massert.Length(msgs, 3), |
||||
massert.Equal(msgs[0].Level.String(), "DEBUG"), |
||||
massert.Equal(msgs[0].Description, "initialization completed successfully"), |
||||
massert.Equal(msgs[0].Contexts, []context.Context{cmp.Context()}), |
||||
massert.Equal(msgs[1].Level.String(), "INFO"), |
||||
massert.Equal(msgs[1].Description, "foo"), |
||||
massert.Equal(msgs[1].Contexts, []context.Context{cmpA.Context()}), |
||||
massert.Equal(msgs[2].Level.String(), "DEBUG"), |
||||
massert.Equal(msgs[2].Description, "bar"), |
||||
massert.Equal(msgs[2].Contexts, []context.Context{cmpA.Context()}), |
||||
) |
||||
}) |
||||
} |
@ -1,29 +0,0 @@ |
||||
// Package mcrypto contains general purpose functionality related to
|
||||
// cryptography, notably related to unique identifiers, signing/verifying data,
|
||||
// and encrypting/decrypting data
|
||||
package mcrypto |
||||
|
||||
import ( |
||||
"strings" |
||||
) |
||||
|
||||
// Instead of outputing opaque hex garbage, this package opts to add a prefix to
|
||||
// the garbage. Each "type" of string returned has its own character which is
|
||||
// not found in the hex range (0-9, a-f), and in addition each also has a
|
||||
// version character prefixed as well, in case something wants to be changed
|
||||
// going forward.
|
||||
//
|
||||
// We keep the constant prefices here to ensure there's no conflicts across
|
||||
// string types in this package.
|
||||
const ( |
||||
uuidV0 = "0u" // u for uuid
|
||||
sigV0 = "0s" // s for signature
|
||||
encryptedV0 = "0n" // n for "n"-crypted, harharhar
|
||||
pubKeyV0 = "0l" // b for pub"l"ic key
|
||||
privKeyV0 = "0v" // v for pri"v"ate key
|
||||
) |
||||
|
||||
func stripPrefix(s, prefix string) (string, bool) { |
||||
trimmed := strings.TrimPrefix(s, prefix) |
||||
return trimmed, len(trimmed) < len(s) |
||||
} |
@ -1,246 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto" |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"crypto/sha256" |
||||
"encoding/binary" |
||||
"encoding/hex" |
||||
"encoding/json" |
||||
"errors" |
||||
"io" |
||||
"math/big" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
) |
||||
|
||||
var ( |
||||
errMalformedPublicKey = errors.New("malformed public key") |
||||
errMalformedPrivateKey = errors.New("malformed private key") |
||||
) |
||||
|
||||
// NewKeyPair generates and returns a complementary public/private key pair
|
||||
func NewKeyPair() (PublicKey, PrivateKey) { |
||||
return newKeyPair(2048) |
||||
} |
||||
|
||||
// NewWeakKeyPair is like NewKeyPair but the returned pair uses fewer bits
|
||||
// (though still a reasonably secure amount for data that doesn't need security
|
||||
// guarantees into the year 3000 whatever).
|
||||
func NewWeakKeyPair() (PublicKey, PrivateKey) { |
||||
return newKeyPair(1024) |
||||
} |
||||
|
||||
func newKeyPair(bits int) (PublicKey, PrivateKey) { |
||||
priv, err := rsa.GenerateKey(rand.Reader, bits) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return PublicKey{priv.PublicKey}, PrivateKey{priv} |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// PublicKey is a wrapper around an rsa.PublicKey which simplifies using it and
|
||||
// adds marshaling/unmarshaling methods.
|
||||
//
|
||||
// A PublicKey automatically implements the Verifier interface.
|
||||
type PublicKey struct { |
||||
rsa.PublicKey |
||||
} |
||||
|
||||
func (pk PublicKey) verify(s Signature, r io.Reader) error { |
||||
h := sha256.New() |
||||
r = sigPrefixReader(r, 32, s.salt, s.t) |
||||
if _, err := io.Copy(h, r); err != nil { |
||||
return err |
||||
} |
||||
if err := rsa.VerifyPSS(&pk.PublicKey, crypto.SHA256, h.Sum(nil), s.sig, nil); err != nil { |
||||
ctx := mctx.Annotate(context.Background(), "sig", s) |
||||
return merr.Wrap(ErrInvalidSig, ctx) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (pk PublicKey) String() string { |
||||
nB := pk.N.Bytes() |
||||
b := make([]byte, 8+len(nB)) |
||||
// the exponent is never negative so this is fine
|
||||
binary.BigEndian.PutUint64(b, uint64(pk.E)) |
||||
copy(b[8:], nB) |
||||
return pubKeyV0 + hex.EncodeToString(b) |
||||
} |
||||
|
||||
// KV implements the method for the mlog.KVer interface
|
||||
func (pk PublicKey) KV() map[string]interface{} { |
||||
return map[string]interface{}{"publicKey": pk.String()} |
||||
} |
||||
|
||||
// MarshalText implements the method for the encoding.TextMarshaler interface
|
||||
func (pk PublicKey) MarshalText() ([]byte, error) { |
||||
return []byte(pk.String()), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the method for the encoding.TextUnmarshaler
|
||||
// interface
|
||||
func (pk *PublicKey) UnmarshalText(b []byte) error { |
||||
str := string(b) |
||||
strEnc, ok := stripPrefix(str, pubKeyV0) |
||||
if !ok || len(strEnc) <= hex.EncodedLen(8) { |
||||
ctx := mctx.Annotate(context.Background(), "pubKeyStr", str) |
||||
return merr.Wrap(errMalformedPublicKey, ctx) |
||||
} |
||||
|
||||
b, err := hex.DecodeString(strEnc) |
||||
if err != nil { |
||||
ctx := mctx.Annotate(context.Background(), "pubKeyStr", str) |
||||
return merr.Wrap(err, ctx) |
||||
} |
||||
|
||||
pk.E = int(binary.BigEndian.Uint64(b)) |
||||
pk.N = new(big.Int) |
||||
pk.N.SetBytes(b[8:]) |
||||
return nil |
||||
} |
||||
|
||||
// MarshalJSON implements the method for the json.Marshaler interface
|
||||
func (pk PublicKey) MarshalJSON() ([]byte, error) { |
||||
return json.Marshal(pk.String()) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the method for the json.Unmarshaler interface
|
||||
func (pk *PublicKey) UnmarshalJSON(b []byte) error { |
||||
var s string |
||||
if err := json.Unmarshal(b, &s); err != nil { |
||||
return err |
||||
} |
||||
return pk.UnmarshalText([]byte(s)) |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// PrivateKey is a wrapper around an rsa.PrivateKey which simplifies using it
|
||||
// and adds marshaling/unmarshaling methods.
|
||||
//
|
||||
// A PrivateKey automatically implements the Signer interface.
|
||||
type PrivateKey struct { |
||||
*rsa.PrivateKey |
||||
} |
||||
|
||||
func (pk PrivateKey) sign(r io.Reader) (Signature, error) { |
||||
salt := make([]byte, 8) |
||||
if _, err := rand.Read(salt); err != nil { |
||||
panic(err) |
||||
} |
||||
t := time.Now() |
||||
h := sha256.New() |
||||
// sigLen has to be 32 here (bytes returned by sha256) cause of the way the
|
||||
// VerifyPSS function is
|
||||
if _, err := io.Copy(h, sigPrefixReader(r, 32, salt, t)); err != nil { |
||||
return Signature{}, err |
||||
} |
||||
sig, err := rsa.SignPSS(rand.Reader, pk.PrivateKey, crypto.SHA256, h.Sum(nil), nil) |
||||
return Signature{sig: sig, salt: salt, t: t}, err |
||||
} |
||||
|
||||
func (pk PrivateKey) String() string { |
||||
numBytes := binary.MaxVarintLen64 * 3 // public exponent, N, and D
|
||||
nB, dB := pk.PublicKey.N.Bytes(), pk.D.Bytes() |
||||
numBytes += len(nB) + len(dB) |
||||
|
||||
primes := make([][]byte, len(pk.Primes)) |
||||
for i, prime := range pk.Primes { |
||||
primes[i] = prime.Bytes() |
||||
numBytes += binary.MaxVarintLen64 + len(primes[i]) |
||||
} |
||||
|
||||
b, ptr := make([]byte, numBytes), 0 |
||||
ptr += binary.PutUvarint(b[ptr:], uint64(pk.E)) |
||||
ptr += binary.PutUvarint(b[ptr:], uint64(len(nB))) |
||||
ptr += copy(b[ptr:], nB) |
||||
ptr += binary.PutUvarint(b[ptr:], uint64(len(dB))) |
||||
ptr += copy(b[ptr:], dB) |
||||
|
||||
for _, prime := range primes { |
||||
ptr += binary.PutUvarint(b[ptr:], uint64(len(prime))) |
||||
ptr += copy(b[ptr:], prime) |
||||
} |
||||
|
||||
return privKeyV0 + hex.EncodeToString(b[:ptr]) |
||||
} |
||||
|
||||
// KV implements the method for the mlog.KVer interface
|
||||
func (pk PrivateKey) KV() map[string]interface{} { |
||||
return map[string]interface{}{"privateKey": pk.String()} |
||||
} |
||||
|
||||
// MarshalText implements the method for the encoding.TextMarshaler interface
|
||||
func (pk PrivateKey) MarshalText() ([]byte, error) { |
||||
return []byte(pk.String()), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the method for the encoding.TextUnmarshaler
|
||||
// interface
|
||||
func (pk *PrivateKey) UnmarshalText(b []byte) error { |
||||
str := string(b) |
||||
strEnc, ok := stripPrefix(str, privKeyV0) |
||||
if !ok { |
||||
return merr.Wrap(errMalformedPrivateKey) |
||||
} |
||||
|
||||
b, err := hex.DecodeString(strEnc) |
||||
if err != nil { |
||||
return merr.Wrap(errMalformedPrivateKey) |
||||
} |
||||
|
||||
e, n := binary.Uvarint(b) |
||||
if n <= 0 { |
||||
return merr.Wrap(errMalformedPrivateKey) |
||||
} |
||||
pk.PublicKey.E = int(e) |
||||
b = b[n:] |
||||
|
||||
bigInt := func() *big.Int { |
||||
if err != nil { |
||||
return nil |
||||
} |
||||
l, n := binary.Uvarint(b) |
||||
if n <= 0 { |
||||
err = merr.Wrap(errMalformedPrivateKey) |
||||
} |
||||
b = b[n:] |
||||
i := new(big.Int) |
||||
i.SetBytes(b[:l]) |
||||
b = b[l:] |
||||
return i |
||||
} |
||||
|
||||
pk.PublicKey.N = bigInt() |
||||
pk.D = bigInt() |
||||
for len(b) > 0 && err == nil { |
||||
pk.Primes = append(pk.Primes, bigInt()) |
||||
} |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// MarshalJSON implements the method for the json.Marshaler interface
|
||||
func (pk PrivateKey) MarshalJSON() ([]byte, error) { |
||||
return json.Marshal(pk.String()) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the method for the json.Unmarshaler interface
|
||||
func (pk *PrivateKey) UnmarshalJSON(b []byte) error { |
||||
var s string |
||||
if err := json.Unmarshal(b, &s); err != nil { |
||||
return err |
||||
} |
||||
return pk.UnmarshalText([]byte(s)) |
||||
} |
@ -1,17 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestKeyPair(t *T) { |
||||
pub, priv := NewWeakKeyPair() |
||||
|
||||
// test signing/verifying
|
||||
str := mrand.Hex(512) |
||||
sig := SignString(priv, str) |
||||
assert.NoError(t, VerifyString(pub, sig, str)) |
||||
} |
@ -1,84 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
"crypto/hmac" |
||||
"crypto/rand" |
||||
"crypto/sha256" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
) |
||||
|
||||
// Secret contains a set of bytes which are inteded to remain secret within some
|
||||
// context (e.g. a backend application keeping a secret from the frontend).
|
||||
//
|
||||
// Secret inherently implements the Signer and Verifier interfaces.
|
||||
//
|
||||
// Secret can be initialized with NewSecret or NewWeakSecret. The Signatures
|
||||
// produced by these will be of differing lengths, but either can Verify a
|
||||
// Signature made by the other as long as the secret bytes they are initialized
|
||||
// with are the same.
|
||||
type Secret struct { |
||||
sigSize uint8 // in bytes, shouldn't be more than 32, cause sha256
|
||||
secret []byte |
||||
|
||||
// only used during tests
|
||||
testNow time.Time |
||||
} |
||||
|
||||
// NewSecret initializes and returns an instance of Secret which uses the given
|
||||
// bytes as the underlying secret.
|
||||
func NewSecret(secret []byte) Secret { |
||||
return Secret{sigSize: 20, secret: secret} |
||||
} |
||||
|
||||
// NewWeakSecret is like NewSecret but the Signatures it produces will be
|
||||
// shorter and weaker (though still secure enough for most applications).
|
||||
// Signatures produced by either normal or weak Secrets can be Verified by the
|
||||
// other.
|
||||
func NewWeakSecret(secret []byte) Secret { |
||||
return Secret{sigSize: 8, secret: secret} |
||||
} |
||||
|
||||
func (s Secret) now() time.Time { |
||||
if !s.testNow.IsZero() { |
||||
return s.testNow |
||||
} |
||||
return time.Now() |
||||
} |
||||
|
||||
func (s Secret) signRaw( |
||||
r io.Reader, |
||||
sigLen uint8, salt []byte, t time.Time, |
||||
) ( |
||||
[]byte, error, |
||||
) { |
||||
h := hmac.New(sha256.New, s.secret) |
||||
r = sigPrefixReader(r, sigLen, salt, t) |
||||
if _, err := io.Copy(h, r); err != nil { |
||||
return nil, err |
||||
} |
||||
return h.Sum(nil)[:sigLen], nil |
||||
} |
||||
|
||||
func (s Secret) sign(r io.Reader) (Signature, error) { |
||||
salt := make([]byte, 8) |
||||
if _, err := rand.Read(salt); err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
t := s.now() |
||||
sig, err := s.signRaw(r, s.sigSize, salt, t) |
||||
return Signature{sig: sig, salt: salt, t: t}, err |
||||
} |
||||
|
||||
func (s Secret) verify(sig Signature, r io.Reader) error { |
||||
sigB, err := s.signRaw(r, uint8(len(sig.sig)), sig.salt, sig.t) |
||||
if err != nil { |
||||
return merr.Wrap(err) |
||||
} else if !hmac.Equal(sigB, sig.sig) { |
||||
return merr.Wrap(ErrInvalidSig) |
||||
} |
||||
return nil |
||||
} |
@ -1,54 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestSecretSignVerify(t *T) { |
||||
secretRaw := mrand.Bytes(16) |
||||
secret := NewSecret(secretRaw) |
||||
weakSecret := NewWeakSecret(secretRaw) |
||||
var prevStr string |
||||
var prevSig, prevWeakSig Signature |
||||
for i := 0; i < 10000; i++ { |
||||
now := time.Now().Round(0) |
||||
secret.testNow = now |
||||
weakSecret.testNow = now |
||||
|
||||
thisStr := mrand.Hex(512) |
||||
thisSig := SignString(secret, thisStr) |
||||
thisWeakSig := SignString(weakSecret, thisStr) |
||||
thisSigStr, thisWeakSigStr := thisSig.String(), thisWeakSig.String() |
||||
|
||||
// sanity checks
|
||||
assert.Equal(t, now, thisSig.Time()) |
||||
assert.Equal(t, now, thisWeakSig.Time()) |
||||
assert.NotEmpty(t, thisSigStr) |
||||
assert.NotEmpty(t, thisWeakSigStr) |
||||
assert.NotEqual(t, thisSigStr, thisWeakSigStr) |
||||
assert.True(t, len(thisSigStr) > len(thisWeakSigStr)) |
||||
|
||||
// Either secret should be able to verify either signature
|
||||
assert.NoError(t, VerifyString(secret, thisSig, thisStr)) |
||||
assert.NoError(t, VerifyString(weakSecret, thisWeakSig, thisStr)) |
||||
assert.NoError(t, VerifyString(secret, thisWeakSig, thisStr)) |
||||
assert.NoError(t, VerifyString(weakSecret, thisSig, thisStr)) |
||||
|
||||
if prevStr != "" { |
||||
assert.NotEqual(t, prevSig.String(), thisSigStr) |
||||
assert.NotEqual(t, prevWeakSig.String(), thisWeakSigStr) |
||||
err := VerifyString(secret, prevSig, thisStr) |
||||
assert.True(t, merr.Equal(err, ErrInvalidSig)) |
||||
err = VerifyString(secret, prevWeakSig, thisStr) |
||||
assert.True(t, merr.Equal(err, ErrInvalidSig)) |
||||
} |
||||
prevStr = thisStr |
||||
prevSig = thisSig |
||||
prevWeakSig = thisWeakSig |
||||
} |
||||
} |
@ -1,184 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"encoding/hex" |
||||
"encoding/json" |
||||
"errors" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
) |
||||
|
||||
var ( |
||||
errMalformedSig = errors.New("malformed signature") |
||||
|
||||
// ErrInvalidSig is returned by Signer related functions when an invalid
|
||||
// signature is used, e.g. it is a signature for different data, or uses a
|
||||
// different secret key, or has expired
|
||||
ErrInvalidSig = errors.New("invalid signature") |
||||
) |
||||
|
||||
// Signature marshals/unmarshals an actual signature, produced internally by a
|
||||
// Signer, along with the timestamp the signing took place and a random salt.
|
||||
//
|
||||
// All signatures produced in this package will have had the timestamp and salt
|
||||
// included in the signature's input data, and so are also checked by the
|
||||
// Verifier.
|
||||
type Signature struct { |
||||
sig, salt []byte // neither of these should ever be more than 255 bytes long
|
||||
t time.Time |
||||
} |
||||
|
||||
// Time returns the timestamp the Signature was generated at
|
||||
func (s Signature) Time() time.Time { |
||||
return s.t |
||||
} |
||||
|
||||
func (s Signature) String() string { |
||||
// ts:8 + saltHeader:1 + salt + sigHeader:1 + sig
|
||||
b := make([]byte, 10+len(s.salt)+len(s.sig)) |
||||
// It will be year 2286 before the nano doesn't fit in uint64
|
||||
binary.BigEndian.PutUint64(b, uint64(s.t.UnixNano())) |
||||
ptr := 8 |
||||
b[ptr], ptr = uint8(len(s.salt)), ptr+1 |
||||
ptr += copy(b[ptr:], s.salt) |
||||
b[ptr], ptr = uint8(len(s.sig)), ptr+1 |
||||
copy(b[ptr:], s.sig) |
||||
return sigV0 + hex.EncodeToString(b) |
||||
} |
||||
|
||||
// KV implements the method for the mlog.KVer interface
|
||||
func (s Signature) KV() map[string]interface{} { |
||||
return map[string]interface{}{"sig": s.String()} |
||||
} |
||||
|
||||
// MarshalText implements the method for the encoding.TextMarshaler interface
|
||||
func (s Signature) MarshalText() ([]byte, error) { |
||||
return []byte(s.String()), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the method for the encoding.TextUnmarshaler
|
||||
// interface
|
||||
func (s *Signature) UnmarshalText(b []byte) error { |
||||
str := string(b) |
||||
strEnc, ok := stripPrefix(str, sigV0) |
||||
if !ok || len(strEnc) < hex.EncodedLen(10) { |
||||
return merr.Wrap(errMalformedSig) |
||||
} |
||||
|
||||
b, err := hex.DecodeString(strEnc) |
||||
if err != nil { |
||||
return merr.Wrap(err) |
||||
} |
||||
|
||||
unixNano, b := int64(binary.BigEndian.Uint64(b[:8])), b[8:] |
||||
s.t = time.Unix(0, unixNano).Local() |
||||
|
||||
readBytes := func() []byte { |
||||
if err != nil { |
||||
return nil |
||||
} else if len(b) < 1+int(b[0]) { |
||||
err = merr.Wrap(errMalformedSig) |
||||
return nil |
||||
} |
||||
out := b[1 : 1+b[0]] |
||||
b = b[1+b[0]:] |
||||
return out |
||||
} |
||||
|
||||
s.salt = readBytes() |
||||
s.sig = readBytes() |
||||
return err |
||||
} |
||||
|
||||
// MarshalJSON implements the method for the json.Marshaler interface
|
||||
func (s Signature) MarshalJSON() ([]byte, error) { |
||||
return json.Marshal(s.String()) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the method for the json.Unmarshaler interface
|
||||
func (s *Signature) UnmarshalJSON(b []byte) error { |
||||
var str string |
||||
if err := json.Unmarshal(b, &str); err != nil { |
||||
return err |
||||
} |
||||
return s.UnmarshalText([]byte(str)) |
||||
} |
||||
|
||||
// returns an io.Reader which will first read out information about the
|
||||
// Signature which is going to be generated for the data, and then the data from
|
||||
// the io.Reader itself. When used in conjunction with the Signer/Verifier's
|
||||
// hashing algorithm this ensures that the other data encoded in the Signature
|
||||
// (the time and salt) are also encompassed in the sig.
|
||||
func sigPrefixReader(r io.Reader, sigLen uint8, salt []byte, t time.Time) io.Reader { |
||||
// ts:8 + saltHeader:1 + salt + sigLen:1
|
||||
b := make([]byte, 10+len(salt)) |
||||
binary.BigEndian.PutUint64(b, uint64(t.UnixNano())) |
||||
b[9] = uint8(len(salt)) |
||||
copy(b[9:9+len(salt)], salt) |
||||
b[9+len(salt)] = sigLen |
||||
return io.MultiReader(bytes.NewBuffer(b), r) |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Signer is some entity which can generate signatures for arbitrary data and
|
||||
// can later verify those signatures
|
||||
type Signer interface { |
||||
sign(io.Reader) (Signature, error) |
||||
} |
||||
|
||||
// Verifier is some entity which can verify Signatures produced by a Signer for
|
||||
// some arbitrary data
|
||||
type Verifier interface { |
||||
// returns an error if io.Reader returns one ever, or if the Signature
|
||||
// couldn't be verified
|
||||
verify(Signature, io.Reader) error |
||||
} |
||||
|
||||
// Sign reads all data from the io.Reader and signs it using the given Signer
|
||||
func Sign(s Signer, r io.Reader) (Signature, error) { |
||||
return s.sign(r) |
||||
} |
||||
|
||||
// SignBytes uses the Signer to generate a Signature for the given []bytes
|
||||
func SignBytes(s Signer, b []byte) Signature { |
||||
sig, err := s.sign(bytes.NewBuffer(b)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return sig |
||||
} |
||||
|
||||
// SignString uses the Signer to generate a Signature for the given string
|
||||
func SignString(s Signer, in string) Signature { |
||||
return SignBytes(s, []byte(in)) |
||||
} |
||||
|
||||
// Verify reads all data from the io.Reader and uses the Verifier to verify that
|
||||
// the Signature is for that data.
|
||||
//
|
||||
// Returns any errors from io.Reader, or ErrInvalidSig (use merr.Equal(err,
|
||||
// mcrypto.ErrInvalidSig) to check).
|
||||
func Verify(v Verifier, s Signature, r io.Reader) error { |
||||
return v.verify(s, r) |
||||
} |
||||
|
||||
// VerifyBytes uses the Verifier to verify that the Signature is for the given
|
||||
// []bytes.
|
||||
//
|
||||
// Returns ErrInvalidSig (use merr.Equal(err, mcrypto.ErrInvalidSig) to check).
|
||||
func VerifyBytes(v Verifier, s Signature, b []byte) error { |
||||
return v.verify(s, bytes.NewBuffer(b)) |
||||
} |
||||
|
||||
// VerifyString uses the Verifier to verify that the Signature is for the given
|
||||
// string.
|
||||
//
|
||||
// Returns ErrInvalidSig (use merr.Equal(err, mcrypto.ErrInvalidSig) to check).
|
||||
func VerifyString(v Verifier, s Signature, in string) error { |
||||
return VerifyBytes(v, s, []byte(in)) |
||||
} |
@ -1,44 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestSignerVerifier(t *T) { |
||||
secret := NewSecret(mrand.Bytes(16)) |
||||
var prevStr string |
||||
var prevSig Signature |
||||
for i := 0; i < 10000; i++ { |
||||
now := time.Now().Round(0) |
||||
secret.testNow = now |
||||
|
||||
thisStr := mrand.Hex(512) |
||||
thisSig := SignString(secret, thisStr) |
||||
thisSigStr := thisSig.String() |
||||
|
||||
// sanity checks
|
||||
assert.NotEmpty(t, thisSigStr) |
||||
assert.Equal(t, now, thisSig.Time()) |
||||
assert.NoError(t, VerifyString(secret, thisSig, thisStr)) |
||||
|
||||
// marshaling/unmarshaling
|
||||
var thisSig2 Signature |
||||
assert.NoError(t, thisSig2.UnmarshalText([]byte(thisSigStr))) |
||||
assert.Equal(t, thisSigStr, thisSig2.String()) |
||||
assert.Equal(t, now, thisSig2.Time()) |
||||
assert.NoError(t, VerifyString(secret, thisSig2, thisStr)) |
||||
|
||||
if prevStr != "" { |
||||
assert.NotEqual(t, prevSig.String(), thisSigStr) |
||||
err := VerifyString(secret, prevSig, thisStr) |
||||
assert.True(t, merr.Equal(err, ErrInvalidSig)) |
||||
} |
||||
prevStr = thisStr |
||||
prevSig = thisSig |
||||
} |
||||
} |
@ -1,97 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"crypto/rand" |
||||
"encoding/binary" |
||||
"encoding/hex" |
||||
"encoding/json" |
||||
"errors" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
) |
||||
|
||||
var errMalformedUUID = errors.New("malformed UUID string") |
||||
|
||||
// UUID is a universally unique identifier which embeds within it a timestamp.
|
||||
//
|
||||
// Only Unmarshal methods should be called on the zero UUID value.
|
||||
//
|
||||
// Comparing the equality of two UUID's should always be done using the Equal
|
||||
// method, or by comparing their string forms.
|
||||
//
|
||||
// The string form of UUIDs (returned by String or MarshalText) are
|
||||
// lexigraphically order-able by their embedded timestamp.
|
||||
type UUID struct { |
||||
b []byte |
||||
} |
||||
|
||||
// NewUUID populates and returns a new UUID instance which embeds the given time
|
||||
func NewUUID(t time.Time) UUID { |
||||
b := make([]byte, 16) |
||||
binary.BigEndian.PutUint64(b[:8], uint64(t.UnixNano())) |
||||
if _, err := rand.Read(b[8:]); err != nil { |
||||
panic(err) |
||||
} |
||||
return UUID{b: b} |
||||
} |
||||
|
||||
func (u UUID) String() string { |
||||
return uuidV0 + hex.EncodeToString(u.b) |
||||
} |
||||
|
||||
// Equal returns whether or not the two UUID's are the same value
|
||||
func (u UUID) Equal(u2 UUID) bool { |
||||
return bytes.Equal(u.b, u2.b) |
||||
} |
||||
|
||||
// Time unpacks and returns the timestamp embedded in the UUID
|
||||
func (u UUID) Time() time.Time { |
||||
unixNano := int64(binary.BigEndian.Uint64(u.b[:8])) |
||||
return time.Unix(0, unixNano).Local() |
||||
} |
||||
|
||||
// KV implements the method for the mlog.KVer interface
|
||||
func (u UUID) KV() map[string]interface{} { |
||||
return map[string]interface{}{"uuid": u.String()} |
||||
} |
||||
|
||||
// MarshalText implements the method for the encoding.TextMarshaler interface
|
||||
func (u UUID) MarshalText() ([]byte, error) { |
||||
return []byte(u.String()), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the method for the encoding.TextUnmarshaler
|
||||
// interface
|
||||
func (u *UUID) UnmarshalText(b []byte) error { |
||||
str := string(b) |
||||
strEnc, ok := stripPrefix(str, uuidV0) |
||||
if !ok || len(strEnc) != hex.EncodedLen(16) { |
||||
ctx := mctx.Annotate(context.Background(), "uuidStr", str) |
||||
return merr.Wrap(errMalformedUUID, ctx) |
||||
} |
||||
b, err := hex.DecodeString(strEnc) |
||||
if err != nil { |
||||
ctx := mctx.Annotate(context.Background(), "uuidStr", str) |
||||
return merr.Wrap(errMalformedUUID, ctx) |
||||
} |
||||
u.b = b |
||||
return nil |
||||
} |
||||
|
||||
// MarshalJSON implements the method for the json.Marshaler interface
|
||||
func (u UUID) MarshalJSON() ([]byte, error) { |
||||
return json.Marshal(u.String()) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the method for the json.Unmarshaler interface
|
||||
func (u *UUID) UnmarshalJSON(b []byte) error { |
||||
var s string |
||||
if err := json.Unmarshal(b, &s); err != nil { |
||||
return err |
||||
} |
||||
return u.UnmarshalText([]byte(s)) |
||||
} |
@ -1,39 +0,0 @@ |
||||
package mcrypto |
||||
|
||||
import ( |
||||
"strings" |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestUUID(t *T) { |
||||
var prevT time.Time |
||||
var prev UUID |
||||
for i := 0; i < 10000; i++ { |
||||
thisT := time.Now().Round(0) // strip monotonic clock
|
||||
require.True(t, thisT.After(prevT)) // sanity check
|
||||
this := NewUUID(thisT) |
||||
|
||||
// basic
|
||||
assert.True(t, strings.HasPrefix(this.String(), uuidV0)) |
||||
|
||||
// comparisons with prev
|
||||
assert.False(t, prev.Equal(this)) |
||||
assert.NotEqual(t, prev.String(), this.String()) |
||||
assert.True(t, this.String() > prev.String()) |
||||
prev = this |
||||
|
||||
// check time unpacking
|
||||
assert.Equal(t, thisT, this.Time()) |
||||
|
||||
// check marshal/unmarshal
|
||||
thisStr, err := this.MarshalText() |
||||
require.NoError(t, err) |
||||
var this2 UUID |
||||
require.NoError(t, this2.UnmarshalText(thisStr)) |
||||
assert.True(t, this.Equal(this2), "this:%q this2:%q", this, this2) |
||||
} |
||||
} |
@ -1,168 +0,0 @@ |
||||
// Package mbigquery implements connecting to Google's BigQuery service and
|
||||
// simplifying a number of interactions with it.
|
||||
package mbigquery |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/mdb" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
|
||||
"cloud.google.com/go/bigquery" |
||||
"google.golang.org/api/googleapi" |
||||
) |
||||
|
||||
// TODO this file needs tests
|
||||
|
||||
func isErrAlreadyExists(err error) bool { |
||||
if err == nil { |
||||
return false |
||||
} |
||||
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == 409 { |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// BigQuery is a wrapper around a bigquery client providing more functionality.
|
||||
type BigQuery struct { |
||||
*bigquery.Client |
||||
gce *mdb.GCE |
||||
ctx context.Context |
||||
|
||||
// key is dataset/tableName
|
||||
tablesL sync.Mutex |
||||
tables map[[2]string]*bigquery.Table |
||||
tableUploaders map[[2]string]*bigquery.Uploader |
||||
} |
||||
|
||||
// WithBigQuery returns a BigQuery instance which will be initialized and
|
||||
// configured when the start event is triggered on the returned (see
|
||||
// mrun.Start). The BigQuery instance will have Close called on it when the stop
|
||||
// event is triggered on the returned Context (see mrun.Stop).
|
||||
//
|
||||
// gce is optional and can be passed in if there's an existing gce object which
|
||||
// should be used, otherwise a new one will be created with mdb.MGCE.
|
||||
func WithBigQuery(parent context.Context, gce *mdb.GCE) (context.Context, *BigQuery) { |
||||
ctx := mctx.NewChild(parent, "mbigquery") |
||||
if gce == nil { |
||||
ctx, gce = mdb.WithGCE(ctx, "") |
||||
} |
||||
|
||||
bq := &BigQuery{ |
||||
gce: gce, |
||||
tables: map[[2]string]*bigquery.Table{}, |
||||
tableUploaders: map[[2]string]*bigquery.Uploader{}, |
||||
} |
||||
|
||||
ctx = mrun.WithStartHook(ctx, func(innerCtx context.Context) error { |
||||
bq.ctx = mctx.MergeAnnotations(bq.ctx, bq.gce.Context()) |
||||
mlog.Info("connecting to bigquery", bq.ctx) |
||||
var err error |
||||
bq.Client, err = bigquery.NewClient(innerCtx, bq.gce.Project, bq.gce.ClientOptions()...) |
||||
return merr.Wrap(err, bq.ctx) |
||||
}) |
||||
ctx = mrun.WithStopHook(ctx, func(context.Context) error { |
||||
return bq.Client.Close() |
||||
}) |
||||
bq.ctx = ctx |
||||
return mctx.WithChild(parent, ctx), bq |
||||
} |
||||
|
||||
// Table initializes and returns the table instance with the given dataset and
|
||||
// schema information. This method caches the Table/Uploader instances it
|
||||
// returns, so multiple calls with the same dataset/tableName will only actually
|
||||
// create those instances on the first call.
|
||||
func (bq *BigQuery) Table( |
||||
ctx context.Context, |
||||
dataset, tableName string, |
||||
schemaObj interface{}, |
||||
) ( |
||||
*bigquery.Table, *bigquery.Uploader, error, |
||||
) { |
||||
bq.tablesL.Lock() |
||||
defer bq.tablesL.Unlock() |
||||
|
||||
key := [2]string{dataset, tableName} |
||||
if table, ok := bq.tables[key]; ok { |
||||
return table, bq.tableUploaders[key], nil |
||||
} |
||||
|
||||
ctx = mctx.MergeAnnotations(ctx, bq.ctx) |
||||
ctx = mctx.Annotate(ctx, "dataset", dataset, "table", tableName) |
||||
|
||||
mlog.Debug("creating/grabbing table", bq.ctx) |
||||
schema, err := bigquery.InferSchema(schemaObj) |
||||
if err != nil { |
||||
return nil, nil, merr.Wrap(err, ctx) |
||||
} |
||||
|
||||
ds := bq.Dataset(dataset) |
||||
if err := ds.Create(ctx, nil); err != nil && !isErrAlreadyExists(err) { |
||||
return nil, nil, merr.Wrap(err, ctx) |
||||
} |
||||
|
||||
table := ds.Table(tableName) |
||||
meta := &bigquery.TableMetadata{ |
||||
Name: tableName, |
||||
Schema: schema, |
||||
} |
||||
if err := table.Create(ctx, meta); err != nil && !isErrAlreadyExists(err) { |
||||
return nil, nil, merr.Wrap(err, ctx) |
||||
} |
||||
uploader := table.Uploader() |
||||
|
||||
bq.tables[key] = table |
||||
bq.tableUploaders[key] = uploader |
||||
return table, uploader, nil |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
const timeFormat = "2006-01-02 15:04:05 MST" |
||||
|
||||
// Time wraps a time.Time object and provides marshaling/unmarshaling for
|
||||
// bigquery's time format.
|
||||
type Time struct { |
||||
time.Time |
||||
} |
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface.
|
||||
func (t Time) MarshalText() ([]byte, error) { |
||||
str := t.Time.Format(timeFormat) |
||||
return []byte(str), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (t *Time) UnmarshalText(b []byte) error { |
||||
tt, err := time.Parse(timeFormat, string(b)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
t.Time = tt |
||||
return nil |
||||
} |
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
func (t *Time) MarshalJSON() ([]byte, error) { |
||||
b, err := t.MarshalText() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(string(b)) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
func (t *Time) UnmarshalJSON(b []byte) error { |
||||
var str string |
||||
if err := json.Unmarshal(b, &str); err != nil { |
||||
return err |
||||
} |
||||
return t.UnmarshalText([]byte(str)) |
||||
} |
@ -1,124 +0,0 @@ |
||||
// Package mbigtable implements connecting to Google's Bigtable service and
|
||||
// simplifying a number of interactions with it.
|
||||
package mbigtable |
||||
|
||||
import ( |
||||
"context" |
||||
"strings" |
||||
|
||||
"cloud.google.com/go/bigtable" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/mdb" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
func isErrAlreadyExists(err error) bool { |
||||
if err == nil { |
||||
return false |
||||
} |
||||
return strings.HasSuffix(err.Error(), " already exists") |
||||
} |
||||
|
||||
// Bigtable is a wrapper around a bigtable client providing more functionality.
|
||||
type Bigtable struct { |
||||
*bigtable.Client |
||||
Instance string |
||||
|
||||
gce *mdb.GCE |
||||
ctx context.Context |
||||
} |
||||
|
||||
// WithBigTable returns a Bigtable instance which will be initialized and
|
||||
// configured when the start event is triggered on the returned Context (see
|
||||
// mrun.Start). The Bigtable instance will have Close called on it when the
|
||||
// stop event is triggered on the returned Context (see mrun.Stop).
|
||||
//
|
||||
// gce is optional and can be passed in if there's an existing gce object which
|
||||
// should be used, otherwise a new one will be created with mdb.MGCE.
|
||||
//
|
||||
// defaultInstance can be given as the instance name to use as the default
|
||||
// parameter value. If empty the parameter will be required to be set.
|
||||
func WithBigTable(parent context.Context, gce *mdb.GCE, defaultInstance string) (context.Context, *Bigtable) { |
||||
ctx := mctx.NewChild(parent, "bigtable") |
||||
if gce == nil { |
||||
ctx, gce = mdb.WithGCE(ctx, "") |
||||
} |
||||
|
||||
bt := &Bigtable{ |
||||
gce: gce, |
||||
} |
||||
|
||||
var inst *string |
||||
{ |
||||
const name, descr = "instance", "name of the bigtable instance in the project to connect to" |
||||
if defaultInstance != "" { |
||||
ctx, inst = mcfg.WithString(ctx, name, defaultInstance, descr) |
||||
} else { |
||||
ctx, inst = mcfg.WithRequiredString(ctx, name, descr) |
||||
} |
||||
} |
||||
|
||||
ctx = mrun.WithStartHook(ctx, func(innerCtx context.Context) error { |
||||
bt.Instance = *inst |
||||
|
||||
bt.ctx = mctx.MergeAnnotations(bt.ctx, bt.gce.Context()) |
||||
bt.ctx = mctx.Annotate(bt.ctx, "instance", bt.Instance) |
||||
|
||||
mlog.Info("connecting to bigtable", bt.ctx) |
||||
var err error |
||||
bt.Client, err = bigtable.NewClient( |
||||
innerCtx, |
||||
bt.gce.Project, bt.Instance, |
||||
bt.gce.ClientOptions()..., |
||||
) |
||||
return merr.Wrap(err, bt.ctx) |
||||
}) |
||||
ctx = mrun.WithStopHook(ctx, func(context.Context) error { |
||||
return bt.Client.Close() |
||||
}) |
||||
bt.ctx = ctx |
||||
return mctx.WithChild(parent, ctx), bt |
||||
} |
||||
|
||||
// EnsureTable ensures that the given table exists and has (at least) the given
|
||||
// column families.
|
||||
//
|
||||
// This method requires admin privileges on the bigtable instance.
|
||||
func (bt *Bigtable) EnsureTable(ctx context.Context, name string, colFams ...string) error { |
||||
ctx = mctx.MergeAnnotations(ctx, bt.ctx) |
||||
ctx = mctx.Annotate(ctx, "table", name) |
||||
mlog.Info("ensuring table", ctx) |
||||
|
||||
mlog.Debug("creating admin client", ctx) |
||||
adminClient, err := bigtable.NewAdminClient(ctx, bt.gce.Project, bt.Instance) |
||||
if err != nil { |
||||
return merr.Wrap(err, ctx) |
||||
} |
||||
defer adminClient.Close() |
||||
|
||||
mlog.Debug("creating bigtable table (if needed)", ctx) |
||||
err = adminClient.CreateTable(ctx, name) |
||||
if err != nil && !isErrAlreadyExists(err) { |
||||
return merr.Wrap(err, ctx) |
||||
} |
||||
|
||||
for _, colFam := range colFams { |
||||
ctx := mctx.Annotate(ctx, "family", colFam) |
||||
mlog.Debug("creating bigtable column family (if needed)", ctx) |
||||
err := adminClient.CreateColumnFamily(ctx, name, colFam) |
||||
if err != nil && !isErrAlreadyExists(err) { |
||||
return merr.Wrap(err, ctx) |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Table returns the bigtable.Table instance which can be used to write/query
|
||||
// the given table.
|
||||
func (bt *Bigtable) Table(tableName string) *bigtable.Table { |
||||
return bt.Open(tableName) |
||||
} |
@ -1,44 +0,0 @@ |
||||
package mbigtable |
||||
|
||||
import ( |
||||
. "testing" |
||||
"time" |
||||
|
||||
"cloud.google.com/go/bigtable" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
func TestBasic(t *T) { |
||||
ctx := mtest.Context() |
||||
ctx = mtest.WithEnv(ctx, "BIGTABLE_GCE_PROJECT", "testProject") |
||||
ctx, bt := WithBigTable(ctx, nil, "testInstance") |
||||
|
||||
mtest.Run(ctx, t, func() { |
||||
tableName := "test-" + mrand.Hex(8) |
||||
colFam := "colFam-" + mrand.Hex(8) |
||||
if err := bt.EnsureTable(ctx, tableName, colFam); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
table := bt.Table(tableName) |
||||
row := "row-" + mrand.Hex(8) |
||||
mut := bigtable.NewMutation() |
||||
mut.Set(colFam, "col", bigtable.Time(time.Now()), []byte("bar")) |
||||
if err := table.Apply(ctx, row, mut); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
readRow, err := table.ReadRow(ctx, row) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
readColFam := readRow[colFam] |
||||
massert.Require(t, |
||||
massert.Length(readColFam, 1), |
||||
massert.Equal(colFam+":col", readColFam[0].Column), |
||||
massert.Equal([]byte("bar"), readColFam[0].Value), |
||||
) |
||||
}) |
||||
} |
@ -1,54 +0,0 @@ |
||||
// Package mdatastore implements connecting to Google's Datastore service and
|
||||
// simplifying a number of interactions with it.
|
||||
package mdatastore |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"cloud.google.com/go/datastore" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/mdb" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
// Datastore is a wrapper around a datastore client providing more
|
||||
// functionality.
|
||||
type Datastore struct { |
||||
*datastore.Client |
||||
|
||||
gce *mdb.GCE |
||||
ctx context.Context |
||||
} |
||||
|
||||
// WithDatastore returns a Datastore instance which will be initialized and
|
||||
// configured when the start event is triggered on the returned Context (see
|
||||
// mrun.Start). The Datastore instance will have Close called on it when the
|
||||
// stop event is triggered on the returned Context (see mrun.Stop).
|
||||
//
|
||||
// gce is optional and can be passed in if there's an existing gce object which
|
||||
// should be used, otherwise a new one will be created with mdb.MGCE.
|
||||
func WithDatastore(parent context.Context, gce *mdb.GCE) (context.Context, *Datastore) { |
||||
ctx := mctx.NewChild(parent, "datastore") |
||||
if gce == nil { |
||||
ctx, gce = mdb.WithGCE(ctx, "") |
||||
} |
||||
|
||||
ds := &Datastore{ |
||||
gce: gce, |
||||
} |
||||
|
||||
ctx = mrun.WithStartHook(ctx, func(innerCtx context.Context) error { |
||||
ds.ctx = mctx.MergeAnnotations(ds.ctx, ds.gce.Context()) |
||||
mlog.Info("connecting to datastore", ds.ctx) |
||||
var err error |
||||
ds.Client, err = datastore.NewClient(innerCtx, ds.gce.Project, ds.gce.ClientOptions()...) |
||||
return merr.Wrap(err, ds.ctx) |
||||
}) |
||||
ctx = mrun.WithStopHook(ctx, func(context.Context) error { |
||||
return ds.Client.Close() |
||||
}) |
||||
ds.ctx = ctx |
||||
return mctx.WithChild(parent, ctx), ds |
||||
} |
@ -1,40 +0,0 @@ |
||||
package mdatastore |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"cloud.google.com/go/datastore" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
// Requires datastore emulator to be running
|
||||
func TestBasic(t *T) { |
||||
ctx := mtest.Context() |
||||
ctx = mtest.WithEnv(ctx, "DATASTORE_GCE_PROJECT", "test") |
||||
ctx, ds := WithDatastore(ctx, nil) |
||||
mtest.Run(ctx, t, func() { |
||||
name := mrand.Hex(8) |
||||
key := datastore.NameKey("testKind", name, nil) |
||||
key.Namespace = "TestBasic_" + mrand.Hex(8) |
||||
type valType struct { |
||||
A, B int |
||||
} |
||||
val := valType{ |
||||
A: mrand.Int(), |
||||
B: mrand.Int(), |
||||
} |
||||
|
||||
if _, err := ds.Put(ctx, key, &val); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
var val2 valType |
||||
if err := ds.Get(ctx, key, &val2); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
massert.Require(t, massert.Equal(val, val2)) |
||||
}) |
||||
} |
@ -1,79 +0,0 @@ |
||||
// Package mdb contains a number of database wrappers for databases I commonly
|
||||
// use
|
||||
package mdb |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
"google.golang.org/api/option" |
||||
) |
||||
|
||||
// GCE wraps configuration parameters commonly used for interacting with GCE
|
||||
// services.
|
||||
type GCE struct { |
||||
cmp *mcmp.Component |
||||
Project string |
||||
CredFile string |
||||
} |
||||
|
||||
type gceOpts struct { |
||||
defaultProject string |
||||
} |
||||
|
||||
// GCEOption is a value which adjusts the behavior of InstGCE.
|
||||
type GCEOption func(*gceOpts) |
||||
|
||||
// GCEDefaultProject sets the given string to be the default project of the GCE
|
||||
// instance. The default project will still be configurable via mcfg regardless
|
||||
// of what this is set to.
|
||||
func GCEDefaultProject(defaultProject string) GCEOption { |
||||
return func(opts *gceOpts) { |
||||
opts.defaultProject = defaultProject |
||||
} |
||||
} |
||||
|
||||
// InstGCE instantiates a GCE which will be initialized when the Init event is
|
||||
// triggered on the given Component. defaultProject is used as the default value
|
||||
// for the mcfg parameter this function creates.
|
||||
func InstGCE(cmp *mcmp.Component, options ...GCEOption) *GCE { |
||||
var opts gceOpts |
||||
for _, opt := range options { |
||||
opt(&opts) |
||||
} |
||||
|
||||
gce := GCE{cmp: cmp.Child("gce")} |
||||
credFile := mcfg.String(gce.cmp, "cred-file", |
||||
mcfg.ParamUsage("Path to GCE credientials JSON file, if any")) |
||||
project := mcfg.String(gce.cmp, "project", |
||||
mcfg.ParamDefaultOrRequired(opts.defaultProject), |
||||
mcfg.ParamUsage("Name of GCE project to use")) |
||||
|
||||
mrun.InitHook(gce.cmp, func(ctx context.Context) error { |
||||
gce.Project = *project |
||||
gce.CredFile = *credFile |
||||
gce.cmp.Annotate("project", gce.Project) |
||||
mlog.From(gce.cmp).Info("GCE config initialized", ctx) |
||||
return nil |
||||
}) |
||||
|
||||
return &gce |
||||
} |
||||
|
||||
// ClientOptions generates and returns the ClientOption instances which can be
|
||||
// passed into most GCE client drivers.
|
||||
func (gce *GCE) ClientOptions() []option.ClientOption { |
||||
var opts []option.ClientOption |
||||
if gce.CredFile != "" { |
||||
opts = append(opts, option.WithCredentialsFile(gce.CredFile)) |
||||
} |
||||
return opts |
||||
} |
||||
|
||||
// Context returns the annotated Context from this instance's initialization.
|
||||
func (gce *GCE) Context() context.Context { |
||||
return gce.cmp.Context() |
||||
} |
@ -1,373 +0,0 @@ |
||||
// Package mpubsub implements connecting to Google's PubSub service and
|
||||
// simplifying a number of interactions with it.
|
||||
package mpubsub |
||||
|
||||
import ( |
||||
"context" |
||||
"sync" |
||||
"time" |
||||
|
||||
"cloud.google.com/go/pubsub" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/mdb" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
// TODO Consume (and probably BatchConsume) don't properly handle the Client
|
||||
// being closed.
|
||||
|
||||
func isErrAlreadyExists(err error) bool { |
||||
if err == nil { |
||||
return false |
||||
} |
||||
s, ok := status.FromError(err) |
||||
return ok && s.Code() == codes.AlreadyExists |
||||
} |
||||
|
||||
// Message aliases the type in the official driver
|
||||
type Message = pubsub.Message |
||||
|
||||
// PubSub is a wrapper around a pubsub client providing more functionality.
|
||||
type PubSub struct { |
||||
*pubsub.Client |
||||
|
||||
gce *mdb.GCE |
||||
cmp *mcmp.Component |
||||
} |
||||
|
||||
type pubsubOpts struct { |
||||
gce *mdb.GCE |
||||
} |
||||
|
||||
// PubSubOpt is a value which adjusts the behavior of InstPubSub.
|
||||
type PubSubOpt func(*pubsubOpts) |
||||
|
||||
// PubSubGCE indicates that InstPubSub should use the given GCE instance rather
|
||||
// than instantiate its own.
|
||||
func PubSubGCE(gce *mdb.GCE) PubSubOpt { |
||||
return func(opts *pubsubOpts) { |
||||
opts.gce = gce |
||||
} |
||||
} |
||||
|
||||
// InstPubSub instantiates a PubSub which will be initialized when the Init
|
||||
// event is triggered on the given Component. The PubSub instance will have
|
||||
// Close called on it when the Shutdown event is triggered on the given
|
||||
// Component.
|
||||
func InstPubSub(cmp *mcmp.Component, options ...PubSubOpt) *PubSub { |
||||
var opts pubsubOpts |
||||
for _, opt := range options { |
||||
opt(&opts) |
||||
} |
||||
|
||||
ps := PubSub{ |
||||
gce: opts.gce, |
||||
cmp: cmp.Child("pubsub"), |
||||
} |
||||
if ps.gce == nil { |
||||
ps.gce = mdb.InstGCE(ps.cmp) |
||||
} |
||||
|
||||
mrun.InitHook(ps.cmp, func(ctx context.Context) error { |
||||
mlog.From(ps.cmp).Info("connecting to pubsub", ctx) |
||||
var err error |
||||
ps.Client, err = pubsub.NewClient(ctx, ps.gce.Project, ps.gce.ClientOptions()...) |
||||
return merr.Wrap(err, ps.cmp.Context(), ctx) |
||||
}) |
||||
|
||||
mrun.ShutdownHook(ps.cmp, func(ctx context.Context) error { |
||||
mlog.From(ps.cmp).Info("closing pubsub", ctx) |
||||
return ps.Client.Close() |
||||
}) |
||||
return &ps |
||||
} |
||||
|
||||
// Topic provides methods around a particular topic in PubSub
|
||||
type Topic struct { |
||||
*PubSub |
||||
Name string |
||||
|
||||
ctx context.Context |
||||
topic *pubsub.Topic |
||||
} |
||||
|
||||
// Topic returns, after potentially creating, a topic of the given name
|
||||
func (ps *PubSub) Topic(ctx context.Context, name string, create bool) (*Topic, error) { |
||||
t := &Topic{ |
||||
PubSub: ps, |
||||
ctx: mctx.Annotate(ps.cmp.Context(), "topicName", name), |
||||
Name: name, |
||||
} |
||||
|
||||
var err error |
||||
if create { |
||||
t.topic, err = ps.Client.CreateTopic(ctx, name) |
||||
if isErrAlreadyExists(err) { |
||||
t.topic = ps.Client.Topic(name) |
||||
} else if err != nil { |
||||
return nil, merr.Wrap(err, t.ctx, ctx) |
||||
} |
||||
} else { |
||||
t.topic = ps.Client.Topic(name) |
||||
if exists, err := t.topic.Exists(t.ctx); err != nil { |
||||
return nil, merr.Wrap(err, t.ctx, ctx) |
||||
} else if !exists { |
||||
return nil, merr.New("topic dne", t.ctx, ctx) |
||||
} |
||||
} |
||||
return t, nil |
||||
} |
||||
|
||||
// Publish publishes a message with the given data as its body to the Topic
|
||||
func (t *Topic) Publish(ctx context.Context, data []byte) error { |
||||
_, err := t.topic.Publish(ctx, &Message{Data: data}).Get(ctx) |
||||
if err != nil { |
||||
return merr.Wrap(err, t.ctx, ctx) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Subscription provides methods around a subscription to a topic in PubSub
|
||||
type Subscription struct { |
||||
*Topic |
||||
Name string |
||||
|
||||
ctx context.Context |
||||
sub *pubsub.Subscription |
||||
|
||||
// only used in tests to trigger batch processing
|
||||
batchTestTrigger chan bool |
||||
} |
||||
|
||||
// Subscription returns a Subscription instance, after potentially creating it,
|
||||
// for the Topic
|
||||
func (t *Topic) Subscription(ctx context.Context, name string, create bool) (*Subscription, error) { |
||||
name = t.Name + "_" + name |
||||
s := &Subscription{ |
||||
Topic: t, |
||||
Name: name, |
||||
ctx: mctx.Annotate(t.ctx, "subName", name), |
||||
} |
||||
|
||||
var err error |
||||
if create { |
||||
s.sub, err = s.CreateSubscription(ctx, name, pubsub.SubscriptionConfig{ |
||||
Topic: t.topic, |
||||
}) |
||||
if isErrAlreadyExists(err) { |
||||
s.sub = s.PubSub.Subscription(s.Name) |
||||
} else if err != nil { |
||||
return nil, merr.Wrap(err, s.ctx, ctx) |
||||
} |
||||
} else { |
||||
s.sub = s.PubSub.Subscription(s.Name) |
||||
if exists, err := s.sub.Exists(ctx); err != nil { |
||||
return nil, merr.Wrap(err, s.ctx, ctx) |
||||
} else if !exists { |
||||
return nil, merr.New("sub dne", s.ctx, ctx) |
||||
} |
||||
} |
||||
return s, nil |
||||
} |
||||
|
||||
// ConsumerFunc is a function which messages being consumed will be passed. The
|
||||
// returned boolean and returned error are independent. If the bool is false the
|
||||
// message will be returned to the queue for retrying later. If an error is
|
||||
// returned it will be logged.
|
||||
//
|
||||
// The Context will be canceled once the deadline has been reached (as set when
|
||||
// Consume is called).
|
||||
type ConsumerFunc func(context.Context, *Message) (bool, error) |
||||
|
||||
// ConsumerOpts are options which effect the behavior of a Consume method call
|
||||
type ConsumerOpts struct { |
||||
// Default 30s. The timeout each message has to complete before its context
|
||||
// is cancelled and the server re-publishes it
|
||||
Timeout time.Duration |
||||
|
||||
// Default 1. Number of concurrent messages to consume at a time
|
||||
Concurrent int |
||||
|
||||
// TODO DisableBatchAutoTrigger
|
||||
// Currently there is no auto-trigger behavior, batches only get processed
|
||||
// on a dumb ticker. This is necessary for the way I plan to have the
|
||||
// datastore writing, but it's not the expected behavior of a batch getting
|
||||
// triggered everytime <Concurrent> messages come in.
|
||||
} |
||||
|
||||
func (co ConsumerOpts) withDefaults() ConsumerOpts { |
||||
if co.Timeout == 0 { |
||||
co.Timeout = 30 * time.Second |
||||
} |
||||
if co.Concurrent == 0 { |
||||
co.Concurrent = 1 |
||||
} |
||||
return co |
||||
} |
||||
|
||||
// Consume uses the given ConsumerFunc and ConsumerOpts to process messages off
|
||||
// the Subscription
|
||||
func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts ConsumerOpts) { |
||||
opts = opts.withDefaults() |
||||
s.sub.ReceiveSettings.MaxExtension = opts.Timeout |
||||
s.sub.ReceiveSettings.MaxOutstandingMessages = opts.Concurrent |
||||
|
||||
for { |
||||
err := s.sub.Receive(ctx, func(ctx context.Context, msg *Message) { |
||||
innerCtx, cancel := context.WithTimeout(ctx, opts.Timeout) |
||||
defer cancel() |
||||
|
||||
ok, err := fn(innerCtx, msg) |
||||
if err != nil { |
||||
mlog.From(s.cmp).Warn("error consuming pubsub message", |
||||
s.ctx, ctx, innerCtx, merr.Context(err)) |
||||
} |
||||
|
||||
if ok { |
||||
msg.Ack() |
||||
} else { |
||||
msg.Nack() |
||||
} |
||||
}) |
||||
if ctx.Err() == context.Canceled || err == nil { |
||||
return |
||||
} else if err != nil { |
||||
mlog.From(s.cmp).Warn("error consuming from pubsub", |
||||
s.ctx, ctx, merr.Context(err)) |
||||
time.Sleep(1 * time.Second) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// BatchConsumerFunc is similar to ConsumerFunc, except it takes in a batch of
|
||||
// multiple messages at once. If the boolean returned will apply to every
|
||||
// message in the batch.
|
||||
type BatchConsumerFunc func(context.Context, []*Message) (bool, error) |
||||
|
||||
// BatchGroupFunc is an optional param to BatchConsume which allows for grouping
|
||||
// messages into separate groups. Each message received is attempted to be
|
||||
// placed in a group. Grouping is done by calling this function with the
|
||||
// received message and a random message from a group, and if this function
|
||||
// returns true then the received message is placed into that group. If this
|
||||
// returns false for all groups then a new group is created.
|
||||
//
|
||||
// This function should be a pure function.
|
||||
type BatchGroupFunc func(a, b *Message) bool |
||||
|
||||
// BatchConsume is like Consume, except it groups incoming messages together,
|
||||
// allowing them to be processed in batches instead of individually.
|
||||
//
|
||||
// BatchConsume first collects messages internally for half the
|
||||
// ConsumerOpts.Timeout value. Once that time has passed it will group all
|
||||
// messages based on the BatchGroupFunc (if nil then all collected messages form
|
||||
// one big group). The BatchConsumerFunc is called for each group, with the
|
||||
// context passed in having a timeout of ConsumerOpts.Timeout/2.
|
||||
//
|
||||
// The ConsumerOpts.Concurrent value determines the maximum number of messages
|
||||
// collected during the first section of the process (before BatchConsumerFn is
|
||||
// called).
|
||||
func (s *Subscription) BatchConsume( |
||||
ctx context.Context, |
||||
fn BatchConsumerFunc, gfn BatchGroupFunc, |
||||
opts ConsumerOpts, |
||||
) { |
||||
opts = opts.withDefaults() |
||||
|
||||
type promise struct { |
||||
msg *Message |
||||
retCh chan bool // must be buffered by one
|
||||
} |
||||
|
||||
var groups [][]promise |
||||
var groupsL sync.Mutex |
||||
|
||||
groupProm := func(prom promise) { |
||||
groupsL.Lock() |
||||
defer groupsL.Unlock() |
||||
for i := range groups { |
||||
if gfn == nil || gfn(groups[i][0].msg, prom.msg) { |
||||
groups[i] = append(groups[i], prom) |
||||
return |
||||
} |
||||
} |
||||
groups = append(groups, []promise{prom}) |
||||
} |
||||
|
||||
wg := new(sync.WaitGroup) |
||||
defer wg.Wait() |
||||
|
||||
processGroups := func() { |
||||
groupsL.Lock() |
||||
thisGroups := groups |
||||
groups = nil |
||||
groupsL.Unlock() |
||||
|
||||
// we do a waitgroup chain so as to properly handle the cancel
|
||||
// function. We hold wg (by adding one) until all routines spawned
|
||||
// here have finished, and once they have release wg and cancel
|
||||
thisCtx, cancel := context.WithTimeout(ctx, opts.Timeout/2) |
||||
thisWG := new(sync.WaitGroup) |
||||
thisWG.Add(1) |
||||
wg.Add(1) |
||||
go func() { |
||||
thisWG.Wait() |
||||
cancel() |
||||
wg.Done() |
||||
}() |
||||
|
||||
for i := range thisGroups { |
||||
thisGroup := thisGroups[i] |
||||
thisWG.Add(1) |
||||
go func() { |
||||
defer thisWG.Done() |
||||
msgs := make([]*Message, len(thisGroup)) |
||||
for i := range thisGroup { |
||||
msgs[i] = thisGroup[i].msg |
||||
} |
||||
ret, err := fn(thisCtx, msgs) |
||||
if err != nil { |
||||
mlog.From(s.cmp).Warn("error consuming pubsub batch messages", |
||||
s.ctx, thisCtx, merr.Context(err)) |
||||
} |
||||
for i := range thisGroup { |
||||
thisGroup[i].retCh <- ret // retCh is buffered
|
||||
} |
||||
}() |
||||
} |
||||
thisWG.Done() |
||||
} |
||||
|
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
tick := time.NewTicker(opts.Timeout / 2) |
||||
defer tick.Stop() |
||||
for { |
||||
select { |
||||
case <-tick.C: |
||||
processGroups() |
||||
case <-s.batchTestTrigger: |
||||
processGroups() |
||||
case <-ctx.Done(): |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
s.Consume(ctx, func(ctx context.Context, msg *Message) (bool, error) { |
||||
retCh := make(chan bool, 1) |
||||
groupProm(promise{msg: msg, retCh: retCh}) |
||||
select { |
||||
case ret := <-retCh: |
||||
return ret, nil |
||||
case <-ctx.Done(): |
||||
return false, merr.New("reading from batch grouping process timed out", s.ctx, ctx) |
||||
} |
||||
}, opts) |
||||
|
||||
} |
@ -1,174 +0,0 @@ |
||||
package mpubsub |
||||
|
||||
import ( |
||||
"context" |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
// this requires the pubsub emulator to be running
|
||||
func TestPubSub(t *T) { |
||||
cmp := mtest.Component() |
||||
mtest.Env(cmp, "PUBSUB_GCE_PROJECT", "test") |
||||
ps := InstPubSub(cmp) |
||||
mtest.Run(cmp, t, func() { |
||||
topicName := "testTopic_" + mrand.Hex(8) |
||||
ctx := context.Background() |
||||
|
||||
// Topic shouldn't exist yet
|
||||
_, err := ps.Topic(ctx, topicName, false) |
||||
require.Error(t, err) |
||||
|
||||
// ...so create it
|
||||
topic, err := ps.Topic(ctx, topicName, true) |
||||
require.NoError(t, err) |
||||
|
||||
// Create a subscription and consumer
|
||||
sub, err := topic.Subscription(ctx, "testSub", true) |
||||
require.NoError(t, err) |
||||
|
||||
msgCh := make(chan *Message) |
||||
go sub.Consume(ctx, func(ctx context.Context, m *Message) (bool, error) { |
||||
msgCh <- m |
||||
return true, nil |
||||
}, ConsumerOpts{}) |
||||
time.Sleep(1 * time.Second) // give consumer time to actually start
|
||||
|
||||
// publish a message and make sure it gets consumed
|
||||
assert.NoError(t, topic.Publish(ctx, []byte("foo"))) |
||||
msg := <-msgCh |
||||
assert.Equal(t, []byte("foo"), msg.Data) |
||||
}) |
||||
} |
||||
|
||||
func TestBatchPubSub(t *T) { |
||||
cmp := mtest.Component() |
||||
mtest.Env(cmp, "PUBSUB_GCE_PROJECT", "test") |
||||
ps := InstPubSub(cmp) |
||||
mtest.Run(cmp, t, func() { |
||||
topicName := "testBatchTopic_" + mrand.Hex(8) |
||||
ctx := context.Background() |
||||
|
||||
topic, err := ps.Topic(ctx, topicName, true) |
||||
require.NoError(t, err) |
||||
|
||||
readBatch := func(ch chan []*Message) map[byte]int { |
||||
select { |
||||
case <-time.After(1 * time.Second): |
||||
assert.Fail(t, "waited too long to read batch") |
||||
return nil |
||||
case mm := <-ch: |
||||
ret := map[byte]int{} |
||||
for _, m := range mm { |
||||
ret[m.Data[0]]++ |
||||
} |
||||
return ret |
||||
} |
||||
} |
||||
|
||||
// we use the same sub across the next two sections to ensure that cleanup
|
||||
// also works
|
||||
sub, err := topic.Subscription(ctx, "testSub", true) |
||||
require.NoError(t, err) |
||||
sub.batchTestTrigger = make(chan bool) |
||||
|
||||
{ // no grouping
|
||||
// Create a subscription and consumer
|
||||
ctx, cancel := context.WithCancel(ctx) |
||||
ch := make(chan []*Message) |
||||
go func() { |
||||
sub.BatchConsume(ctx, |
||||
func(ctx context.Context, mm []*Message) (bool, error) { |
||||
ch <- mm |
||||
return true, nil |
||||
}, |
||||
nil, |
||||
ConsumerOpts{Concurrent: 5}, |
||||
) |
||||
close(ch) |
||||
}() |
||||
time.Sleep(1 * time.Second) // give consumer time to actually start
|
||||
|
||||
exp := map[byte]int{} |
||||
for i := byte(0); i <= 9; i++ { |
||||
require.NoError(t, topic.Publish(ctx, []byte{i})) |
||||
exp[i] = 1 |
||||
} |
||||
|
||||
time.Sleep(1 * time.Second) |
||||
sub.batchTestTrigger <- true |
||||
gotA := readBatch(ch) |
||||
assert.Len(t, gotA, 5) |
||||
|
||||
time.Sleep(1 * time.Second) |
||||
sub.batchTestTrigger <- true |
||||
gotB := readBatch(ch) |
||||
assert.Len(t, gotB, 5) |
||||
|
||||
for i, c := range gotB { |
||||
gotA[i] += c |
||||
} |
||||
assert.Equal(t, exp, gotA) |
||||
|
||||
time.Sleep(1 * time.Second) // give time to ack before cancelling
|
||||
cancel() |
||||
<-ch |
||||
} |
||||
|
||||
{ // with grouping
|
||||
ctx, cancel := context.WithCancel(ctx) |
||||
ch := make(chan []*Message) |
||||
go func() { |
||||
sub.BatchConsume(ctx, |
||||
func(ctx context.Context, mm []*Message) (bool, error) { |
||||
ch <- mm |
||||
return true, nil |
||||
}, |
||||
func(a, b *Message) bool { return a.Data[0]%2 == b.Data[0]%2 }, |
||||
ConsumerOpts{Concurrent: 10}, |
||||
) |
||||
close(ch) |
||||
}() |
||||
time.Sleep(1 * time.Second) // give consumer time to actually start
|
||||
|
||||
exp := map[byte]int{} |
||||
for i := byte(0); i <= 9; i++ { |
||||
require.NoError(t, topic.Publish(ctx, []byte{i})) |
||||
exp[i] = 1 |
||||
} |
||||
|
||||
time.Sleep(1 * time.Second) |
||||
sub.batchTestTrigger <- true |
||||
gotA := readBatch(ch) |
||||
assert.Len(t, gotA, 5) |
||||
gotB := readBatch(ch) |
||||
assert.Len(t, gotB, 5) |
||||
|
||||
assertGotGrouped := func(got map[byte]int) { |
||||
prev := byte(255) |
||||
for i := range got { |
||||
if prev != 255 { |
||||
assert.Equal(t, prev%2, i%2) |
||||
} |
||||
prev = i |
||||
} |
||||
} |
||||
|
||||
assertGotGrouped(gotA) |
||||
assertGotGrouped(gotB) |
||||
for i, c := range gotB { |
||||
gotA[i] += c |
||||
} |
||||
assert.Equal(t, exp, gotA) |
||||
|
||||
time.Sleep(1 * time.Second) // give time to ack before cancelling
|
||||
cancel() |
||||
<-ch |
||||
} |
||||
}) |
||||
} |
@ -1,75 +0,0 @@ |
||||
// Package mredis implements connecting to a redis instance.
|
||||
package mredis |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
"github.com/mediocregopher/radix/v3" |
||||
) |
||||
|
||||
// Redis is a wrapper around a redis client which provides more functionality.
|
||||
type Redis struct { |
||||
radix.Client |
||||
cmp *mcmp.Component |
||||
} |
||||
|
||||
type redisOpts struct { |
||||
dialOpts []radix.DialOpt |
||||
} |
||||
|
||||
// RedisOption is a value which adjusts the behavior of InstRedis.
|
||||
type RedisOption func(*redisOpts) |
||||
|
||||
// RedisDialOpts specifies that the given set of DialOpts should be used when
|
||||
// creating any new connections.
|
||||
func RedisDialOpts(dialOpts ...radix.DialOpt) RedisOption { |
||||
return func(opts *redisOpts) { |
||||
opts.dialOpts = dialOpts |
||||
} |
||||
} |
||||
|
||||
// InstRedis instantiates a Redis instance which will be initialized when the
|
||||
// Init event is triggered on the given Component. The redis client will have
|
||||
// Close called on it when the Shutdown event is triggered on the given
|
||||
// Component.
|
||||
func InstRedis(parent *mcmp.Component, options ...RedisOption) *Redis { |
||||
var opts redisOpts |
||||
for _, opt := range options { |
||||
opt(&opts) |
||||
} |
||||
|
||||
cmp := parent.Child("redis") |
||||
client := new(struct{ radix.Client }) |
||||
|
||||
addr := mcfg.String(cmp, "addr", |
||||
mcfg.ParamDefault("127.0.0.1:6379"), |
||||
mcfg.ParamUsage("Address redis is listening on")) |
||||
poolSize := mcfg.Int(cmp, "pool-size", |
||||
mcfg.ParamDefault(4), |
||||
mcfg.ParamUsage("Number of connections in pool")) |
||||
mrun.InitHook(cmp, func(ctx context.Context) error { |
||||
cmp.Annotate("addr", *addr, "poolSize", *poolSize) |
||||
mlog.From(cmp).Info("connecting to redis", ctx) |
||||
var err error |
||||
client.Client, err = radix.NewPool( |
||||
"tcp", *addr, *poolSize, |
||||
radix.PoolConnFunc(func(network, addr string) (radix.Conn, error) { |
||||
return radix.Dial(network, addr, opts.dialOpts...) |
||||
}), |
||||
) |
||||
return err |
||||
}) |
||||
mrun.ShutdownHook(cmp, func(ctx context.Context) error { |
||||
mlog.From(cmp).Info("shutting down redis", ctx) |
||||
return client.Close() |
||||
}) |
||||
|
||||
return &Redis{ |
||||
Client: client, |
||||
cmp: cmp, |
||||
} |
||||
} |
@ -1,22 +0,0 @@ |
||||
package mredis |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
|
||||
"github.com/mediocregopher/radix/v3" |
||||
) |
||||
|
||||
func TestRedis(t *T) { |
||||
cmp := mtest.Component() |
||||
redis := InstRedis(cmp) |
||||
mtest.Run(cmp, t, func() { |
||||
var info string |
||||
if err := redis.Do(radix.Cmd(&info, "INFO")); err != nil { |
||||
t.Fatal(err) |
||||
} else if len(info) < 0 { |
||||
t.Fatal("empty info return") |
||||
} |
||||
}) |
||||
} |
@ -1,247 +0,0 @@ |
||||
package mredis |
||||
|
||||
import ( |
||||
"bufio" |
||||
"errors" |
||||
"strconv" |
||||
"strings" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
|
||||
"github.com/mediocregopher/radix/v3" |
||||
"github.com/mediocregopher/radix/v3/resp/resp2" |
||||
) |
||||
|
||||
// borrowed from radix
|
||||
type streamReaderEntry struct { |
||||
stream []byte |
||||
entries []radix.StreamEntry |
||||
} |
||||
|
||||
func (s *streamReaderEntry) UnmarshalRESP(br *bufio.Reader) error { |
||||
var ah resp2.ArrayHeader |
||||
if err := ah.UnmarshalRESP(br); err != nil { |
||||
return err |
||||
} |
||||
if ah.N != 2 { |
||||
return errors.New("invalid xread[group] response") |
||||
} |
||||
|
||||
var stream resp2.BulkStringBytes |
||||
stream.B = s.stream[:0] |
||||
if err := stream.UnmarshalRESP(br); err != nil { |
||||
return err |
||||
} |
||||
s.stream = stream.B |
||||
|
||||
return (resp2.Any{I: &s.entries}).UnmarshalRESP(br) |
||||
} |
||||
|
||||
// StreamEntry wraps radix's StreamEntry type in order to provde some extra
|
||||
// functionality.
|
||||
type StreamEntry struct { |
||||
radix.StreamEntry |
||||
|
||||
// Ack is used in order to acknowledge that a stream message has been
|
||||
// successfully consumed and should not be consumed again.
|
||||
Ack func() error |
||||
|
||||
// Nack is used to declare that a stream message was not successfully
|
||||
// consumed and it needs to be consumed again.
|
||||
Nack func() |
||||
} |
||||
|
||||
// StreamOpts are options used to initialize a Stream instance. Fields are
|
||||
// required unless otherwise noted.
|
||||
type StreamOpts struct { |
||||
// Key is the redis key at which the redis stream resides.
|
||||
Key string |
||||
|
||||
// Group is the name of the consumer group which will consume from Key.
|
||||
Group string |
||||
|
||||
// Consumer is the name of this particular consumer. This value should
|
||||
// remain the same across restarts of the process.
|
||||
Consumer string |
||||
|
||||
// (Optional) InitialCursor is only used when the consumer group is first
|
||||
// being created, and indicates where in the stream the consumer group
|
||||
// should start consuming from.
|
||||
//
|
||||
// "0" indicates the group should consume from the start of the stream. "$"
|
||||
// indicates the group should not consume any old messages, only those added
|
||||
// after the group is initialized. A specific message id can be given to
|
||||
// consume only those messages with greater ids.
|
||||
//
|
||||
// Defaults to "$".
|
||||
InitialCursor string |
||||
|
||||
// (Optional) ReadCount indicates the max number of messages which should be
|
||||
// read on every XREADGROUP call. 0 indicates no limit.
|
||||
ReadCount int |
||||
|
||||
// (Optional) Block indicates what BLOCK value is sent to XREADGROUP calls.
|
||||
// This value _must_ be less than the ReadtTimeout the redis client is
|
||||
// using.
|
||||
//
|
||||
// Defaults to 5 * time.Second
|
||||
Block time.Duration |
||||
} |
||||
|
||||
func (opts *StreamOpts) fillDefaults() { |
||||
if opts.InitialCursor == "" { |
||||
opts.InitialCursor = "$" |
||||
} |
||||
if opts.Block == 0 { |
||||
opts.Block = 5 * time.Second |
||||
} |
||||
} |
||||
|
||||
// Stream wraps a Redis instance in order to provide an abstraction over
|
||||
// consuming messages from a single redis stream. Stream is intended to be used
|
||||
// in a single-threaded manner, and doesn't spawn any go-routines.
|
||||
//
|
||||
// See https://redis.io/topics/streams-intro
|
||||
type Stream struct { |
||||
client *Redis |
||||
opts StreamOpts |
||||
|
||||
// entries are stored to buf in id decreasing order, and then read from it
|
||||
// from back-to-front. This allows us to not have to re-allocate the buffer
|
||||
// during runtime.
|
||||
buf []StreamEntry |
||||
|
||||
hasInit bool |
||||
numPending int64 |
||||
} |
||||
|
||||
// NewStream initializes and returns a Stream instance using the given options.
|
||||
func NewStream(r *Redis, opts StreamOpts) *Stream { |
||||
opts.fillDefaults() |
||||
return &Stream{ |
||||
client: r, |
||||
opts: opts, |
||||
buf: make([]StreamEntry, 0, opts.ReadCount), |
||||
} |
||||
} |
||||
|
||||
func (s *Stream) getNumPending() (int64, error) { |
||||
var res []interface{} |
||||
err := s.client.Do(radix.Cmd(&res, "XPENDING", s.opts.Key, s.opts.Group)) |
||||
if err != nil { |
||||
return 0, merr.Wrap(err, s.client.cmp.Context()) |
||||
} |
||||
return res[0].(int64), nil |
||||
} |
||||
|
||||
func (s *Stream) init() error { |
||||
// MKSTREAM is not documented, but will make the stream if it doesn't
|
||||
// already exist. Only the most elite redis gurus know of it's
|
||||
// existence, don't tell anyone.
|
||||
err := s.client.Do(radix.Cmd(nil, "XGROUP", "CREATE", s.opts.Key, s.opts.Group, s.opts.InitialCursor, "MKSTREAM")) |
||||
if err == nil { |
||||
// cool
|
||||
} else if errStr := err.Error(); !strings.HasPrefix(errStr, `BUSYGROUP Consumer Group name already exists`) { |
||||
return merr.Wrap(err, s.client.cmp.Context()) |
||||
} |
||||
|
||||
numPending, err := s.getNumPending() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
atomic.StoreInt64(&s.numPending, numPending) |
||||
|
||||
// if we're here it means init succeeded, mark as such and gtfo
|
||||
s.hasInit = true |
||||
return nil |
||||
} |
||||
|
||||
func (s *Stream) wrapEntry(entry radix.StreamEntry) StreamEntry { |
||||
return StreamEntry{ |
||||
StreamEntry: entry, |
||||
Ack: func() error { |
||||
return s.client.Do(radix.Cmd(nil, "XACK", s.opts.Key, s.opts.Group, entry.ID.String())) |
||||
}, |
||||
Nack: func() { atomic.AddInt64(&s.numPending, 1) }, |
||||
} |
||||
} |
||||
|
||||
func (s *Stream) fillBufFrom(id string) error { |
||||
args := []string{"GROUP", s.opts.Group, s.opts.Consumer} |
||||
if s.opts.ReadCount > 0 { |
||||
args = append(args, "COUNT", strconv.Itoa(s.opts.ReadCount)) |
||||
} |
||||
blockMS := int(s.opts.Block / time.Millisecond) |
||||
args = append(args, "BLOCK", strconv.Itoa(blockMS)) |
||||
args = append(args, "STREAMS", s.opts.Key, id) |
||||
|
||||
var srEntries []streamReaderEntry |
||||
err := s.client.Do(radix.Cmd(&srEntries, "XREADGROUP", args...)) |
||||
if err != nil { |
||||
return merr.Wrap(err, s.client.cmp.Context()) |
||||
} else if len(srEntries) == 0 { |
||||
return nil // no messages
|
||||
} else if len(srEntries) != 1 || string(srEntries[0].stream) != s.opts.Key { |
||||
return merr.New("malformed return from XREADGROUP", |
||||
mctx.Annotate(s.client.cmp.Context(), "srEntries", srEntries)) |
||||
} |
||||
entries := srEntries[0].entries |
||||
|
||||
for i := len(entries) - 1; i >= 0; i-- { |
||||
s.buf = append(s.buf, s.wrapEntry(entries[i])) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (s *Stream) fillBuf() error { |
||||
if len(s.buf) > 0 { |
||||
return nil |
||||
} else if !s.hasInit { |
||||
if err := s.init(); err != nil { |
||||
return err |
||||
} else if !s.hasInit { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
numPending := atomic.LoadInt64(&s.numPending) |
||||
if numPending > 0 { |
||||
if err := s.fillBufFrom("0"); err != nil { |
||||
return err |
||||
} else if len(s.buf) > 0 { |
||||
return nil |
||||
} |
||||
|
||||
// no pending entries, we can mark Stream as such and continue. This
|
||||
// _might_ fail if another routine called Nack in between originally
|
||||
// loading numPending and now, in which case we should leave the buffer
|
||||
// alone and let it get filled again later.
|
||||
if !atomic.CompareAndSwapInt64(&s.numPending, numPending, 0) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
return s.fillBufFrom(">") |
||||
} |
||||
|
||||
// Next returns the next StreamEntry which needs processing, or false. This
|
||||
// method is expected to block for up to the value of the Block field in
|
||||
// StreamOpts.
|
||||
//
|
||||
// If an error is returned it's up to the caller whether or not they want to
|
||||
// keep retrying.
|
||||
func (s *Stream) Next() (StreamEntry, bool, error) { |
||||
if err := s.fillBuf(); err != nil { |
||||
return StreamEntry{}, false, err |
||||
} else if len(s.buf) == 0 { |
||||
return StreamEntry{}, false, nil |
||||
} |
||||
|
||||
l := len(s.buf) |
||||
entry := s.buf[l-1] |
||||
s.buf = s.buf[:l-1] |
||||
return entry, true, nil |
||||
} |
@ -1,158 +0,0 @@ |
||||
package mredis |
||||
|
||||
import ( |
||||
"reflect" |
||||
"sync" |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
|
||||
"github.com/mediocregopher/radix/v3" |
||||
) |
||||
|
||||
func TestStream(t *T) { |
||||
cmp := mtest.Component() |
||||
redis := InstRedis(cmp) |
||||
|
||||
streamKey := "stream-" + mrand.Hex(8) |
||||
group := "group-" + mrand.Hex(8) |
||||
stream := NewStream(redis, StreamOpts{ |
||||
Key: streamKey, |
||||
Group: group, |
||||
Consumer: "consumer-" + mrand.Hex(8), |
||||
InitialCursor: "0", |
||||
}) |
||||
|
||||
mtest.Run(cmp, t, func() { |
||||
// once the test is ready to be finished up this will be closed
|
||||
finishUpCh := make(chan struct{}) |
||||
|
||||
// continually publish messages, adding them to the expEntries
|
||||
t.Log("creating publisher") |
||||
pubDone := make(chan struct{}) |
||||
expEntries := map[radix.StreamEntryID]radix.StreamEntry{} |
||||
go func() { |
||||
defer close(pubDone) |
||||
tick := time.NewTicker(50 * time.Millisecond) |
||||
defer tick.Stop() |
||||
|
||||
for { |
||||
var id radix.StreamEntryID |
||||
key, val := mrand.Hex(8), mrand.Hex(8) |
||||
if err := redis.Do(radix.Cmd(&id, "XADD", streamKey, "*", key, val)); err != nil { |
||||
t.Fatalf("error XADDing: %v", err) |
||||
} |
||||
expEntries[id] = radix.StreamEntry{ |
||||
ID: id, |
||||
Fields: map[string]string{key: val}, |
||||
} |
||||
|
||||
select { |
||||
case <-tick.C: |
||||
continue |
||||
case <-finishUpCh: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
gotEntriesL := new(sync.Mutex) |
||||
gotEntries := map[radix.StreamEntryID]radix.StreamEntry{} |
||||
|
||||
// spawn some workers which will process the StreamEntry's. We do this
|
||||
// to try and suss out any race conditions with Nack'ing. Each worker
|
||||
// will have a random chance of Nack'ing, until finishUpCh is closed and
|
||||
// then they will Ack everything.
|
||||
t.Log("creating workers") |
||||
const numWorkers = 5 |
||||
wg := new(sync.WaitGroup) |
||||
entryCh := make(chan StreamEntry, numWorkers*10) |
||||
for i := 0; i < numWorkers; i++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
for entry := range entryCh { |
||||
select { |
||||
case <-finishUpCh: |
||||
default: |
||||
if mrand.Intn(10) == 0 { |
||||
entry.Nack() |
||||
continue |
||||
} |
||||
} |
||||
|
||||
if err := entry.Ack(); err != nil { |
||||
t.Fatalf("error calling Ack: %v", err) |
||||
} |
||||
gotEntriesL.Lock() |
||||
gotEntries[entry.ID] = entry.StreamEntry |
||||
gotEntriesL.Unlock() |
||||
} |
||||
}() |
||||
} |
||||
|
||||
t.Log("consuming...") |
||||
waitTimer := time.After(5 * time.Second) |
||||
loop: |
||||
for { |
||||
select { |
||||
case <-waitTimer: |
||||
break loop |
||||
default: |
||||
} |
||||
|
||||
entry, ok, err := stream.Next() |
||||
if err != nil { |
||||
t.Fatalf("error calling Next (1): %v", err) |
||||
} else if ok { |
||||
entryCh <- entry |
||||
} |
||||
} |
||||
|
||||
// after 5 seconds we declare that it's time to finish up
|
||||
t.Log("finishing up...") |
||||
close(finishUpCh) |
||||
<-pubDone |
||||
|
||||
// Keep consuming until all messages have come in, then tell the workers
|
||||
// to clean themselves up.
|
||||
t.Log("consuming last of the entries") |
||||
for { |
||||
entry, ok, err := stream.Next() |
||||
if err != nil { |
||||
t.Fatalf("error calling Next (2): %v", err) |
||||
} else if ok { |
||||
entryCh <- entry |
||||
} else { |
||||
break // must be empty
|
||||
} |
||||
} |
||||
close(entryCh) |
||||
wg.Wait() |
||||
t.Log("all workers cleaned up") |
||||
|
||||
// call XPENDING to see if anything comes back, nothing should.
|
||||
t.Log("checking for leftover pending entries") |
||||
numPending, err := stream.getNumPending() |
||||
if err != nil { |
||||
t.Fatalf("error calling XPENDING: %v", err) |
||||
} else if numPending > 0 { |
||||
t.Fatalf("XPENDING says there's %v pending msgs, there should be 0", numPending) |
||||
} |
||||
|
||||
if len(expEntries) != len(gotEntries) { |
||||
t.Errorf("len(expEntries):%d != len(gotEntries):%d", len(expEntries), len(gotEntries)) |
||||
} |
||||
|
||||
for id, expEntry := range expEntries { |
||||
gotEntry, ok := gotEntries[id] |
||||
if !ok { |
||||
t.Errorf("did not consume entry %s", id) |
||||
} else if !reflect.DeepEqual(gotEntry, expEntry) { |
||||
t.Errorf("expEntry:%#v != gotEntry:%#v", expEntry, gotEntry) |
||||
} |
||||
} |
||||
}) |
||||
} |
@ -1,69 +0,0 @@ |
||||
// Package msql implements connecting to a MySQL/MariaDB instance (and possibly
|
||||
// others) and simplifies a number of interactions with it.
|
||||
package msql |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
// If something is importing msql it must need mysql, because that's all
|
||||
// that is implemented at the moment
|
||||
_ "github.com/go-sql-driver/mysql" |
||||
"github.com/jmoiron/sqlx" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
// SQL is a wrapper around a sqlx client which provides more functionality.
|
||||
type SQL struct { |
||||
*sqlx.DB |
||||
cmp *mcmp.Component |
||||
} |
||||
|
||||
// InstMySQL returns a SQL instance which will be initialized when the Init
|
||||
// event is triggered on the given Component. The SQL instance will have Close
|
||||
// called on it when the Shutdown event is triggered on the given Component.
|
||||
//
|
||||
// defaultDB indicates the name of the database in MySQL to use by default,
|
||||
// though it will be overwritable in the config.
|
||||
func InstMySQL(cmp *mcmp.Component, defaultDB string) *SQL { |
||||
sql := SQL{cmp: cmp.Child("mysql")} |
||||
|
||||
addr := mcfg.String(sql.cmp, "addr", |
||||
mcfg.ParamDefault("[::1]:3306"), |
||||
mcfg.ParamUsage("Address where MySQL server can be found")) |
||||
user := mcfg.String(sql.cmp, "user", |
||||
mcfg.ParamDefault("root"), |
||||
mcfg.ParamUsage("User to authenticate to MySQL server as")) |
||||
pass := mcfg.String(sql.cmp, "password", |
||||
mcfg.ParamUsage("Password to authenticate to MySQL server with")) |
||||
db := mcfg.String(sql.cmp, "database", |
||||
mcfg.ParamDefault(defaultDB), |
||||
mcfg.ParamUsage("MySQL database to use")) |
||||
|
||||
mrun.InitHook(sql.cmp, func(ctx context.Context) error { |
||||
sql.cmp.Annotate("addr", *addr, "user", *user) |
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", *user, *pass, *addr, *db) |
||||
mlog.From(sql.cmp).Debug("constructed dsn", mctx.Annotate(ctx, "dsn", dsn)) |
||||
mlog.From(sql.cmp).Info("connecting to MySQL server", ctx) |
||||
var err error |
||||
sql.DB, err = sqlx.ConnectContext(ctx, "mysql", dsn) |
||||
return merr.Wrap(err, sql.cmp.Context(), ctx) |
||||
}) |
||||
|
||||
mrun.ShutdownHook(sql.cmp, func(ctx context.Context) error { |
||||
mlog.From(sql.cmp).Info("closing connection to MySQL server", ctx) |
||||
return merr.Wrap(sql.Close(), sql.cmp.Context(), ctx) |
||||
}) |
||||
|
||||
return &sql |
||||
} |
||||
|
||||
// Context returns the annotated Context from this instance's initialization.
|
||||
func (sql *SQL) Context() context.Context { |
||||
return sql.cmp.Context() |
||||
} |
@ -1,18 +0,0 @@ |
||||
package msql |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
) |
||||
|
||||
func TestMySQL(t *T) { |
||||
cmp := mtest.Component() |
||||
sql := InstMySQL(cmp, "test") |
||||
mtest.Run(cmp, t, func() { |
||||
_, err := sql.Exec("CREATE TABLE IF NOT EXISTS msql_test (id INT);") |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
}) |
||||
} |
@ -1,48 +0,0 @@ |
||||
package merr |
||||
|
||||
import ( |
||||
"strings" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
func TestStack(t *T) { |
||||
foo := New("test") |
||||
fooStack, ok := Stack(foo) |
||||
massert.Require(t, massert.Equal(true, ok)) |
||||
|
||||
// test Frame
|
||||
frame := fooStack.Frame() |
||||
massert.Require(t, |
||||
massert.Equal(true, strings.Contains(frame.File, "stack_test.go")), |
||||
massert.Equal(true, strings.Contains(frame.Function, "TestStack")), |
||||
) |
||||
|
||||
frames := fooStack.Frames() |
||||
massert.Require(t, massert.Comment( |
||||
massert.All( |
||||
massert.Equal(true, len(frames) >= 2), |
||||
massert.Equal(true, strings.Contains(frames[0].File, "stack_test.go")), |
||||
massert.Equal(true, strings.Contains(frames[0].Function, "TestStack")), |
||||
), |
||||
"fooStack.FullString():\n%s", fooStack.FullString(), |
||||
)) |
||||
|
||||
// test that WithStack works and can be used to skip frames
|
||||
inner := func() { |
||||
bar := WithStack(foo, 1) |
||||
barStack, _ := Stack(bar) |
||||
frames := barStack.Frames() |
||||
massert.Require(t, massert.Comment( |
||||
massert.All( |
||||
massert.Equal(true, len(frames) >= 2), |
||||
massert.Equal(true, strings.Contains(frames[0].File, "stack_test.go")), |
||||
massert.Equal(true, strings.Contains(frames[0].Function, "TestStack")), |
||||
), |
||||
"barStack.FullString():\n%s", barStack.FullString(), |
||||
)) |
||||
} |
||||
inner() |
||||
|
||||
} |
@ -1,107 +0,0 @@ |
||||
// Package mhttp extends the standard package with extra functionality which is
|
||||
// commonly useful
|
||||
package mhttp |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
"net/http" |
||||
"net/http/httputil" |
||||
"net/url" |
||||
"strings" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mnet" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
// Server is returned by WithListeningServer and simply wraps an *http.Server.
|
||||
type Server struct { |
||||
*http.Server |
||||
cmp *mcmp.Component |
||||
} |
||||
|
||||
// InstListeningServer returns a *Server which will be initialized and have
|
||||
// ListenAndServe called on it (asynchronously) when the Init event is triggered
|
||||
// on the Component. The Server will have Shutdown called on it when the
|
||||
// Shutdown event is triggered on the Component.
|
||||
//
|
||||
// This function automatically handles setting up configuration parameters via
|
||||
// mcfg. The default listen address is ":0".
|
||||
func InstListeningServer(cmp *mcmp.Component, h http.Handler) *Server { |
||||
srv := &Server{ |
||||
Server: &http.Server{Handler: h}, |
||||
cmp: cmp.Child("http"), |
||||
} |
||||
|
||||
listener := mnet.InstListener(srv.cmp, |
||||
// http.Server.Shutdown will handle this
|
||||
mnet.ListenerCloseOnShutdown(false), |
||||
) |
||||
|
||||
threadCtx := context.Background() |
||||
mrun.InitHook(srv.cmp, func(ctx context.Context) error { |
||||
srv.Addr = listener.Addr().String() |
||||
threadCtx = mrun.WithThreads(threadCtx, 1, func() error { |
||||
mlog.From(srv.cmp).Info("serving requests", ctx) |
||||
if err := srv.Serve(listener); !merr.Equal(err, http.ErrServerClosed) { |
||||
mlog.From(srv.cmp).Error("error serving listener", ctx, merr.Context(err)) |
||||
return merr.Wrap(err, srv.cmp.Context(), ctx) |
||||
} |
||||
return nil |
||||
}) |
||||
return nil |
||||
}) |
||||
|
||||
mrun.ShutdownHook(srv.cmp, func(ctx context.Context) error { |
||||
mlog.From(srv.cmp).Info("shutting down server", ctx) |
||||
if err := srv.Shutdown(ctx); err != nil { |
||||
return merr.Wrap(err, srv.cmp.Context(), ctx) |
||||
} |
||||
err := mrun.Wait(threadCtx, ctx.Done()) |
||||
return merr.Wrap(err, srv.cmp.Context(), ctx) |
||||
}) |
||||
|
||||
return srv |
||||
} |
||||
|
||||
// AddXForwardedFor populates the X-Forwarded-For header on the Request to
|
||||
// convey that the request is being proxied for IP.
|
||||
//
|
||||
// If the IP is invalid, loopback, or otherwise part of a reserved range, this
|
||||
// does nothing.
|
||||
func AddXForwardedFor(r *http.Request, ipStr string) { |
||||
const xff = "X-Forwarded-For" |
||||
ip := net.ParseIP(ipStr) |
||||
if ip == nil || mnet.IsReservedIP(ip) { // IsReservedIP includes loopback
|
||||
return |
||||
} |
||||
prev, _ := r.Header[xff] |
||||
r.Header.Set(xff, strings.Join(append(prev, ip.String()), ", ")) |
||||
} |
||||
|
||||
// ReverseProxy returns an httputil.ReverseProxy which will send requests to the
|
||||
// given URL and copy their responses back without modification.
|
||||
//
|
||||
// Only the Scheme and Host of the given URL are used.
|
||||
//
|
||||
// Any http.ResponseWriters passed into the ServeHTTP call of the returned
|
||||
// instance should not be modified afterwards.
|
||||
func ReverseProxy(u *url.URL) *httputil.ReverseProxy { |
||||
rp := new(httputil.ReverseProxy) |
||||
rp.Director = func(req *http.Request) { |
||||
if ipStr, _, err := net.SplitHostPort(req.RemoteAddr); err != nil { |
||||
AddXForwardedFor(req, ipStr) |
||||
} |
||||
|
||||
req.URL.Scheme = u.Scheme |
||||
req.URL.Host = u.Host |
||||
} |
||||
|
||||
// TODO when this package has a function for creating a Client use that for
|
||||
// the default here
|
||||
|
||||
return rp |
||||
} |
@ -1,70 +0,0 @@ |
||||
package mhttp |
||||
|
||||
import ( |
||||
"bytes" |
||||
"io" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
func TestMListenAndServe(t *T) { |
||||
cmp := mtest.Component() |
||||
|
||||
srv := InstListeningServer(cmp, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||
io.Copy(rw, r.Body) |
||||
})) |
||||
|
||||
mtest.Run(cmp, t, func() { |
||||
body := bytes.NewBufferString("HELLO") |
||||
resp, err := http.Post("http://"+srv.Addr, "text/plain", body) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
defer resp.Body.Close() |
||||
|
||||
respBody, err := ioutil.ReadAll(resp.Body) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} else if string(respBody) != "HELLO" { |
||||
t.Fatalf("unexpected respBody: %q", respBody) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func TestAddXForwardedFor(t *T) { |
||||
assertXFF := func(prev []string, ipStr, expected string) massert.Assertion { |
||||
r := httptest.NewRequest("GET", "/", nil) |
||||
for i := range prev { |
||||
r.Header.Add("X-Forwarded-For", prev[i]) |
||||
} |
||||
AddXForwardedFor(r, ipStr) |
||||
var a massert.Assertion |
||||
if expected == "" { |
||||
a = massert.Length(r.Header["X-Forwarded-For"], 0) |
||||
} else { |
||||
a = massert.All( |
||||
massert.Length(r.Header["X-Forwarded-For"], 1), |
||||
massert.Equal(expected, r.Header["X-Forwarded-For"][0]), |
||||
) |
||||
} |
||||
return massert.Comment(a, "prev:%#v ipStr:%q", prev, ipStr) |
||||
} |
||||
|
||||
massert.Require(t, |
||||
assertXFF(nil, "invalid", ""), |
||||
assertXFF(nil, "::1", ""), |
||||
assertXFF([]string{"8.0.0.0"}, "invalid", "8.0.0.0"), |
||||
assertXFF([]string{"8.0.0.0"}, "::1", "8.0.0.0"), |
||||
|
||||
assertXFF(nil, "8.0.0.0", "8.0.0.0"), |
||||
assertXFF([]string{"8.0.0.0"}, "8.0.0.1", "8.0.0.0, 8.0.0.1"), |
||||
assertXFF([]string{"8.0.0.0, 8.0.0.1"}, "8.0.0.2", "8.0.0.0, 8.0.0.1, 8.0.0.2"), |
||||
assertXFF([]string{"8.0.0.0, 8.0.0.1", "8.0.0.2"}, "8.0.0.3", |
||||
"8.0.0.0, 8.0.0.1, 8.0.0.2, 8.0.0.3"), |
||||
) |
||||
} |
@ -1,211 +0,0 @@ |
||||
// Package mnet extends the standard package with extra functionality which is
|
||||
// commonly useful
|
||||
package mnet |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
"strings" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mctx" |
||||
"github.com/mediocregopher/mediocre-go-lib/merr" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
// Listener is returned by InstListener and simply wraps a net.Listener.
|
||||
type Listener struct { |
||||
// One of these will be populated during the start hook, depending on the
|
||||
// protocol configured.
|
||||
net.Listener |
||||
net.PacketConn |
||||
|
||||
cmp *mcmp.Component |
||||
} |
||||
|
||||
type listenerOpts struct { |
||||
proto string |
||||
defaultAddr string |
||||
closeOnShutdown bool |
||||
} |
||||
|
||||
func (lOpts listenerOpts) isPacketConn() bool { |
||||
proto := strings.ToLower(lOpts.proto) |
||||
return strings.HasPrefix(proto, "udp") || |
||||
proto == "unixgram" || |
||||
strings.HasPrefix(proto, "ip") |
||||
} |
||||
|
||||
// ListenerOpt is a value which adjusts the behavior of InstListener.
|
||||
type ListenerOpt func(*listenerOpts) |
||||
|
||||
// ListenerProtocol adjusts the protocol which the Listener uses. The default is
|
||||
// "tcp".
|
||||
func ListenerProtocol(proto string) ListenerOpt { |
||||
return func(opts *listenerOpts) { |
||||
opts.proto = proto |
||||
} |
||||
} |
||||
|
||||
// ListenerCloseOnShutdown sets the Listener's behavior when mrun's Shutdown
|
||||
// event is triggered on its Component. If true the Listener will call Close on
|
||||
// itself, if false it will do nothing.
|
||||
//
|
||||
// Defaults to true.
|
||||
func ListenerCloseOnShutdown(closeOnShutdown bool) ListenerOpt { |
||||
return func(opts *listenerOpts) { |
||||
opts.closeOnShutdown = closeOnShutdown |
||||
} |
||||
} |
||||
|
||||
// ListenerDefaultAddr adjusts the defaultAddr which the Listener will use. The
|
||||
// addr will still be configurable via mcfg regardless of what this is set to.
|
||||
// The default is ":0".
|
||||
func ListenerDefaultAddr(defaultAddr string) ListenerOpt { |
||||
return func(opts *listenerOpts) { |
||||
opts.defaultAddr = defaultAddr |
||||
} |
||||
} |
||||
|
||||
// InstListener instantiates a Listener which will be initialized when the Init
|
||||
// event is triggered on the given Component, and closed when the Shutdown event
|
||||
// is triggered on the returned Component.
|
||||
func InstListener(cmp *mcmp.Component, opts ...ListenerOpt) *Listener { |
||||
lOpts := listenerOpts{ |
||||
proto: "tcp", |
||||
defaultAddr: ":0", |
||||
closeOnShutdown: true, |
||||
} |
||||
for _, opt := range opts { |
||||
opt(&lOpts) |
||||
} |
||||
|
||||
cmp = cmp.Child("net") |
||||
l := &Listener{cmp: cmp} |
||||
|
||||
addr := mcfg.String(cmp, "listen-addr", |
||||
mcfg.ParamDefault(lOpts.defaultAddr), |
||||
mcfg.ParamUsage( |
||||
strings.ToUpper(lOpts.proto)+" address to listen on in format "+ |
||||
"[host]:port. If port is 0 then a random one will be chosen", |
||||
), |
||||
) |
||||
|
||||
mrun.InitHook(cmp, func(ctx context.Context) error { |
||||
cmp.Annotate("proto", lOpts.proto, "addr", *addr) |
||||
|
||||
var err error |
||||
if lOpts.isPacketConn() { |
||||
if l.PacketConn, err = net.ListenPacket(lOpts.proto, *addr); err != nil { |
||||
return merr.Wrap(err, cmp.Context(), ctx) |
||||
} |
||||
cmp.Annotate("addr", l.PacketConn.LocalAddr().String()) |
||||
} else { |
||||
if l.Listener, err = net.Listen(lOpts.proto, *addr); err != nil { |
||||
return merr.Wrap(err, cmp.Context(), ctx) |
||||
} |
||||
cmp.Annotate("addr", l.Listener.Addr().String()) |
||||
} |
||||
|
||||
mlog.From(cmp).Info("listening") |
||||
return nil |
||||
}) |
||||
|
||||
// TODO track connections and wait for them to complete before shutting
|
||||
// down?
|
||||
mrun.ShutdownHook(cmp, func(context.Context) error { |
||||
if !lOpts.closeOnShutdown { |
||||
return nil |
||||
} |
||||
mlog.From(cmp).Info("shutting down listener") |
||||
return l.Close() |
||||
}) |
||||
|
||||
return l |
||||
} |
||||
|
||||
// Accept wraps a call to Accept on the underlying net.Listener, providing debug
|
||||
// logging.
|
||||
func (l *Listener) Accept() (net.Conn, error) { |
||||
conn, err := l.Listener.Accept() |
||||
if err != nil { |
||||
return conn, err |
||||
} |
||||
mlog.From(l.cmp).Debug("connection accepted", |
||||
mctx.Annotated("remoteAddr", conn.RemoteAddr().String())) |
||||
return conn, nil |
||||
} |
||||
|
||||
// Close wraps a call to Close on the underlying net.Listener, providing debug
|
||||
// logging.
|
||||
func (l *Listener) Close() error { |
||||
mlog.From(l.cmp).Info("listener closing") |
||||
if l.Listener != nil { |
||||
return l.Listener.Close() |
||||
} |
||||
return l.PacketConn.Close() |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func mustGetCIDRNetwork(cidr string) *net.IPNet { |
||||
_, n, err := net.ParseCIDR(cidr) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return n |
||||
} |
||||
|
||||
// https://en.wikipedia.org/wiki/Reserved_IP_addresses
|
||||
|
||||
var reservedCIDRs4 = []*net.IPNet{ |
||||
mustGetCIDRNetwork("0.0.0.0/8"), // current network
|
||||
mustGetCIDRNetwork("10.0.0.0/8"), // private network
|
||||
mustGetCIDRNetwork("100.64.0.0/10"), // private network
|
||||
mustGetCIDRNetwork("127.0.0.0/8"), // localhost
|
||||
mustGetCIDRNetwork("169.254.0.0/16"), // link-local
|
||||
mustGetCIDRNetwork("172.16.0.0/12"), // private network
|
||||
mustGetCIDRNetwork("192.0.0.0/24"), // IETF protocol assignments
|
||||
mustGetCIDRNetwork("192.0.2.0/24"), // documentation and examples
|
||||
mustGetCIDRNetwork("192.88.99.0/24"), // 6to4 Relay
|
||||
mustGetCIDRNetwork("192.168.0.0/16"), // private network
|
||||
mustGetCIDRNetwork("198.18.0.0/15"), // private network
|
||||
mustGetCIDRNetwork("198.51.100.0/24"), // documentation and examples
|
||||
mustGetCIDRNetwork("203.0.113.0/24"), // documentation and examples
|
||||
mustGetCIDRNetwork("224.0.0.0/4"), // IP multicast
|
||||
mustGetCIDRNetwork("240.0.0.0/4"), // reserved
|
||||
mustGetCIDRNetwork("255.255.255.255/32"), // limited broadcast address
|
||||
} |
||||
|
||||
var reservedCIDRs6 = []*net.IPNet{ |
||||
mustGetCIDRNetwork("::/128"), // unspecified address
|
||||
mustGetCIDRNetwork("::1/128"), // loopback address
|
||||
mustGetCIDRNetwork("100::/64"), // discard prefix
|
||||
mustGetCIDRNetwork("2001::/32"), // Teredo tunneling
|
||||
mustGetCIDRNetwork("2001:20::/28"), // ORCHID v2
|
||||
mustGetCIDRNetwork("2001:db8::/32"), // documentation and examples
|
||||
mustGetCIDRNetwork("2002::/16"), // 6to4 addressing
|
||||
mustGetCIDRNetwork("fc00::/7"), // unique local
|
||||
mustGetCIDRNetwork("fe80::/10"), // link local
|
||||
mustGetCIDRNetwork("ff00::/8"), // multicast
|
||||
} |
||||
|
||||
// IsReservedIP returns true if the given valid IP is part of a reserved IP
|
||||
// range.
|
||||
func IsReservedIP(ip net.IP) bool { |
||||
containedBy := func(cidrs []*net.IPNet) bool { |
||||
for _, cidr := range cidrs { |
||||
if cidr.Contains(ip) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
if ip.To4() != nil { |
||||
return containedBy(reservedCIDRs4) |
||||
} |
||||
return containedBy(reservedCIDRs6) |
||||
} |
@ -1,61 +0,0 @@ |
||||
package mnet |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/ioutil" |
||||
"net" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mtest" |
||||
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" |
||||
) |
||||
|
||||
func TestIsReservedIP(t *T) { |
||||
assertReserved := func(ipStr string) massert.Assertion { |
||||
ip := net.ParseIP(ipStr) |
||||
if ip == nil { |
||||
panic("ip:" + ipStr + " not valid") |
||||
} |
||||
return massert.Comment(massert.Equal(true, IsReservedIP(ip)), |
||||
"ip:%q", ipStr) |
||||
} |
||||
|
||||
massert.Require(t, |
||||
assertReserved("127.0.0.1"), |
||||
assertReserved("::ffff:127.0.0.1"), |
||||
assertReserved("192.168.40.50"), |
||||
assertReserved("::1"), |
||||
assertReserved("100::1"), |
||||
) |
||||
|
||||
massert.Require(t, massert.None( |
||||
assertReserved("8.8.8.8"), |
||||
assertReserved("::ffff:8.8.8.8"), |
||||
assertReserved("2600:1700:7580:6e80:21c:25ff:fe97:44df"), |
||||
)) |
||||
} |
||||
|
||||
func TestWithListener(t *T) { |
||||
cmp := mtest.Component() |
||||
l := InstListener(cmp) |
||||
mtest.Run(cmp, t, func() { |
||||
go func() { |
||||
conn, err := net.Dial("tcp", l.Addr().String()) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} else if _, err = fmt.Fprint(conn, "hello world"); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
conn.Close() |
||||
}() |
||||
|
||||
conn, err := l.Accept() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} else if b, err := ioutil.ReadAll(conn); err != nil { |
||||
t.Fatal(err) |
||||
} else if string(b) != "hello world" { |
||||
t.Fatalf("read %q from conn", b) |
||||
} |
||||
}) |
||||
} |
@ -1,67 +0,0 @@ |
||||
package mrand |
||||
|
||||
import ( |
||||
"math/rand" |
||||
"sync" |
||||
) |
||||
|
||||
// Everything in this file is taken from the math/rand package, which really
|
||||
// ought to expose lockedSource publicly.
|
||||
|
||||
func read(p []byte, int63 func() int64, readVal *int64, readPos *int8) (n int, err error) { |
||||
pos := *readPos |
||||
val := *readVal |
||||
for n = 0; n < len(p); n++ { |
||||
if pos == 0 { |
||||
val = int63() |
||||
pos = 7 |
||||
} |
||||
p[n] = byte(val) |
||||
val >>= 8 |
||||
pos-- |
||||
} |
||||
*readPos = pos |
||||
*readVal = val |
||||
return |
||||
} |
||||
|
||||
type lockedSource struct { |
||||
lk sync.Mutex |
||||
src rand.Source64 |
||||
} |
||||
|
||||
func (r *lockedSource) Int63() (n int64) { |
||||
r.lk.Lock() |
||||
n = r.src.Int63() |
||||
r.lk.Unlock() |
||||
return |
||||
} |
||||
|
||||
func (r *lockedSource) Uint64() (n uint64) { |
||||
r.lk.Lock() |
||||
n = r.src.Uint64() |
||||
r.lk.Unlock() |
||||
return |
||||
} |
||||
|
||||
func (r *lockedSource) Seed(seed int64) { |
||||
r.lk.Lock() |
||||
r.src.Seed(seed) |
||||
r.lk.Unlock() |
||||
} |
||||
|
||||
// seedPos implements Seed for a lockedSource without a race condition.
|
||||
func (r *lockedSource) seedPos(seed int64, readPos *int8) { |
||||
r.lk.Lock() |
||||
r.src.Seed(seed) |
||||
*readPos = 0 |
||||
r.lk.Unlock() |
||||
} |
||||
|
||||
// read implements Read for a lockedSource without a race condition.
|
||||
func (r *lockedSource) read(p []byte, readVal *int64, readPos *int8) (n int, err error) { |
||||
r.lk.Lock() |
||||
n, err = read(p, r.src.Int63, readVal, readPos) |
||||
r.lk.Unlock() |
||||
return |
||||
} |
@ -1,105 +0,0 @@ |
||||
// Package mrand implements extensions and conveniences for using the default
|
||||
// math/rand package.
|
||||
package mrand |
||||
|
||||
import ( |
||||
"encoding/hex" |
||||
"math/rand" |
||||
"reflect" |
||||
"time" |
||||
) |
||||
|
||||
// Rand extends the default rand.Rand type with extra functionality.
|
||||
type Rand struct { |
||||
*rand.Rand |
||||
} |
||||
|
||||
// NewSyncRand initializes and returns a new Rand instance using the given
|
||||
// Source. The returned Rand will be safe for concurrent use.
|
||||
//
|
||||
// This will panic if the given Source doesn't implement rand.Source64.
|
||||
func NewSyncRand(src rand.Source) Rand { |
||||
return Rand{ |
||||
Rand: rand.New(&lockedSource{src: src.(rand.Source64)}), |
||||
} |
||||
} |
||||
|
||||
// Bytes returns n random bytes.
|
||||
func (r Rand) Bytes(n int) []byte { |
||||
b := make([]byte, n) |
||||
if _, err := r.Read(b); err != nil { |
||||
panic(err) |
||||
} |
||||
return b |
||||
} |
||||
|
||||
// Hex returns a random hex string which is n characters long.
|
||||
func (r Rand) Hex(n int) string { |
||||
origN := n |
||||
if n%2 == 1 { |
||||
n++ |
||||
} |
||||
b := r.Bytes(hex.DecodedLen(n)) |
||||
return hex.EncodeToString(b)[:origN] |
||||
} |
||||
|
||||
// Element returns a random element from the given slice.
|
||||
//
|
||||
// If a weighting function is given then that function is used to weight each
|
||||
// element of the slice relative to the others, based on whatever metric and
|
||||
// scale is desired. The weight function must be able to be called more than
|
||||
// once on each element.
|
||||
func (r Rand) Element(slice interface{}, weight func(i int) uint64) interface{} { |
||||
v := reflect.ValueOf(slice) |
||||
l := v.Len() |
||||
|
||||
if weight == nil { |
||||
return v.Index(r.Intn(l)).Interface() |
||||
} |
||||
|
||||
var totalWeight uint64 |
||||
for i := 0; i < l; i++ { |
||||
totalWeight += weight(i) |
||||
} |
||||
|
||||
target := r.Int63n(int64(totalWeight)) |
||||
for i := 0; i < l; i++ { |
||||
w := int64(weight(i)) |
||||
target -= w |
||||
if target < 0 { |
||||
return v.Index(i).Interface() |
||||
} |
||||
} |
||||
panic("should never get here, perhaps the weighting function is inconsistent?") |
||||
} |
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// DefaultRand is an instance off Rand whose methods are directly exported by
|
||||
// this package for convenience.
|
||||
var DefaultRand = NewSyncRand(rand.NewSource(time.Now().UnixNano())) |
||||
|
||||
// Methods off DefaultRand exported to the top level of this package.
|
||||
var ( |
||||
ExpFloat64 = DefaultRand.ExpFloat64 |
||||
Float32 = DefaultRand.Float32 |
||||
Float64 = DefaultRand.Float64 |
||||
Int = DefaultRand.Int |
||||
Int31 = DefaultRand.Int31 |
||||
Int31n = DefaultRand.Int31n |
||||
Int63 = DefaultRand.Int63 |
||||
Int63n = DefaultRand.Int63n |
||||
Intn = DefaultRand.Intn |
||||
NormFloat64 = DefaultRand.NormFloat64 |
||||
Perm = DefaultRand.Perm |
||||
Read = DefaultRand.Read |
||||
Seed = DefaultRand.Seed |
||||
Shuffle = DefaultRand.Shuffle |
||||
Uint32 = DefaultRand.Uint32 |
||||
Uint64 = DefaultRand.Uint64 |
||||
|
||||
// extended methods
|
||||
Bytes = DefaultRand.Bytes |
||||
Hex = DefaultRand.Hex |
||||
Element = DefaultRand.Element |
||||
) |
@ -1,59 +0,0 @@ |
||||
package mrand |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestRandBytes(t *T) { |
||||
var prev []byte |
||||
for i := 0; i < 10000; i++ { |
||||
curr := Bytes(16) |
||||
assert.Len(t, curr, 16) |
||||
assert.NotEqual(t, prev, curr) |
||||
prev = curr |
||||
} |
||||
} |
||||
|
||||
func TestRandHex(t *T) { |
||||
// RandHex is basically a wrapper of RandBytes, so we don't have to test it
|
||||
// much
|
||||
assert.Len(t, Hex(16), 16) |
||||
} |
||||
|
||||
func TestRandElement(t *T) { |
||||
slice := []uint64{1, 2, 3} // values are also each value's weight
|
||||
total := func() uint64 { |
||||
var t uint64 |
||||
for i := range slice { |
||||
t += slice[i] |
||||
} |
||||
return t |
||||
}() |
||||
m := map[uint64]uint64{} |
||||
|
||||
iterations := 100000 |
||||
for i := 0; i < iterations; i++ { |
||||
el := Element(slice, func(i int) uint64 { return slice[i] }).(uint64) |
||||
m[el]++ |
||||
} |
||||
|
||||
for i := range slice { |
||||
t.Logf("%d -> %d (%f)", slice[i], m[slice[i]], float64(m[slice[i]])/float64(iterations)) |
||||
} |
||||
|
||||
assertEl := func(i int) { |
||||
el, elF := slice[i], float64(slice[i]) |
||||
gotRatio := float64(m[el]) / float64(iterations) |
||||
expRatio := elF / float64(total) |
||||
diff := (gotRatio - expRatio) / expRatio |
||||
if diff > 0.1 || diff < -0.1 { |
||||
t.Fatalf("ratio of element %d is off: got %f, expected %f (diff:%f)", el, gotRatio, expRatio, diff) |
||||
} |
||||
} |
||||
|
||||
for i := range slice { |
||||
assertEl(i) |
||||
} |
||||
} |
@ -1,138 +0,0 @@ |
||||
// Package mrpc contains types and functionality to facilitate creating RPC
|
||||
// interfaces and for making calls against those same interfaces
|
||||
//
|
||||
// This package contains a few fundamental types: Handler, Call, and
|
||||
// Client. Together these form the components needed to implement nearly any RPC
|
||||
// system.
|
||||
//
|
||||
// TODO an example of an implementation of these interfaces can be found in the
|
||||
// m package
|
||||
package mrpc |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"reflect" |
||||
) |
||||
|
||||
// Handler is a type which serves RPC calls. For each incoming Call the ServeRPC
|
||||
// method is called, and the return from the method is used as the response. If
|
||||
// an error is returned the response return is ignored.
|
||||
type Handler interface { |
||||
ServeRPC(Call) (interface{}, error) |
||||
} |
||||
|
||||
// HandlerFunc can be used to wrap an individual function which fits the
|
||||
// ServeRPC signature, and use that function as a Handler
|
||||
type HandlerFunc func(Call) (interface{}, error) |
||||
|
||||
// ServeRPC implements the method for the Handler interface by calling the
|
||||
// underlying function
|
||||
func (hf HandlerFunc) ServeRPC(c Call) (interface{}, error) { |
||||
return hf(c) |
||||
} |
||||
|
||||
// Call is passed into the ServeRPC method and contains all information about
|
||||
// the incoming RPC call which is being made
|
||||
type Call interface { |
||||
Context() context.Context |
||||
|
||||
// Method returns the name of the RPC method being called
|
||||
Method() string |
||||
|
||||
// UnmarshalArgs takes in a pointer and unmarshals the RPC call's arguments
|
||||
// into it. The properties of the unmarshaling are dependent on the
|
||||
// underlying implementation of the codec types.
|
||||
UnmarshalArgs(interface{}) error |
||||
} |
||||
|
||||
type call struct { |
||||
ctx context.Context |
||||
method string |
||||
unmarshalArgs func(interface{}) error |
||||
} |
||||
|
||||
func (c call) Context() context.Context { |
||||
return c.ctx |
||||
} |
||||
|
||||
func (c call) Method() string { |
||||
return c.method |
||||
} |
||||
|
||||
func (c call) UnmarshalArgs(i interface{}) error { |
||||
return c.unmarshalArgs(i) |
||||
} |
||||
|
||||
// WithContext returns the same Call it's given, but the new Call will return
|
||||
// the given context when Context() is called
|
||||
func WithContext(c Call, ctx context.Context) Call { |
||||
return call{ctx: ctx, method: c.Method(), unmarshalArgs: c.UnmarshalArgs} |
||||
} |
||||
|
||||
// WithMethod returns the same Call it's given, but the new Call will return the
|
||||
// given method name when Method() is called
|
||||
func WithMethod(c Call, method string) Call { |
||||
return call{ctx: c.Context(), method: method, unmarshalArgs: c.UnmarshalArgs} |
||||
} |
||||
|
||||
// Client is an entity which can perform RPC calls against a remote endpoint.
|
||||
//
|
||||
// res should be a pointer into which the result of the RPC call will be
|
||||
// unmarshaled according to Client's implementation. args will be marshaled and
|
||||
// sent to the remote endpoint according to Client's implementation.
|
||||
type Client interface { |
||||
CallRPC(ctx context.Context, res interface{}, method string, args interface{}) error |
||||
} |
||||
|
||||
// ClientFunc can be used to wrap an individual function which fits the CallRPC
|
||||
// signature, and use that function as a Client
|
||||
type ClientFunc func(context.Context, interface{}, string, interface{}) error |
||||
|
||||
// CallRPC implements the method for the Client interface by calling the
|
||||
// underlying function
|
||||
func (cf ClientFunc) CallRPC( |
||||
ctx context.Context, |
||||
res interface{}, |
||||
method string, |
||||
args interface{}, |
||||
) error { |
||||
return cf(ctx, res, method, args) |
||||
} |
||||
|
||||
// ReflectClient returns a Client whose CallRPC method will use reflection to
|
||||
// call the given Handler's ServeRPC method directly, using reflect.Value's Set
|
||||
// method to copy CallRPC's args parameter into UnmarshalArgs' receiver
|
||||
// parameter, and similarly to copy the result from ServeRPC into CallRPC's
|
||||
// receiver parameter.
|
||||
func ReflectClient(h Handler) Client { |
||||
into := func(dst, src interface{}) error { |
||||
dstV, srcV := reflect.ValueOf(dst), reflect.ValueOf(src) |
||||
dstVi, srcVi := reflect.Indirect(dstV), reflect.Indirect(srcV) |
||||
if !dstVi.CanSet() || dstVi.Type() != srcVi.Type() { |
||||
return fmt.Errorf("can't set value of type %v into type %v", srcV.Type(), dstV.Type()) |
||||
} |
||||
dstVi.Set(srcVi) |
||||
return nil |
||||
} |
||||
|
||||
return ClientFunc(func( |
||||
ctx context.Context, |
||||
resInto interface{}, |
||||
method string, |
||||
args interface{}, |
||||
) error { |
||||
c := call{ |
||||
ctx: ctx, |
||||
method: method, |
||||
unmarshalArgs: func(i interface{}) error { return into(i, args) }, |
||||
} |
||||
|
||||
res, err := h.ServeRPC(c) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
return into(resInto, res) |
||||
}) |
||||
} |
@ -1,67 +0,0 @@ |
||||
package mrpc |
||||
|
||||
import ( |
||||
"context" |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestReflectClient(t *T) { |
||||
type argT struct { |
||||
In string |
||||
} |
||||
|
||||
type resT struct { |
||||
Out string |
||||
} |
||||
|
||||
ctx := context.Background() |
||||
|
||||
{ // test with handler returning non-pointer
|
||||
client := ReflectClient(HandlerFunc(func(c Call) (interface{}, error) { |
||||
var args argT |
||||
assert.NoError(t, c.UnmarshalArgs(&args)) |
||||
assert.Equal(t, "foo", c.Method()) |
||||
return resT{Out: args.In}, nil |
||||
})) |
||||
|
||||
{ // test with arg being non-pointer
|
||||
in := mrand.Hex(8) |
||||
var res resT |
||||
assert.NoError(t, client.CallRPC(ctx, &res, "foo", argT{In: in})) |
||||
assert.Equal(t, in, res.Out) |
||||
} |
||||
|
||||
{ // test with arg being pointer
|
||||
in := mrand.Hex(8) |
||||
var res resT |
||||
assert.NoError(t, client.CallRPC(ctx, &res, "foo", &argT{In: in})) |
||||
assert.Equal(t, in, res.Out) |
||||
} |
||||
} |
||||
|
||||
{ // test with handler returning pointer
|
||||
client := ReflectClient(HandlerFunc(func(c Call) (interface{}, error) { |
||||
var args argT |
||||
assert.NoError(t, c.UnmarshalArgs(&args)) |
||||
assert.Equal(t, "foo", c.Method()) |
||||
return &resT{Out: args.In}, nil |
||||
})) |
||||
|
||||
{ // test with arg being non-pointer
|
||||
in := mrand.Hex(8) |
||||
var res resT |
||||
assert.NoError(t, client.CallRPC(ctx, &res, "foo", argT{In: in})) |
||||
assert.Equal(t, in, res.Out) |
||||
} |
||||
|
||||
{ // test with arg being pointer
|
||||
in := mrand.Hex(8) |
||||
var res resT |
||||
assert.NoError(t, client.CallRPC(ctx, &res, "foo", &argT{In: in})) |
||||
assert.Equal(t, in, res.Out) |
||||
} |
||||
} |
||||
} |
@ -1,213 +0,0 @@ |
||||
// Package mchk implements a framework for writing property checker tests, where
|
||||
// test cases are generated randomly and performed, and failing test cases are
|
||||
// output in a way so as to easily be able to rerun them. In addition failing
|
||||
// test cases are minimized so the smallest possible case is returned.
|
||||
//
|
||||
// The central type of the package is Checker. For every Run call on Checker a
|
||||
// new initial State is generated, and then an Action is generated off of that.
|
||||
// The Action is applied to the State to obtain a new State, and a new Action is
|
||||
// generated from there, and so on. If any Action fails it is output along with
|
||||
// all of the Actions leading up to it.
|
||||
package mchk |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
) |
||||
|
||||
// RunErr represents an test case error which was returned by a Checker Run.
|
||||
//
|
||||
// The string form of RunErr includes the sequence of Params which can be
|
||||
// copy-pasted directly into Checker's RunCase method's arguments.
|
||||
type RunErr struct { |
||||
// The sequence of Action Params which generated the error
|
||||
Params []Params |
||||
|
||||
// The error returned by the final Action
|
||||
Err error |
||||
} |
||||
|
||||
func (ce RunErr) Error() string { |
||||
buf := new(bytes.Buffer) |
||||
fmt.Fprintf(buf, "Test case: []mtest.Params{\n") |
||||
for _, p := range ce.Params { |
||||
fmt.Fprintf(buf, "\t%#v,\n", p) |
||||
} |
||||
fmt.Fprintf(buf, "}\n") |
||||
fmt.Fprintf(buf, "Generated error: %s\n", ce.Err) |
||||
return buf.String() |
||||
} |
||||
|
||||
// State represents the current state of a Checker run. It can be any value
|
||||
// convenient and useful to the test.
|
||||
type State interface{} |
||||
|
||||
// Params represent the parameters to an Action used during a Checker run. It
|
||||
// should be a static value, meaning no pointers or channels.
|
||||
type Params interface{} |
||||
|
||||
// Action describes a change which can take place on a state.
|
||||
type Action struct { |
||||
// Params are defined by the test and affect the behavior of the Action.
|
||||
Params Params |
||||
|
||||
// Incomplete can be set to true to indicate that this Action should never
|
||||
// be the last Action applied, even if that means the length of the Run goes
|
||||
// over MaxLength.
|
||||
Incomplete bool |
||||
|
||||
// Terminate can be set to true to indicate that this Action should always
|
||||
// be the last Action applied, even if the Run's length hasn't reached
|
||||
// MaxLength yet.
|
||||
Terminate bool |
||||
} |
||||
|
||||
// Checker implements a very basic property checker. It generates random test
|
||||
// cases, attempting to find and print out failing ones.
|
||||
type Checker struct { |
||||
// Init returns the initial state of the test. It should always return the
|
||||
// exact same value.
|
||||
Init func() State |
||||
|
||||
// Next returns a new Action which can be Apply'd to the given State. This
|
||||
// function should not modify the State in any way.
|
||||
Next func(State) Action |
||||
|
||||
// Apply performs the Action's changes to a State, returning the new State.
|
||||
// After modifying the State this function should also assert that the new
|
||||
// State is what it's expected to be, returning an error if it's not.
|
||||
Apply func(State, Action) (State, error) |
||||
|
||||
// Cleanup is an optional function which can perform any necessary cleanup
|
||||
// operations on the State. This is called even on error.
|
||||
Cleanup func(State) |
||||
|
||||
// MaxLength indicates the maximum number of Actions which can be strung
|
||||
// together in a single Run. Defaults to 10 if not set.
|
||||
MaxLength int |
||||
|
||||
// If true the Run and RunFor methods will return the first erroring Action
|
||||
// sequence, without trying to remove extraneous Actions from it first.
|
||||
DontMinimize bool |
||||
} |
||||
|
||||
func (c Checker) withDefaults() Checker { |
||||
if c.MaxLength == 0 { |
||||
c.MaxLength = 10 |
||||
} |
||||
return c |
||||
} |
||||
|
||||
// RunFor performs Runs in a loop until maxDuration has elapsed.
|
||||
func (c Checker) RunFor(maxDuration time.Duration) error { |
||||
doneTimer := time.After(maxDuration) |
||||
for { |
||||
select { |
||||
case <-doneTimer: |
||||
return nil |
||||
default: |
||||
} |
||||
|
||||
if err := c.Run(); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Run generates a single sequence of Actions and applies them in order,
|
||||
// returning nil once the number of Actions performed has reached MaxLength or a
|
||||
// CheckErr if an error is returned. If an error is to be returned this will
|
||||
// attempt to minimize the Actions sequence in order to find the smallest
|
||||
// reproducible test case.
|
||||
func (c Checker) Run() error { |
||||
c = c.withDefaults() |
||||
s := c.Init() |
||||
params := make([]Params, 0, c.MaxLength) |
||||
for { |
||||
action := c.Next(s) |
||||
var err error |
||||
s, err = c.Apply(s, action) |
||||
params = append(params, action.Params) |
||||
|
||||
if err != nil && c.DontMinimize { |
||||
return RunErr{ |
||||
Params: params, |
||||
Err: err, |
||||
} |
||||
} else if err != nil { |
||||
minParams := c.MinimizeCase(params...) |
||||
if minErr := c.RunCase(minParams...); minErr != nil { |
||||
// RunCase already wraps errs in RunErrs, so that's not
|
||||
// necessary here
|
||||
return minErr |
||||
} |
||||
// if the minParams didn't return an error here it means the test
|
||||
// case isn't consistent, as a fallback return the original which
|
||||
// definitely errored
|
||||
return RunErr{ |
||||
Params: params, |
||||
Err: err, |
||||
} |
||||
} else if action.Incomplete { |
||||
continue |
||||
} else if action.Terminate || len(params) >= c.MaxLength { |
||||
return nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
// MinimizeCase repeatedly randomly picks a Param from the set and performs
|
||||
// RunCase without that Param. It does this until it can't remove a single Param
|
||||
// without the error ceasing, and returns that minimized set.
|
||||
func (c Checker) MinimizeCase(params ...Params) []Params { |
||||
outer: |
||||
for { |
||||
if len(params) == 1 { |
||||
return params |
||||
} |
||||
|
||||
tried := map[int]bool{} |
||||
for { |
||||
if len(tried) == len(params) { |
||||
return params |
||||
} |
||||
i := mrand.Intn(len(params)) |
||||
if tried[i] { |
||||
continue |
||||
} |
||||
newParams := make([]Params, 0, len(params)-1) |
||||
newParams = append(newParams, params[:i]...) |
||||
newParams = append(newParams, params[i+1:]...) |
||||
if err := c.RunCase(newParams...); err == nil { |
||||
tried[i] = true |
||||
continue |
||||
} |
||||
params = newParams |
||||
continue outer |
||||
} |
||||
} |
||||
} |
||||
|
||||
// RunCase performs a single sequence of Actions with the given Params.
|
||||
func (c Checker) RunCase(params ...Params) error { |
||||
s := c.Init() |
||||
if c.Cleanup != nil { |
||||
// wrap in a function so we don't capture the value of s right here
|
||||
defer func() { |
||||
c.Cleanup(s) |
||||
}() |
||||
} |
||||
for i := range params { |
||||
var err error |
||||
if s, err = c.Apply(s, Action{Params: params[i]}); err != nil { |
||||
return RunErr{ |
||||
Params: params[:i+1], |
||||
Err: err, |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
@ -1,49 +0,0 @@ |
||||
package mchk |
||||
|
||||
import ( |
||||
"errors" |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mrand" |
||||
) |
||||
|
||||
func TestCheckerRun(t *T) { |
||||
c := Checker{ |
||||
Init: func() State { return 0 }, |
||||
Next: func(State) Action { |
||||
if mrand.Intn(3) == 0 { |
||||
return Action{Params: -1} |
||||
} |
||||
return Action{Params: 1} |
||||
}, |
||||
Apply: func(s State, a Action) (State, error) { |
||||
si := s.(int) + a.Params.(int) |
||||
if si > 5 { |
||||
return nil, errors.New("went over 5") |
||||
} |
||||
return si, nil |
||||
}, |
||||
MaxLength: 4, |
||||
} |
||||
|
||||
// 4 Actions should never be able to go over 5
|
||||
if err := c.RunFor(time.Second); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
// 20 should always go over 5 eventually
|
||||
c.MaxLength = 20 |
||||
err := c.RunFor(time.Second) |
||||
if err == nil { |
||||
t.Fatal("expected error when maxDepth is 20") |
||||
} else if len(err.(RunErr).Params) < 6 { |
||||
t.Fatalf("strange RunErr when maxDepth is 20: %s", err) |
||||
} |
||||
|
||||
t.Logf("got expected error with large maxDepth:\n%s", err) |
||||
caseErr := c.RunCase(err.(RunErr).Params...) |
||||
if caseErr == nil || err.Error() != caseErr.Error() { |
||||
t.Fatalf("unexpected caseErr: %v", caseErr) |
||||
} |
||||
} |
@ -1,69 +0,0 @@ |
||||
// Package mtest implements functionality useful for testing.
|
||||
package mtest |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
"github.com/mediocregopher/mediocre-go-lib/mcmp" |
||||
"github.com/mediocregopher/mediocre-go-lib/mlog" |
||||
"github.com/mediocregopher/mediocre-go-lib/mrun" |
||||
) |
||||
|
||||
type envCmpKey int |
||||
|
||||
// Component creates and returns a root Component suitable for testing.
|
||||
func Component() *mcmp.Component { |
||||
cmp := new(mcmp.Component) |
||||
logger := mlog.NewLogger() |
||||
logger.SetMaxLevel(mlog.DebugLevel) |
||||
mlog.SetLogger(cmp, logger) |
||||
|
||||
mrun.InitHook(cmp, func(context.Context) error { |
||||
envVals := mcmp.SeriesValues(cmp, envCmpKey(0)) |
||||
env := make([]string, 0, len(envVals)) |
||||
for _, val := range envVals { |
||||
tup := val.([2]string) |
||||
env = append(env, tup[0]+"="+tup[1]) |
||||
} |
||||
return mcfg.Populate(cmp, &mcfg.SourceEnv{Env: env}) |
||||
}) |
||||
|
||||
return cmp |
||||
} |
||||
|
||||
// Env sets the given environment variable on the given Component, such that it
|
||||
// will be used as if it was a real environment variable when the Run function
|
||||
// from this package is called.
|
||||
//
|
||||
// This function will panic if not called on the root Component.
|
||||
func Env(cmp *mcmp.Component, key, val string) { |
||||
if len(cmp.Path()) != 0 { |
||||
panic("Env should only be called on the root Component") |
||||
} |
||||
mcmp.AddSeriesValue(cmp, envCmpKey(0), [2]string{key, val}) |
||||
} |
||||
|
||||
// Run performs the following using the given Component:
|
||||
//
|
||||
// - Calls mrun.Init, which calls mcfg.Populate using any variables set by Env.
|
||||
//
|
||||
// - Calls the passed in body callback.
|
||||
//
|
||||
// - Calls mrun.Shutdown
|
||||
//
|
||||
// The intention is that Run is used within a test on a Component created via
|
||||
// this package's Component function, after any instantiation functions have
|
||||
// been called (e.g. mnet.InstListener).
|
||||
func Run(cmp *mcmp.Component, t *testing.T, body func()) { |
||||
if err := mrun.Init(context.Background(), cmp); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
body() |
||||
|
||||
if err := mrun.Shutdown(context.Background(), cmp); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
@ -1,19 +0,0 @@ |
||||
package mtest |
||||
|
||||
import ( |
||||
. "testing" |
||||
|
||||
"github.com/mediocregopher/mediocre-go-lib/mcfg" |
||||
) |
||||
|
||||
func TestRun(t *T) { |
||||
cmp := Component() |
||||
Env(cmp, "ARG", "foo") |
||||
|
||||
arg := mcfg.String(cmp, "arg", mcfg.ParamRequired()) |
||||
Run(cmp, t, func() { |
||||
if *arg != "foo" { |
||||
t.Fatalf(`arg not set to "foo", is set to %q`, *arg) |
||||
} |
||||
}) |
||||
} |
@ -1,42 +0,0 @@ |
||||
package mtime |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"time" |
||||
) |
||||
|
||||
// Duration wraps time.Duration to implement marshaling and unmarshaling methods
|
||||
type Duration struct { |
||||
time.Duration |
||||
} |
||||
|
||||
// MarshalText implements the text.Marshaler interface
|
||||
func (d Duration) MarshalText() ([]byte, error) { |
||||
return []byte(d.Duration.String()), nil |
||||
} |
||||
|
||||
// UnmarshalText implements the text.Unmarshaler interface
|
||||
func (d *Duration) UnmarshalText(b []byte) error { |
||||
var err error |
||||
d.Duration, err = time.ParseDuration(string(b)) |
||||
return err |
||||
} |
||||
|
||||
// MarshalJSON implements the json.Marshaler interface, marshaling the Duration
|
||||
// as a json string via Duration's String method
|
||||
func (d Duration) MarshalJSON() ([]byte, error) { |
||||
return json.Marshal(d.String()) |
||||
} |
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface, unmarshaling the
|
||||
// Duration as a JSON string and using the time.ParseDuration function on that
|
||||
func (d *Duration) UnmarshalJSON(b []byte) error { |
||||
var s string |
||||
err := json.Unmarshal(b, &s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
d.Duration, err = time.ParseDuration(s) |
||||
return err |
||||
} |
@ -1,30 +0,0 @@ |
||||
package mtime |
||||
|
||||
import ( |
||||
. "testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestDuration(t *T) { |
||||
{ |
||||
b, err := Duration{5 * time.Second}.MarshalText() |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, []byte("5s"), b) |
||||
|
||||
var d Duration |
||||
assert.NoError(t, d.UnmarshalText(b)) |
||||
assert.Equal(t, 5*time.Second, d.Duration) |
||||
} |
||||
|
||||
{ |
||||
b, err := Duration{5 * time.Second}.MarshalJSON() |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, []byte(`"5s"`), b) |
||||
|
||||
var d Duration |
||||
assert.NoError(t, d.UnmarshalJSON(b)) |
||||
assert.Equal(t, 5*time.Second, d.Duration) |
||||
} |
||||
} |
@ -1,2 +0,0 @@ |
||||
// Package mtime extends the standard time package with extra functionality
|
||||
package mtime |
@ -1,118 +0,0 @@ |
||||
package mtime |
||||
|
||||
// Code based off the timeutil package in github.com/levenlabs/golib
|
||||
// Changes performed:
|
||||
// - Renamed Timestamp to TS for brevity
|
||||
// - Added NewTS function
|
||||
// - Moved Float64 method
|
||||
// - Moved initialization methods to top
|
||||
// - Made MarshalJSON use String method
|
||||
// - TSNow -> NowTS, make it use NewTS
|
||||
|
||||
import ( |
||||
"bytes" |
||||
"strconv" |
||||
"time" |
||||
) |
||||
|
||||
var unixZero = time.Unix(0, 0) |
||||
|
||||
func timeToFloat(t time.Time) float64 { |
||||
// If time.Time is the empty value, UnixNano will return the farthest back
|
||||
// timestamp a float can represent, which is some large negative value. We
|
||||
// compromise and call it zero
|
||||
if t.IsZero() { |
||||
return 0 |
||||
} |
||||
return float64(t.UnixNano()) / 1e9 |
||||
} |
||||
|
||||
// TS is a wrapper around time.Time which adds methods to marshal and
|
||||
// unmarshal the value as a unix timestamp instead of a formatted string
|
||||
type TS struct { |
||||
time.Time |
||||
} |
||||
|
||||
// NewTS returns a new TS instance wrapping the given time.Time, which will
|
||||
// possibly be truncated a certain amount to account for floating point
|
||||
// precision.
|
||||
func NewTS(t time.Time) TS { |
||||
return TSFromFloat64(timeToFloat(t)) |
||||
} |
||||
|
||||
// NowTS is a wrapper around time.Now which returns a TS.
|
||||
func NowTS() TS { |
||||
return NewTS(time.Now()) |
||||
} |
||||
|
||||
// TSFromInt64 returns a TS equal to the given int64, assuming it too is a unix
|
||||
// timestamp
|
||||
func TSFromInt64(ts int64) TS { |
||||
return TS{time.Unix(ts, 0)} |
||||
} |
||||
|
||||
// TSFromFloat64 returns a TS equal to the given float64, assuming it too is a
|
||||
// unix timestamp. The float64 is interpreted as number of seconds, with
|
||||
// everything after the decimal indicating milliseconds, microseconds, and
|
||||
// nanoseconds
|
||||
func TSFromFloat64(ts float64) TS { |
||||
secs := int64(ts) |
||||
nsecs := int64((ts - float64(secs)) * 1e9) |
||||
return TS{time.Unix(secs, nsecs)} |
||||
} |
||||
|
||||
// TSFromString attempts to parse the string as a float64, and then passes that
|
||||
// into TSFromFloat64, returning the result
|
||||
func TSFromString(ts string) (TS, error) { |
||||
f, err := strconv.ParseFloat(ts, 64) |
||||
if err != nil { |
||||
return TS{}, err |
||||
} |
||||
return TSFromFloat64(f), nil |
||||
} |
||||
|
||||
// String returns the string representation of the TS, in the form of a floating
|
||||
// point form of the time as a unix timestamp
|
||||
func (t TS) String() string { |
||||
ts := timeToFloat(t.Time) |
||||
return strconv.FormatFloat(ts, 'f', -1, 64) |
||||
} |
||||
|
||||
// Float64 returns the float representation of the timestamp in seconds.
|
||||
func (t TS) Float64() float64 { |
||||
return timeToFloat(t.Time) |
||||
} |
||||
|
||||
var jsonNull = []byte("null") |
||||
|
||||
// MarshalJSON returns the JSON representation of the TS as an integer. It
|
||||
// never returns an error
|
||||
func (t TS) MarshalJSON() ([]byte, error) { |
||||
if t.IsZero() { |
||||
return jsonNull, nil |
||||
} |
||||
|
||||
return []byte(t.String()), nil |
||||
} |
||||
|
||||
// UnmarshalJSON takes a JSON integer and converts it into a TS, or
|
||||
// returns an error if this can't be done
|
||||
func (t *TS) UnmarshalJSON(b []byte) error { |
||||
// since 0 is a valid timestamp we can't use that to mean "unset", so we
|
||||
// take null to mean unset instead
|
||||
if bytes.Equal(b, jsonNull) { |
||||
t.Time = time.Time{} |
||||
return nil |
||||
} |
||||
|
||||
var err error |
||||
*t, err = TSFromString(string(b)) |
||||
return err |
||||
} |
||||
|
||||
// IsUnixZero returns true if the timestamp is equal to the unix zero timestamp,
|
||||
// representing 1/1/1970. This is different than checking if the timestamp is
|
||||
// the empty value (which should be done with IsZero)
|
||||
func (t TS) IsUnixZero() bool { |
||||
return t.Equal(unixZero) |
||||
} |
@ -1,118 +0,0 @@ |
||||
package mtime |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"strconv" |
||||
. "testing" |
||||
"time" |
||||
|
||||
"gopkg.in/mgo.v2/bson" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestTS(t *T) { |
||||
ts := NowTS() |
||||
|
||||
tsJ, err := json.Marshal(&ts) |
||||
require.Nil(t, err) |
||||
|
||||
// tsJ should basically be an integer
|
||||
tsF, err := strconv.ParseFloat(string(tsJ), 64) |
||||
require.Nil(t, err) |
||||
assert.True(t, tsF > 0) |
||||
|
||||
ts2 := TSFromFloat64(tsF) |
||||
assert.Equal(t, ts, ts2) |
||||
|
||||
var ts3 TS |
||||
err = json.Unmarshal(tsJ, &ts3) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, ts, ts3) |
||||
} |
||||
|
||||
// Make sure that we can take in a non-float from json
|
||||
func TestTSMarshalInt(t *T) { |
||||
now := time.Now() |
||||
tsJ := []byte(strconv.FormatInt(now.Unix(), 10)) |
||||
var ts TS |
||||
err := json.Unmarshal(tsJ, &ts) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, ts.Float64(), float64(now.Unix())) |
||||
} |
||||
|
||||
type Foo struct { |
||||
T TS `json:"timestamp" bson:"t"` |
||||
} |
||||
|
||||
func TestTSJSON(t *T) { |
||||
now := NowTS() |
||||
in := Foo{now} |
||||
b, err := json.Marshal(in) |
||||
require.Nil(t, err) |
||||
assert.NotEmpty(t, b) |
||||
|
||||
var out Foo |
||||
err = json.Unmarshal(b, &out) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, in, out) |
||||
} |
||||
|
||||
func TestTSJSONNull(t *T) { |
||||
{ |
||||
var foo Foo |
||||
timestampNull := []byte(`{"timestamp":null}`) |
||||
fooJSON, err := json.Marshal(foo) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, timestampNull, fooJSON) |
||||
|
||||
require.Nil(t, json.Unmarshal(timestampNull, &foo)) |
||||
assert.True(t, foo.T.IsZero()) |
||||
assert.False(t, foo.T.IsUnixZero()) |
||||
} |
||||
|
||||
{ |
||||
var foo Foo |
||||
foo.T = TS{Time: unixZero} |
||||
timestampZero := []byte(`{"timestamp":0}`) |
||||
fooJSON, err := json.Marshal(foo) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, timestampZero, fooJSON) |
||||
|
||||
require.Nil(t, json.Unmarshal(timestampZero, &foo)) |
||||
assert.False(t, foo.T.IsZero()) |
||||
assert.True(t, foo.T.IsUnixZero()) |
||||
} |
||||
} |
||||
|
||||
func TestTSZero(t *T) { |
||||
var ts TS |
||||
assert.True(t, ts.IsZero()) |
||||
assert.False(t, ts.IsUnixZero()) |
||||
tsf := timeToFloat(ts.Time) |
||||
assert.Zero(t, tsf) |
||||
|
||||
ts = TSFromFloat64(0) |
||||
assert.False(t, ts.IsZero()) |
||||
assert.True(t, ts.IsUnixZero()) |
||||
tsf = timeToFloat(ts.Time) |
||||
assert.Zero(t, tsf) |
||||
} |
||||
|
||||
func TestTSBSON(t *T) { |
||||
// BSON only supports up to millisecond precision, but even if we keep that
|
||||
// many it kinda gets messed up due to rounding errors. So we just give it
|
||||
// one with second precision
|
||||
now := TSFromInt64(time.Now().Unix()) |
||||
|
||||
in := Foo{now} |
||||
b, err := bson.Marshal(in) |
||||
require.Nil(t, err) |
||||
assert.NotEmpty(t, b) |
||||
|
||||
var out Foo |
||||
err = bson.Unmarshal(b, &out) |
||||
require.Nil(t, err) |
||||
assert.Equal(t, in, out) |
||||
} |
Loading…
Reference in new issue