refactor gracehttp for readability

This commit is contained in:
Naitik Shah 2013-10-15 11:54:28 -07:00
parent aad68df4be
commit d1f693d1d4
2 changed files with 107 additions and 98 deletions

View File

@ -16,40 +16,66 @@ import (
"github.com/daaku/go.grace" "github.com/daaku/go.grace"
) )
type serverSlice []*http.Server
var ( var (
verbose = flag.Bool("gracehttp.log", true, "Enable logging.") verbose = flag.Bool("gracehttp.log", true, "Enable logging.")
errListenersCount = errors.New("unexpected listeners count") errListenersCount = errors.New("unexpected listeners count")
errNoStartedListeners = errors.New("no started listeners")
) )
// Creates new listeners for all the given addresses. // Defines an application containing various servers and associated
func (servers serverSlice) newListeners() ([]grace.Listener, error) { // configuration.
listeners := make([]grace.Listener, len(servers)) type App struct {
for index, pair := range servers { Servers []*http.Server
addr, err := net.ResolveTCPAddr("tcp", pair.Addr) listeners []grace.Listener
errors chan error
}
// Inherit or create new listeners. Returns a bool indicating if we inherited
// listeners. This return value is useful in order to decide if we should
// instruct the parent process to terminate.
func (a *App) Listen() (bool, error) {
var err error
a.errors = make(chan error, len(a.Servers))
a.listeners, err = grace.Inherit()
if err == nil {
if len(a.Servers) != len(a.listeners) {
return true, errListenersCount
}
return true, nil
} else if err == grace.ErrNotInheriting {
if a.listeners, err = a.newListeners(); err != nil {
return false, err
}
return false, nil
}
return false, fmt.Errorf("Failed graceful handoff: %s", err)
}
// Creates new listeners (as in not inheriting) for all the configured Servers.
func (a *App) newListeners() ([]grace.Listener, error) {
listeners := make([]grace.Listener, len(a.Servers))
for index, server := range a.Servers {
addr, err := net.ResolveTCPAddr("tcp", server.Addr)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("net.ResolveTCPAddr %s: %s", server.Addr, err)
"Failed net.ResolveTCPAddr for %s: %s", pair.Addr, err)
} }
l, err := net.ListenTCP("tcp", addr) l, err := net.ListenTCP("tcp", addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed net.ListenTCP for %s: %s", pair.Addr, err) return nil, fmt.Errorf("net.ListenTCP %s: %s", server.Addr, err)
} }
listeners[index] = grace.NewListener(l) listeners[index] = grace.NewListener(l)
} }
return listeners, nil return listeners, nil
} }
// Serve on the given listeners and wait for signals. // Serve the configured servers, but do not block. You must call Wait at some
func (servers serverSlice) serveWait(listeners []grace.Listener) error { // point to ensure correctly waiting for graceful termination.
if len(servers) != len(listeners) { func (a *App) Serve() {
return errListenersCount for i, l := range a.listeners {
}
errch := make(chan error, len(listeners)+1) // listeners + grace.Wait
for i, l := range listeners {
go func(i int, l net.Listener) { go func(i int, l net.Listener) {
server := servers[i] server := a.Servers[i]
// Wrap the listener for TLS support if necessary.
if server.TLSConfig != nil { if server.TLSConfig != nil {
l = tls.NewListener(l, server.TLSConfig) l = tls.NewListener(l, server.TLSConfig)
} }
@ -58,61 +84,65 @@ func (servers serverSlice) serveWait(listeners []grace.Listener) error {
// The underlying Accept() will return grace.ErrAlreadyClosed // The underlying Accept() will return grace.ErrAlreadyClosed
// when a signal to do the same is returned, which we are okay with. // when a signal to do the same is returned, which we are okay with.
if err != nil && err != grace.ErrAlreadyClosed { if err != nil && err != grace.ErrAlreadyClosed {
errch <- fmt.Errorf("Failed http.Serve: %s", err) a.errors <- fmt.Errorf("http.Serve: %s", err)
} }
}(i, l) }(i, l)
} }
go func() { }
err := grace.Wait(listeners)
if err != nil { // Wait for the serving goroutines to finish.
errch <- fmt.Errorf("Failed grace.Wait: %s", err) func (a *App) Wait() error {
} else { waiterr := make(chan error)
errch <- nil go func() { waiterr <- grace.Wait(a.listeners) }()
} select {
}() case err := <-waiterr:
return <-errch return err
case err := <-a.errors:
return err
}
} }
// Serve will serve the given pairs of addresses and listeners and // Serve will serve the given pairs of addresses and listeners and
// will monitor for signals allowing for graceful termination (SIGTERM) // will monitor for signals allowing for graceful termination (SIGTERM)
// or restart (SIGUSR2). // or restart (SIGUSR2).
func Serve(servers ...*http.Server) error { func Serve(servers ...*http.Server) error {
sslice := serverSlice(servers) app := &App{Servers: servers}
listeners, err := grace.Inherit() inherited, err := app.Listen()
if err == nil {
err = grace.CloseParent()
if err != nil {
return fmt.Errorf("Failed to close parent: %s", err)
}
if *verbose {
ppid := os.Getppid()
if ppid == 1 {
log.Printf("Listening on init activated %s", pprintAddr(listeners))
} else {
log.Printf(
"Graceful handoff of %s with new pid %d and old pid %d.",
pprintAddr(listeners), os.Getpid(), ppid)
}
}
} else if err == grace.ErrNotInheriting {
listeners, err = sslice.newListeners()
if err != nil {
return err
}
if *verbose {
log.Printf("Serving %s with pid %d.", pprintAddr(listeners), os.Getpid())
}
} else {
return fmt.Errorf("Failed graceful handoff: %s", err)
}
err = sslice.serveWait(listeners)
if err != nil { if err != nil {
return err return err
} }
if *verbose {
if inherited {
ppid := os.Getppid()
if ppid == 1 {
log.Printf("Listening on init activated %s", pprintAddr(app.listeners))
} else {
const msg = "Graceful handoff of %s with new pid %d and old pid %d"
log.Printf(msg, pprintAddr(app.listeners), os.Getpid(), ppid)
}
} else {
const msg = "Serving %s with pid %d"
log.Printf(msg, pprintAddr(app.listeners), os.Getpid())
}
}
app.Serve()
// Close the parent if we inherited and it wasn't init that started us.
if inherited && os.Getppid() != 1 {
if err := grace.CloseParent(); err != nil {
return fmt.Errorf("Failed to close parent: %s", err)
}
}
err = app.Wait()
if *verbose { if *verbose {
log.Printf("Exiting pid %d.", os.Getpid()) log.Printf("Exiting pid %d.", os.Getpid())
} }
return nil
return err
} }
// Used for pretty printing addresses. // Used for pretty printing addresses.

View File

@ -7,13 +7,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"sync" "sync"
"syscall" "syscall"
"testing" "testing"
@ -23,8 +20,13 @@ import (
"github.com/daaku/go.tool" "github.com/daaku/go.tool"
) )
// Debug logging. var (
var debugLog = flag.Bool("debug", false, "enable debug logging") // Debug logging.
debugLog = flag.Bool("debug", false, "enable debug logging")
testserverCommand = &tool.CommandBuild{
ImportPath: "github.com/daaku/go.grace/gracehttp/testserver",
}
)
func debug(format string, a ...interface{}) { func debug(format string, a ...interface{}) {
if *debugLog { if *debugLog {
@ -38,34 +40,6 @@ var (
buildOnce sync.Once buildOnce sync.Once
) )
// Builds the command.
func build(t *testing.T) string {
buildOnce.Do(func() {
const pkg = "github.com/daaku/go.grace/gracehttp/testserver"
basename := filepath.Base(pkg)
tempFile, err := ioutil.TempFile("", basename+"-")
if err != nil {
buildErr = err
return
}
buildOut = tempFile.Name()
_ = os.Remove(buildOut) // the build tool will create this
options := tool.Options{
ImportPaths: []string{pkg},
Output: buildOut,
}
_, err = options.Command("build")
if err != nil {
buildErr = err
return
}
})
if buildErr != nil {
t.Fatal(buildErr)
}
return buildOut
}
// The response from the test server. // The response from the test server.
type response struct { type response struct {
Sleep time.Duration Sleep time.Duration
@ -104,8 +78,13 @@ func (h *harness) setupAddr() {
// Start a fresh server and wait for pid updates on restart. // Start a fresh server and wait for pid updates on restart.
func (h *harness) Start() { func (h *harness) Start() {
bin, err := testserverCommand.Build()
if err != nil {
h.T.Fatalf("build error: %s", err)
}
h.setupAddr() h.setupAddr()
cmd := exec.Command(build(h.T), "-http", h.httpAddr, "-https", h.httpsAddr) cmd := exec.Command(bin, "-http", h.httpAddr, "-https", h.httpsAddr)
stderr, err := cmd.StderrPipe() stderr, err := cmd.StderrPipe()
if err != nil { if err != nil {
h.T.Fatal(err) h.T.Fatal(err)
@ -118,22 +97,22 @@ func (h *harness) Start() {
return return
} }
if err != nil { if err != nil {
log.Fatalf("Failed to read line from server process: %s", err) println(fmt.Sprintf("Failed to read line from server process: %s", err))
} }
if isPrefix { if isPrefix {
log.Fatalf("Deal with isPrefix for line: %s", line) println(fmt.Sprintf("Deal with isPrefix for line: %s", line))
} }
res := &response{} res := &response{}
err = json.Unmarshal([]byte(line), res) err = json.Unmarshal([]byte(line), res)
if err != nil { if err != nil {
log.Fatalf("Could not parse json from stderr %s: %s", line, err) println(fmt.Sprintf("Could not parse json from stderr %s: %s", line, err))
} }
if res.Error != "" { if res.Error != "" {
println(fmt.Sprintf("Got error from process: %v", res)) println(fmt.Sprintf("Got error from process: %v", res))
} }
process, err := os.FindProcess(res.Pid) process, err := os.FindProcess(res.Pid)
if err != nil { if err != nil {
log.Fatalf("Could not find process with pid: %d", res.Pid) println(fmt.Sprintf("Could not find process with pid: %d", res.Pid))
} }
h.ProcessMutex.Lock() h.ProcessMutex.Lock()
h.Process = append(h.Process, process) h.Process = append(h.Process, process)