diff --git a/worker/agent.go b/worker/agent.go index 7385211..afd6f76 100644 --- a/worker/agent.go +++ b/worker/agent.go @@ -2,6 +2,8 @@ package worker import ( "bufio" + "bytes" + "encoding/binary" "net" "sync" "io" @@ -53,7 +55,7 @@ func (a *agent) work() { var err error var data, leftdata []byte for { - if data, err = a.read(bufferSize); err != nil { + if data, err = a.read(); err != nil { if opErr, ok := err.(*net.OpError); ok { if opErr.Temporary() { continue @@ -159,20 +161,31 @@ func (a *agent) reconnect() (error){ } // read length bytes from the socket -func (a *agent) read(length int) (data []byte, err error) { +func (a *agent) read() (data []byte, err error) { n := 0 - buf := getBuffer(bufferSize) - // read until data can be unpacked - for i := length; i > 0 || len(data) < minPacketLength; i -= n { - if n, err = a.rw.Read(buf); err != nil { - return - } - data = append(data, buf[0:n]...) - if n < bufferSize { - break - } + + tmp := getBuffer(bufferSize) + var buf bytes.Buffer + + // read the header so we can get the length of the data + if n, err = a.rw.Read(tmp); err != nil { + return } - return + dl := int(binary.BigEndian.Uint32(tmp[8:12])) + + // write what we read so far + buf.Write(tmp[:n]) + + // read until we receive all the data + for buf.Len() < dl+minPacketLength { + if n, err = a.rw.Read(tmp); err != nil { + return buf.Bytes(), err + } + + buf.Write(tmp[:n]) + } + + return buf.Bytes(), err } // Internal write the encoded job. diff --git a/worker/worker_test.go b/worker/worker_test.go index 7c47b14..b496a47 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -1,6 +1,7 @@ package worker import ( + "bytes" "sync" "testing" "time" @@ -78,6 +79,61 @@ func TestWork(t *testing.T) { wg.Wait() } +func TestLargeDataWork(t *testing.T) { + worker := New(Unlimited) + defer worker.Close() + + if err := worker.AddServer(Network, "127.0.0.1:4730"); err != nil { + t.Error(err) + } + worker.Ready() + + l := 5714 + var wg sync.WaitGroup + + bigdataHandler := func(job Job) error { + defer wg.Done() + if len(job.Data()) != l { + t.Errorf("expected length %d. got %d.", l, len(job.Data())) + } + return nil + } + if err := worker.AddFunc("bigdata", foobar, 0); err != nil { + defer wg.Done() + t.Error(err) + } + + worker.JobHandler = bigdataHandler + + worker.ErrorHandler = func(err error) { + t.Fatal("shouldn't have received an error") + } + + if err := worker.Ready(); err != nil { + t.Error(err) + return + } + go worker.Work() + wg.Add(1) + + // var cli *client.Client + // var err error + // if cli, err = client.New(client.Network, "127.0.0.1:4730"); err != nil { + // t.Fatal(err) + // } + // cli.ErrorHandler = func(e error) { + // t.Error(e) + // } + + // _, err = cli.Do("bigdata", bytes.Repeat([]byte("a"), l), client.JobLow, func(res *client.Response) { + // }) + // if err != nil { + // t.Error(err) + // } + + worker.Echo(bytes.Repeat([]byte("a"), l)) + wg.Wait() +} func TestWorkerClose(t *testing.T) { worker.Close()