Fix two race conditions in client.Do

* Common race condition is fixed by identifying that Client.respHandler
  can be completely removed since all respHandler operations (get, put
  and invocation) can be moved into the Client.processLoop goroutine,
  meaning that zero locking is required. This race condition resulted
  in a deadlock that was resolved by the response timeout at the
  end of client.Do, returning ErrLostConn

* Rare race condition is fixed by changing responseHandlerMap.get
  to .getAndRemove. This race condition resulted in the innerHandler
  for a new dtJobCreated assigned in client.Do overriding a stale older
  dtJobCreated request, and the newer innerHandler being removed by an
  older dtJobCreated in client.processLoop > client.handleInner. When
  the newer dtJobCreated response was received, the handler for it had
  already been deleted. This was resolved by the response timeout at the
  end of client.Do, returning ErrLostConn
This commit is contained in:
Paul Cameron 2021-05-02 11:02:15 +10:00
parent 81d00aa9ce
commit a8f0a04c3d
3 changed files with 131 additions and 39 deletions

View File

@ -19,7 +19,6 @@ type Client struct {
sync.Mutex sync.Mutex
net, addr string net, addr string
respHandler *responseHandlerMap
innerHandler *responseHandlerMap innerHandler *responseHandlerMap
in chan *Response in chan *Response
conn net.Conn conn net.Conn
@ -32,11 +31,16 @@ type Client struct {
type responseHandlerMap struct { type responseHandlerMap struct {
sync.Mutex sync.Mutex
holder map[string]ResponseHandler holder map[string]handledResponse
}
type handledResponse struct {
internal ResponseHandler // internal handler, always non-nil
external ResponseHandler // handler passed in from (*Client).Do, sometimes nil
} }
func newResponseHandlerMap() *responseHandlerMap { func newResponseHandlerMap() *responseHandlerMap {
return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)} return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)}
} }
func (r *responseHandlerMap) remove(key string) { func (r *responseHandlerMap) remove(key string) {
@ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) {
r.Unlock() r.Unlock()
} }
func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) { func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) {
r.Lock() r.Lock()
rh, b := r.holder[key] rh, b := r.holder[key]
delete(r.holder, key)
r.Unlock() r.Unlock()
return rh, b return rh, b
} }
func (r *responseHandlerMap) put(key string, rh ResponseHandler) { func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) {
r.Lock() r.Lock()
r.holder[key] = rh r.holder[key] = handledResponse{internal: internal, external: external}
r.Unlock() r.Unlock()
} }
func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) { func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
r.holder[key] = rh r.putWithExternalHandler(key, rh, nil)
} }
// New returns a client. // New returns a client.
@ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) {
client = &Client{ client = &Client{
net: network, net: network,
addr: addr, addr: addr,
respHandler: newResponseHandlerMap(),
innerHandler: newResponseHandlerMap(), innerHandler: newResponseHandlerMap(),
in: make(chan *Response, queueSize), in: make(chan *Response, queueSize),
ResponseTimeout: DefaultTimeout, ResponseTimeout: DefaultTimeout,
@ -168,21 +172,26 @@ ReadLoop:
} }
func (client *Client) processLoop() { func (client *Client) processLoop() {
rhandlers := map[string]ResponseHandler{}
for resp := range client.in { for resp := range client.in {
switch resp.DataType { switch resp.DataType {
case dtError: case dtError:
client.err(getError(resp.Data)) client.err(getError(resp.Data))
case dtStatusRes: case dtStatusRes:
resp = client.handleInner("s"+resp.Handle, resp) client.handleInner("s"+resp.Handle, resp, nil)
case dtJobCreated: case dtJobCreated:
resp = client.handleInner("c", resp) client.handleInner("c", resp, rhandlers)
case dtEchoRes: case dtEchoRes:
resp = client.handleInner("e", resp) client.handleInner("e", resp, nil)
case dtWorkData, dtWorkWarning, dtWorkStatus: case dtWorkData, dtWorkWarning, dtWorkStatus:
resp = client.handleResponse(resp.Handle, resp) if cb := rhandlers[resp.Handle]; cb != nil {
cb(resp)
}
case dtWorkComplete, dtWorkFail, dtWorkException: case dtWorkComplete, dtWorkFail, dtWorkException:
client.handleResponse(resp.Handle, resp) if cb := rhandlers[resp.Handle]; cb != nil {
client.respHandler.remove(resp.Handle) cb(resp)
delete(rhandlers, resp.Handle)
}
} }
} }
} }
@ -193,21 +202,13 @@ func (client *Client) err(e error) {
} }
} }
func (client *Client) handleResponse(key string, resp *Response) *Response { func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) {
if h, ok := client.respHandler.get(key); ok { if h, ok := client.innerHandler.getAndRemove(key); ok {
h(resp) if h.external != nil && resp.Handle != "" {
return nil rhandlers[resp.Handle] = h.external
} }
return resp h.internal(resp)
} }
func (client *Client) handleInner(key string, resp *Response) *Response {
if h, ok := client.innerHandler.get(key); ok {
h(resp)
client.innerHandler.remove(key)
return nil
}
return resp
} }
type handleOrError struct { type handleOrError struct {
@ -216,14 +217,14 @@ type handleOrError struct {
} }
func (client *Client) do(funcname string, data []byte, func (client *Client) do(funcname string, data []byte,
flag uint32) (handle string, err error) { flag uint32, h ResponseHandler) (handle string, err error) {
if client.conn == nil { if client.conn == nil {
return "", ErrLostConn return "", ErrLostConn
} }
var result = make(chan handleOrError, 1) var result = make(chan handleOrError, 1)
client.Lock() client.Lock()
defer client.Unlock() defer client.Unlock()
client.innerHandler.put("c", func(resp *Response) { client.innerHandler.putWithExternalHandler("c", func(resp *Response) {
if resp.DataType == dtError { if resp.DataType == dtError {
err = getError(resp.Data) err = getError(resp.Data)
result <- handleOrError{"", err} result <- handleOrError{"", err}
@ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte,
} }
handle = resp.Handle handle = resp.Handle
result <- handleOrError{handle, nil} result <- handleOrError{handle, nil}
}) }, h)
id := IdGen.Id() id := IdGen.Id()
req := getJob(id, []byte(funcname), data) req := getJob(id, []byte(funcname), data)
req.DataType = flag req.DataType = flag
@ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte,
datatype = dtSubmitJob datatype = dtSubmitJob
} }
client.respHandler.Lock() handle, err = client.do(funcname, data, datatype, h)
defer client.respHandler.Unlock()
handle, err = client.do(funcname, data, datatype)
if err == nil && h != nil {
client.respHandler.putNoLock(handle, h)
}
return return
} }
@ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte,
default: default:
datatype = dtSubmitJobBg datatype = dtSubmitJobBg
} }
handle, err = client.do(funcname, data, datatype) handle, err = client.do(funcname, data, datatype, nil)
return return
} }

