mnet: implement ListenerOnStart

This commit is contained in:
Brian Picciano 2019-01-11 17:47:30 -05:00
parent 96db88b7d0
commit 8a8cebd127
2 changed files with 71 additions and 0 deletions

View File

@ -4,8 +4,42 @@ package mnet
import ( import (
"net" "net"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrun"
) )
// ListenerOnStart returns a Listener which will be initialized when the start
// event is triggered on ctx (see mrun.Start).
//
// network defaults to "tcp" if empty. defaultAddr defaults to ":0" if empty,
// and will be configurable via mcfg.
func ListenerOnStart(ctx mctx.Context, network, defaultAddr string) net.Listener {
if network == "" {
network = "tcp"
}
if defaultAddr == "" {
defaultAddr = ":0"
}
addr := mcfg.String(ctx, "addr", defaultAddr, network+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen")
var l struct{ net.Listener }
mrun.OnStart(ctx, func(mctx.Context) error {
var err error
if l.Listener, err = net.Listen(network, *addr); err != nil {
return err
}
mlog.From(ctx).Info("listening", mlog.KV{"addr": l.Addr()})
return nil
})
return &l
}
////////////////////////////////////////////////////////////////////////////////
func mustGetCIDRNetwork(cidr string) *net.IPNet { func mustGetCIDRNetwork(cidr string) *net.IPNet {
_, n, err := net.ParseCIDR(cidr) _, n, err := net.ParseCIDR(cidr)
if err != nil { if err != nil {

View File

@ -1,9 +1,14 @@
package mnet package mnet
import ( import (
"fmt"
"io/ioutil"
"net" "net"
. "testing" . "testing"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mrun"
"github.com/mediocregopher/mediocre-go-lib/mtest/massert" "github.com/mediocregopher/mediocre-go-lib/mtest/massert"
) )
@ -31,3 +36,35 @@ func TestIsReservedIP(t *T) {
assertReserved("2600:1700:7580:6e80:21c:25ff:fe97:44df"), assertReserved("2600:1700:7580:6e80:21c:25ff:fe97:44df"),
)) ))
} }
func TestListen(t *T) {
ctx := mctx.New()
l := ListenerOnStart(ctx, "", "")
if err := mcfg.Populate(ctx, nil); err != nil {
t.Fatal(err)
} else if err := mrun.Start(ctx); err != nil {
t.Fatal(err)
}
go func() {
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
} else if _, err = fmt.Fprint(conn, "hello world"); err != nil {
t.Fatal(err)
}
conn.Close()
}()
conn, err := l.Accept()
if err != nil {
t.Fatal(err)
} else if b, err := ioutil.ReadAll(conn); err != nil {
t.Fatal(err)
} else if string(b) != "hello world" {
t.Fatalf("read %q from conn", b)
}
conn.Close()
l.Close()
}