diff --git a/mnet/mnet.go b/mnet/mnet.go index c8a4523..33dbe10 100644 --- a/mnet/mnet.go +++ b/mnet/mnet.go @@ -4,8 +4,42 @@ package mnet import ( "net" + + "github.com/mediocregopher/mediocre-go-lib/mcfg" + "github.com/mediocregopher/mediocre-go-lib/mctx" + "github.com/mediocregopher/mediocre-go-lib/mlog" + "github.com/mediocregopher/mediocre-go-lib/mrun" ) +// ListenerOnStart returns a Listener which will be initialized when the start +// event is triggered on ctx (see mrun.Start). +// +// network defaults to "tcp" if empty. defaultAddr defaults to ":0" if empty, +// and will be configurable via mcfg. +func ListenerOnStart(ctx mctx.Context, network, defaultAddr string) net.Listener { + if network == "" { + network = "tcp" + } + if defaultAddr == "" { + defaultAddr = ":0" + } + addr := mcfg.String(ctx, "addr", defaultAddr, network+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen") + + var l struct{ net.Listener } + mrun.OnStart(ctx, func(mctx.Context) error { + var err error + if l.Listener, err = net.Listen(network, *addr); err != nil { + return err + } + mlog.From(ctx).Info("listening", mlog.KV{"addr": l.Addr()}) + return nil + }) + + return &l +} + +//////////////////////////////////////////////////////////////////////////////// + func mustGetCIDRNetwork(cidr string) *net.IPNet { _, n, err := net.ParseCIDR(cidr) if err != nil { diff --git a/mnet/mnet_test.go b/mnet/mnet_test.go index 64fc7fe..bf878e2 100644 --- a/mnet/mnet_test.go +++ b/mnet/mnet_test.go @@ -1,9 +1,14 @@ package mnet import ( + "fmt" + "io/ioutil" "net" . "testing" + "github.com/mediocregopher/mediocre-go-lib/mcfg" + "github.com/mediocregopher/mediocre-go-lib/mctx" + "github.com/mediocregopher/mediocre-go-lib/mrun" "github.com/mediocregopher/mediocre-go-lib/mtest/massert" ) @@ -31,3 +36,35 @@ func TestIsReservedIP(t *T) { assertReserved("2600:1700:7580:6e80:21c:25ff:fe97:44df"), )) } + +func TestListen(t *T) { + ctx := mctx.New() + l := ListenerOnStart(ctx, "", "") + if err := mcfg.Populate(ctx, nil); err != nil { + t.Fatal(err) + } else if err := mrun.Start(ctx); err != nil { + t.Fatal(err) + } + + go func() { + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } else if _, err = fmt.Fprint(conn, "hello world"); err != nil { + t.Fatal(err) + } + conn.Close() + }() + + conn, err := l.Accept() + if err != nil { + t.Fatal(err) + } else if b, err := ioutil.ReadAll(conn); err != nil { + t.Fatal(err) + } else if string(b) != "hello world" { + t.Fatalf("read %q from conn", b) + } + + conn.Close() + l.Close() +}