[ADDED] Support for processing of asynchronous INFO protocols

This would be used in conjunction with server's PR #314.
The client may receive INFO protocols with a field containing
a possible array of URLs that correspond to server addresses in
the full mesh cluster that clients can connect to.
The server pool is updated with these URLs.
3 new Options and 2 new Options setters have been introduced to
manage username/password and/or tokens when dealing with these
bare URLs.

cc /derekcollison
This commit is contained in:
Ivan Kozlovic
2016-08-01 12:27:39 -06:00
parent c0ad3f0797
commit 6db5372cae
4 changed files with 394 additions and 37 deletions

155
nats.go
View File

@@ -141,6 +141,10 @@ type Options struct {
// NOTE: This does not affect AsyncSubscriptions which are // NOTE: This does not affect AsyncSubscriptions which are
// dictated by PendingLimits() // dictated by PendingLimits()
SubChanLen int SubChanLen int
User string
Password string
Token string
} }
const ( const (
@@ -176,6 +180,7 @@ type Conn struct {
url *url.URL url *url.URL
conn net.Conn conn net.Conn
srvPool []*srv srvPool []*srv
urls map[string]struct{} // Keep track of all known URLs (used by processInfo)
bw *bufio.Writer bw *bufio.Writer
pending *bytes.Buffer pending *bytes.Buffer
fch chan bool fch chan bool
@@ -261,15 +266,25 @@ type srv struct {
} }
type serverInfo struct { type serverInfo struct {
Id string `json:"server_id"` Id string `json:"server_id"`
Host string `json:"host"` Host string `json:"host"`
Port uint `json:"port"` Port uint `json:"port"`
Version string `json:"version"` Version string `json:"version"`
AuthRequired bool `json:"auth_required"` AuthRequired bool `json:"auth_required"`
TLSRequired bool `json:"tls_required"` TLSRequired bool `json:"tls_required"`
MaxPayload int64 `json:"max_payload"` MaxPayload int64 `json:"max_payload"`
ConnectURLs []string `json:"connect_urls,omitempty"`
} }
const (
// clientProtoZero is the original client protocol from 2009.
// http://nats.io/documentation/internals/nats-protocol/
clientProtoZero = iota
// clientProtoInfo signals a client can receive more then the original INFO block.
// This can be used to update clients on other cluster members, etc.
clientProtoInfo
)
type connectInfo struct { type connectInfo struct {
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
Pedantic bool `json:"pedantic"` Pedantic bool `json:"pedantic"`
@@ -280,6 +295,7 @@ type connectInfo struct {
Name string `json:"name"` Name string `json:"name"`
Lang string `json:"lang"` Lang string `json:"lang"`
Version string `json:"version"` Version string `json:"version"`
Protocol int `json:"protocol"`
} }
// MsgHandler is a callback function that processes messages delivered to // MsgHandler is a callback function that processes messages delivered to
@@ -445,6 +461,25 @@ func ErrorHandler(cb ErrHandler) Option {
} }
} }
// UserInfo is an Option to set the username and password to
// use when not included directly in the URLs.
func UserInfo(user, password string) Option {
return func(o *Options) error {
o.User = user
o.Password = password
return nil
}
}
// Token is an Option to set the token to use when not included
// directly in the URLs.
func Token(token string) Option {
return func(o *Options) error {
o.Token = token
return nil
}
}
// Handler processing // Handler processing
// SetDisconnectHandler will set the disconnect event handler. // SetDisconnectHandler will set the disconnect event handler.
@@ -621,44 +656,31 @@ const tlsScheme = "tls"
// the NoRandomize flag is set. // the NoRandomize flag is set.
func (nc *Conn) setupServerPool() error { func (nc *Conn) setupServerPool() error {
nc.srvPool = make([]*srv, 0, srvPoolSize) nc.srvPool = make([]*srv, 0, srvPoolSize)
nc.urls = make(map[string]struct{}, srvPoolSize)
if nc.Opts.Url != _EMPTY_ { if nc.Opts.Url != _EMPTY_ {
u, err := url.Parse(nc.Opts.Url) if err := nc.addURLToPool(nc.Opts.Url); err != nil {
if err != nil {
return err return err
} }
s := &srv{url: u}
nc.srvPool = append(nc.srvPool, s)
} }
var srvrs []string // Create srv objects from each url string in nc.Opts.Servers
source := rand.NewSource(time.Now().UnixNano()) // and add them to the pool
r := rand.New(source) for _, urlString := range nc.Opts.Servers {
if err := nc.addURLToPool(urlString); err != nil {
if nc.Opts.NoRandomize {
srvrs = nc.Opts.Servers
} else {
in := r.Perm(len(nc.Opts.Servers))
for _, i := range in {
srvrs = append(srvrs, nc.Opts.Servers[i])
}
}
for _, urlString := range srvrs {
u, err := url.Parse(urlString)
if err != nil {
return err return err
} }
s := &srv{url: u} }
nc.srvPool = append(nc.srvPool, s)
// Randomize if allowed to
if !nc.Opts.NoRandomize {
nc.shufflePool()
} }
// Place default URL if pool is empty. // Place default URL if pool is empty.
if len(nc.srvPool) <= 0 { if len(nc.srvPool) <= 0 {
u, err := url.Parse(DefaultURL) if err := nc.addURLToPool(DefaultURL); err != nil {
if err != nil {
return err return err
} }
s := &srv{url: u}
nc.srvPool = append(nc.srvPool, s)
} }
// Check for Scheme hint to move to TLS mode. // Check for Scheme hint to move to TLS mode.
@@ -675,6 +697,31 @@ func (nc *Conn) setupServerPool() error {
return nc.pickServer() return nc.pickServer()
} }
// addURLToPool adds an entry to the server pool
func (nc *Conn) addURLToPool(sURL string) error {
u, err := url.Parse(sURL)
if err != nil {
return err
}
s := &srv{url: u}
nc.srvPool = append(nc.srvPool, s)
nc.urls[u.Host] = struct{}{}
return nil
}
// shufflePool swaps randomly elements in the server pool
func (nc *Conn) shufflePool() {
if len(nc.srvPool) <= 1 {
return
}
source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source)
for i := range nc.srvPool {
j := r.Intn(i + 1)
nc.srvPool[i], nc.srvPool[j] = nc.srvPool[j], nc.srvPool[i]
}
}
// createConn will connect to the server and wrap the appropriate // createConn will connect to the server and wrap the appropriate
// bufio structures. It will do the right thing when an existing // bufio structures. It will do the right thing when an existing
// connection is in place. // connection is in place.
@@ -844,7 +891,10 @@ func (nc *Conn) connect() error {
// For first connect we walk all servers in the pool and try // For first connect we walk all servers in the pool and try
// to connect immediately. // to connect immediately.
nc.mu.Lock() nc.mu.Lock()
for i := range nc.srvPool { // Get the size of the pool. The pool may change inside a loop
// iteration due to INFO protocol.
poolSize := len(nc.srvPool)
for i := 0; i < poolSize; i++ {
nc.url = nc.srvPool[i].url nc.url = nc.srvPool[i].url
if err := nc.createConn(); err == nil { if err := nc.createConn(); err == nil {
@@ -866,6 +916,9 @@ func (nc *Conn) connect() error {
nc.mu.Lock() nc.mu.Lock()
nc.url = nil nc.url = nil
} }
// Refresh our view of pool length since it may have been
// modified when processing the INFO protocol.
poolSize = len(nc.srvPool)
} else { } else {
// Cancel out default connection refused, will trigger the // Cancel out default connection refused, will trigger the
// No servers error conditional // No servers error conditional
@@ -956,10 +1009,15 @@ func (nc *Conn) connectProto() (string, error) {
user = u.Username() user = u.Username()
pass, _ = u.Password() pass, _ = u.Password()
} }
} else {
// Take from options (pssibly all empty strings)
user = nc.Opts.User
pass = nc.Opts.Password
token = nc.Opts.Token
} }
cinfo := connectInfo{o.Verbose, o.Pedantic, cinfo := connectInfo{o.Verbose, o.Pedantic,
user, pass, token, user, pass, token,
o.Secure, o.Name, LangString, Version} o.Secure, o.Name, LangString, Version, clientProtoInfo}
b, err := json.Marshal(cinfo) b, err := json.Marshal(cinfo)
if err != nil { if err != nil {
return _EMPTY_, ErrJsonParse return _EMPTY_, ErrJsonParse
@@ -1552,11 +1610,38 @@ func (nc *Conn) processOK() {
// processInfo is used to parse the info messages sent // processInfo is used to parse the info messages sent
// from the server. // from the server.
// This function May update the server pool.
func (nc *Conn) processInfo(info string) error { func (nc *Conn) processInfo(info string) error {
if info == _EMPTY_ { if info == _EMPTY_ {
return nil return nil
} }
return json.Unmarshal([]byte(info), &nc.info) if err := json.Unmarshal([]byte(info), &nc.info); err != nil {
return err
}
updated := false
urls := nc.info.ConnectURLs
for _, curl := range urls {
if _, present := nc.urls[curl]; !present {
if err := nc.addURLToPool(fmt.Sprintf("nats://%s", curl)); err != nil {
continue
}
updated = true
}
}
if updated && !nc.Opts.NoRandomize {
nc.shufflePool()
}
return nil
}
// processAsyncInfo does the same than processInfo, but is called
// from the parser. Calls processInfo under connection's lock
// protection.
func (nc *Conn) processAsyncInfo(info []byte) {
nc.mu.Lock()
// Ignore errors, we will simply not update the server pool...
nc.processInfo(string(info))
nc.mu.Unlock()
} }
// LastError reports the last error encountered via the connection. // LastError reports the last error encountered via the connection.

View File

@@ -785,3 +785,186 @@ func TestNormalizeError(t *testing.T) {
t.Fatalf("Expected '%s', got '%s'", expected, s) t.Fatalf("Expected '%s', got '%s'", expected, s)
} }
} }
func TestAsyncINFO(t *testing.T) {
opts := DefaultOptions
c := &Conn{Opts: opts}
c.ps = &parseState{}
if c.ps.state != OP_START {
t.Fatalf("Expected OP_START vs %d\n", c.ps.state)
}
info := []byte("INFO {}\r\n")
if c.ps.state != OP_START {
t.Fatalf("Expected OP_START vs %d\n", c.ps.state)
}
err := c.parse(info[:1])
if err != nil || c.ps.state != OP_I {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
err = c.parse(info[1:2])
if err != nil || c.ps.state != OP_IN {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
err = c.parse(info[2:3])
if err != nil || c.ps.state != OP_INF {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
err = c.parse(info[3:4])
if err != nil || c.ps.state != OP_INFO {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
err = c.parse(info[4:5])
if err != nil || c.ps.state != OP_INFO_SPC {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
err = c.parse(info[5:])
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// All at once
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Server pool needs to be setup
c.setupServerPool()
// Good INFOs
good := []string{"INFO {}\r\n", "INFO {} \r\n", "INFO { \"server_id\": \"test\" } \r\n", "INFO {\"connect_urls\":[]}\r\n"}
for _, gi := range good {
c.ps = &parseState{}
err = c.parse([]byte(gi))
if err != nil || c.ps.state != OP_START {
t.Fatalf("Protocol %q should be fine. Err=%v state=%v", gi, err, c.ps.state)
}
}
// Wrong INFOs
wrong := []string{"INFOx {}\r\n", "INFO{}\r\n", "INFO {}"}
for _, wi := range wrong {
c.ps = &parseState{}
err = c.parse([]byte(wi))
if err == nil && c.ps.state == OP_START {
t.Fatalf("Protocol %q should have failed", wi)
}
}
checkPool := func(urls ...string) {
// Check both pool and urls map
if len(c.srvPool) != len(urls) {
t.Fatalf("Pool should have %d elements, has %d", len(urls), len(c.srvPool))
}
if len(c.urls) != len(urls) {
t.Fatalf("Map should have %d elements, has %d", len(urls), len(c.urls))
}
for i, url := range urls {
if c.Opts.NoRandomize {
if c.srvPool[i].url.Host != url {
t.Fatalf("Pool should have %q at index %q, has %q", url, i, c.srvPool[i].url.Host)
}
} else {
if _, present := c.urls[url]; !present {
t.Fatalf("Pool should have %q", url)
}
}
}
}
// Now test the decoding of "connect_urls"
// No randomize for now
c.Opts.NoRandomize = true
// Reset the pool
c.setupServerPool()
info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) and localhost:5222
checkPool("localhost:4222", "localhost:5222")
// Make sure that if client receives the same, it is not added again.
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool should still contain localhost:4222 (the default URL) and localhost:5222
checkPool("localhost:4222", "localhost:5222")
// Receive a new URL
info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222
checkPool("localhost:4222", "localhost:5222", "localhost:6222")
// Receive more than 1 URL at once
info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222
// localhost:7222 and localhost:8222
checkPool("localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222")
// Test with pool randomization now
c.Opts.NoRandomize = false
c.setupServerPool()
info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) and localhost:5222
checkPool("localhost:4222", "localhost:5222")
// Make sure that if client receives the same, it is not added again.
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool should still contain localhost:4222 (the default URL) and localhost:5222
checkPool("localhost:4222", "localhost:5222")
// Receive a new URL
info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222
checkPool("localhost:4222", "localhost:5222", "localhost:6222")
// Receive more than 1 URL at once
info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\"]}\r\n")
err = c.parse(info)
if err != nil || c.ps.state != OP_START {
t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err)
}
// Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222
// localhost:7222 and localhost:8222
checkPool("localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222")
// Finally, check that the pool should be randomized.
allUrls := []string{"localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222"}
same := 0
for i, url := range c.srvPool {
if url.url.Host == allUrls[i] {
same++
}
}
if same == len(allUrls) {
t.Fatal("Pool does not seem to be randomized")
}
}

View File

@@ -50,6 +50,12 @@ const (
OP_PO OP_PO
OP_PON OP_PON
OP_PONG OP_PONG
OP_I
OP_IN
OP_INF
OP_INFO
OP_INFO_SPC
OP_INFO_ARG
) )
// parse is the fast protocol parser engine. // parse is the fast protocol parser engine.
@@ -72,6 +78,8 @@ func (nc *Conn) parse(buf []byte) error {
nc.ps.state = OP_PLUS nc.ps.state = OP_PLUS
case '-': case '-':
nc.ps.state = OP_MINUS nc.ps.state = OP_MINUS
case 'I', 'i':
nc.ps.state = OP_I
default: default:
goto parseErr goto parseErr
} }
@@ -289,6 +297,61 @@ func (nc *Conn) parse(buf []byte) error {
nc.processPing() nc.processPing()
nc.ps.drop, nc.ps.state = 0, OP_START nc.ps.drop, nc.ps.state = 0, OP_START
} }
case OP_I:
switch b {
case 'N', 'n':
nc.ps.state = OP_IN
default:
goto parseErr
}
case OP_IN:
switch b {
case 'F', 'f':
nc.ps.state = OP_INF
default:
goto parseErr
}
case OP_INF:
switch b {
case 'O', 'o':
nc.ps.state = OP_INFO
default:
goto parseErr
}
case OP_INFO:
switch b {
case ' ', '\t':
nc.ps.state = OP_INFO_SPC
default:
goto parseErr
}
case OP_INFO_SPC:
switch b {
case ' ', '\t':
continue
default:
nc.ps.state = OP_INFO_ARG
nc.ps.as = i
}
case OP_INFO_ARG:
switch b {
case '\r':
nc.ps.drop = 1
case '\n':
var arg []byte
if nc.ps.argBuf != nil {
arg = nc.ps.argBuf
nc.ps.argBuf = nil
} else {
arg = buf[nc.ps.as : i-nc.ps.drop]
}
nc.processAsyncInfo(arg)
nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START
default:
if nc.ps.argBuf != nil {
nc.ps.argBuf = append(nc.ps.argBuf, b)
}
}
default: default:
goto parseErr goto parseErr
} }

