diff --git a/client/client.go b/client/client.go index 46b544b..082efb4 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,7 @@ import ( "io" "net" "sync" + "bufio" ) // One client connect to one server. @@ -19,6 +20,7 @@ type Client struct { in chan *Response isConn bool conn net.Conn + rw *bufio.ReadWriter ErrorHandler ErrorHandler } @@ -36,6 +38,8 @@ func New(network, addr string) (client *Client, err error) { if err != nil { return } + client.rw = bufio.NewReadWriter(bufio.NewReader(client.conn), + bufio.NewWriter(client.conn)) client.isConn = true go client.readLoop() go client.processLoop() @@ -46,12 +50,12 @@ func (client *Client) write(req *request) (err error) { var n int buf := req.Encode() for i := 0; i < len(buf); i += n { - n, err = client.conn.Write(buf[i:]) + n, err = client.rw.Write(buf[i:]) if err != nil { return } } - return + return client.rw.Flush() } func (client *Client) read(length int) (data []byte, err error) { @@ -59,7 +63,7 @@ func (client *Client) read(length int) (data []byte, err error) { buf := getBuffer(bufferSize) // read until data can be unpacked for i := length; i > 0 || len(data) < minPacketLength; i -= n { - if n, err = client.conn.Read(buf); err != nil { + if n, err = client.rw.Read(buf); err != nil { if err == io.EOF { err = ErrLostConn } @@ -78,6 +82,7 @@ func (client *Client) readLoop() { var data, leftdata []byte var err error var resp *Response +ReadLoop: for { if data, err = client.read(bufferSize); err != nil { client.err(err) @@ -93,24 +98,30 @@ func (client *Client) readLoop() { client.err(err) break } + client.rw = bufio.NewReadWriter(bufio.NewReader(client.conn), + bufio.NewWriter(client.conn)) continue } if len(leftdata) > 0 { // some data left for processing data = append(leftdata, data...) } - l := len(data) - if l < minPacketLength { // not enough data - leftdata = data - continue - } - if resp, l, err = decodeResponse(data); err != nil { - client.err(err) - continue - } - client.in <- resp - leftdata = nil - if len(data) > l { - leftdata = data[l:] + for { + l := len(data) + if l < minPacketLength { // not enough data + leftdata = data + continue ReadLoop + } + if resp, l, err = decodeResponse(data); err != nil { + leftdata = data[l:] + continue ReadLoop + } else { + client.in <- resp + } + data = data[l:] + if len(data) > 0 { + continue + } + break } } } @@ -131,9 +142,13 @@ func (client *Client) processLoop() { resp = client.handleInner("c", resp) case dtEchoRes: resp = client.handleInner("e", resp) - case dtWorkData, dtWorkWarning, dtWorkStatus, dtWorkComplete, - dtWorkFail, dtWorkException: + case dtWorkData, dtWorkWarning, dtWorkStatus: + resp = client.handleResponse(resp.Handle, resp) + case dtWorkComplete, dtWorkFail, dtWorkException: resp = client.handleResponse(resp.Handle, resp) + if resp != nil { + delete(client.respHandler, resp.Handle) + } } } } @@ -147,7 +162,6 @@ func (client *Client) err(e error) { func (client *Client) handleResponse(key string, resp *Response) *Response { if h, ok := client.respHandler[key]; ok { h(resp) - delete(client.respHandler, key) return nil } return resp @@ -227,7 +241,7 @@ func (client *Client) Status(handle string) (status *Status, err error) { client.lastcall = "s" + handle client.innerHandler["s"+handle] = func(resp *Response) { var err error - status, err = resp.Status() + status, err = resp._status() if err != nil { client.err(err) } diff --git a/client/common.go b/client/common.go index b011662..89ec4c4 100644 --- a/client/common.go +++ b/client/common.go @@ -49,19 +49,22 @@ const ( dtSubmitJobHighBg = 32 dtSubmitJobLow = 33 dtSubmitJobLowBg = 34 + + WorkComplate = dtWorkComplete + WorkDate = dtWorkData + WorkStatus = dtWorkStatus + WorkWarning = dtWorkWarning + WorkFail = dtWorkFail + WorkException = dtWorkException ) const ( // Job type - // JOB_NORMAL | JOB_BG means a normal level job run in background - // normal level - JobNormal = 0 - // background job - JobBg = 1 + JobNormal = iota // low level - JobLow = 2 + JobLow // high level - JobHigh = 4 + JobHigh ) func getBuffer(l int) (buf []byte) { diff --git a/client/response.go b/client/response.go index 3903ab7..c215cb1 100644 --- a/client/response.go +++ b/client/response.go @@ -32,13 +32,7 @@ func (resp *Response) Result() (data []byte, err error) { err = ErrWorkException fallthrough case dtWorkComplete: - s := bytes.SplitN(resp.Data, []byte{'\x00'}, 2) - if len(s) != 2 { - err = fmt.Errorf("Invalid data: %V", resp.Data) - return - } - resp.Handle = string(s[0]) - data = s[1] + data = resp.Data default: err = ErrDataType } @@ -52,26 +46,25 @@ func (resp *Response) Update() (data []byte, err error) { err = ErrDataType return } - s := bytes.SplitN(resp.Data, []byte{'\x00'}, 2) - if len(s) != 2 { - err = ErrInvalidData - return - } + data = resp.Data if resp.DataType == dtWorkWarning { err = ErrWorkWarning } - resp.Handle = string(s[0]) - data = s[1] return } // Decode a job from byte slice func decodeResponse(data []byte) (resp *Response, l int, err error) { - if len(data) < minPacketLength { // valid package should not less 12 bytes + a := len(data) + if a < minPacketLength { // valid package should not less 12 bytes err = fmt.Errorf("Invalid data: %V", data) return } dl := int(binary.BigEndian.Uint32(data[8:12])) + if a < minPacketLength + dl { + err = fmt.Errorf("Invalid data: %V", data) + return + } dt := data[minPacketLength : dl+minPacketLength] if len(dt) != int(dl) { // length not equal err = fmt.Errorf("Invalid data: %V", data) @@ -101,8 +94,31 @@ func decodeResponse(data []byte) (resp *Response, l int, err error) { return } -// status handler func (resp *Response) Status() (status *Status, err error) { + data := bytes.SplitN(resp.Data, []byte{'\x00'}, 2) + if len(data) != 2 { + err = fmt.Errorf("Invalid data: %V", resp.Data) + return + } + status = &Status{} + status.Handle = resp.Handle + status.Known = true + status.Running = true + status.Numerator, err = strconv.ParseUint(string(data[0]), 10, 0) + if err != nil { + err = fmt.Errorf("Invalid Integer: %s", data[0]) + return + } + status.Denominator, err = strconv.ParseUint(string(data[1]), 10, 0) + if err != nil { + err = fmt.Errorf("Invalid Integer: %s", data[1]) + return + } + return +} + +// status handler +func (resp *Response) _status() (status *Status, err error) { data := bytes.SplitN(resp.Data, []byte{'\x00'}, 4) if len(data) != 4 { err = fmt.Errorf("Invalid data: %V", resp.Data) diff --git a/worker/agent.go b/worker/agent.go index 76fbbde..241b88a 100644 --- a/worker/agent.go +++ b/worker/agent.go @@ -5,12 +5,14 @@ import ( "net" "strings" "sync" + "bufio" ) // The agent of job server. type agent struct { sync.Mutex conn net.Conn + rw *bufio.ReadWriter worker *Worker in chan []byte net, addr string @@ -34,6 +36,8 @@ func (a *agent) Connect() (err error) { if err != nil { return } + a.rw = bufio.NewReadWriter(bufio.NewReader(a.conn), + bufio.NewWriter(a.conn)) go a.work() return } @@ -58,6 +62,8 @@ func (a *agent) work() { a.worker.err(err) break } + a.rw = bufio.NewReadWriter(bufio.NewReader(a.conn), + bufio.NewWriter(a.conn)) } if len(leftdata) > 0 { // some data left for processing data = append(leftdata, data...) @@ -120,7 +126,7 @@ func (a *agent) read(length int) (data []byte, err error) { buf := getBuffer(bufferSize) // read until data can be unpacked for i := length; i > 0 || len(data) < minPacketLength; i -= n { - if n, err = a.conn.Read(buf); err != nil { + if n, err = a.rw.Read(buf); err != nil { if isClosed(err) { err = ErrLostConn } @@ -139,10 +145,10 @@ func (a *agent) write(outpack *outPack) (err error) { var n int buf := outpack.Encode() for i := 0; i < len(buf); i += n { - n, err = a.conn.Write(buf[i:]) + n, err = a.rw.Write(buf[i:]) if err != nil { return err } } - return + return a.rw.Flush() }