diff --git a/gracehttp/http.go b/gracehttp/http.go index 14e5e91..a1012a4 100644 --- a/gracehttp/http.go +++ b/gracehttp/http.go @@ -16,40 +16,66 @@ import ( "github.com/daaku/go.grace" ) -type serverSlice []*http.Server - var ( - verbose = flag.Bool("gracehttp.log", true, "Enable logging.") - errListenersCount = errors.New("unexpected listeners count") + verbose = flag.Bool("gracehttp.log", true, "Enable logging.") + errListenersCount = errors.New("unexpected listeners count") + errNoStartedListeners = errors.New("no started listeners") ) -// Creates new listeners for all the given addresses. -func (servers serverSlice) newListeners() ([]grace.Listener, error) { - listeners := make([]grace.Listener, len(servers)) - for index, pair := range servers { - addr, err := net.ResolveTCPAddr("tcp", pair.Addr) +// Defines an application containing various servers and associated +// configuration. +type App struct { + Servers []*http.Server + listeners []grace.Listener + errors chan error +} + +// Inherit or create new listeners. Returns a bool indicating if we inherited +// listeners. This return value is useful in order to decide if we should +// instruct the parent process to terminate. +func (a *App) Listen() (bool, error) { + var err error + a.errors = make(chan error, len(a.Servers)) + a.listeners, err = grace.Inherit() + if err == nil { + if len(a.Servers) != len(a.listeners) { + return true, errListenersCount + } + return true, nil + } else if err == grace.ErrNotInheriting { + if a.listeners, err = a.newListeners(); err != nil { + return false, err + } + return false, nil + } + return false, fmt.Errorf("Failed graceful handoff: %s", err) +} + +// Creates new listeners (as in not inheriting) for all the configured Servers. +func (a *App) newListeners() ([]grace.Listener, error) { + listeners := make([]grace.Listener, len(a.Servers)) + for index, server := range a.Servers { + addr, err := net.ResolveTCPAddr("tcp", server.Addr) if err != nil { - return nil, fmt.Errorf( - "Failed net.ResolveTCPAddr for %s: %s", pair.Addr, err) + return nil, fmt.Errorf("net.ResolveTCPAddr %s: %s", server.Addr, err) } l, err := net.ListenTCP("tcp", addr) if err != nil { - return nil, fmt.Errorf("Failed net.ListenTCP for %s: %s", pair.Addr, err) + return nil, fmt.Errorf("net.ListenTCP %s: %s", server.Addr, err) } listeners[index] = grace.NewListener(l) } return listeners, nil } -// Serve on the given listeners and wait for signals. -func (servers serverSlice) serveWait(listeners []grace.Listener) error { - if len(servers) != len(listeners) { - return errListenersCount - } - errch := make(chan error, len(listeners)+1) // listeners + grace.Wait - for i, l := range listeners { +// Serve the configured servers, but do not block. You must call Wait at some +// point to ensure correctly waiting for graceful termination. +func (a *App) Serve() { + for i, l := range a.listeners { go func(i int, l net.Listener) { - server := servers[i] + server := a.Servers[i] + + // Wrap the listener for TLS support if necessary. if server.TLSConfig != nil { l = tls.NewListener(l, server.TLSConfig) } @@ -58,61 +84,65 @@ func (servers serverSlice) serveWait(listeners []grace.Listener) error { // The underlying Accept() will return grace.ErrAlreadyClosed // when a signal to do the same is returned, which we are okay with. if err != nil && err != grace.ErrAlreadyClosed { - errch <- fmt.Errorf("Failed http.Serve: %s", err) + a.errors <- fmt.Errorf("http.Serve: %s", err) } }(i, l) } - go func() { - err := grace.Wait(listeners) - if err != nil { - errch <- fmt.Errorf("Failed grace.Wait: %s", err) - } else { - errch <- nil - } - }() - return <-errch +} + +// Wait for the serving goroutines to finish. +func (a *App) Wait() error { + waiterr := make(chan error) + go func() { waiterr <- grace.Wait(a.listeners) }() + select { + case err := <-waiterr: + return err + case err := <-a.errors: + return err + } } // Serve will serve the given pairs of addresses and listeners and // will monitor for signals allowing for graceful termination (SIGTERM) // or restart (SIGUSR2). func Serve(servers ...*http.Server) error { - sslice := serverSlice(servers) - listeners, err := grace.Inherit() - if err == nil { - err = grace.CloseParent() - if err != nil { - return fmt.Errorf("Failed to close parent: %s", err) - } - if *verbose { - ppid := os.Getppid() - if ppid == 1 { - log.Printf("Listening on init activated %s", pprintAddr(listeners)) - } else { - log.Printf( - "Graceful handoff of %s with new pid %d and old pid %d.", - pprintAddr(listeners), os.Getpid(), ppid) - } - } - } else if err == grace.ErrNotInheriting { - listeners, err = sslice.newListeners() - if err != nil { - return err - } - if *verbose { - log.Printf("Serving %s with pid %d.", pprintAddr(listeners), os.Getpid()) - } - } else { - return fmt.Errorf("Failed graceful handoff: %s", err) - } - err = sslice.serveWait(listeners) + app := &App{Servers: servers} + inherited, err := app.Listen() if err != nil { return err } + + if *verbose { + if inherited { + ppid := os.Getppid() + if ppid == 1 { + log.Printf("Listening on init activated %s", pprintAddr(app.listeners)) + } else { + const msg = "Graceful handoff of %s with new pid %d and old pid %d" + log.Printf(msg, pprintAddr(app.listeners), os.Getpid(), ppid) + } + } else { + const msg = "Serving %s with pid %d" + log.Printf(msg, pprintAddr(app.listeners), os.Getpid()) + } + } + + app.Serve() + + // Close the parent if we inherited and it wasn't init that started us. + if inherited && os.Getppid() != 1 { + if err := grace.CloseParent(); err != nil { + return fmt.Errorf("Failed to close parent: %s", err) + } + } + + err = app.Wait() + if *verbose { log.Printf("Exiting pid %d.", os.Getpid()) } - return nil + + return err } // Used for pretty printing addresses. diff --git a/gracehttp/http_test.go b/gracehttp/http_test.go index 546ed22..9930c5e 100644 --- a/gracehttp/http_test.go +++ b/gracehttp/http_test.go @@ -7,13 +7,10 @@ import ( "flag" "fmt" "io" - "io/ioutil" - "log" "net" "net/http" "os" "os/exec" - "path/filepath" "sync" "syscall" "testing" @@ -23,8 +20,13 @@ import ( "github.com/daaku/go.tool" ) -// Debug logging. -var debugLog = flag.Bool("debug", false, "enable debug logging") +var ( + // Debug logging. + debugLog = flag.Bool("debug", false, "enable debug logging") + testserverCommand = &tool.CommandBuild{ + ImportPath: "github.com/daaku/go.grace/gracehttp/testserver", + } +) func debug(format string, a ...interface{}) { if *debugLog { @@ -38,34 +40,6 @@ var ( buildOnce sync.Once ) -// Builds the command. -func build(t *testing.T) string { - buildOnce.Do(func() { - const pkg = "github.com/daaku/go.grace/gracehttp/testserver" - basename := filepath.Base(pkg) - tempFile, err := ioutil.TempFile("", basename+"-") - if err != nil { - buildErr = err - return - } - buildOut = tempFile.Name() - _ = os.Remove(buildOut) // the build tool will create this - options := tool.Options{ - ImportPaths: []string{pkg}, - Output: buildOut, - } - _, err = options.Command("build") - if err != nil { - buildErr = err - return - } - }) - if buildErr != nil { - t.Fatal(buildErr) - } - return buildOut -} - // The response from the test server. type response struct { Sleep time.Duration @@ -104,8 +78,13 @@ func (h *harness) setupAddr() { // Start a fresh server and wait for pid updates on restart. func (h *harness) Start() { + bin, err := testserverCommand.Build() + if err != nil { + h.T.Fatalf("build error: %s", err) + } + h.setupAddr() - cmd := exec.Command(build(h.T), "-http", h.httpAddr, "-https", h.httpsAddr) + cmd := exec.Command(bin, "-http", h.httpAddr, "-https", h.httpsAddr) stderr, err := cmd.StderrPipe() if err != nil { h.T.Fatal(err) @@ -118,22 +97,22 @@ func (h *harness) Start() { return } if err != nil { - log.Fatalf("Failed to read line from server process: %s", err) + println(fmt.Sprintf("Failed to read line from server process: %s", err)) } if isPrefix { - log.Fatalf("Deal with isPrefix for line: %s", line) + println(fmt.Sprintf("Deal with isPrefix for line: %s", line)) } res := &response{} err = json.Unmarshal([]byte(line), res) if err != nil { - log.Fatalf("Could not parse json from stderr %s: %s", line, err) + println(fmt.Sprintf("Could not parse json from stderr %s: %s", line, err)) } if res.Error != "" { println(fmt.Sprintf("Got error from process: %v", res)) } process, err := os.FindProcess(res.Pid) if err != nil { - log.Fatalf("Could not find process with pid: %d", res.Pid) + println(fmt.Sprintf("Could not find process with pid: %d", res.Pid)) } h.ProcessMutex.Lock() h.Process = append(h.Process, process)