gitroast/vendor/github.com/pingcap/go-hbase/conn.go

292 lines
5.7 KiB
Go

package hbase
import (
"bufio"
"bytes"
"io"
"net"
"strings"
"sync"
pb "github.com/golang/protobuf/proto"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/go-hbase/iohelper"
"github.com/pingcap/go-hbase/proto"
)
// ServiceType identifies which HBase RPC service a connection is opened for.
// It is used as the key of ServiceString and is stored on each connection.
type ServiceType byte
// Service identifiers. Numbering starts at 1 (iota + 1), so the zero value
// does not correspond to any service.
//
// NOTE(review): these are untyped integer constants, not typed ServiceType
// values; they convert implicitly wherever a ServiceType is expected.
const (
MasterMonitorService = iota + 1
MasterService
MasterAdminService
AdminService
ClientService
RegionServerStatusService
)
// ServiceString maps each service constant to its service name string.
// writeConnectionHeader sends that string as the ServiceName of the
// connection header, and newConnection rejects any service type that has no
// entry in this map.
var ServiceString = map[ServiceType]string{
MasterMonitorService: "MasterMonitorService",
MasterService: "MasterService",
MasterAdminService: "MasterAdminService",
AdminService: "AdminService",
ClientService: "ClientService",
RegionServerStatusService: "RegionServerStatusService",
}
// idGenerator is a lock-protected integer counter. connection.call uses it
// to obtain the id of each outgoing request.
type idGenerator struct {
	n  int
	mu *sync.RWMutex
}

// newIdGenerator returns a generator whose counter starts at zero.
func newIdGenerator() *idGenerator {
	gen := new(idGenerator)
	gen.mu = new(sync.RWMutex)
	return gen
}

// get returns the current counter value without changing it.
func (a *idGenerator) get() int {
	a.mu.RLock()
	defer a.mu.RUnlock()
	return a.n
}

