diff --git a/README.md b/README.md index c6244de..07faa65 100644 --- a/README.md +++ b/README.md @@ -1 +1,38 @@ # hashring + +A golang consistent hashring + +Install +=== + + go get github.com/g4zhuj/hashring + +Usage +=== + + +``` +// virtualSpots means virtual spots created by each node + nodeWeight := make(map[string]int) + nodeWeight["node1"] = 1 + nodeWeight["node2"] = 1 + nodeWeight["node3"] = 2 + vitualSpots := 100 + hash := NewHashRing(virtualSpots) + + + //add nodes + hash.AddNodes(nodeWeight) + + //remove node + hash.RemoveNode("node3") + + + //add node + hash.AddNode("node3", 3) + + + //get key's node + node := hash.GetNode("key") + +``` \ No newline at end of file diff --git a/hashring.go b/hashring.go new file mode 100644 index 0000000..e0d7378 --- /dev/null +++ b/hashring.go @@ -0,0 +1,117 @@ +package hashring + +import ( + "crypto/sha1" + // "hash" + "math" + "sort" + "strconv" +) + +const ( + DefaultVirualSpots = 40 +) + +type node struct { + nodeKey string + spotValue uint32 +} + +type nodesArray []node + +func (p nodesArray) Len() int { return len(p) } +func (p nodesArray) Less(i, j int) bool { return p[i].spotValue < p[j].spotValue } +func (p nodesArray) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p nodesArray) Sort() { sort.Sort(p) } + +type HashRing struct { + virualSpots int + nodes nodesArray + weights map[string]int +} + +func NewHashRing(spots int) *HashRing { + if spots == 0 { + spots = DefaultVirualSpots + } + + h := &HashRing{ + virualSpots: spots, + weights: make(map[string]int), + } + return h +} + +func (h *HashRing) AddNodes(nodeWeight map[string]int) { + for nodeKey, w := range nodeWeight { + h.weights[nodeKey] = w + } + h.generate() +} + +func (h *HashRing) AddNode(nodeKey string, weight int) { + h.weights[nodeKey] = weight + h.generate() +} + +func (h *HashRing) RemoveNode(nodeKey string) { + delete(h.weights, nodeKey) + h.generate() +} + +func (h *HashRing) UpdateNode(nodeKey string, weight int) { + h.weights[nodeKey] = weight + h.generate() +} + +func (h *HashRing) generate() { + var totalW int + for _, w := range h.weights { + totalW += w + } + + totalVirtualSpots := h.virualSpots * len(h.weights) + h.nodes = nodesArray{} + + for nodeKey, w := range h.weights { + spots := int(math.Floor(float64(w) / float64(totalW) * float64(totalVirtualSpots))) + for i := 1; i <= spots; i++ { + hash := sha1.New() + hash.Write([]byte(nodeKey + ":" + strconv.Itoa(i))) + hashBytes := hash.Sum(nil) + n := node{ + nodeKey: nodeKey, + spotValue: genValue(hashBytes[6:10]), + } + h.nodes = append(h.nodes, n) + hash.Reset() + } + } + h.nodes.Sort() +} + +func genValue(bs []byte) uint32 { + if len(bs) < 4 { + return 0 + } + v := (uint32(bs[3]) << 24) | (uint32(bs[2]) << 16) | (uint32(bs[1]) << 8) | (uint32(bs[0])) + return v +} + +func (h *HashRing) GetNode(s string) string { + if len(h.nodes) == 0 { + return "" + } + + hash := sha1.New() + hash.Write([]byte(s)) + hashBytes := hash.Sum(nil) + v := genValue(hashBytes[6:10]) + i := sort.Search(len(h.nodes), func(i int) bool { return h.nodes[i].spotValue >= v }) + + if i == len(h.nodes) { + i = 0 + } + + return h.nodes[i].nodeKey +} diff --git a/hashring_test.go b/hashring_test.go new file mode 100644 index 0000000..c92fdfc --- /dev/null +++ b/hashring_test.go @@ -0,0 +1,84 @@ +package hashring + +import ( + // "fmt" + "testing" +) + +const ( + node1 = "192.168.1.1" + node2 = "192.168.1.2" + node3 = "192.168.1.3" +) + +func getNodesCount(nodes nodesArray) (int, int, int) { + node1Count := 0 + node2Count := 0 + node3Count := 0 + + for _, node := range nodes { + if node.nodeKey == node1 { + node1Count += 1 + } + if node.nodeKey == node2 { + node2Count += 1 + + } + if node.nodeKey == node3 { + node3Count += 1 + + } + } + return node1Count, node2Count, node3Count +} + +func TestHash(t *testing.T) { + + nodeWeight := make(map[string]int) + nodeWeight[node1] = 2 + nodeWeight[node2] = 2 + nodeWeight[node3] = 3 + vitualSpots := 100 + + hash := NewHashRing(vitualSpots) + + hash.AddNodes(nodeWeight) + if hash.GetNode("1") != node3 { + t.Fatalf("expetcd %v got %v", node3, hash.GetNode("1")) + } + if hash.GetNode("2") != node3 { + t.Fatalf("expetcd %v got %v", node3, hash.GetNode("2")) + } + if hash.GetNode("3") != node2 { + t.Fatalf("expetcd %v got %v", node2, hash.GetNode("3")) + } + c1, c2, c3 := getNodesCount(hash.nodes) + t.Logf("len of nodes is %v after AddNodes node1:%v, node2:%v, node3:%v", len(hash.nodes), c1, c2, c3) + + hash.RemoveNode(node3) + if hash.GetNode("1") != node1 { + t.Fatalf("expetcd %v got %v", node1, hash.GetNode("1")) + } + if hash.GetNode("2") != node2 { + t.Fatalf("expetcd %v got %v", node1, hash.GetNode("2")) + } + if hash.GetNode("3") != node2 { + t.Fatalf("expetcd %v got %v", node2, hash.GetNode("3")) + } + c1, c2, c3 = getNodesCount(hash.nodes) + t.Logf("len of nodes is %v after RemoveNode node1:%v, node2:%v, node3:%v", len(hash.nodes), c1, c2, c3) + + hash.AddNode(node3, 3) + if hash.GetNode("1") != node3 { + t.Fatalf("expetcd %v got %v", node3, hash.GetNode("1")) + } + if hash.GetNode("2") != node3 { + t.Fatalf("expetcd %v got %v", node3, hash.GetNode("2")) + } + if hash.GetNode("3") != node2 { + t.Fatalf("expetcd %v got %v", node2, hash.GetNode("3")) + } + c1, c2, c3 = getNodesCount(hash.nodes) + t.Logf("len of nodes is %v after AddNode node1:%v, node2:%v, node3:%v", len(hash.nodes), c1, c2, c3) + +}