View File

@@ -38,6 +38,19 @@ func TestAuth(t *testing.T) {
t.Fatal("Should have connected successfully with a token") t.Fatal("Should have connected successfully with a token")
} }
nc.Close() nc.Close()
// Use Options
nc, err = nats.Connect("nats://localhost:8232", nats.UserInfo("derek", "foo"))
if err != nil {
t.Fatalf("Should have connected successfully with a token: %v", err)
}
nc.Close()
// Verify that credentials in URL take precedence.
nc, err = nats.Connect("nats://derek:foo@localhost:8232", nats.UserInfo("foo", "bar"))
if err != nil {
t.Fatalf("Should have connected successfully with a token: %v", err)
}
nc.Close()
} }
func TestAuthFailNoDisconnectCB(t *testing.T) { func TestAuthFailNoDisconnectCB(t *testing.T) {
@@ -145,10 +158,23 @@ func TestTokenAuth(t *testing.T) {
t.Fatal("Should have received an error while trying to connect") t.Fatal("Should have received an error while trying to connect")
} }
tokenUrl := fmt.Sprintf("nats://%s@localhost:8232", secret) tokenURL := fmt.Sprintf("nats://%s@localhost:8232", secret)
nc, err := nats.Connect(tokenUrl) nc, err := nats.Connect(tokenURL)
if err != nil { if err != nil {
t.Fatal("Should have connected successfully") t.Fatal("Should have connected successfully")
} }
nc.Close() nc.Close()
// Use Options
nc, err = nats.Connect("nats://localhost:8232", nats.Token(secret))
if err != nil {
t.Fatalf("Should have connected successfully: %v", err)
}
nc.Close()
// Verify that token in the URL takes precedence.
nc, err = nats.Connect(tokenURL, nats.Token("badtoken"))
if err != nil {
t.Fatalf("Should have connected successfully: %v", err)
}
nc.Close()
} }