https support in gracehttp

This commit is contained in:
Naitik Shah 2013-08-20 11:29:40 -07:00
parent 8711fea1ad
commit 8060336110
3 changed files with 130 additions and 58 deletions

View File

@ -4,14 +4,16 @@ package gracehttp
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/daaku/go.grace"
"log" "log"
"net" "net"
"net/http" "net/http"
"os" "os"
"github.com/daaku/go.grace"
) )
type serverSlice []*http.Server type serverSlice []*http.Server
@ -47,7 +49,12 @@ func (servers serverSlice) serveWait(listeners []grace.Listener) error {
errch := make(chan error, len(listeners)+1) // listeners + grace.Wait errch := make(chan error, len(listeners)+1) // listeners + grace.Wait
for i, l := range listeners { for i, l := range listeners {
go func(i int, l net.Listener) { go func(i int, l net.Listener) {
err := servers[i].Serve(l) server := servers[i]
if server.TLSConfig != nil {
l = tls.NewListener(l, server.TLSConfig)
}
err := server.Serve(l)
// The underlying Accept() will return grace.ErrAlreadyClosed // The underlying Accept() will return grace.ErrAlreadyClosed
// when a signal to do the same is returned, which we are okay with. // when a signal to do the same is returned, which we are okay with.
if err != nil && err != grace.ErrAlreadyClosed { if err != nil && err != grace.ErrAlreadyClosed {

View File

@ -2,12 +2,12 @@ package gracehttp_test
import ( import (
"bufio" "bufio"
"crypto/tls"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/daaku/go.freeport"
"github.com/daaku/go.tool"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -17,6 +17,9 @@ import (
"syscall" "syscall"
"testing" "testing"
"time" "time"
"github.com/daaku/go.freeport"
"github.com/daaku/go.tool"
) )
// Debug logging. // Debug logging.
@ -32,6 +35,7 @@ func debug(format string, a ...interface{}) {
type response struct { type response struct {
Sleep time.Duration Sleep time.Duration
Pid int Pid int
Error string
} }
// State for the test run. // State for the test run.
@ -39,37 +43,29 @@ type harness struct {
T *testing.T // The test instance. T *testing.T // The test instance.
ImportPath string // The import path for the server command. ImportPath string // The import path for the server command.
ExeName string // The temp binary from the build. ExeName string // The temp binary from the build.
Addr []string // The addresses for the http servers. httpAddr string // The address for the http server.
httpsAddr string // The address for the https server.
Process []*os.Process // The server commands, oldest to newest. Process []*os.Process // The server commands, oldest to newest.
ProcessMutex sync.Mutex // The mutex to guard Process manipulation. ProcessMutex sync.Mutex // The mutex to guard Process manipulation.
RequestWaitGroup sync.WaitGroup // The wait group for the HTTP requests. RequestWaitGroup sync.WaitGroup // The wait group for the HTTP requests.
newProcess chan bool // A bool is sent on restart. newProcess chan bool // A bool is sent on start/restart.
requestCount int requestCount int
requestCountMutex sync.Mutex requestCountMutex sync.Mutex
} }
// Find 3 free ports and setup addresses. // Find 3 free ports and setup addresses.
func (h *harness) SetupAddr() { func (h *harness) SetupAddr() {
for i := 3; i > 0; i-- {
port, err := freeport.Get() port, err := freeport.Get()
if err != nil { if err != nil {
h.T.Fatalf("Failed to find a free port: %s", err) h.T.Fatalf("Failed to find a free port: %s", err)
} }
h.Addr = append(h.Addr, fmt.Sprintf("127.0.0.1:%d", port)) h.httpAddr = fmt.Sprintf("127.0.0.1:%d", port)
}
}
// Builds the command line arguments. port, err = freeport.Get()
func (h *harness) Args() []string { if err != nil {
if h.Addr == nil { h.T.Fatalf("Failed to find a free port: %s", err)
h.SetupAddr()
}
return []string{
"-gracehttp.log=false",
"-a0", h.Addr[0],
"-a1", h.Addr[1],
"-a2", h.Addr[2],
} }
h.httpsAddr = fmt.Sprintf("127.0.0.1:%d", port)
} }
// Builds the command. // Builds the command.
@ -93,26 +89,30 @@ func (h *harness) Build() {
// Start a fresh server and wait for pid updates on restart. // Start a fresh server and wait for pid updates on restart.
func (h *harness) Start() { func (h *harness) Start() {
cmd := exec.Command(h.ExeName, h.Args()...) h.SetupAddr()
cmd := exec.Command(h.ExeName, "-http", h.httpAddr, "-https", h.httpsAddr)
stderr, err := cmd.StderrPipe() stderr, err := cmd.StderrPipe()
go func() { go func() {
reader := bufio.NewReader(stderr) reader := bufio.NewReader(stderr)
for { for {
line, isPrefix, err := reader.ReadLine() line, isPrefix, err := reader.ReadLine()
if err != nil { if err != nil {
h.T.Fatalf("Failed to read line from server process: %s", err) log.Fatalf("Failed to read line from server process: %s", err)
} }
if isPrefix { if isPrefix {
h.T.Fatalf("Deal with isPrefix for line: %s", line) log.Fatalf("Deal with isPrefix for line: %s", line)
} }
res := &response{} res := &response{}
err = json.Unmarshal([]byte(line), res) err = json.Unmarshal([]byte(line), res)
if err != nil { if err != nil {
h.T.Fatalf("Could not parse json from stderr %s: %s", line, err) log.Fatalf("Could not parse json from stderr %s: %s", line, err)
}
if res.Error != "" {
println(fmt.Sprintf("Got error from process: %v", res))
} }
process, err := os.FindProcess(res.Pid) process, err := os.FindProcess(res.Pid)
if err != nil { if err != nil {
h.T.Fatalf("Could not find process with pid: %d", res.Pid) log.Fatalf("Could not find process with pid: %d", res.Pid)
} }
h.ProcessMutex.Lock() h.ProcessMutex.Lock()
h.Process = append(h.Process, process) h.Process = append(h.Process, process)
@ -163,7 +163,7 @@ func (h *harness) RemoveExe() {
} }
} }
// Get the global request count. // Get the global request count and increment it.
func (h *harness) RequestCount() int { func (h *harness) RequestCount() int {
h.requestCountMutex.Lock() h.requestCountMutex.Lock()
defer h.requestCountMutex.Unlock() defer h.requestCountMutex.Unlock()
@ -173,9 +173,10 @@ func (h *harness) RequestCount() int {
} }
// Helper for sending a single request. // Helper for sending a single request.
func (h *harness) SendOne(dialgroup *sync.WaitGroup, duration time.Duration, addr string, pid int) { func (h *harness) SendOne(dialgroup *sync.WaitGroup, url string, pid int) {
defer h.RequestWaitGroup.Done()
count := h.RequestCount() count := h.RequestCount()
debug("Send %02d pid=%d duration=%s", count, pid, duration) debug("Send %02d pid=%d url=%s", count, pid, url)
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DisableKeepAlives: true, DisableKeepAlives: true,
@ -183,39 +184,46 @@ func (h *harness) SendOne(dialgroup *sync.WaitGroup, duration time.Duration, add
defer dialgroup.Done() defer dialgroup.Done()
return net.Dial(network, addr) return net.Dial(network, addr)
}, },
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}, },
} }
url := fmt.Sprintf("http://%s/sleep/?duration=%s", addr, duration.String())
r, err := client.Get(url) r, err := client.Get(url)
if err != nil { if err != nil {
h.T.Fatalf("Failed request to %s: %s", url, err) h.T.Fatalf("Failed request %02d to %s pid=%d: %s", count, url, pid, err)
} }
debug("Body %02d pid=%d duration=%s", count, pid, duration) debug("Body %02d pid=%d url=%s", count, pid, url)
defer r.Body.Close() defer r.Body.Close()
res := &response{} res := &response{}
err = json.NewDecoder(r.Body).Decode(res) err = json.NewDecoder(r.Body).Decode(res)
if err != nil { if err != nil {
h.T.Fatalf("Failed to ready decode json response body: %s", err) h.T.Fatalf("Failed to ready decode json response body pid=%d: %s", pid, err)
} }
if pid != res.Pid { if pid != res.Pid {
h.T.Fatalf("Didn't get expected pid %d instead got %d", pid, res.Pid) h.T.Fatalf("Didn't get expected pid %d instead got %d", pid, res.Pid)
} }
debug("Done %02d pid=%d duration=%s", count, pid, duration) debug("Done %02d pid=%d url=%s", count, pid, url)
h.RequestWaitGroup.Done()
} }
// Send test HTTP request. // Send test HTTP request.
func (h *harness) SendRequest() { func (h *harness) SendRequest() {
pid := h.MostRecentProcess().Pid pid := h.MostRecentProcess().Pid
httpFastUrl := fmt.Sprintf("http://%s/sleep/?duration=0", h.httpAddr)
httpSlowUrl := fmt.Sprintf("http://%s/sleep/?duration=2s", h.httpAddr)
httpsFastUrl := fmt.Sprintf("https://%s/sleep/?duration=0", h.httpsAddr)
httpsSlowUrl := fmt.Sprintf("https://%s/sleep/?duration=2s", h.httpsAddr)
var dialgroup sync.WaitGroup var dialgroup sync.WaitGroup
for _, addr := range h.Addr { h.RequestWaitGroup.Add(4)
debug("Added 2 Requests") dialgroup.Add(4)
h.RequestWaitGroup.Add(2) go h.SendOne(&dialgroup, httpFastUrl, pid)
dialgroup.Add(2) go h.SendOne(&dialgroup, httpSlowUrl, pid)
go h.SendOne(&dialgroup, time.Second*0, addr, pid) go h.SendOne(&dialgroup, httpsFastUrl, pid)
go h.SendOne(&dialgroup, time.Second*2, addr, pid) go h.SendOne(&dialgroup, httpsSlowUrl, pid)
} debug("Added Requests pid=%d", pid)
dialgroup.Wait() dialgroup.Wait()
debug("Dialed Requests pid=%d", pid)
} }
// Wait for everything. // Wait for everything.

View File

@ -2,6 +2,7 @@
package main package main
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
@ -17,25 +18,56 @@ import (
type response struct { type response struct {
Sleep time.Duration Sleep time.Duration
Pid int Pid int
Error string `json:,omitempty`
} }
func wait(wg *sync.WaitGroup, addr string) { func wait(wg *sync.WaitGroup, url string) {
defer wg.Done() defer wg.Done()
url := fmt.Sprintf("http://%s/sleep/?duration=0", addr)
for { for {
if _, err := http.Get(url); err == nil { _, err := http.Get(url)
if err == nil {
return return
} else {
e2 := json.NewEncoder(os.Stderr).Encode(&response{
Error: err.Error(),
Pid: os.Getpid(),
})
if e2 != nil {
log.Fatalf("Error writing error json: %s", e2)
} }
} }
} }
}
func httpsServer(addr string) *http.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
log.Fatal("error loading cert: %v", err)
}
return &http.Server{
Addr: addr,
Handler: newHandler(),
TLSConfig: &tls.Config{
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{cert},
},
}
}
func main() { func main() {
var addrs [3]string var httpAddr, httpsAddr string
flag.StringVar(&addrs[0], "a0", ":48560", "Zero address to bind to.") flag.StringVar(&httpAddr, "http", ":48560", "http address to bind to")
flag.StringVar(&addrs[1], "a1", ":48561", "First address to bind to.") flag.StringVar(&httpsAddr, "https", ":48561", "https address to bind to")
flag.StringVar(&addrs[2], "a2", ":48562", "Second address to bind to.")
flag.Parse() flag.Parse()
// we have self signed certs
http.DefaultTransport = &http.Transport{
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
err := flag.Set("gracehttp.log", "false") err := flag.Set("gracehttp.log", "false")
if err != nil { if err != nil {
log.Fatalf("Error setting gracehttp.log: %s", err) log.Fatalf("Error setting gracehttp.log: %s", err)
@ -45,10 +77,9 @@ func main() {
// addresses. the ensures we only print the line once we're ready to serve. // addresses. the ensures we only print the line once we're ready to serve.
go func() { go func() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(addrs)) wg.Add(2)
for _, addr := range addrs { go wait(&wg, fmt.Sprintf("http://%s/sleep/?duration=0", httpAddr))
go wait(&wg, addr) go wait(&wg, fmt.Sprintf("https://%s/sleep/?duration=0", httpsAddr))
}
wg.Wait() wg.Wait()
err = json.NewEncoder(os.Stderr).Encode(&response{Pid: os.Getpid()}) err = json.NewEncoder(os.Stderr).Encode(&response{Pid: os.Getpid()})
@ -58,9 +89,8 @@ func main() {
}() }()
err = gracehttp.Serve( err = gracehttp.Serve(
&http.Server{Addr: addrs[0], Handler: newHandler()}, &http.Server{Addr: httpAddr, Handler: newHandler()},
&http.Server{Addr: addrs[1], Handler: newHandler()}, httpsServer(httpsAddr),
&http.Server{Addr: addrs[2], Handler: newHandler()},
) )
if err != nil { if err != nil {
log.Fatalf("Error in gracehttp.Serve: %s", err) log.Fatalf("Error in gracehttp.Serve: %s", err)
@ -85,3 +115,30 @@ func newHandler() http.Handler {
}) })
return mux return mux
} }
// localhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time).
// generated from src/pkg/crypto/tls:
// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBALyCfqwwip8BvTKgVKGdmjZTU8DD
ndR+WALmFPIRqn89bOU3s30olKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEA
AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAr/09uy108p51rheIOSnz4zgduyTl
M+4AmRo8/U1twEZLgfAGG/GZjREv2y4mCEUIM3HebCAqlA5jpRg76Rf8jw==
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOQIBAAJBALyCfqwwip8BvTKgVKGdmjZTU8DDndR+WALmFPIRqn89bOU3s30o
lKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEAAQJAPXuWUxTV8XyAt8VhNQER
LgzJcUKb9JVsoS1nwXgPksXnPDKnL9ax8VERrdNr+nZbj2Q9cDSXBUovfdtehcdP
qQIhAO48ZsPylbTrmtjDEKiHT2Ik04rLotZYS2U873J6I7WlAiEAypDjYxXyafv/
Yo1pm9onwcetQKMW8CS3AjuV9Axzj6cCIEx2Il19fEMG4zny0WPlmbrcKvD/DpJQ
4FHrzsYlIVTpAiAas7S1uAvneqd0l02HlN9OxQKKlbUNXNme+rnOnOGS2wIgS0jW
zl1jvrOSJeP1PpAHohWz6LOhEr8uvltWkN6x3vE=
-----END RSA PRIVATE KEY-----`)