diff --git a/cassandra_test.go b/cassandra_test.go index ec6969190..50c8de47f 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,7 +32,6 @@ import ( "context" "errors" "fmt" - "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -45,6 +44,8 @@ import ( "time" "unicode" + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" ) @@ -808,7 +809,7 @@ func TestReconnection(t *testing.T) { session := createSessionFromCluster(cluster, t) defer session.Close() - h := session.ring.allHosts()[0] + h := session.hostSource.getHostsList()[0] session.handleNodeDown(h.ConnectAddress(), h.Port()) if h.State() != NodeDown { @@ -1613,7 +1614,7 @@ func TestPrepare_PreparedCacheEviction(t *testing.T) { } // Walk through all the configured hosts and test cache retention and eviction - for _, host := range session.ring.hosts { + for _, host := range session.hostSource.hosts { _, ok := session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 0")) if ok { t.Errorf("expected first select to be purged but was in cache for host=%q", host) @@ -2769,7 +2770,7 @@ func TestTokenAwareConnPool(t *testing.T) { session := createSessionFromCluster(cluster, t) defer session.Close() - expectedPoolSize := cluster.NumConns * len(session.ring.allHosts()) + expectedPoolSize := cluster.NumConns * len(session.hostSource.getHostsList()) // wait for pool to fill for i := 0; i < 10; i++ { diff --git a/conn.go b/conn.go index ae02bd71c..d5b76c82e 100644 --- a/conn.go +++ b/conn.go @@ -1697,7 +1697,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { } for _, row := range rows { - host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}) + host, err := hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}, c.session.cfg.translateAddressPort) if err != nil { goto cont } diff --git a/control.go b/control.go index b30b44ea3..daa662dd0 100644 --- a/control.go +++ b/control.go @@ -294,12 +294,12 @@ type connHost struct { func (c *controlConn) setupConn(conn *Conn) error { // we need up-to-date host info for the filterHost call below iter := conn.querySystemLocal(context.TODO()) - host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.conn.RemoteAddr().(*net.TCPAddr).Port) + host, err := hostInfoFromIter(iter, conn.host.connectAddress, conn.conn.RemoteAddr().(*net.TCPAddr).Port, c.session.cfg.translateAddressPort) if err != nil { return err } - host = c.session.ring.addOrUpdate(host) + host = c.session.hostSource.addOrUpdate(host) if c.session.cfg.filterHost(host) { return fmt.Errorf("host was filtered: %v", host.ConnectAddress()) @@ -385,7 +385,7 @@ func (c *controlConn) reconnect() { } func (c *controlConn) attemptReconnect() (*Conn, error) { - hosts := c.session.ring.allHosts() + hosts := c.session.hostSource.getHostsList() hosts = shuffleHosts(hosts) // keep the old behavior of connecting to the old host first by moving it to diff --git a/control_ccm_test.go b/control_ccm_test.go index 426a59aef..ecccb1ff0 100644 --- a/control_ccm_test.go +++ b/control_ccm_test.go @@ -131,7 +131,7 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) { }() assertNodeDown := func() error { - hosts := session.ring.currentHosts() + hosts := session.hostSource.getHostsMap() if len(hosts) != 1 { return fmt.Errorf("expected 1 host in ring but there were %v", len(hosts)) } @@ -169,7 +169,7 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) { } assertNodeUp := func() error { - hosts := session.ring.currentHosts() + hosts := session.hostSource.getHostsMap() if len(hosts) != len(allCcmHosts) { return fmt.Errorf("expected %v hosts in ring but there were %v", len(allCcmHosts), len(hosts)) } diff --git a/events.go b/events.go index 93b001acc..bfddf16bb 100644 --- a/events.go +++ b/events.go @@ -217,7 +217,7 @@ func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort) } - host, ok := s.ring.getHostByIP(eventIp.String()) + host, ok := s.hostSource.getHostByIP(eventIp.String()) if !ok { s.debounceRingRefresh() return @@ -256,7 +256,7 @@ func (s *Session) handleNodeDown(ip net.IP, port int) { s.logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port) } - host, ok := s.ring.getHostByIP(ip.String()) + host, ok := s.hostSource.getHostByIP(ip.String()) if ok { host.setState(NodeDown) if s.cfg.filterHost(host) { diff --git a/events_ccm_test.go b/events_ccm_test.go index a105985bc..1e3c96018 100644 --- a/events_ccm_test.go +++ b/events_ccm_test.go @@ -104,7 +104,7 @@ func TestEventNodeDownControl(t *testing.T) { } session.pool.mu.RUnlock() - host := session.ring.getHost(node.Addr) + host := session.hostSource.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if host.IsUp() { @@ -146,7 +146,7 @@ func TestEventNodeDown(t *testing.T) { t.Fatal("node not removed after remove event") } - host := session.ring.getHost(node.Addr) + host := session.hostSource.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if host.IsUp() { @@ -203,7 +203,7 @@ func TestEventNodeUp(t *testing.T) { t.Fatal("node not added after node added event") } - host := session.ring.getHost(node.Addr) + host := session.hostSource.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if !host.IsUp() { diff --git a/host_source.go b/host_source.go index a0bab9ad0..ad150a4b6 100644 --- a/host_source.go +++ b/host_source.go @@ -25,7 +25,6 @@ package gocql import ( - "context" "errors" "fmt" "net" @@ -445,14 +444,6 @@ func (h *HostInfo) String() string { h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens)) } -// Polls system.peers at a specific interval to find new hosts -type ringDescriber struct { - session *Session - mu sync.Mutex - prevHosts []*HostInfo - prevPartitioner string -} - // Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces func checkSystemSchema(control *controlConn) (bool, error) { iter := control.query("SELECT * FROM system_schema.keyspaces") @@ -471,7 +462,7 @@ func checkSystemSchema(control *controlConn) (bool, error) { // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo -func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { +func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) { const assertErrorMsg = "Assertion failed for %s" var ok bool @@ -583,14 +574,14 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* // Not sure what the port field will be called until the JIRA issue is complete } - ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port) + ip, port := translateAddressPort(host.ConnectAddress(), host.port) host.connectAddress = ip host.port = port return host, nil } -func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int) (*HostInfo, error) { +func hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) { rows, err := iter.SliceMap() if err != nil { // TODO(zariel): make typed error @@ -601,106 +592,13 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor return nil, errors.New("query returned 0 rows") } - host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}) + host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, translateAddressPort) if err != nil { return nil, err } return host, nil } -// Ask the control node for the local host info -func (r *ringDescriber) getLocalHostInfo() (*HostInfo, error) { - if r.session.control == nil { - return nil, errNoControl - } - - iter := r.session.control.withConnHost(func(ch *connHost) *Iter { - return ch.conn.querySystemLocal(context.TODO()) - }) - - if iter == nil { - return nil, errNoControl - } - - host, err := r.session.hostInfoFromIter(iter, nil, r.session.cfg.Port) - if err != nil { - return nil, fmt.Errorf("could not retrieve local host info: %w", err) - } - return host, nil -} - -// Ask the control node for host info on all it's known peers -func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, error) { - if r.session.control == nil { - return nil, errNoControl - } - - var peers []*HostInfo - iter := r.session.control.withConnHost(func(ch *connHost) *Iter { - return ch.conn.querySystemPeers(context.TODO(), localHost.version) - }) - - if iter == nil { - return nil, errNoControl - } - - rows, err := iter.SliceMap() - if err != nil { - // TODO(zariel): make typed error - return nil, fmt.Errorf("unable to fetch peer host info: %s", err) - } - - for _, row := range rows { - // extract all available info about the peer - host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}) - if err != nil { - return nil, err - } else if !isValidPeer(host) { - // If it's not a valid peer - r.session.logger.Printf("Found invalid peer '%s' "+ - "Likely due to a gossip or snitch issue, this host will be ignored", host) - continue - } - - peers = append(peers, host) - } - - return peers, nil -} - -// Return true if the host is a valid peer -func isValidPeer(host *HostInfo) bool { - return !(len(host.RPCAddress()) == 0 || - host.hostId == "" || - host.dataCenter == "" || - host.rack == "" || - len(host.tokens) == 0) -} - -// GetHosts returns a list of hosts found via queries to system.local and system.peers -func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) { - r.mu.Lock() - defer r.mu.Unlock() - - localHost, err := r.getLocalHostInfo() - if err != nil { - return r.prevHosts, r.prevPartitioner, err - } - - peerHosts, err := r.getClusterPeerInfo(localHost) - if err != nil { - return r.prevHosts, r.prevPartitioner, err - } - - hosts := append([]*HostInfo{localHost}, peerHosts...) - var partitioner string - if len(hosts) > 0 { - partitioner = hosts[0].Partitioner() - } - - return hosts, partitioner, nil -} - // debounceRingRefresh submits a ring refresh request to the ring refresh debouncer. func (s *Session) debounceRingRefresh() { s.ringRefresher.debounce() @@ -716,21 +614,21 @@ func (s *Session) refreshRing() error { return err } -func refreshRing(r *ringDescriber) error { - hosts, partitioner, err := r.GetHosts() +func refreshRing(s *Session) error { + hosts, partitioner, err := s.hostSource.GetHostsFromSystem() if err != nil { return err } - prevHosts := r.session.ring.currentHosts() + prevHosts := s.hostSource.getHostsMap() for _, h := range hosts { - if r.session.cfg.filterHost(h) { + if s.cfg.filterHost(h) { continue } - if host, ok := r.session.ring.addHostIfMissing(h); !ok { - r.session.startPoolFill(h) + if host, ok := s.hostSource.addHostIfMissing(h); !ok { + s.startPoolFill(h) } else { // host (by hostID) already exists; determine if IP has changed newHostID := h.HostID() @@ -744,23 +642,21 @@ func refreshRing(r *ringDescriber) error { } else { // host IP has changed // remove old HostInfo (w/old IP) - r.session.removeHost(existing) - if _, alreadyExists := r.session.ring.addHostIfMissing(h); alreadyExists { + s.removeHost(existing) + if _, alreadyExists := s.hostSource.addHostIfMissing(h); alreadyExists { return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists) } // add new HostInfo (same hostID, new IP) - r.session.startPoolFill(h) + s.startPoolFill(h) } } delete(prevHosts, h.HostID()) } for _, host := range prevHosts { - r.session.removeHost(host) + s.removeHost(host) } - - r.session.metadata.setPartitioner(partitioner) - r.session.policy.SetPartitioner(partitioner) + s.policy.SetPartitioner(partitioner) return nil } diff --git a/integration_test.go b/integration_test.go index 61ffbf504..0cb936f13 100644 --- a/integration_test.go +++ b/integration_test.go @@ -62,12 +62,12 @@ func TestAuthentication(t *testing.T) { session.Close() } -func TestGetHosts(t *testing.T) { +func TestGetHostsFromSystem(t *testing.T) { clusterHosts := getClusterHosts() cluster := createCluster() session := createSessionFromCluster(cluster, t) - hosts, partitioner, err := session.hostSource.GetHosts() + hosts, partitioner, err := session.hostSource.GetHostsFromSystem() assertTrue(t, "err == nil", err == nil) assertEqual(t, "len(hosts)", len(clusterHosts), len(hosts)) diff --git a/ring.go b/ring.go deleted file mode 100644 index 6821c0df2..000000000 --- a/ring.go +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/* - * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 - * Copyright (c) 2016, The Gocql authors, - * provided under the BSD-3-Clause License. - * See the NOTICE file distributed with this work for additional information. - */ - -package gocql - -import ( - "fmt" - "sync" - "sync/atomic" -) - -type ring struct { - // endpoints are the set of endpoints which the driver will attempt to connect - // to in the case it can not reach any of its hosts. They are also used to boot - // strap the initial connection. - endpoints []*HostInfo - - mu sync.RWMutex - // hosts are the set of all hosts in the cassandra ring that we know of. - // key of map is host_id. - hosts map[string]*HostInfo - // hostIPToUUID maps host native address to host_id. - hostIPToUUID map[string]string - - hostList []*HostInfo - pos uint32 - - // TODO: we should store the ring metadata here also. -} - -func (r *ring) rrHost() *HostInfo { - r.mu.RLock() - defer r.mu.RUnlock() - if len(r.hostList) == 0 { - return nil - } - - pos := int(atomic.AddUint32(&r.pos, 1) - 1) - return r.hostList[pos%len(r.hostList)] -} - -func (r *ring) getHostByIP(ip string) (*HostInfo, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - hi, ok := r.hostIPToUUID[ip] - return r.hosts[hi], ok -} - -func (r *ring) getHost(hostID string) *HostInfo { - r.mu.RLock() - host := r.hosts[hostID] - r.mu.RUnlock() - return host -} - -func (r *ring) allHosts() []*HostInfo { - r.mu.RLock() - hosts := make([]*HostInfo, 0, len(r.hosts)) - for _, host := range r.hosts { - hosts = append(hosts, host) - } - r.mu.RUnlock() - return hosts -} - -func (r *ring) currentHosts() map[string]*HostInfo { - r.mu.RLock() - hosts := make(map[string]*HostInfo, len(r.hosts)) - for k, v := range r.hosts { - hosts[k] = v - } - r.mu.RUnlock() - return hosts -} - -func (r *ring) addOrUpdate(host *HostInfo) *HostInfo { - if existingHost, ok := r.addHostIfMissing(host); ok { - existingHost.update(host) - host = existingHost - } - return host -} - -func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) { - if host.invalidConnectAddr() { - panic(fmt.Sprintf("invalid host: %v", host)) - } - hostID := host.HostID() - - r.mu.Lock() - if r.hosts == nil { - r.hosts = make(map[string]*HostInfo) - } - if r.hostIPToUUID == nil { - r.hostIPToUUID = make(map[string]string) - } - - existing, ok := r.hosts[hostID] - if !ok { - r.hosts[hostID] = host - r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID - existing = host - r.hostList = append(r.hostList, host) - } - r.mu.Unlock() - return existing, ok -} - -func (r *ring) removeHost(hostID string) bool { - r.mu.Lock() - if r.hosts == nil { - r.hosts = make(map[string]*HostInfo) - } - if r.hostIPToUUID == nil { - r.hostIPToUUID = make(map[string]string) - } - - h, ok := r.hosts[hostID] - if ok { - for i, host := range r.hostList { - if host.HostID() == hostID { - r.hostList = append(r.hostList[:i], r.hostList[i+1:]...) - break - } - } - delete(r.hostIPToUUID, h.nodeToNodeAddress().String()) - } - delete(r.hosts, hostID) - r.mu.Unlock() - return ok -} - -type clusterMetadata struct { - mu sync.RWMutex - partitioner string -} - -func (c *clusterMetadata) setPartitioner(partitioner string) { - c.mu.Lock() - defer c.mu.Unlock() - - if c.partitioner != partitioner { - // TODO: update other things now - c.partitioner = partitioner - } -} diff --git a/ring_describer.go b/ring_describer.go new file mode 100644 index 000000000..6a1bec50a --- /dev/null +++ b/ring_describer.go @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "context" + "fmt" + "sync" +) + +// Polls system.peers at a specific interval to find new hosts +type ringDescriber struct { + control *controlConn + cfg *ClusterConfig + logger StdLogger + prevHosts []*HostInfo + prevPartitioner string + + mu sync.RWMutex + // hosts are the set of all hosts in the cassandra ring that we know of. + // key of map is host_id. + hosts map[string]*HostInfo + // hostIPToUUID maps host native address to host_id. + hostIPToUUID map[string]string +} + +func (r *ringDescriber) setControlConn(c *controlConn) { + r.mu.Lock() + defer r.mu.Unlock() + + r.control = c +} + +// Ask the control node for the local host info +func (r *ringDescriber) getLocalHostInfo() (*HostInfo, error) { + if r.control == nil { + return nil, errNoControl + } + + iter := r.control.withConnHost(func(ch *connHost) *Iter { + return ch.conn.querySystemLocal(context.TODO()) + }) + + if iter == nil { + return nil, errNoControl + } + + host, err := hostInfoFromIter(iter, nil, r.cfg.Port, r.cfg.translateAddressPort) + if err != nil { + return nil, fmt.Errorf("could not retrieve local host info: %w", err) + } + return host, nil +} + +// Ask the control node for host info on all it's known peers +func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, error) { + if r.control == nil { + return nil, errNoControl + } + + var peers []*HostInfo + iter := r.control.withConnHost(func(ch *connHost) *Iter { + return ch.conn.querySystemPeers(context.TODO(), localHost.version) + }) + + if iter == nil { + return nil, errNoControl + } + + rows, err := iter.SliceMap() + if err != nil { + // TODO(zariel): make typed error + return nil, fmt.Errorf("unable to fetch peer host info: %s", err) + } + + for _, row := range rows { + // extract all available info about the peer + host, err := hostInfoFromMap(row, &HostInfo{port: r.cfg.Port}, r.cfg.translateAddressPort) + if err != nil { + return nil, err + } else if !isValidPeer(host) { + // If it's not a valid peer + r.logger.Printf("Found invalid peer '%s' "+ + "Likely due to a gossip or snitch issue, this host will be ignored", host) + continue + } + + peers = append(peers, host) + } + + return peers, nil +} + +// Return true if the host is a valid peer +func isValidPeer(host *HostInfo) bool { + return !(len(host.RPCAddress()) == 0 || + host.hostId == "" || + host.dataCenter == "" || + host.rack == "" || + len(host.tokens) == 0) +} + +// GetHostsFromSystem returns a list of hosts found via queries to system.local and system.peers +func (r *ringDescriber) GetHostsFromSystem() ([]*HostInfo, string, error) { + r.mu.Lock() + defer r.mu.Unlock() + + localHost, err := r.getLocalHostInfo() + if err != nil { + return r.prevHosts, r.prevPartitioner, err + } + + peerHosts, err := r.getClusterPeerInfo(localHost) + if err != nil { + return r.prevHosts, r.prevPartitioner, err + } + + hosts := append([]*HostInfo{localHost}, peerHosts...) + var partitioner string + if len(hosts) > 0 { + partitioner = hosts[0].Partitioner() + } + + r.prevHosts = hosts + r.prevPartitioner = partitioner + + return hosts, partitioner, nil +} + +// Given an ip/port return HostInfo for the specified ip/port +func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) { + var host *HostInfo + for _, table := range []string{"system.peers", "system.local"} { + ch := r.control.getConn() + var iter *Iter + if ch.host.HostID() == hostID.String() { + host = ch.host + iter = nil + } + + if table == "system.peers" { + if ch.conn.getIsSchemaV2() { + iter = ch.conn.querySystem(context.TODO(), qrySystemPeersV2) + } else { + iter = ch.conn.querySystem(context.TODO(), qrySystemPeers) + } + } else { + iter = ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table)) + } + + if iter != nil { + rows, err := iter.SliceMap() + if err != nil { + return nil, err + } + + for _, row := range rows { + h, err := hostInfoFromMap(row, &HostInfo{port: r.cfg.Port}, r.cfg.translateAddressPort) + if err != nil { + return nil, err + } + + if h.HostID() == hostID.String() { + host = h + break + } + } + } + } + + if host == nil { + return nil, errors.New("unable to fetch host info: invalid control connection") + } else if host.invalidConnectAddr() { + return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", host.connectAddress, host) + } + + return host, nil +} + +func (r *ringDescriber) rrHost() *HostInfo { + r.mu.RLock() + defer r.mu.RUnlock() + if len(r.hostList) == 0 { + return nil + } + + pos := int(atomic.AddUint32(&r.pos, 1) - 1) + return r.hostList[pos%len(r.hostList)] +} + +func (r *ringDescriber) getHostByIP(ip string) (*HostInfo, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + hi, ok := r.hostIPToUUID[ip] + return r.hosts[hi], ok +} + +func (r *ringDescriber) getHost(hostID string) *HostInfo { + r.mu.RLock() + host := r.hosts[hostID] + r.mu.RUnlock() + return host +} + +func (r *ringDescriber) getHostsList() []*HostInfo { + r.mu.RLock() + hosts := make([]*HostInfo, 0, len(r.hosts)) + for _, host := range r.hosts { + hosts = append(hosts, host) + } + r.mu.RUnlock() + return hosts +} + +func (r *ringDescriber) getHostsMap() map[string]*HostInfo { + r.mu.RLock() + hosts := make(map[string]*HostInfo, len(r.hosts)) + for k, v := range r.hosts { + hosts[k] = v + } + r.mu.RUnlock() + return hosts +} + +func (r *ringDescriber) addOrUpdate(host *HostInfo) *HostInfo { + if existingHost, ok := r.addHostIfMissing(host); ok { + existingHost.update(host) + host = existingHost + } + return host +} + +func (r *ringDescriber) addHostIfMissing(host *HostInfo) (*HostInfo, bool) { + if host.invalidConnectAddr() { + panic(fmt.Sprintf("invalid host: %v", host)) + } + hostID := host.HostID() + + r.mu.Lock() + if r.hosts == nil { + r.hosts = make(map[string]*HostInfo) + } + if r.hostIPToUUID == nil { + r.hostIPToUUID = make(map[string]string) + } + + existing, ok := r.hosts[hostID] + if !ok { + r.hosts[hostID] = host + r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID + existing = host + } + r.mu.Unlock() + return existing, ok +} + +func (r *ringDescriber) removeHost(hostID string) bool { + r.mu.Lock() + if r.hosts == nil { + r.hosts = make(map[string]*HostInfo) + } + if r.hostIPToUUID == nil { + r.hostIPToUUID = make(map[string]string) + } + + h, ok := r.hosts[hostID] + if ok { + delete(r.hostIPToUUID, h.nodeToNodeAddress().String()) + } + delete(r.hosts, hostID) + r.mu.Unlock() + return ok +} diff --git a/ring_test.go b/ring_test.go index 3e9533ecd..c7a797947 100644 --- a/ring_test.go +++ b/ring_test.go @@ -30,7 +30,7 @@ import ( ) func TestRing_AddHostIfMissing_Missing(t *testing.T) { - ring := &ring{} + ring := &ringDescriber{} host := &HostInfo{hostId: MustRandomUUID().String(), connectAddress: net.IPv4(1, 1, 1, 1)} h1, ok := ring.addHostIfMissing(host) @@ -44,7 +44,7 @@ func TestRing_AddHostIfMissing_Missing(t *testing.T) { } func TestRing_AddHostIfMissing_Existing(t *testing.T) { - ring := &ring{} + ring := &ringDescriber{} host := &HostInfo{hostId: MustRandomUUID().String(), connectAddress: net.IPv4(1, 1, 1, 1)} ring.addHostIfMissing(host) diff --git a/session.go b/session.go index d04a13672..c83fa353f 100644 --- a/session.go +++ b/session.go @@ -72,9 +72,6 @@ type Session struct { pool *policyConnPool policy HostSelectionPolicy - ring ring - metadata clusterMetadata - mu sync.RWMutex control *controlConn @@ -166,8 +163,8 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) - s.hostSource = &ringDescriber{session: s} - s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s.hostSource) }) + s.hostSource = &ringDescriber{cfg: &s.cfg, logger: s.logger} + s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s) }) if cfg.PoolConfig.HostSelectionPolicy == nil { cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() @@ -216,7 +213,6 @@ func (s *Session) init() error { if err != nil { return err } - s.ring.endpoints = hosts if !s.cfg.disableControlConn { s.control = createControlConn(s) @@ -239,7 +235,7 @@ func (s *Session) init() error { if !s.cfg.DisableInitialHostLookup { var partitioner string - newHosts, partitioner, err := s.hostSource.GetHosts() + newHosts, partitioner, err := s.hostSource.GetHostsFromSystem() if err != nil { return err } @@ -255,6 +251,8 @@ func (s *Session) init() error { } } + s.hostSource.setControlConn(s.control) + for _, host := range hosts { // In case when host lookup is disabled and when we are in unit tests, // host are not discovered, and we are missing host ID information used @@ -282,7 +280,7 @@ func (s *Session) init() error { // again atomic.AddInt64(&left, 1) for _, host := range hostMap { - host := s.ring.addOrUpdate(host) + host := s.hostSource.addOrUpdate(host) if s.cfg.filterHost(host) { continue } @@ -344,7 +342,7 @@ func (s *Session) init() error { newer, _ := checkSystemSchema(s.control) s.useSystemSchema = newer } else { - version := s.ring.rrHost().Version() + version := s.hostSource.getHostsList()[0].Version() s.useSystemSchema = version.AtLeast(3, 0, 0) s.hasAggregatesAndFunctions = version.AtLeast(2, 2, 0) } @@ -388,11 +386,11 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) { for { select { case <-reconnectTicker.C: - hosts := s.ring.allHosts() + hosts := s.hostSource.getHostsList() - // Print session.ring for debug. + // Print session.hostSource for debug. if gocqlDebug { - buf := bytes.NewBufferString("Session.ring:") + buf := bytes.NewBufferString("Session.hostSource:") for _, h := range hosts { buf.WriteString("[" + h.ConnectAddress().String() + ":" + h.State().String() + "]") } @@ -558,7 +556,7 @@ func (s *Session) removeHost(h *HostInfo) { s.policy.RemoveHost(h) hostID := h.HostID() s.pool.removeHost(hostID) - s.ring.removeHost(hostID) + s.hostSource.removeHost(hostID) } // KeyspaceMetadata returns the schema metadata for the keyspace specified. Returns an error if the keyspace does not exist. @@ -574,7 +572,7 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) { } func (s *Session) getConn() *Conn { - hosts := s.ring.allHosts() + hosts := s.hostSource.getHostsList() for _, host := range hosts { if !host.IsUp() { continue