diff --git a/gracehttp/http.go b/gracehttp/http.go index 73e7595..93e26a6 100644 --- a/gracehttp/http.go +++ b/gracehttp/http.go @@ -5,152 +5,182 @@ package gracehttp import ( "bytes" "crypto/tls" - "errors" "flag" "fmt" "log" "net" "net/http" "os" + "os/signal" + "sync" + "syscall" - "github.com/facebookgo/grace" + "github.com/facebookgo/grace/gracenet" + "github.com/facebookgo/httpdown" ) var ( - verbose = flag.Bool("gracehttp.log", true, "Enable logging.") - errListenersCount = errors.New("unexpected listeners count") - errNoStartedListeners = errors.New("no started listeners") + verbose = flag.Bool("gracehttp.log", true, "Enable logging.") + didInherit = os.Getenv("LISTEN_FDS") != "" + ppid = os.Getppid() ) -// An App contains one or more servers and associated configuration. -type App struct { - Servers []*http.Server - listeners []grace.Listener +// An app contains one or more servers and associated configuration. +type app struct { + servers []*http.Server + http *httpdown.HTTP + net *gracenet.Net + listeners []net.Listener + sds []httpdown.Server errors chan error } -// Listen will inherit or create new listeners. Returns a bool indicating if we -// inherited listeners. This return value is useful in order to decide if we -// should instruct the parent process to terminate. -func (a *App) Listen() (bool, error) { - var err error - a.errors = make(chan error, len(a.Servers)) - a.listeners, err = grace.Inherit() - if err == nil { - if len(a.Servers) != len(a.listeners) { - return true, errListenersCount - } - return true, nil - } else if err == grace.ErrNotInheriting { - if a.listeners, err = a.newListeners(); err != nil { - return false, err - } - return false, nil +func newApp(servers []*http.Server) *app { + return &app{ + servers: servers, + http: &httpdown.HTTP{}, + net: &gracenet.Net{}, + listeners: make([]net.Listener, 0, len(servers)), + sds: make([]httpdown.Server, 0, len(servers)), + + // 2x num servers for possible Close or Stop errors + 1 for possible + // StartProcess error. + errors: make(chan error, 1+(len(servers)*2)), } - return false, fmt.Errorf("failed graceful handoff: %s", err) } -// Creates new listeners (as in not inheriting) for all the configured Servers. -func (a *App) newListeners() ([]grace.Listener, error) { - listeners := make([]grace.Listener, len(a.Servers)) - for index, server := range a.Servers { - addr, err := net.ResolveTCPAddr("tcp", server.Addr) +func (a *app) listen() error { + for _, s := range a.servers { + // TODO: default addresses + l, err := a.net.Listen("tcp", s.Addr) if err != nil { - return nil, fmt.Errorf("net.ResolveTCPAddr %s: %s", server.Addr, err) + return err } - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, fmt.Errorf("net.ListenTCP %s: %s", server.Addr, err) + if s.TLSConfig != nil { + l = tls.NewListener(l, s.TLSConfig) } - listeners[index] = grace.NewListener(l) + a.listeners = append(a.listeners, l) } - return listeners, nil + return nil } -// Serve the configured servers, but do not block. You must call Wait at some -// point to ensure correctly waiting for graceful termination. -func (a *App) Serve() { - for i, l := range a.listeners { - go func(i int, l net.Listener) { - server := a.Servers[i] - - // Wrap the listener for TLS support if necessary. - if server.TLSConfig != nil { - l = tls.NewListener(l, server.TLSConfig) - } - - err := server.Serve(l) - // The underlying Accept() will return grace.ErrAlreadyClosed - // when a signal to do the same is returned, which we are okay with. - if err != nil && err != grace.ErrAlreadyClosed { - a.errors <- fmt.Errorf("http.Serve: %s", err) - } - }(i, l) +func (a *app) serve() { + for i, s := range a.servers { + a.sds = append(a.sds, a.http.Serve(s, a.listeners[i])) } } -// Wait for the serving goroutines to finish. -func (a *App) Wait() error { - waiterr := make(chan error) - go func() { waiterr <- grace.Wait(a.listeners) }() - select { - case err := <-waiterr: - return err - case err := <-a.errors: - return err +func (a *app) wait() { + var wg sync.WaitGroup + wg.Add(len(a.sds) * 2) // Wait & Stop + go a.signalHandler(&wg) + for _, s := range a.sds { + go func(s httpdown.Server) { + defer wg.Done() + if err := s.Wait(); err != nil { + a.errors <- err + } + }(s) + } + wg.Wait() +} + +func (a *app) term(wg *sync.WaitGroup) { + for _, s := range a.sds { + go func(s httpdown.Server) { + defer wg.Done() + if err := s.Stop(); err != nil { + a.errors <- err + } + }(s) + } +} + +func (a *app) signalHandler(wg *sync.WaitGroup) { + ch := make(chan os.Signal, 10) + signal.Notify(ch, syscall.SIGTERM, syscall.SIGUSR2) + for { + sig := <-ch + switch sig { + case syscall.SIGTERM: + // this ensures a subsequent TERM will trigger standard go behaviour of + // terminating. + signal.Stop(ch) + a.term(wg) + return + case syscall.SIGUSR2: + // we only return here if there's an error, otherwise the new process + // will send us a TERM when it's ready to trigger the actual shutdown. + if _, err := a.net.StartProcess(); err != nil { + a.errors <- err + } + } } } // Serve will serve the given http.Servers and will monitor for signals // allowing for graceful termination (SIGTERM) or restart (SIGUSR2). func Serve(servers ...*http.Server) error { - app := &App{Servers: servers} - inherited, err := app.Listen() - if err != nil { + a := newApp(servers) + + // Acquire Listeners + if err := a.listen(); err != nil { return err } + // Some useful logging. if *verbose { - if inherited { - ppid := os.Getppid() + if didInherit { if ppid == 1 { - log.Printf("Listening on init activated %s", pprintAddr(app.listeners)) + log.Printf("Listening on init activated %s", pprintAddr(a.listeners)) } else { const msg = "Graceful handoff of %s with new pid %d and old pid %d" - log.Printf(msg, pprintAddr(app.listeners), os.Getpid(), ppid) + log.Printf(msg, pprintAddr(a.listeners), os.Getpid(), ppid) } } else { const msg = "Serving %s with pid %d" - log.Printf(msg, pprintAddr(app.listeners), os.Getpid()) + log.Printf(msg, pprintAddr(a.listeners), os.Getpid()) } } - app.Serve() + // Start serving. + a.serve() // Close the parent if we inherited and it wasn't init that started us. - if inherited && os.Getppid() != 1 { - if err := grace.CloseParent(); err != nil { + if didInherit && ppid != 1 { + if err := syscall.Kill(ppid, syscall.SIGTERM); err != nil { return fmt.Errorf("failed to close parent: %s", err) } } - err = app.Wait() + waitdone := make(chan struct{}) + go func() { + defer close(waitdone) + a.wait() + }() - if *verbose { - log.Printf("Exiting pid %d.", os.Getpid()) + select { + case err := <-a.errors: + if err == nil { + panic("unexpected nil error") + } + return err + case <-waitdone: + if *verbose { + log.Printf("Exiting pid %d.", os.Getpid()) + } + return nil } - - return err } // Used for pretty printing addresses. -func pprintAddr(listeners []grace.Listener) []byte { - out := bytes.NewBuffer(nil) +func pprintAddr(listeners []net.Listener) []byte { + var out bytes.Buffer for i, l := range listeners { if i != 0 { - fmt.Fprint(out, ", ") + fmt.Fprint(&out, ", ") } - fmt.Fprint(out, l.Addr()) + fmt.Fprint(&out, l.Addr()) } return out.Bytes() } diff --git a/gracehttp/http_test.go b/gracehttp/http_test.go index 7ce36f9..6d9c1fe 100644 --- a/gracehttp/http_test.go +++ b/gracehttp/http_test.go @@ -171,7 +171,6 @@ func (h *harness) SendOne(dialgroup *sync.WaitGroup, url string, pid int) { debug("Send %02d pid=%d url=%s", count, pid, url) client := &http.Client{ Transport: &http.Transport{ - DisableKeepAlives: true, Dial: func(network, addr string) (net.Conn, error) { defer func() { time.Sleep(50 * time.Millisecond) diff --git a/gracenet/net.go b/gracenet/net.go new file mode 100644 index 0000000..268b524 --- /dev/null +++ b/gracenet/net.go @@ -0,0 +1,252 @@ +// Package gracenet provides a family of Listen functions that either open a +// fresh connection or provide an inherited connection from when the process +// was started. The behave like their counterparts in the net pacakge, but +// transparently provide support for graceful restarts without dropping +// connections. This is provided in a systemd socket activation compatible form +// to allow using socket activation. +// +// BUG: Doesn't handle closing of listeners. +package gracenet + +import ( + "fmt" + "net" + "os" + "os/exec" + "strconv" + "strings" + "sync" +) + +const ( + // Used to indicate a graceful restart in the new process. + envCountKey = "LISTEN_FDS" + envCountKeyPrefix = envCountKey + "=" +) + +// In order to keep the working directory the same as when we started we record +// it at startup. +var originalWD, _ = os.Getwd() + +// Net provides the family of Listen functions and maintains the associated +// state. Typically you will have only once instance of Net per application. +type Net struct { + inherited []net.Listener + active []net.Listener + mutex sync.Mutex + inheritOnce sync.Once + + // used in tests to override the default behavior of starting from fd 3. + fdStart int +} + +func (n *Net) inherit() error { + var retErr error + n.inheritOnce.Do(func() { + n.mutex.Lock() + defer n.mutex.Unlock() + countStr := os.Getenv(envCountKey) + if countStr == "" { + return + } + count, err := strconv.Atoi(countStr) + if err != nil { + retErr = fmt.Errorf("found invalid count value: %s=%s", envCountKey, countStr) + return + } + + // In tests this may be overridden. + fdStart := n.fdStart + if fdStart == 0 { + // In normal operations if we are inheriting, the listeners will begin at + // fd 3. + fdStart = 3 + } + + for i := fdStart; i < fdStart+count; i++ { + file := os.NewFile(uintptr(i), "listener") + l, err := net.FileListener(file) + if err != nil { + file.Close() + retErr = fmt.Errorf("error inheriting socket fd %d: %s", i, err) + return + } + if err := file.Close(); err != nil { + retErr = fmt.Errorf("error closing inherited socket fd %d: %s", i, err) + return + } + n.inherited = append(n.inherited, l) + } + }) + return retErr +} + +// Listen announces on the local network address laddr. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It +// returns an inherited net.Listener for the matching network and address, or +// creates a new one using net.Listen. +func (n *Net) Listen(nett, laddr string) (net.Listener, error) { + switch nett { + default: + return nil, net.UnknownNetworkError(nett) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(nett, laddr) + if err != nil { + return nil, err + } + return n.ListenTCP(nett, addr) + case "unix", "unixpacket", "invalid_unix_net_for_test": + addr, err := net.ResolveUnixAddr(nett, laddr) + if err != nil { + return nil, err + } + return n.ListenUnix(nett, addr) + } +} + +// ListenTCP announces on the local network address laddr. The network net must +// be: "tcp", "tcp4" or "tcp6". It returns an inherited net.Listener for the +// matching network and address, or creates a new one using net.ListenTCP. +func (n *Net) ListenTCP(nett string, laddr *net.TCPAddr) (*net.TCPListener, error) { + if err := n.inherit(); err != nil { + return nil, err + } + + n.mutex.Lock() + defer n.mutex.Unlock() + + // look for an inherited listener + for i, l := range n.inherited { + if l == nil { // we nil used inherited listeners + continue + } + if isSameAddr(l.Addr(), laddr) { + n.inherited[i] = nil + n.active = append(n.active, l) + return l.(*net.TCPListener), nil + } + } + + // make a fresh listener + l, err := net.ListenTCP(nett, laddr) + if err != nil { + return nil, err + } + n.active = append(n.active, l) + return l, nil +} + +// ListenUnix announces on the local network address laddr. The network net +// must be a: "unix" or "unixpacket". It returns an inherited net.Listener for +// the matching network and address, or creates a new one using net.ListenUnix. +func (n *Net) ListenUnix(nett string, laddr *net.UnixAddr) (*net.UnixListener, error) { + if err := n.inherit(); err != nil { + return nil, err + } + + n.mutex.Lock() + defer n.mutex.Unlock() + + // look for an inherited listener + for i, l := range n.inherited { + if l == nil { // we nil used inherited listeners + continue + } + if isSameAddr(l.Addr(), laddr) { + n.inherited[i] = nil + n.active = append(n.active, l) + return l.(*net.UnixListener), nil + } + } + + // make a fresh listener + l, err := net.ListenUnix(nett, laddr) + if err != nil { + return nil, err + } + n.active = append(n.active, l) + return l, nil +} + +// activeListeners returns a snapshot copy of the active listeners. +func (n *Net) activeListeners() ([]net.Listener, error) { + n.mutex.Lock() + defer n.mutex.Unlock() + ls := make([]net.Listener, len(n.active)) + copy(ls, n.active) + return ls, nil +} + +func isSameAddr(a1, a2 net.Addr) bool { + if a1.Network() != a2.Network() { + return false + } + a1s := a1.String() + a2s := a2.String() + if a1s == a2s { + return true + } + + // This allows for ipv6 vs ipv4 local addresses to compare as equal. This + // scenario is common when listening on localhost. + const ipv6prefix = "[::]" + a1s = strings.TrimPrefix(a1s, ipv6prefix) + a2s = strings.TrimPrefix(a2s, ipv6prefix) + const ipv4prefix = "0.0.0.0" + a1s = strings.TrimPrefix(a1s, ipv4prefix) + a2s = strings.TrimPrefix(a2s, ipv4prefix) + return a1s == a2s +} + +// StartProcess starts a new process passing it the active listeners. It +// doesn't fork, but starts a new process using the same environment and +// arguments as when it was originally started. This allows for a newly +// deployed binary to be started. It returns the pid of the newly started +// process when successful. +func (n *Net) StartProcess() (int, error) { + listeners, err := n.activeListeners() + if err != nil { + return 0, err + } + + // Extract the fds from the listeners. + files := make([]*os.File, len(listeners)) + for i, l := range listeners { + files[i], err = l.(filer).File() + if err != nil { + return 0, err + } + defer files[i].Close() + } + + // Use the original binary location. This works with symlinks such that if + // the file it points to has been changed we will use the updated symlink. + argv0, err := exec.LookPath(os.Args[0]) + if err != nil { + return 0, err + } + + // Pass on the environment and replace the old count key with the new one. + var env []string + for _, v := range os.Environ() { + if !strings.HasPrefix(v, envCountKeyPrefix) { + env = append(env, v) + } + } + env = append(env, fmt.Sprintf("%s%d", envCountKeyPrefix, len(listeners))) + + allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) + process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ + Dir: originalWD, + Env: env, + Files: allFiles, + }) + if err != nil { + return 0, err + } + return process.Pid, nil +} + +type filer interface { + File() (*os.File, error) +} diff --git a/gracenet/net_test.go b/gracenet/net_test.go new file mode 100644 index 0000000..8b79476 --- /dev/null +++ b/gracenet/net_test.go @@ -0,0 +1,359 @@ +package gracenet + +import ( + "fmt" + "io/ioutil" + "net" + "os" + "path/filepath" + "regexp" + "syscall" + "testing" + + "github.com/facebookgo/ensure" + "github.com/facebookgo/freeport" +) + +func TestEmptyCountEnvVariable(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + ensure.Nil(t, n.inherit()) +} + +func TestZeroCountEnvVariable(t *testing.T) { + var n Net + os.Setenv(envCountKey, "0") + ensure.Nil(t, n.inherit()) +} + +func TestInvalidCountEnvVariable(t *testing.T) { + var n Net + os.Setenv(envCountKey, "a") + expected := regexp.MustCompile("^found invalid count value: LISTEN_FDS=a$") + ensure.Err(t, n.inherit(), expected) +} + +func TestInvalidFileInherit(t *testing.T) { + var n Net + tmpfile, err := ioutil.TempFile("", "TestInvalidFileInherit-") + ensure.Nil(t, err) + defer os.Remove(tmpfile.Name()) + n.fdStart = dup(t, int(tmpfile.Fd())) + os.Setenv(envCountKey, "1") + ensure.Err(t, n.inherit(), regexp.MustCompile("^error inheriting socket fd")) + ensure.DeepEqual(t, len(n.inherited), 0) + ensure.Nil(t, tmpfile.Close()) +} + +func TestInheritErrorOnListenTCPWithInvalidCount(t *testing.T) { + var n Net + os.Setenv(envCountKey, "a") + _, err := n.Listen("tcp", ":0") + ensure.NotNil(t, err) +} + +func TestInheritErrorOnListenUnixWithInvalidCount(t *testing.T) { + var n Net + os.Setenv(envCountKey, "a") + tmpdir, err := ioutil.TempDir("", "TestInheritErrorOnListenUnixWithInvalidCount-") + ensure.Nil(t, err) + ensure.Nil(t, os.RemoveAll(tmpdir)) + _, err = n.Listen("unix", filepath.Join(tmpdir, "socket")) + ensure.NotNil(t, err) +} + +func TestOneTcpInherit(t *testing.T) { + var n Net + l, err := net.Listen("tcp", ":0") + ensure.Nil(t, err) + file, err := l.(*net.TCPListener).File() + ensure.Nil(t, err) + ensure.Nil(t, l.Close()) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l, err = n.Listen("tcp", l.Addr().String()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + active, err := n.activeListeners() + ensure.Nil(t, err) + ensure.DeepEqual(t, len(active), 1) + ensure.Nil(t, l.Close()) +} + +func TestSecondTcpListen(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + l, err := n.Listen("tcp", ":0") + ensure.Nil(t, err) + _, err = n.Listen("tcp", l.Addr().String()) + ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) + ensure.Nil(t, l.Close()) +} + +func TestSecondTcpListenInherited(t *testing.T) { + var n Net + l, err := net.Listen("tcp", ":0") + ensure.Nil(t, err) + file, err := l.(*net.TCPListener).File() + ensure.Nil(t, err) + ensure.Nil(t, l.Close()) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l, err = n.Listen("tcp", l.Addr().String()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + _, err = n.Listen("tcp", l.Addr().String()) + ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) + ensure.Nil(t, l.Close()) +} + +func TestInvalidNetwork(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + _, err := n.Listen("foo", "") + ensure.Err(t, err, regexp.MustCompile("^unknown network foo$")) +} + +func TestInvalidNetworkUnix(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + _, err := n.Listen("invalid_unix_net_for_test", "") + ensure.Err(t, err, regexp.MustCompile("^unknown network invalid_unix_net_for_test$")) +} + +func TestWithTcp0000(t *testing.T) { + var n Net + port, err := freeport.Get() + ensure.Nil(t, err) + addr := fmt.Sprintf("0.0.0.0:%d", port) + l, err := net.Listen("tcp", addr) + ensure.Nil(t, err) + file, err := l.(*net.TCPListener).File() + ensure.Nil(t, err) + ensure.Nil(t, l.Close()) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l, err = n.Listen("tcp", addr) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + ensure.Nil(t, l.Close()) +} + +func TestWithTcpIPv6Loal(t *testing.T) { + var n Net + l, err := net.Listen("tcp", "[::]:0") + ensure.Nil(t, err) + file, err := l.(*net.TCPListener).File() + ensure.Nil(t, err) + ensure.Nil(t, l.Close()) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l, err = n.Listen("tcp", l.Addr().String()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + ensure.Nil(t, l.Close()) +} + +func TestOneUnixInherit(t *testing.T) { + var n Net + tmpfile, err := ioutil.TempFile("", "TestOneUnixInherit-") + ensure.Nil(t, err) + ensure.Nil(t, tmpfile.Close()) + ensure.Nil(t, os.Remove(tmpfile.Name())) + defer os.Remove(tmpfile.Name()) + l, err := net.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + file, err := l.(*net.UnixListener).File() + ensure.Nil(t, err) + ensure.Nil(t, l.Close()) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l, err = n.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + ensure.Nil(t, l.Close()) +} + +func TestInvalidTcpAddr(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + _, err := n.Listen("tcp", "abc") + ensure.Err(t, err, regexp.MustCompile("^missing port in address abc$")) +} + +func TestTwoTCP(t *testing.T) { + var n Net + + port1, err := freeport.Get() + ensure.Nil(t, err) + addr1 := fmt.Sprintf(":%d", port1) + l1, err := net.Listen("tcp", addr1) + ensure.Nil(t, err) + + port2, err := freeport.Get() + ensure.Nil(t, err) + addr2 := fmt.Sprintf(":%d", port2) + l2, err := net.Listen("tcp", addr2) + ensure.Nil(t, err) + + file1, err := l1.(*net.TCPListener).File() + ensure.Nil(t, err) + file2, err := l2.(*net.TCPListener).File() + ensure.Nil(t, err) + + // assign both to prevent GC from kicking in the finalizer + fds := []int{dup(t, int(file1.Fd())), dup(t, int(file2.Fd()))} + n.fdStart = fds[0] + os.Setenv(envCountKey, "2") + + // Close these after to ensure we get coalaced file descriptors. + ensure.Nil(t, l1.Close()) + ensure.Nil(t, l2.Close()) + + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 2) + + l1, err = n.Listen("tcp", addr1) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + ensure.Nil(t, l1.Close()) + ensure.Nil(t, file1.Close()) + + l2, err = n.Listen("tcp", addr2) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 2) + ensure.DeepEqual(t, n.inherited[1], nil) + ensure.Nil(t, l2.Close()) + ensure.Nil(t, file2.Close()) +} + +func TestOneUnixAndOneTcpInherit(t *testing.T) { + var n Net + + tmpfile, err := ioutil.TempFile("", "TestOneUnixAndOneTcpInherit-") + ensure.Nil(t, err) + ensure.Nil(t, tmpfile.Close()) + ensure.Nil(t, os.Remove(tmpfile.Name())) + defer os.Remove(tmpfile.Name()) + unixL, err := net.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + + port, err := freeport.Get() + ensure.Nil(t, err) + addr := fmt.Sprintf(":%d", port) + tcpL, err := net.Listen("tcp", addr) + ensure.Nil(t, err) + + tcpF, err := tcpL.(*net.TCPListener).File() + ensure.Nil(t, err) + unixF, err := unixL.(*net.UnixListener).File() + ensure.Nil(t, err) + + // assign both to prevent GC from kicking in the finalizer + fds := []int{dup(t, int(tcpF.Fd())), dup(t, int(unixF.Fd()))} + n.fdStart = fds[0] + os.Setenv(envCountKey, "2") + + // Close these after to ensure we get coalaced file descriptors. + ensure.Nil(t, tcpL.Close()) + ensure.Nil(t, unixL.Close()) + + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 2) + + unixL, err = n.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[1], nil) + ensure.Nil(t, unixL.Close()) + ensure.Nil(t, unixF.Close()) + + tcpL, err = n.Listen("tcp", addr) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 2) + ensure.DeepEqual(t, n.inherited[0], nil) + ensure.Nil(t, tcpL.Close()) + ensure.Nil(t, tcpF.Close()) +} + +func TestSecondUnixListen(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + tmpfile, err := ioutil.TempFile("", "TestSecondUnixListen-") + ensure.Nil(t, err) + ensure.Nil(t, tmpfile.Close()) + ensure.Nil(t, os.Remove(tmpfile.Name())) + defer os.Remove(tmpfile.Name()) + l, err := n.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + _, err = n.Listen("unix", tmpfile.Name()) + ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) + ensure.Nil(t, l.Close()) +} + +func TestSecondUnixListenInherited(t *testing.T) { + var n Net + tmpfile, err := ioutil.TempFile("", "TestSecondUnixListenInherited-") + ensure.Nil(t, err) + ensure.Nil(t, tmpfile.Close()) + ensure.Nil(t, os.Remove(tmpfile.Name())) + defer os.Remove(tmpfile.Name()) + l1, err := net.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + file, err := l1.(*net.UnixListener).File() + ensure.Nil(t, err) + n.fdStart = dup(t, int(file.Fd())) + ensure.Nil(t, file.Close()) + os.Setenv(envCountKey, "1") + ensure.Nil(t, n.inherit()) + ensure.DeepEqual(t, len(n.inherited), 1) + l2, err := n.Listen("unix", tmpfile.Name()) + ensure.Nil(t, err) + ensure.DeepEqual(t, len(n.active), 1) + ensure.DeepEqual(t, n.inherited[0], nil) + _, err = n.Listen("unix", tmpfile.Name()) + ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) + ensure.Nil(t, l1.Close()) + ensure.Nil(t, l2.Close()) +} + +func TestPortZeroTwice(t *testing.T) { + var n Net + os.Setenv(envCountKey, "") + l1, err := n.Listen("tcp", ":0") + ensure.Nil(t, err) + l2, err := n.Listen("tcp", ":0") + ensure.Nil(t, err) + ensure.Nil(t, l1.Close()) + ensure.Nil(t, l2.Close()) +} + +// We dup file descriptors because the os.Files are closed by a finalizer when +// they are GCed, which interacts badly with the fact that the OS reuses fds, +// and that we emulating inheriting the fd by it's integer value in our tests. +func dup(t *testing.T, fd int) int { + nfd, err := syscall.Dup(fd) + ensure.Nil(t, err) + return nfd +}