Merge pull request #94 from cameronpm/client_race

Fix two race conditions in client.Do
This commit is contained in:
Xing 2021-05-03 16:56:33 +12:00 committed by GitHub
commit d36dcb7fc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 39 deletions

View File

@ -19,7 +19,6 @@ type Client struct {
sync.Mutex sync.Mutex
net, addr string net, addr string
respHandler *responseHandlerMap
innerHandler *responseHandlerMap innerHandler *responseHandlerMap
in chan *Response in chan *Response
conn net.Conn conn net.Conn
@ -32,11 +31,16 @@ type Client struct {
type responseHandlerMap struct { type responseHandlerMap struct {
sync.Mutex sync.Mutex
holder map[string]ResponseHandler holder map[string]handledResponse
}
type handledResponse struct {
internal ResponseHandler // internal handler, always non-nil
external ResponseHandler // handler passed in from (*Client).Do, sometimes nil
} }
func newResponseHandlerMap() *responseHandlerMap { func newResponseHandlerMap() *responseHandlerMap {
return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)} return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)}
} }
func (r *responseHandlerMap) remove(key string) { func (r *responseHandlerMap) remove(key string) {
@ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) {
r.Unlock() r.Unlock()
} }
func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) { func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) {
r.Lock() r.Lock()
rh, b := r.holder[key] rh, b := r.holder[key]
delete(r.holder, key)
r.Unlock() r.Unlock()
return rh, b return rh, b
} }
func (r *responseHandlerMap) put(key string, rh ResponseHandler) { func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) {
r.Lock() r.Lock()
r.holder[key] = rh r.holder[key] = handledResponse{internal: internal, external: external}
r.Unlock() r.Unlock()
} }
func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) { func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
r.holder[key] = rh r.putWithExternalHandler(key, rh, nil)
} }
// New returns a client. // New returns a client.
@ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) {
client = &Client{ client = &Client{
net: network, net: network,
addr: addr, addr: addr,
respHandler: newResponseHandlerMap(),
innerHandler: newResponseHandlerMap(), innerHandler: newResponseHandlerMap(),
in: make(chan *Response, queueSize), in: make(chan *Response, queueSize),
ResponseTimeout: DefaultTimeout, ResponseTimeout: DefaultTimeout,
@ -168,21 +172,26 @@ ReadLoop:
} }
func (client *Client) processLoop() { func (client *Client) processLoop() {
rhandlers := map[string]ResponseHandler{}
for resp := range client.in { for resp := range client.in {
switch resp.DataType { switch resp.DataType {
case dtError: case dtError:
client.err(getError(resp.Data)) client.err(getError(resp.Data))
case dtStatusRes: case dtStatusRes:
resp = client.handleInner("s"+resp.Handle, resp) client.handleInner("s"+resp.Handle, resp, nil)
case dtJobCreated: case dtJobCreated:
resp = client.handleInner("c", resp) client.handleInner("c", resp, rhandlers)
case dtEchoRes: case dtEchoRes:
resp = client.handleInner("e", resp) client.handleInner("e", resp, nil)
case dtWorkData, dtWorkWarning, dtWorkStatus: case dtWorkData, dtWorkWarning, dtWorkStatus:
resp = client.handleResponse(resp.Handle, resp) if cb := rhandlers[resp.Handle]; cb != nil {
cb(resp)
}
case dtWorkComplete, dtWorkFail, dtWorkException: case dtWorkComplete, dtWorkFail, dtWorkException:
client.handleResponse(resp.Handle, resp) if cb := rhandlers[resp.Handle]; cb != nil {
client.respHandler.remove(resp.Handle) cb(resp)
delete(rhandlers, resp.Handle)
}
} }
} }
} }
@ -193,21 +202,13 @@ func (client *Client) err(e error) {
} }
} }
func (client *Client) handleResponse(key string, resp *Response) *Response { func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) {
if h, ok := client.respHandler.get(key); ok { if h, ok := client.innerHandler.getAndRemove(key); ok {
h(resp) if h.external != nil && resp.Handle != "" {
return nil rhandlers[resp.Handle] = h.external
}
h.internal(resp)
} }
return resp
}
func (client *Client) handleInner(key string, resp *Response) *Response {
if h, ok := client.innerHandler.get(key); ok {
h(resp)
client.innerHandler.remove(key)
return nil
}
return resp
} }
type handleOrError struct { type handleOrError struct {
@ -216,14 +217,14 @@ type handleOrError struct {
} }
func (client *Client) do(funcname string, data []byte, func (client *Client) do(funcname string, data []byte,
flag uint32) (handle string, err error) { flag uint32, h ResponseHandler) (handle string, err error) {
if client.conn == nil { if client.conn == nil {
return "", ErrLostConn return "", ErrLostConn
} }
var result = make(chan handleOrError, 1) var result = make(chan handleOrError, 1)
client.Lock() client.Lock()
defer client.Unlock() defer client.Unlock()
client.innerHandler.put("c", func(resp *Response) { client.innerHandler.putWithExternalHandler("c", func(resp *Response) {
if resp.DataType == dtError { if resp.DataType == dtError {
err = getError(resp.Data) err = getError(resp.Data)
result <- handleOrError{"", err} result <- handleOrError{"", err}
@ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte,
} }
handle = resp.Handle handle = resp.Handle
result <- handleOrError{handle, nil} result <- handleOrError{handle, nil}
}) }, h)
id := IdGen.Id() id := IdGen.Id()
req := getJob(id, []byte(funcname), data) req := getJob(id, []byte(funcname), data)
req.DataType = flag req.DataType = flag
@ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte,
datatype = dtSubmitJob datatype = dtSubmitJob
} }
client.respHandler.Lock() handle, err = client.do(funcname, data, datatype, h)
defer client.respHandler.Unlock()
handle, err = client.do(funcname, data, datatype)
if err == nil && h != nil {
client.respHandler.putNoLock(handle, h)
}
return return
} }
@ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte,
default: default:
datatype = dtSubmitJobBg datatype = dtSubmitJobBg
} }
handle, err = client.do(funcname, data, datatype) handle, err = client.do(funcname, data, datatype, nil)
return return
} }

