diff --git a/grace.go b/grace.go index 6a77d50..88b1fe4 100644 --- a/grace.go +++ b/grace.go @@ -91,7 +91,19 @@ func (l *listener) Close() error { return err } -func (l *listener) Accept() (net.Conn, error) { +func (l *listener) Accept() (c net.Conn, err error) { + // Presume we'll accept and decrement in defer if we don't. If we did this + // after a successful accept we would have a race condition where we may end + // up incorrectly shutting down between the time we do a successful accept + // and the increment. + l.wg.Add(1) + defer func() { + // If we didn't accept, we decrement our presumptuous count above. + if c == nil { + l.wg.Done() + } + }() + l.closedMutex.RLock() if l.closed { l.closedMutex.RUnlock() @@ -99,7 +111,7 @@ func (l *listener) Accept() (net.Conn, error) { } l.closedMutex.RUnlock() - c, err := l.Listener.Accept() + c, err = l.Listener.Accept() if err != nil { if strings.HasSuffix(err.Error(), errClosed) { return nil, ErrAlreadyClosed @@ -118,7 +130,6 @@ func (l *listener) Accept() (net.Conn, error) { } return nil, err } - l.wg.Add(1) return conn{Conn: c, wg: &l.wg}, nil }