forked from yuxh/gearman-go
		
	Merge pull request #94 from cameronpm/client_race
Fix two race conditions in client.Do
This commit is contained in:
		
						commit
						d36dcb7fc2
					
				@ -19,7 +19,6 @@ type Client struct {
 | 
				
			|||||||
	sync.Mutex
 | 
						sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	net, addr    string
 | 
						net, addr    string
 | 
				
			||||||
	respHandler  *responseHandlerMap
 | 
					 | 
				
			||||||
	innerHandler *responseHandlerMap
 | 
						innerHandler *responseHandlerMap
 | 
				
			||||||
	in           chan *Response
 | 
						in           chan *Response
 | 
				
			||||||
	conn         net.Conn
 | 
						conn         net.Conn
 | 
				
			||||||
@ -32,11 +31,16 @@ type Client struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type responseHandlerMap struct {
 | 
					type responseHandlerMap struct {
 | 
				
			||||||
	sync.Mutex
 | 
						sync.Mutex
 | 
				
			||||||
	holder map[string]ResponseHandler
 | 
						holder map[string]handledResponse
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type handledResponse struct {
 | 
				
			||||||
 | 
						internal ResponseHandler // internal handler, always non-nil
 | 
				
			||||||
 | 
						external ResponseHandler // handler passed in from (*Client).Do, sometimes nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newResponseHandlerMap() *responseHandlerMap {
 | 
					func newResponseHandlerMap() *responseHandlerMap {
 | 
				
			||||||
	return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)}
 | 
						return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *responseHandlerMap) remove(key string) {
 | 
					func (r *responseHandlerMap) remove(key string) {
 | 
				
			||||||
@ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) {
 | 
				
			|||||||
	r.Unlock()
 | 
						r.Unlock()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) {
 | 
					func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) {
 | 
				
			||||||
	r.Lock()
 | 
						r.Lock()
 | 
				
			||||||
	rh, b := r.holder[key]
 | 
						rh, b := r.holder[key]
 | 
				
			||||||
 | 
						delete(r.holder, key)
 | 
				
			||||||
	r.Unlock()
 | 
						r.Unlock()
 | 
				
			||||||
	return rh, b
 | 
						return rh, b
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
 | 
					func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) {
 | 
				
			||||||
	r.Lock()
 | 
						r.Lock()
 | 
				
			||||||
	r.holder[key] = rh
 | 
						r.holder[key] = handledResponse{internal: internal, external: external}
 | 
				
			||||||
	r.Unlock()
 | 
						r.Unlock()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) {
 | 
					func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
 | 
				
			||||||
	r.holder[key] = rh
 | 
						r.putWithExternalHandler(key, rh, nil)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// New returns a client.
 | 
					// New returns a client.
 | 
				
			||||||
@ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) {
 | 
				
			|||||||
	client = &Client{
 | 
						client = &Client{
 | 
				
			||||||
		net:             network,
 | 
							net:             network,
 | 
				
			||||||
		addr:            addr,
 | 
							addr:            addr,
 | 
				
			||||||
		respHandler:     newResponseHandlerMap(),
 | 
					 | 
				
			||||||
		innerHandler:    newResponseHandlerMap(),
 | 
							innerHandler:    newResponseHandlerMap(),
 | 
				
			||||||
		in:              make(chan *Response, queueSize),
 | 
							in:              make(chan *Response, queueSize),
 | 
				
			||||||
		ResponseTimeout: DefaultTimeout,
 | 
							ResponseTimeout: DefaultTimeout,
 | 
				
			||||||
@ -168,21 +172,26 @@ ReadLoop:
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (client *Client) processLoop() {
 | 
					func (client *Client) processLoop() {
 | 
				
			||||||
 | 
						rhandlers := map[string]ResponseHandler{}
 | 
				
			||||||
	for resp := range client.in {
 | 
						for resp := range client.in {
 | 
				
			||||||
		switch resp.DataType {
 | 
							switch resp.DataType {
 | 
				
			||||||
		case dtError:
 | 
							case dtError:
 | 
				
			||||||
			client.err(getError(resp.Data))
 | 
								client.err(getError(resp.Data))
 | 
				
			||||||
		case dtStatusRes:
 | 
							case dtStatusRes:
 | 
				
			||||||
			resp = client.handleInner("s"+resp.Handle, resp)
 | 
								client.handleInner("s"+resp.Handle, resp, nil)
 | 
				
			||||||
		case dtJobCreated:
 | 
							case dtJobCreated:
 | 
				
			||||||
			resp = client.handleInner("c", resp)
 | 
								client.handleInner("c", resp, rhandlers)
 | 
				
			||||||
		case dtEchoRes:
 | 
							case dtEchoRes:
 | 
				
			||||||
			resp = client.handleInner("e", resp)
 | 
								client.handleInner("e", resp, nil)
 | 
				
			||||||
		case dtWorkData, dtWorkWarning, dtWorkStatus:
 | 
							case dtWorkData, dtWorkWarning, dtWorkStatus:
 | 
				
			||||||
			resp = client.handleResponse(resp.Handle, resp)
 | 
								if cb := rhandlers[resp.Handle]; cb != nil {
 | 
				
			||||||
 | 
									cb(resp)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		case dtWorkComplete, dtWorkFail, dtWorkException:
 | 
							case dtWorkComplete, dtWorkFail, dtWorkException:
 | 
				
			||||||
			client.handleResponse(resp.Handle, resp)
 | 
								if cb := rhandlers[resp.Handle]; cb != nil {
 | 
				
			||||||
			client.respHandler.remove(resp.Handle)
 | 
									cb(resp)
 | 
				
			||||||
 | 
									delete(rhandlers, resp.Handle)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -193,21 +202,13 @@ func (client *Client) err(e error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (client *Client) handleResponse(key string, resp *Response) *Response {
 | 
					func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) {
 | 
				
			||||||
	if h, ok := client.respHandler.get(key); ok {
 | 
						if h, ok := client.innerHandler.getAndRemove(key); ok {
 | 
				
			||||||
		h(resp)
 | 
							if h.external != nil && resp.Handle != "" {
 | 
				
			||||||
		return nil
 | 
								rhandlers[resp.Handle] = h.external
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							h.internal(resp)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return resp
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (client *Client) handleInner(key string, resp *Response) *Response {
 | 
					 | 
				
			||||||
	if h, ok := client.innerHandler.get(key); ok {
 | 
					 | 
				
			||||||
		h(resp)
 | 
					 | 
				
			||||||
		client.innerHandler.remove(key)
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return resp
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type handleOrError struct {
 | 
					type handleOrError struct {
 | 
				
			||||||
@ -216,14 +217,14 @@ type handleOrError struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (client *Client) do(funcname string, data []byte,
 | 
					func (client *Client) do(funcname string, data []byte,
 | 
				
			||||||
	flag uint32) (handle string, err error) {
 | 
						flag uint32, h ResponseHandler) (handle string, err error) {
 | 
				
			||||||
	if client.conn == nil {
 | 
						if client.conn == nil {
 | 
				
			||||||
		return "", ErrLostConn
 | 
							return "", ErrLostConn
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var result = make(chan handleOrError, 1)
 | 
						var result = make(chan handleOrError, 1)
 | 
				
			||||||
	client.Lock()
 | 
						client.Lock()
 | 
				
			||||||
	defer client.Unlock()
 | 
						defer client.Unlock()
 | 
				
			||||||
	client.innerHandler.put("c", func(resp *Response) {
 | 
						client.innerHandler.putWithExternalHandler("c", func(resp *Response) {
 | 
				
			||||||
		if resp.DataType == dtError {
 | 
							if resp.DataType == dtError {
 | 
				
			||||||
			err = getError(resp.Data)
 | 
								err = getError(resp.Data)
 | 
				
			||||||
			result <- handleOrError{"", err}
 | 
								result <- handleOrError{"", err}
 | 
				
			||||||
@ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte,
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		handle = resp.Handle
 | 
							handle = resp.Handle
 | 
				
			||||||
		result <- handleOrError{handle, nil}
 | 
							result <- handleOrError{handle, nil}
 | 
				
			||||||
	})
 | 
						}, h)
 | 
				
			||||||
	id := IdGen.Id()
 | 
						id := IdGen.Id()
 | 
				
			||||||
	req := getJob(id, []byte(funcname), data)
 | 
						req := getJob(id, []byte(funcname), data)
 | 
				
			||||||
	req.DataType = flag
 | 
						req.DataType = flag
 | 
				
			||||||
@ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte,
 | 
				
			|||||||
		datatype = dtSubmitJob
 | 
							datatype = dtSubmitJob
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	client.respHandler.Lock()
 | 
						handle, err = client.do(funcname, data, datatype, h)
 | 
				
			||||||
	defer client.respHandler.Unlock()
 | 
					 | 
				
			||||||
	handle, err = client.do(funcname, data, datatype)
 | 
					 | 
				
			||||||
	if err == nil && h != nil {
 | 
					 | 
				
			||||||
		client.respHandler.putNoLock(handle, h)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte,
 | 
				
			|||||||
	default:
 | 
						default:
 | 
				
			||||||
		datatype = dtSubmitJobBg
 | 
							datatype = dtSubmitJobBg
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	handle, err = client.do(funcname, data, datatype)
 | 
						handle, err = client.do(funcname, data, datatype, nil)
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,9 +1,11 @@
 | 
				
			|||||||
package client
 | 
					package client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"flag"
 | 
						"flag"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@ -17,6 +19,7 @@ var (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestMain(m *testing.M) {
 | 
					func TestMain(m *testing.M) {
 | 
				
			||||||
	integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)")
 | 
						integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)")
 | 
				
			||||||
 | 
						flag.Parse()
 | 
				
			||||||
	if integrationsTestFlag != nil {
 | 
						if integrationsTestFlag != nil {
 | 
				
			||||||
		runIntegrationTests = *integrationsTestFlag
 | 
							runIntegrationTests = *integrationsTestFlag
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -95,6 +98,66 @@ func TestClientDo(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestClientMultiDo(t *testing.T) {
 | 
				
			||||||
 | 
						if !runIntegrationTests {
 | 
				
			||||||
 | 
							t.Skip("To run this test, use: go test -integration")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// This integration test requires that examples/pl/worker_multi.pl be running.
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// Test invocation is:
 | 
				
			||||||
 | 
						//    go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$'
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// Send 1000 requests to go through all race conditions
 | 
				
			||||||
 | 
						const nreqs = 1000
 | 
				
			||||||
 | 
						errCh := make(chan error)
 | 
				
			||||||
 | 
						gotCh := make(chan string, nreqs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						olderrh := client.ErrorHandler
 | 
				
			||||||
 | 
						client.ErrorHandler = func(e error) { errCh <- e }
 | 
				
			||||||
 | 
						client.ResponseTimeout = 5 * time.Second
 | 
				
			||||||
 | 
						defer func() { client.ErrorHandler = olderrh }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nextJobCh := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(nextJobCh)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for range nextJobCh {
 | 
				
			||||||
 | 
								start := time.Now()
 | 
				
			||||||
 | 
								handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) })
 | 
				
			||||||
 | 
								if err == ErrLostConn && time.Since(start) > client.ResponseTimeout {
 | 
				
			||||||
 | 
									errCh <- errors.New("Impossible 'lost conn', deadlock bug detected")
 | 
				
			||||||
 | 
								} else if err != nil {
 | 
				
			||||||
 | 
									errCh <- err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if handle == "" {
 | 
				
			||||||
 | 
									errCh <- errors.New("Handle is empty.")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < nreqs; i++ {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case err := <-errCh:
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							case nextJobCh <- struct{}{}:
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						remaining := nreqs
 | 
				
			||||||
 | 
						for remaining > 0 {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case err := <-errCh:
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							case got := <-gotCh:
 | 
				
			||||||
 | 
								if got != "ABCDEF" {
 | 
				
			||||||
 | 
									t.Error("Unexpected response from PerlDoUpper: ", got)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								remaining--
 | 
				
			||||||
 | 
								t.Logf("%d response remaining", remaining)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestClientStatus(t *testing.T) {
 | 
					func TestClientStatus(t *testing.T) {
 | 
				
			||||||
	if !runIntegrationTests {
 | 
						if !runIntegrationTests {
 | 
				
			||||||
		t.Skip("To run this test, use: go test -integration")
 | 
							t.Skip("To run this test, use: go test -integration")
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,33 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/perl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Runs 20 children that expose "PerlToUpper" before returning the result.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					use strict; use warnings;
 | 
				
			||||||
 | 
					use constant CHILDREN => 20;
 | 
				
			||||||
 | 
					use Time::HiRes qw(usleep);
 | 
				
			||||||
 | 
					use Gearman::Worker;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					$|++;
 | 
				
			||||||
 | 
					my @child_pids;
 | 
				
			||||||
 | 
					for (1 .. CHILDREN) {
 | 
				
			||||||
 | 
					  if (my $pid = fork) {
 | 
				
			||||||
 | 
					    push @child_pids, $pid;
 | 
				
			||||||
 | 
					    next;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  eval {
 | 
				
			||||||
 | 
					    my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730');
 | 
				
			||||||
 | 
					    $w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg });
 | 
				
			||||||
 | 
					    $w->work while 1;
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					  warn $@ if $@;
 | 
				
			||||||
 | 
					  exit 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					$SIG{INT} = $SIG{HUP} = sub {
 | 
				
			||||||
 | 
					  kill 9, @child_pids;
 | 
				
			||||||
 | 
					  print "\nChildren shut down, gracefully exiting\n";
 | 
				
			||||||
 | 
					  exit 0;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN;
 | 
				
			||||||
 | 
					sleep;
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user