forked from yuxh/gearman-go
		
	Merge pull request #94 from cameronpm/client_race
Fix two race conditions in client.Do
This commit is contained in:
		
						commit
						d36dcb7fc2
					
				@ -19,7 +19,6 @@ type Client struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
 | 
			
		||||
	net, addr    string
 | 
			
		||||
	respHandler  *responseHandlerMap
 | 
			
		||||
	innerHandler *responseHandlerMap
 | 
			
		||||
	in           chan *Response
 | 
			
		||||
	conn         net.Conn
 | 
			
		||||
@ -32,11 +31,16 @@ type Client struct {
 | 
			
		||||
 | 
			
		||||
type responseHandlerMap struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
	holder map[string]ResponseHandler
 | 
			
		||||
	holder map[string]handledResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type handledResponse struct {
 | 
			
		||||
	internal ResponseHandler // internal handler, always non-nil
 | 
			
		||||
	external ResponseHandler // handler passed in from (*Client).Do, sometimes nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newResponseHandlerMap() *responseHandlerMap {
 | 
			
		||||
	return &responseHandlerMap{holder: make(map[string]ResponseHandler, queueSize)}
 | 
			
		||||
	return &responseHandlerMap{holder: make(map[string]handledResponse, queueSize)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *responseHandlerMap) remove(key string) {
 | 
			
		||||
@ -45,21 +49,22 @@ func (r *responseHandlerMap) remove(key string) {
 | 
			
		||||
	r.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *responseHandlerMap) get(key string) (ResponseHandler, bool) {
 | 
			
		||||
func (r *responseHandlerMap) getAndRemove(key string) (handledResponse, bool) {
 | 
			
		||||
	r.Lock()
 | 
			
		||||
	rh, b := r.holder[key]
 | 
			
		||||
	delete(r.holder, key)
 | 
			
		||||
	r.Unlock()
 | 
			
		||||
	return rh, b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
 | 
			
		||||
func (r *responseHandlerMap) putWithExternalHandler(key string, internal, external ResponseHandler) {
 | 
			
		||||
	r.Lock()
 | 
			
		||||
	r.holder[key] = rh
 | 
			
		||||
	r.holder[key] = handledResponse{internal: internal, external: external}
 | 
			
		||||
	r.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *responseHandlerMap) putNoLock(key string, rh ResponseHandler) {
 | 
			
		||||
	r.holder[key] = rh
 | 
			
		||||
func (r *responseHandlerMap) put(key string, rh ResponseHandler) {
 | 
			
		||||
	r.putWithExternalHandler(key, rh, nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New returns a client.
 | 
			
		||||
@ -67,7 +72,6 @@ func New(network, addr string) (client *Client, err error) {
 | 
			
		||||
	client = &Client{
 | 
			
		||||
		net:             network,
 | 
			
		||||
		addr:            addr,
 | 
			
		||||
		respHandler:     newResponseHandlerMap(),
 | 
			
		||||
		innerHandler:    newResponseHandlerMap(),
 | 
			
		||||
		in:              make(chan *Response, queueSize),
 | 
			
		||||
		ResponseTimeout: DefaultTimeout,
 | 
			
		||||
@ -168,21 +172,26 @@ ReadLoop:
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (client *Client) processLoop() {
 | 
			
		||||
	rhandlers := map[string]ResponseHandler{}
 | 
			
		||||
	for resp := range client.in {
 | 
			
		||||
		switch resp.DataType {
 | 
			
		||||
		case dtError:
 | 
			
		||||
			client.err(getError(resp.Data))
 | 
			
		||||
		case dtStatusRes:
 | 
			
		||||
			resp = client.handleInner("s"+resp.Handle, resp)
 | 
			
		||||
			client.handleInner("s"+resp.Handle, resp, nil)
 | 
			
		||||
		case dtJobCreated:
 | 
			
		||||
			resp = client.handleInner("c", resp)
 | 
			
		||||
			client.handleInner("c", resp, rhandlers)
 | 
			
		||||
		case dtEchoRes:
 | 
			
		||||
			resp = client.handleInner("e", resp)
 | 
			
		||||
			client.handleInner("e", resp, nil)
 | 
			
		||||
		case dtWorkData, dtWorkWarning, dtWorkStatus:
 | 
			
		||||
			resp = client.handleResponse(resp.Handle, resp)
 | 
			
		||||
			if cb := rhandlers[resp.Handle]; cb != nil {
 | 
			
		||||
				cb(resp)
 | 
			
		||||
			}
 | 
			
		||||
		case dtWorkComplete, dtWorkFail, dtWorkException:
 | 
			
		||||
			client.handleResponse(resp.Handle, resp)
 | 
			
		||||
			client.respHandler.remove(resp.Handle)
 | 
			
		||||
			if cb := rhandlers[resp.Handle]; cb != nil {
 | 
			
		||||
				cb(resp)
 | 
			
		||||
				delete(rhandlers, resp.Handle)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -193,21 +202,13 @@ func (client *Client) err(e error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (client *Client) handleResponse(key string, resp *Response) *Response {
 | 
			
		||||
	if h, ok := client.respHandler.get(key); ok {
 | 
			
		||||
		h(resp)
 | 
			
		||||
		return nil
 | 
			
		||||
func (client *Client) handleInner(key string, resp *Response, rhandlers map[string]ResponseHandler) {
 | 
			
		||||
	if h, ok := client.innerHandler.getAndRemove(key); ok {
 | 
			
		||||
		if h.external != nil && resp.Handle != "" {
 | 
			
		||||
			rhandlers[resp.Handle] = h.external
 | 
			
		||||
		}
 | 
			
		||||
		h.internal(resp)
 | 
			
		||||
	}
 | 
			
		||||
	return resp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (client *Client) handleInner(key string, resp *Response) *Response {
 | 
			
		||||
	if h, ok := client.innerHandler.get(key); ok {
 | 
			
		||||
		h(resp)
 | 
			
		||||
		client.innerHandler.remove(key)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return resp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type handleOrError struct {
 | 
			
		||||
@ -216,14 +217,14 @@ type handleOrError struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (client *Client) do(funcname string, data []byte,
 | 
			
		||||
	flag uint32) (handle string, err error) {
 | 
			
		||||
	flag uint32, h ResponseHandler) (handle string, err error) {
 | 
			
		||||
	if client.conn == nil {
 | 
			
		||||
		return "", ErrLostConn
 | 
			
		||||
	}
 | 
			
		||||
	var result = make(chan handleOrError, 1)
 | 
			
		||||
	client.Lock()
 | 
			
		||||
	defer client.Unlock()
 | 
			
		||||
	client.innerHandler.put("c", func(resp *Response) {
 | 
			
		||||
	client.innerHandler.putWithExternalHandler("c", func(resp *Response) {
 | 
			
		||||
		if resp.DataType == dtError {
 | 
			
		||||
			err = getError(resp.Data)
 | 
			
		||||
			result <- handleOrError{"", err}
 | 
			
		||||
@ -231,7 +232,7 @@ func (client *Client) do(funcname string, data []byte,
 | 
			
		||||
		}
 | 
			
		||||
		handle = resp.Handle
 | 
			
		||||
		result <- handleOrError{handle, nil}
 | 
			
		||||
	})
 | 
			
		||||
	}, h)
 | 
			
		||||
	id := IdGen.Id()
 | 
			
		||||
	req := getJob(id, []byte(funcname), data)
 | 
			
		||||
	req.DataType = flag
 | 
			
		||||
@ -264,12 +265,7 @@ func (client *Client) Do(funcname string, data []byte,
 | 
			
		||||
		datatype = dtSubmitJob
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client.respHandler.Lock()
 | 
			
		||||
	defer client.respHandler.Unlock()
 | 
			
		||||
	handle, err = client.do(funcname, data, datatype)
 | 
			
		||||
	if err == nil && h != nil {
 | 
			
		||||
		client.respHandler.putNoLock(handle, h)
 | 
			
		||||
	}
 | 
			
		||||
	handle, err = client.do(funcname, data, datatype, h)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -289,7 +285,7 @@ func (client *Client) DoBg(funcname string, data []byte,
 | 
			
		||||
	default:
 | 
			
		||||
		datatype = dtSubmitJobBg
 | 
			
		||||
	}
 | 
			
		||||
	handle, err = client.do(funcname, data, datatype)
 | 
			
		||||
	handle, err = client.do(funcname, data, datatype, nil)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,11 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"flag"
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@ -17,6 +19,7 @@ var (
 | 
			
		||||
 | 
			
		||||
func TestMain(m *testing.M) {
 | 
			
		||||
	integrationsTestFlag := flag.Bool("integration", false, "Run the integration tests (in addition to the unit tests)")
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
	if integrationsTestFlag != nil {
 | 
			
		||||
		runIntegrationTests = *integrationsTestFlag
 | 
			
		||||
	}
 | 
			
		||||
@ -95,6 +98,66 @@ func TestClientDo(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClientMultiDo(t *testing.T) {
 | 
			
		||||
	if !runIntegrationTests {
 | 
			
		||||
		t.Skip("To run this test, use: go test -integration")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// This integration test requires that examples/pl/worker_multi.pl be running.
 | 
			
		||||
	//
 | 
			
		||||
	// Test invocation is:
 | 
			
		||||
	//    go test -integration -timeout 10s -run '^TestClient(AddServer|MultiDo)$'
 | 
			
		||||
	//
 | 
			
		||||
	// Send 1000 requests to go through all race conditions
 | 
			
		||||
	const nreqs = 1000
 | 
			
		||||
	errCh := make(chan error)
 | 
			
		||||
	gotCh := make(chan string, nreqs)
 | 
			
		||||
 | 
			
		||||
	olderrh := client.ErrorHandler
 | 
			
		||||
	client.ErrorHandler = func(e error) { errCh <- e }
 | 
			
		||||
	client.ResponseTimeout = 5 * time.Second
 | 
			
		||||
	defer func() { client.ErrorHandler = olderrh }()
 | 
			
		||||
 | 
			
		||||
	nextJobCh := make(chan struct{})
 | 
			
		||||
	defer close(nextJobCh)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for range nextJobCh {
 | 
			
		||||
			start := time.Now()
 | 
			
		||||
			handle, err := client.Do("PerlToUpper", []byte("abcdef"), JobNormal, func(r *Response) { gotCh <- string(r.Data) })
 | 
			
		||||
			if err == ErrLostConn && time.Since(start) > client.ResponseTimeout {
 | 
			
		||||
				errCh <- errors.New("Impossible 'lost conn', deadlock bug detected")
 | 
			
		||||
			} else if err != nil {
 | 
			
		||||
				errCh <- err
 | 
			
		||||
			}
 | 
			
		||||
			if handle == "" {
 | 
			
		||||
				errCh <- errors.New("Handle is empty.")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < nreqs; i++ {
 | 
			
		||||
		select {
 | 
			
		||||
		case err := <-errCh:
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		case nextJobCh <- struct{}{}:
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	remaining := nreqs
 | 
			
		||||
	for remaining > 0 {
 | 
			
		||||
		select {
 | 
			
		||||
		case err := <-errCh:
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		case got := <-gotCh:
 | 
			
		||||
			if got != "ABCDEF" {
 | 
			
		||||
				t.Error("Unexpected response from PerlDoUpper: ", got)
 | 
			
		||||
			}
 | 
			
		||||
			remaining--
 | 
			
		||||
			t.Logf("%d response remaining", remaining)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClientStatus(t *testing.T) {
 | 
			
		||||
	if !runIntegrationTests {
 | 
			
		||||
		t.Skip("To run this test, use: go test -integration")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								example/pl/worker_multi.pl
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,33 @@
 | 
			
		||||
#!/usr/bin/perl
 | 
			
		||||
 | 
			
		||||
# Runs 20 children that expose "PerlToUpper" before returning the result.
 | 
			
		||||
 | 
			
		||||
use strict; use warnings;
 | 
			
		||||
use constant CHILDREN => 20;
 | 
			
		||||
use Time::HiRes qw(usleep);
 | 
			
		||||
use Gearman::Worker;
 | 
			
		||||
 | 
			
		||||
$|++;
 | 
			
		||||
my @child_pids;
 | 
			
		||||
for (1 .. CHILDREN) {
 | 
			
		||||
  if (my $pid = fork) {
 | 
			
		||||
    push @child_pids, $pid;
 | 
			
		||||
    next;
 | 
			
		||||
  }
 | 
			
		||||
  eval {
 | 
			
		||||
    my $w = Gearman::Worker->new(job_servers => '127.0.0.1:4730');
 | 
			
		||||
    $w->register_function(PerlToUpper => sub { print "."; uc $_[0]->arg });
 | 
			
		||||
    $w->work while 1;
 | 
			
		||||
  };
 | 
			
		||||
  warn $@ if $@;
 | 
			
		||||
  exit 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
$SIG{INT} = $SIG{HUP} = sub {
 | 
			
		||||
  kill 9, @child_pids;
 | 
			
		||||
  print "\nChildren shut down, gracefully exiting\n";
 | 
			
		||||
  exit 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
printf "Forked %d children, serving 'PerlToUpper' function to gearman\n", CHILDREN;
 | 
			
		||||
sleep;
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user