diff --git a/client/client.go b/client/client.go index bf95cf4..2f12c11 100644 --- a/client/client.go +++ b/client/client.go @@ -19,7 +19,6 @@ type Client struct { sync.Mutex net, addr string - respHandler *responseHandlerMap innerHandler *responseHandlerMap in chan *Response conn net.Conn @@ -32,11 +31,16 @@ type Client struct { type responseHandlerMap struct { sync.Mutex - holder map[string]ResponseHandler + holder map[string]handledResponse +} + +type handledResponse struct { + internal ResponseHandler // internal handler, always non-nil + external ResponseHandler // handler passed in from (*Client).Do, sometimes nil } func newResponseHandlerMap() *responseHandlerMap { - return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)} + return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)} } func (r *responseHandlerMap) remove(key string) { @@ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) { r.Unlock() } -func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) { +func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) { r.Lock() rh, b := r.holder[key] + delete(r.holder, key) r.Unlock() return rh, b } -func (r *responseHandlerMap) put(key string, rh ResponseHandler) { +func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) { r.Lock() - r.holder[key] = rh + r.holder[key] = handledResponse{internal: internal, external: external} r.Unlock() } -func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) { - r.holder[key] = rh +func (r *responseHandlerMap) put(key string, rh ResponseHandler) { + r.putWithExternalHandler(key, rh, nil) } // New returns a client. @@ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) { client = &Client{ net: network, addr: addr, - respHandler: newResponseHandlerMap(), innerHandler: newResponseHandlerMap(), in: make(chan *Response, queueSize), ResponseTimeout: DefaultTimeout, @@ -168,21 +172,26 @@ ReadLoop: } func (client *Client) processLoop() { + rhandlers := map[string]ResponseHandler{} for resp := range client.in { switch resp.DataType { case dtError: client.err(getError(resp.Data)) case dtStatusRes: - resp = client.handleInner("s"+resp.Handle, resp) + client.handleInner("s"+resp.Handle, resp, nil) case dtJobCreated: - resp = client.handleInner("c", resp) + client.handleInner("c", resp, rhandlers) case dtEchoRes: - resp = client.handleInner("e", resp) + client.handleInner("e", resp, nil) case dtWorkData, dtWorkWarning, dtWorkStatus: - resp = client.handleResponse(resp.Handle, resp) + if cb := rhandlers[resp.Handle]; cb != nil { + cb(resp) + } case dtWorkComplete, dtWorkFail, dtWorkException: - client.handleResponse(resp.Handle, resp) - client.respHandler.remove(resp.Handle) + if cb := rhandlers[resp.Handle]; cb != nil { + cb(resp) + delete(rhandlers, resp.Handle) + } } } } @@ -193,21 +202,13 @@ func (client *Client) err(e error) { } } -func (client *Client) handleResponse(key string, resp *Response) *Response { - if h, ok := client.respHandler.get(key); ok { - h(resp) - return nil +func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) { + if h, ok := client.innerHandler.getAndRemove(key); ok { + if h.external != nil && resp.Handle != "" { + rhandlers[resp.Handle] = h.external + } + h.internal(resp) } - return resp -} - -func (client *Client) handleInner(key string, resp *Response) *Response { - if h, ok := client.innerHandler.get(key); ok { - h(resp) - client.innerHandler.remove(key) - return nil - } - return resp } type handleOrError struct { @@ -216,14 +217,14 @@ type handleOrError struct { } func (client *Client) do(funcname string, data []byte, - flag uint32) (handle string, err error) { + flag uint32, h ResponseHandler) (handle string, err error) { if client.conn == nil { return "", ErrLostConn } var result = make(chan handleOrError, 1) client.Lock() defer client.Unlock() - client.innerHandler.put("c", func(resp *Response) { + client.innerHandler.putWithExternalHandler("c", func(resp *Response) { if resp.DataType == dtError { err = getError(resp.Data) result <- handleOrError{"", err} @@ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte, } handle = resp.Handle result <- handleOrError{handle, nil} - }) + }, h) id := IdGen.Id() req := getJob(id, []byte(funcname), data) req.DataType = flag @@ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte, datatype = dtSubmitJob } - client.respHandler.Lock() - defer client.respHandler.Unlock() - handle, err = client.do(funcname, data, datatype) - if err == nil && h != nil { - client.respHandler.putNoLock(handle, h) - } + handle, err = client.do(funcname, data, datatype, h) return } @@ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte, default: datatype = dtSubmitJobBg } - handle, err = client.do(funcname, data, datatype) + handle, err = client.do(funcname, data, datatype, nil) return } diff --git a/client/client_test.go b/client/client_test.go index 025f30b..e19ea98 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,9 +1,11 @@ package client import ( + "errors" "flag" "os" "testing" + "time" ) const ( @@ -17,6 +19,7 @@ var ( func TestMain(m *testing.M) { integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)") + flag.Parse() if integrationsTestFlag != nil { runIntegrationTests = *integrationsTestFlag } @@ -95,6 +98,66 @@ func TestClientDo(t *testing.T) { } } +func TestClientMultiDo(t *testing.T) { + if !runIntegrationTests { + t.Skip("To run this test, use: go test -integration") + } + + // This integration test requires that examples/pl/worker_multi.pl be running. + // + // Test invocation is: + // go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$' + // + // Send 1000 requests to go through all race conditions + const nreqs = 1000 + errCh := make(chan error) + gotCh := make(chan string, nreqs) + + olderrh := client.ErrorHandler + client.ErrorHandler = func(e error) { errCh <- e } + client.ResponseTimeout = 5 * time.Second + defer func() { client.ErrorHandler = olderrh }() + + nextJobCh := make(chan struct{}) + defer close(nextJobCh) + go func() { + for range nextJobCh { + start := time.Now() + handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) }) + if err == ErrLostConn && time.Since(start) > client.ResponseTimeout { + errCh <- errors.New("Impossible 'lost conn', deadlock bug detected") + } else if err != nil { + errCh <- err + } + if handle == "" { + errCh <- errors.New("Handle is empty.") + } + } + }() + + for i := 0; i < nreqs; i++ { + select { + case err := <-errCh: + t.Fatal(err) + case nextJobCh <- struct{}{}: + } + } + + remaining := nreqs + for remaining > 0 { + select { + case err := <-errCh: + t.Fatal(err) + case got := <-gotCh: + if got != "ABCDEF" { + t.Error("Unexpected response from PerlDoUpper: ", got) + } + remaining-- + t.Logf("%d response remaining", remaining) + } + } +} + func TestClientStatus(t *testing.T) { if !runIntegrationTests { t.Skip("To run this test, use: go test -integration") diff --git a/example/pl/worker_multi.pl b/example/pl/worker_multi.pl new file mode 100644 index 0000000..9da6cc7 --- /dev/null +++ b/example/pl/worker_multi.pl @@ -0,0 +1,33 @@ +#!/usr/bin/perl + +# Runs 20 children that expose "PerlToUpper" before returning the result. + +use strict; use warnings; +use constant CHILDREN => 20; +use Time::HiRes qw(usleep); +use Gearman::Worker; + +$|++; +my @child_pids; +for (1 .. CHILDREN) { + if (my $pid = fork) { + push @child_pids, $pid; + next; + } + eval { + my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730'); + $w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg }); + $w->work while 1; + }; + warn $@ if $@; + exit 0; +} + +$SIG{INT} = $SIG{HUP} = sub { + kill 9, @child_pids; + print "\nChildren shut down, gracefully exiting\n"; + exit 0; +}; + +printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN; +sleep;