Merge pull request #94 from cameronpm/client_race
Fix two race conditions in client.Do
This commit is contained in:
		
						commit
						d36dcb7fc2
					
				| @ -19,7 +19,6 @@ type Client struct { | |||||||
| 	sync.Mutex | 	sync.Mutex | ||||||
| 
 | 
 | ||||||
| 	net, addr    string | 	net, addr    string | ||||||
| 	respHandler  *responseHandlerMap |  | ||||||
| 	innerHandler *responseHandlerMap | 	innerHandler *responseHandlerMap | ||||||
| 	in           chan *Response | 	in           chan *Response | ||||||
| 	conn         net.Conn | 	conn         net.Conn | ||||||
| @ -32,11 +31,16 @@ type Client struct { | |||||||
| 
 | 
 | ||||||
| type responseHandlerMap struct { | type responseHandlerMap struct { | ||||||
| 	sync.Mutex | 	sync.Mutex | ||||||
| 	holder map[string]ResponseHandler | 	holder map[string]handledResponse | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type handledResponse struct { | ||||||
|  | 	internal ResponseHandler // internal handler, always non-nil
 | ||||||
|  | 	external ResponseHandler // handler passed in from (*Client).Do, sometimes nil
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newResponseHandlerMap() *responseHandlerMap { | func newResponseHandlerMap() *responseHandlerMap { | ||||||
| 	return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)} | 	return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *responseHandlerMap) remove(key string) { | func (r *responseHandlerMap) remove(key string) { | ||||||
| @ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) { | |||||||
| 	r.Unlock() | 	r.Unlock() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) { | func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) { | ||||||
| 	r.Lock() | 	r.Lock() | ||||||
| 	rh, b := r.holder[key] | 	rh, b := r.holder[key] | ||||||
|  | 	delete(r.holder, key) | ||||||
| 	r.Unlock() | 	r.Unlock() | ||||||
| 	return rh, b | 	return rh, b | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *responseHandlerMap) put(key string, rh ResponseHandler) { | func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) { | ||||||
| 	r.Lock() | 	r.Lock() | ||||||
| 	r.holder[key] = rh | 	r.holder[key] = handledResponse{internal: internal, external: external} | ||||||
| 	r.Unlock() | 	r.Unlock() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) { | func (r *responseHandlerMap) put(key string, rh ResponseHandler) { | ||||||
| 	r.holder[key] = rh | 	r.putWithExternalHandler(key, rh, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New returns a client.
 | // New returns a client.
 | ||||||
| @ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) { | |||||||
| 	client = &Client{ | 	client = &Client{ | ||||||
| 		net:             network, | 		net:             network, | ||||||
| 		addr:            addr, | 		addr:            addr, | ||||||
| 		respHandler:     newResponseHandlerMap(), |  | ||||||
| 		innerHandler:    newResponseHandlerMap(), | 		innerHandler:    newResponseHandlerMap(), | ||||||
| 		in:              make(chan *Response, queueSize), | 		in:              make(chan *Response, queueSize), | ||||||
| 		ResponseTimeout: DefaultTimeout, | 		ResponseTimeout: DefaultTimeout, | ||||||
| @ -168,21 +172,26 @@ ReadLoop: | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (client *Client) processLoop() { | func (client *Client) processLoop() { | ||||||
|  | 	rhandlers := map[string]ResponseHandler{} | ||||||
| 	for resp := range client.in { | 	for resp := range client.in { | ||||||
| 		switch resp.DataType { | 		switch resp.DataType { | ||||||
| 		case dtError: | 		case dtError: | ||||||
| 			client.err(getError(resp.Data)) | 			client.err(getError(resp.Data)) | ||||||
| 		case dtStatusRes: | 		case dtStatusRes: | ||||||
| 			resp = client.handleInner("s"+resp.Handle, resp) | 			client.handleInner("s"+resp.Handle, resp, nil) | ||||||
| 		case dtJobCreated: | 		case dtJobCreated: | ||||||
| 			resp = client.handleInner("c", resp) | 			client.handleInner("c", resp, rhandlers) | ||||||
| 		case dtEchoRes: | 		case dtEchoRes: | ||||||
| 			resp = client.handleInner("e", resp) | 			client.handleInner("e", resp, nil) | ||||||
| 		case dtWorkData, dtWorkWarning, dtWorkStatus: | 		case dtWorkData, dtWorkWarning, dtWorkStatus: | ||||||
| 			resp = client.handleResponse(resp.Handle, resp) | 			if cb := rhandlers[resp.Handle]; cb != nil { | ||||||
|  | 				cb(resp) | ||||||
|  | 			} | ||||||
| 		case dtWorkComplete, dtWorkFail, dtWorkException: | 		case dtWorkComplete, dtWorkFail, dtWorkException: | ||||||
| 			client.handleResponse(resp.Handle, resp) | 			if cb := rhandlers[resp.Handle]; cb != nil { | ||||||
| 			client.respHandler.remove(resp.Handle) | 				cb(resp) | ||||||
|  | 				delete(rhandlers, resp.Handle) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -193,21 +202,13 @@ func (client *Client) err(e error) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (client *Client) handleResponse(key string, resp *Response) *Response { | func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) { | ||||||
| 	if h, ok := client.respHandler.get(key); ok { | 	if h, ok := client.innerHandler.getAndRemove(key); ok { | ||||||
| 		h(resp) | 		if h.external != nil && resp.Handle != "" { | ||||||
| 		return nil | 			rhandlers[resp.Handle] = h.external | ||||||
|  | 		} | ||||||
|  | 		h.internal(resp) | ||||||
| 	} | 	} | ||||||
| 	return resp |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (client *Client) handleInner(key string, resp *Response) *Response { |  | ||||||
| 	if h, ok := client.innerHandler.get(key); ok { |  | ||||||
| 		h(resp) |  | ||||||
| 		client.innerHandler.remove(key) |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	return resp |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type handleOrError struct { | type handleOrError struct { | ||||||
| @ -216,14 +217,14 @@ type handleOrError struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (client *Client) do(funcname string, data []byte, | func (client *Client) do(funcname string, data []byte, | ||||||
| 	flag uint32) (handle string, err error) { | 	flag uint32, h ResponseHandler) (handle string, err error) { | ||||||
| 	if client.conn == nil { | 	if client.conn == nil { | ||||||
| 		return "", ErrLostConn | 		return "", ErrLostConn | ||||||
| 	} | 	} | ||||||
| 	var result = make(chan handleOrError, 1) | 	var result = make(chan handleOrError, 1) | ||||||
| 	client.Lock() | 	client.Lock() | ||||||
| 	defer client.Unlock() | 	defer client.Unlock() | ||||||
| 	client.innerHandler.put("c", func(resp *Response) { | 	client.innerHandler.putWithExternalHandler("c", func(resp *Response) { | ||||||
| 		if resp.DataType == dtError { | 		if resp.DataType == dtError { | ||||||
| 			err = getError(resp.Data) | 			err = getError(resp.Data) | ||||||
| 			result <- handleOrError{"", err} | 			result <- handleOrError{"", err} | ||||||
| @ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte, | |||||||
| 		} | 		} | ||||||
| 		handle = resp.Handle | 		handle = resp.Handle | ||||||
| 		result <- handleOrError{handle, nil} | 		result <- handleOrError{handle, nil} | ||||||
| 	}) | 	}, h) | ||||||
| 	id := IdGen.Id() | 	id := IdGen.Id() | ||||||
| 	req := getJob(id, []byte(funcname), data) | 	req := getJob(id, []byte(funcname), data) | ||||||
| 	req.DataType = flag | 	req.DataType = flag | ||||||
| @ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte, | |||||||
| 		datatype = dtSubmitJob | 		datatype = dtSubmitJob | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	client.respHandler.Lock() | 	handle, err = client.do(funcname, data, datatype, h) | ||||||
| 	defer client.respHandler.Unlock() |  | ||||||
| 	handle, err = client.do(funcname, data, datatype) |  | ||||||
| 	if err == nil && h != nil { |  | ||||||
| 		client.respHandler.putNoLock(handle, h) |  | ||||||
| 	} |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte, | |||||||
| 	default: | 	default: | ||||||
| 		datatype = dtSubmitJobBg | 		datatype = dtSubmitJobBg | ||||||
| 	} | 	} | ||||||
| 	handle, err = client.do(funcname, data, datatype) | 	handle, err = client.do(funcname, data, datatype, nil) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,9 +1,11 @@ | |||||||
| package client | package client | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -17,6 +19,7 @@ var ( | |||||||
| 
 | 
 | ||||||
| func TestMain(m *testing.M) { | func TestMain(m *testing.M) { | ||||||
| 	integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)") | 	integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)") | ||||||
|  | 	flag.Parse() | ||||||
| 	if integrationsTestFlag != nil { | 	if integrationsTestFlag != nil { | ||||||
| 		runIntegrationTests = *integrationsTestFlag | 		runIntegrationTests = *integrationsTestFlag | ||||||
| 	} | 	} | ||||||
| @ -95,6 +98,66 @@ func TestClientDo(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestClientMultiDo(t *testing.T) { | ||||||
|  | 	if !runIntegrationTests { | ||||||
|  | 		t.Skip("To run this test, use: go test -integration") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// This integration test requires that examples/pl/worker_multi.pl be running.
 | ||||||
|  | 	//
 | ||||||
|  | 	// Test invocation is:
 | ||||||
|  | 	//    go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$'
 | ||||||
|  | 	//
 | ||||||
|  | 	// Send 1000 requests to go through all race conditions
 | ||||||
|  | 	const nreqs = 1000 | ||||||
|  | 	errCh := make(chan error) | ||||||
|  | 	gotCh := make(chan string, nreqs) | ||||||
|  | 
 | ||||||
|  | 	olderrh := client.ErrorHandler | ||||||
|  | 	client.ErrorHandler = func(e error) { errCh <- e } | ||||||
|  | 	client.ResponseTimeout = 5 * time.Second | ||||||
|  | 	defer func() { client.ErrorHandler = olderrh }() | ||||||
|  | 
 | ||||||
|  | 	nextJobCh := make(chan struct{}) | ||||||
|  | 	defer close(nextJobCh) | ||||||
|  | 	go func() { | ||||||
|  | 		for range nextJobCh { | ||||||
|  | 			start := time.Now() | ||||||
|  | 			handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) }) | ||||||
|  | 			if err == ErrLostConn && time.Since(start) > client.ResponseTimeout { | ||||||
|  | 				errCh <- errors.New("Impossible 'lost conn', deadlock bug detected") | ||||||
|  | 			} else if err != nil { | ||||||
|  | 				errCh <- err | ||||||
|  | 			} | ||||||
|  | 			if handle == "" { | ||||||
|  | 				errCh <- errors.New("Handle is empty.") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	for i := 0; i < nreqs; i++ { | ||||||
|  | 		select { | ||||||
|  | 		case err := <-errCh: | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		case nextJobCh <- struct{}{}: | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	remaining := nreqs | ||||||
|  | 	for remaining > 0 { | ||||||
|  | 		select { | ||||||
|  | 		case err := <-errCh: | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		case got := <-gotCh: | ||||||
|  | 			if got != "ABCDEF" { | ||||||
|  | 				t.Error("Unexpected response from PerlDoUpper: ", got) | ||||||
|  | 			} | ||||||
|  | 			remaining-- | ||||||
|  | 			t.Logf("%d response remaining", remaining) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestClientStatus(t *testing.T) { | func TestClientStatus(t *testing.T) { | ||||||
| 	if !runIntegrationTests { | 	if !runIntegrationTests { | ||||||
| 		t.Skip("To run this test, use: go test -integration") | 		t.Skip("To run this test, use: go test -integration") | ||||||
|  | |||||||
							
								
								
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | |||||||
|  | #!/usr/bin/perl | ||||||
|  | 
 | ||||||
|  | # Runs 20 children that expose "PerlToUpper" before returning the result. | ||||||
|  | 
 | ||||||
|  | use strict; use warnings; | ||||||
|  | use constant CHILDREN => 20; | ||||||
|  | use Time::HiRes qw(usleep); | ||||||
|  | use Gearman::Worker; | ||||||
|  | 
 | ||||||
|  | $|++; | ||||||
|  | my @child_pids; | ||||||
|  | for (1 .. CHILDREN) { | ||||||
|  |   if (my $pid = fork) { | ||||||
|  |     push @child_pids, $pid; | ||||||
|  |     next; | ||||||
|  |   } | ||||||
|  |   eval { | ||||||
|  |     my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730'); | ||||||
|  |     $w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg }); | ||||||
|  |     $w->work while 1; | ||||||
|  |   }; | ||||||
|  |   warn $@ if $@; | ||||||
|  |   exit 0; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | $SIG{INT} = $SIG{HUP} = sub { | ||||||
|  |   kill 9, @child_pids; | ||||||
|  |   print "\nChildren shut down, gracefully exiting\n"; | ||||||
|  |   exit 0; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN; | ||||||
|  | sleep; | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user