// incrAndGet adds one to the counter and returns the new value. The
// increment and the read happen under the same write lock, so concurrent
// callers never observe the same value.
func (a *idGenerator) incrAndGet() int {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.n++
	return a.n
}
// connection is one TCP connection to a single service endpoint, together
// with the bookkeeping needed to match incoming responses to the calls that
// are still waiting for them.
type connection struct {
// mu guards ongoingCalls.
mu sync.Mutex
// addr is the address that was dialed.
addr string
// conn is the underlying network connection; responses are read from it.
conn net.Conn
// bw is a buffered writer over conn used by dispatch for outgoing requests.
bw *bufio.Writer
// idGen produces the call id of each outgoing request.
idGen *idGenerator
// serviceType selects the service name sent in the connection header.
serviceType ServiceType
// in queues serialized requests for the dispatch goroutine.
in chan *iohelper.PbBuffer
// ongoingCalls holds the calls awaiting a response, keyed by call id.
ongoingCalls map[int]*call
}
// processMessage splits msg into consecutive raw-bytes chunks, as decoded by
// the protobuf Buffer, and returns them in order.
//
// Decoding stops at the first error whose text contains "unexpected EOF";
// that is taken to mean the end of msg, and the chunks decoded so far are
// returned (a non-nil, possibly empty slice). Any other decode error is
// logged and returned.
//
// NOTE(review): a chunk that is truncated part-way presumably produces the
// same "unexpected EOF" error and would therefore be dropped silently;
// confirm whether that is intended.
func processMessage(msg []byte) ([][]byte, error) {
	decoder := pb.NewBuffer(msg)
	chunks := [][]byte{}
	for {
		chunk, err := decoder.DecodeRawBytes(true)
		if err == nil {
			chunks = append(chunks, chunk)
			continue
		}
		// Running off the end of the buffer is the loop's normal exit; it is
		// recognised by the error text.
		if !strings.Contains(err.Error(), "unexpected EOF") {
			log.Errorf("Decode raw bytes error - %v", errors.ErrorStack(err))
			return nil, errors.Trace(err)
		}
		return chunks, nil
	}
}
// readPayloads reads one size-prefixed message from r and returns the
// payloads it contains.
//
// The size is obtained with iohelper.ReadInt32 and the body with
// iohelper.ReadN; processMessage then splits the body into payloads. On
// success the returned slice always has at least one element.
//
// Errors:
//   - any failure reading the size or the body is returned (traced);
//   - any error from processMessage is returned (traced);
//   - a non-positive size, or a body that yields no payloads, produces the
//     error "unexpected payload".
func readPayloads(r io.Reader) ([][]byte, error) {
	size, err := iohelper.ReadInt32(r)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if size <= 0 {
		return nil, errors.New("unexpected payload")
	}

	// Every read failure is fatal for this message, not only io.EOF: without
	// the complete body there is nothing meaningful to decode, and reporting
	// the real cause is more useful than a generic "unexpected payload".
	body, err := iohelper.ReadN(r, size)
	if err != nil {
		return nil, errors.Trace(err)
	}

	payloads, err := processMessage(body)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(payloads) == 0 {
		return nil, errors.New("unexpected payload")
	}
	return payloads, nil
}
// newConnection dials addr over TCP, performs the connection handshake for
// the given service type and starts the connection's background goroutines
// (see init).
//
// It returns an error if srvType has no entry in ServiceString, if dialing
// fails, or if the handshake fails. No socket is left open on any error
// path.
func newConnection(addr string, srvType ServiceType) (*connection, error) {
	// Validate first so that an invalid service type never opens a socket.
	if _, ok := ServiceString[srvType]; !ok {
		return nil, errors.Errorf("unexpected service type [serviceType=%d]", srvType)
	}

	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, errors.Trace(err)
	}

	c := &connection{
		addr:         addr,
		bw:           bufio.NewWriter(conn),
		conn:         conn,
		in:           make(chan *iohelper.PbBuffer, 20),
		serviceType:  srvType,
		idGen:        newIdGenerator(),
		ongoingCalls: map[int]*call{},
	}

	if err := c.init(); err != nil {
		// init only starts its goroutines after both handshake writes
		// succeed, so nothing else is using conn here. The close error is
		// dropped on purpose: the handshake failure is the one to report.
		_ = conn.Close()
		return nil, errors.Trace(err)
	}
	return c, nil
}
// init performs the connection handshake and then starts the background
// goroutines.
//
// It sends the preamble (writeHead) followed by the connection header
// (writeConnectionHeader). If either write fails the error is returned and
// no goroutine is started. Otherwise it launches one goroutine running
// processMessages, which logs a warning if that loop ends with an error, and
// one running dispatch.
func (c *connection) init() error {
	if err := c.writeHead(); err != nil {
		return errors.Trace(err)
	}
	if err := c.writeConnectionHeader(); err != nil {
		return errors.Trace(err)
	}

	// Reader side: runs until processMessages gives up.
	go func() {
		if err := c.processMessages(); err != nil {
			log.Warnf("process messages failed - %v", errors.ErrorStack(err))
		}
	}()

	// Writer side.
	go c.dispatch()

	return nil
}
// processMessages is the connection's read loop. It repeatedly reads one
// response from c.conn, looks up the waiting call by the call id in the
// response header, removes it from ongoingCalls and completes it.
//
// A call is completed with:
//   - an error built from the exception class name and stack trace when the
//     header carries an exception;
//   - the second payload when the response has exactly two payloads;
//   - an error in every other case, so that the caller is not left without
//     any completion for a call that has already been unregistered.
//
// The loop only returns on failure: a read error, a header that cannot be
// unmarshalled, or a call id with no registered call. Calls still present
// in ongoingCalls at that point are not completed by this function.
func (c *connection) processMessages() error {
	for {
		// readPayloads guarantees at least one payload on success, so
		// msgs[0] is always present.
		msgs, err := readPayloads(c.conn)
		if err != nil {
			return errors.Trace(err)
		}

		var rh proto.ResponseHeader
		if err := pb.Unmarshal(msgs[0], &rh); err != nil {
			return errors.Trace(err)
		}

		callId := rh.GetCallId()

		c.mu.Lock()
		pending, ok := c.ongoingCalls[int(callId)]
		if ok {
			delete(c.ongoingCalls, int(callId))
		}
		c.mu.Unlock()
		if !ok {
			return errors.Errorf("Invalid call id: %d", callId)
		}

		switch exception := rh.GetException(); {
		case exception != nil:
			pending.complete(errors.Errorf("Exception returned: %s\n%s", exception.GetExceptionClassName(), exception.GetStackTrace()), nil)
		case len(msgs) == 2:
			pending.complete(nil, msgs[1])
		default:
			// The call has already been unregistered; complete it with an
			// error instead of dropping it.
			pending.complete(errors.Errorf("unexpected number of payloads in response: %d", len(msgs)), nil)
		}
	}
}
// writeHead sends the connection preamble: hbaseHeaderBytes followed by the
// two bytes 0 and 80. The preamble is written directly to c.conn, not
// through the buffered writer.
//
// NOTE(review): 0 and 80 are presumably the RPC version and the
// authentication-method code of the HBase preamble; confirm against the
// protocol specification.
func (c *connection) writeHead() error {
	var preamble bytes.Buffer
	preamble.Write(hbaseHeaderBytes)
	preamble.Write([]byte{0, 80})

	_, err := c.conn.Write(preamble.Bytes())
	return errors.Trace(err)
}
// writeConnectionHeader serializes a ConnectionHeader carrying the effective
// user "pingcap" and the service name for c.serviceType, prepends its size,
// and writes the result directly to c.conn.
//
// It returns the (traced) error of whichever step fails first:
// serialization, size prefixing, or the network write.
func (c *connection) writeConnectionHeader() error {
	header := &proto.ConnectionHeader{
		UserInfo: &proto.UserInformation{
			EffectiveUser: pb.String("pingcap"),
		},
		ServiceName: pb.String(ServiceString[c.serviceType]),
	}

	frame := iohelper.NewPbBuffer()
	if err := frame.WritePBMessage(header); err != nil {
		return errors.Trace(err)
	}
	if err := frame.PrependSize(); err != nil {
		return errors.Trace(err)
	}
	if _, err := c.conn.Write(frame.Bytes()); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// dispatch is the connection's write loop. It takes serialized requests
// from c.in and writes them through the buffered writer c.bw, flushing only
// when the queue is empty so that back-to-back requests share a flush.
//
// Write and flush failures are logged and the loop keeps draining c.in. It
// must keep draining because senders block once the channel is full. The
// loop ends if c.in is closed.
func (c *connection) dispatch() {
	for buf := range c.in {
		if _, err := c.bw.Write(buf.Bytes()); err != nil {
			// Skip the flush: bufio.Writer keeps returning the same error
			// after a failed write, so flushing would only log it twice.
			log.Warnf("write request failed - %v", err)
			continue
		}
		if len(c.in) == 0 {
			if err := c.bw.Flush(); err != nil {
				log.Warnf("flush request failed - %v", err)
			}
		}
	}
}
// call assigns request a fresh call id, serializes its header and body,
// registers it in ongoingCalls and queues the serialized bytes for the
// dispatch goroutine.
//
// It returns an error only when serializing the header or the body fails;
// in that case nothing is registered or queued. It does not wait for the
// response: processMessages completes the call once the response arrives.
// The send on c.in blocks while the queue is full.
func (c *connection) call(request *call) error {
	id := c.idGen.incrAndGet()
	rh := &proto.RequestHeader{
		CallId:       pb.Uint32(uint32(id)),
		MethodName:   pb.String(request.methodName),
		RequestParam: pb.Bool(true),
	}
	request.id = uint32(id)

	bfrh := iohelper.NewPbBuffer()
	err := bfrh.WritePBMessage(rh)
	if err != nil {
		return errors.Trace(err)
	}

	bfr := iohelper.NewPbBuffer()
	err = bfr.WritePBMessage(request.request)
	if err != nil {
		return errors.Trace(err)
	}

	// Buf =>
	// | total size | pb1 size | pb1 | pb2 size | pb2 | ...
	buf := iohelper.NewPbBuffer()
	buf.WriteDelimitedBuffers(bfrh, bfr)

	// Register the call before it can reach the wire, so a response can
	// always be matched to it.
	c.mu.Lock()
	c.ongoingCalls[id] = request
	c.mu.Unlock()

	// Queue outside the lock: this send blocks when the channel is full, and
	// holding c.mu meanwhile would stall processMessages, which needs the
	// same mutex to deliver responses.
	c.in <- buf
	return nil
}
// close closes the underlying network connection and returns the error
// reported by its Close method. It does not flush c.bw first.
func (c *connection) close() error {
return c.conn.Close()
}