mirror of
https://github.com/smallnest/rpcx.git
synced 2025-09-27 04:26:26 +08:00
212 lines
4.8 KiB
Go
212 lines
4.8 KiB
Go
package client
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/smallnest/rpcx/log"
|
|
"github.com/smallnest/rpcx/share"
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
type ConnFactoryFn func(c *Client, network, address string) (net.Conn, error)
|
|
|
|
var ConnFactories = make(map[string]ConnFactoryFn)
|
|
|
|
func init() {
|
|
ConnFactories["http"] = newDirectHTTPConn
|
|
ConnFactories["kcp"] = newDirectKCPConn
|
|
ConnFactories["quic"] = newDirectQuicConn
|
|
ConnFactories["unix"] = newDirectConn
|
|
ConnFactories["memu"] = newMemuConn
|
|
ConnFactories["iouring"] = newIOUringConn
|
|
}
|
|
|
|
// Connect connects the server via specified network.
|
|
func (client *Client) Connect(network, address string) error {
|
|
var conn net.Conn
|
|
var err error
|
|
|
|
switch network {
|
|
case "http":
|
|
conn, err = newDirectHTTPConn(client, network, address)
|
|
case "ws", "wss":
|
|
conn, err = newDirectWSConn(client, network, address)
|
|
default:
|
|
fn := ConnFactories[network]
|
|
if fn != nil {
|
|
conn, err = fn(client, network, address)
|
|
} else {
|
|
conn, err = newDirectConn(client, network, address)
|
|
}
|
|
}
|
|
|
|
if err == nil && conn != nil {
|
|
if tc, ok := conn.(*net.TCPConn); ok && client.option.TCPKeepAlivePeriod > 0 {
|
|
_ = tc.SetKeepAlive(true)
|
|
_ = tc.SetKeepAlivePeriod(client.option.TCPKeepAlivePeriod)
|
|
}
|
|
|
|
if client.option.IdleTimeout != 0 {
|
|
_ = conn.SetDeadline(time.Now().Add(client.option.IdleTimeout))
|
|
}
|
|
|
|
if client.Plugins != nil {
|
|
conn, err = client.Plugins.DoConnCreated(conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
client.Conn = conn
|
|
client.r = bufio.NewReaderSize(conn, ReaderBuffsize)
|
|
// c.w = bufio.NewWriterSize(conn, WriterBuffsize)
|
|
|
|
// start reading and writing since connected
|
|
go client.input()
|
|
|
|
if client.option.Heartbeat && client.option.HeartbeatInterval > 0 {
|
|
go client.heartbeat()
|
|
}
|
|
|
|
}
|
|
|
|
if err != nil && client.Plugins != nil {
|
|
client.Plugins.DoConnCreateFailed(network, address)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func newDirectConn(c *Client, network, address string) (net.Conn, error) {
|
|
var conn net.Conn
|
|
var tlsConn *tls.Conn
|
|
var err error
|
|
|
|
if c == nil {
|
|
err = fmt.Errorf("nil client")
|
|
return nil, err
|
|
}
|
|
|
|
if c.option.TLSConfig != nil {
|
|
dialer := &net.Dialer{
|
|
Timeout: c.option.ConnectTimeout,
|
|
}
|
|
tlsConn, err = tls.DialWithDialer(dialer, network, address, c.option.TLSConfig)
|
|
// or conn:= tls.Client(netConn, &config)
|
|
conn = net.Conn(tlsConn)
|
|
} else {
|
|
conn, err = net.DialTimeout(network, address, c.option.ConnectTimeout)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Warnf("failed to dial server: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
var connected = "200 Connected to rpcx"
|
|
|
|
func newDirectHTTPConn(c *Client, network, address string) (net.Conn, error) {
|
|
if c == nil {
|
|
return nil, errors.New("empty client")
|
|
}
|
|
path := c.option.RPCPath
|
|
if path == "" {
|
|
path = share.DefaultRPCPath
|
|
}
|
|
|
|
var conn net.Conn
|
|
var tlsConn *tls.Conn
|
|
var err error
|
|
|
|
if c.option.TLSConfig != nil {
|
|
dialer := &net.Dialer{
|
|
Timeout: c.option.ConnectTimeout,
|
|
}
|
|
tlsConn, err = tls.DialWithDialer(dialer, "tcp", address, c.option.TLSConfig)
|
|
// or conn:= tls.Client(netConn, &config)
|
|
|
|
conn = net.Conn(tlsConn)
|
|
} else {
|
|
conn, err = net.DialTimeout("tcp", address, c.option.ConnectTimeout)
|
|
}
|
|
if err != nil {
|
|
log.Errorf("failed to dial server: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
_, err = io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
|
|
if err != nil {
|
|
// Dial() success but Write() failed here, close the successfully
|
|
// created conn before return.
|
|
conn.Close()
|
|
|
|
log.Errorf("failed to make CONNECT: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
// Require successful HTTP response
|
|
// before switching to RPC protocol.
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
|
|
if err == nil && resp.Status == connected {
|
|
return conn, nil
|
|
}
|
|
if err == nil {
|
|
log.Errorf("unexpected HTTP response: %v", err)
|
|
err = errors.New("unexpected HTTP response: " + resp.Status)
|
|
}
|
|
conn.Close()
|
|
return nil, &net.OpError{
|
|
Op: "dial-http",
|
|
Net: network + " " + address,
|
|
Addr: nil,
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
func newDirectWSConn(c *Client, network, address string) (net.Conn, error) {
|
|
if c == nil {
|
|
return nil, errors.New("empty client")
|
|
}
|
|
path := c.option.RPCPath
|
|
if path == "" {
|
|
path = share.DefaultRPCPath
|
|
}
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
|
|
// url := "ws://localhost:12345/ws"
|
|
|
|
var url, origin string
|
|
if network == "ws" {
|
|
url = fmt.Sprintf("ws://%s%s", address, path)
|
|
origin = fmt.Sprintf("http://%s", address)
|
|
} else {
|
|
url = fmt.Sprintf("wss://%s%s", address, path)
|
|
origin = fmt.Sprintf("https://%s", address)
|
|
}
|
|
|
|
if c.option.TLSConfig != nil {
|
|
config, erri := websocket.NewConfig(url, origin)
|
|
if erri != nil {
|
|
return nil, erri
|
|
}
|
|
config.TlsConfig = c.option.TLSConfig
|
|
conn, err = websocket.DialConfig(config)
|
|
} else {
|
|
conn, err = websocket.Dial(url, "", origin)
|
|
}
|
|
|
|
return conn, err
|
|
}
|