View File

@ -1,9 +1,11 @@
package client package client
import ( import (
"errors"
"flag" "flag"
"os" "os"
"testing" "testing"
"time"
) )
const ( const (
@ -17,6 +19,7 @@ var (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)") integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)")
flag.Parse()
if integrationsTestFlag != nil { if integrationsTestFlag != nil {
runIntegrationTests = *integrationsTestFlag runIntegrationTests = *integrationsTestFlag
} }
@ -95,6 +98,66 @@ func TestClientDo(t *testing.T) {
} }
} }
func TestClientMultiDo(t *testing.T) {
if !runIntegrationTests {
t.Skip("To run this test, use: go test -integration")
}
// This integration test requires that examples/pl/worker_multi.pl be running.
//
// Test invocation is:
// go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$'
//
// Send 1000 requests to go through all race conditions
const nreqs = 1000
errCh := make(chan error)
gotCh := make(chan string, nreqs)
olderrh := client.ErrorHandler
client.ErrorHandler = func(e error) { errCh <- e }
client.ResponseTimeout = 5 * time.Second
defer func() { client.ErrorHandler = olderrh }()
nextJobCh := make(chan struct{})
defer close(nextJobCh)
go func() {
for range nextJobCh {
start := time.Now()
handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) })
if err == ErrLostConn && time.Since(start) > client.ResponseTimeout {
errCh <- errors.New("Impossible 'lost conn', deadlock bug detected")
} else if err != nil {
errCh <- err
}
if handle == "" {
errCh <- errors.New("Handle is empty.")
}
}
}()
for i := 0; i < nreqs; i++ {
select {
case err := <-errCh:
t.Fatal(err)
case nextJobCh <- struct{}{}:
}
}
remaining := nreqs
for remaining > 0 {
select {
case err := <-errCh:
t.Fatal(err)
case got := <-gotCh:
if got != "ABCDEF" {
t.Error("Unexpected response from PerlDoUpper: ", got)
}
remaining--
t.Logf("%d response remaining", remaining)
}
}
}
func TestClientStatus(t *testing.T) { func TestClientStatus(t *testing.T) {
if !runIntegrationTests { if !runIntegrationTests {
t.Skip("To run this test, use: go test -integration") t.Skip("To run this test, use: go test -integration")

View File

@ -0,0 +1,33 @@
#!/usr/bin/perl
# Runs 20 children that expose "PerlToUpper" before returning the result.
use strict; use warnings;
use constant CHILDREN => 20;
use Time::HiRes qw(usleep);
use Gearman::Worker;
$|++;
my @child_pids;
for (1 .. CHILDREN) {
if (my $pid = fork) {
push @child_pids, $pid;
next;
}
eval {
my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730');
$w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg });
$w->work while 1;
};
warn $@ if $@;
exit 0;
}
$SIG{INT} = $SIG{HUP} = sub {
kill 9, @child_pids;
print "\nChildren shut down, gracefully exiting\n";
exit 0;
};
printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN;
sleep;