View File

@ -1,9 +1,11 @@
package client package client
import ( import (
"errors"
"flag" "flag"
"os" "os"
"testing" "testing"
"time"
) )
const ( const (
@ -17,6 +19,7 @@ var (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)") integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)")
flag.Parse()
if integrationsTestFlag != nil { if integrationsTestFlag != nil {
runIntegrationTests = *integrationsTestFlag runIntegrationTests = *integrationsTestFlag
} }
@ -95,6 +98,66 @@ func TestClientDo(t *testing.T) {
} }
} }
func TestClientMultiDo(t *testing.T) {
if !runIntegrationTests {
t.Skip("To run this test, use: go test -integration")
}
// This integration test requires that examples/pl/worker_multi.pl be running.
//
// Test invocation is:
// go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$'
//
// Send 1000 requests to go through all race conditions
const nreqs = 1000
errCh := make(chan error)
gotCh := make(chan string, nreqs)
olderrh := client.ErrorHandler
client.ErrorHandler = func(e error) { errCh <- e }
client.ResponseTimeout = 5 * time.Second
defer func() { client.ErrorHandler = olderrh }()
nextJobCh := make(chan struct{})
defer close(nextJobCh)
go func() {
for range nextJobCh {
start := time.Now()
handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) })
if err == ErrLostConn && time.Since(start) > client.ResponseTimeout {
errCh <- errors.New("Impossible 'lost conn', deadlock bug detected")
} else if err != nil {
errCh <- err
}
if handle == "" {
errCh <- errors.New("Handle is empty.")
}
}
}()
for i := 0; i < nreqs; i++ {
select {
case err := <-errCh:
t.Fatal(err)
case nextJobCh <- struct{}{}:
}
}
remaining := nreqs
for remaining > 0 {
select {
case err := <-errCh:
t.Fatal(err)
case got := <-gotCh:
if got != "ABCDEF" {
t.Error("Unexpected response from PerlDoUpper: ", got)
}
remaining--
t.Logf("%d response remaining", remaining)
}
}
}
func TestClientStatus(t *testing.T) { func TestClientStatus(t *testing.T) {
if !runIntegrationTests { if !runIntegrationTests {
t.Skip("To run this test, use: go test -integration") t.Skip("To run this test, use: go test -integration")

View File

@ -0,0 +1,33 @@
#!/usr/bin/perl
# Runs 20 children that expose "PerlToUpper" before returning the result.
use strict; use warnings;
use constant CHILDREN => 20;
use Time::HiRes qw(usleep);
use Gearman::Worker;
$|++;
my @child_pids;
for (1 .. CHILDREN) {
if (my $pid = fork) {
push @child_pids, $pid;
next;
}
eval {
my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730');
$w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg });
$w->work while 1;
};
warn $@ if $@;
exit 0;
}
$SIG{INT} = $SIG{HUP} = sub {
kill 9, @child_pids;
print "\nChildren shut down, gracefully exiting\n";
exit 0;
};
printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN;
sleep;