Merge all upstream changes

This commit is contained in:
glebarez
2023-01-28 20:44:52 +07:00
parent 4143dc8465
commit b8c64c30ab
15 changed files with 791 additions and 67 deletions

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run generator.go
//go:generate go run generator.go -full-path-comments
package sqlite // import "modernc.org/sqlite"
@@ -695,14 +695,14 @@ type tx struct {
c *conn
}
func newTx(c *conn) (*tx, error) {
func newTx(c *conn, opts driver.TxOptions) (*tx, error) {
r := &tx{c: c}
var sql string
if c.beginMode != "" {
sql := "begin"
if !opts.ReadOnly && c.beginMode != "" {
sql = "begin " + c.beginMode
} else {
sql = "begin"
}
if err := r.exec(context.Background(), sql); err != nil {
return nil, err
}
@@ -790,7 +790,7 @@ type conn struct {
}
func newConn(dsn string) (*conn, error) {
var query string
var query, vfsName string
// Parse the query parameters from the dsn and them from the dsn if not prefixed by file:
// https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1046
@@ -798,6 +798,12 @@ func newConn(dsn string) (*conn, error) {
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
query = dsn[pos+1:]
var err error
vfsName, err = getVFSName(query)
if err != nil {
return nil, err
}
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
@@ -806,6 +812,7 @@ func newConn(dsn string) (*conn, error) {
c := &conn{tls: libc.NewTLS()}
db, err := c.openV2(
dsn,
vfsName,
sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE|
sqlite3.SQLITE_OPEN_FULLMUTEX|
sqlite3.SQLITE_OPEN_URI,
@@ -846,6 +853,23 @@ func stmtLog(tls *libc.TLS, type1 uint32, cd uintptr, pd uintptr, xd uintptr) in
return sqlite3.SQLITE_OK
}
func getVFSName(query string) (r string, err error) {
q, err := url.ParseQuery(query)
if err != nil {
return "", err
}
for _, v := range q["vfs"] {
if r != "" && r != v {
return "", fmt.Errorf("conflicting vfs query parameters: %v", q["vfs"])
}
r = v
}
return r, nil
}
func applyQueryParams(c *conn, query string) error {
q, err := url.ParseQuery(query)
if err != nil {
@@ -1285,8 +1309,8 @@ func (c *conn) extendedResultCodes(on bool) error {
// const char *zVfs /* Name of VFS module to use */
//
// );
func (c *conn) openV2(name string, flags int32) (uintptr, error) {
var p, s uintptr
func (c *conn) openV2(name, vfsName string, flags int32) (uintptr, error) {
var p, s, vfs uintptr
defer func() {
if p != 0 {
@@ -1295,6 +1319,9 @@ func (c *conn) openV2(name string, flags int32) (uintptr, error) {
if s != 0 {
c.free(s)
}
if vfs != 0 {
c.free(vfs)
}
}()
p, err := c.malloc(int(ptrSize))
@@ -1306,7 +1333,13 @@ func (c *conn) openV2(name string, flags int32) (uintptr, error) {
return 0, err
}
if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, 0); rc != sqlite3.SQLITE_OK {
if vfsName != "" {
if vfs, err = libc.CString(vfsName); err != nil {
return 0, err
}
}
if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, vfs); rc != sqlite3.SQLITE_OK {
return 0, c.errstr(rc)
}
@@ -1352,7 +1385,7 @@ func (c *conn) Begin() (driver.Tx, error) {
}
func (c *conn) begin(ctx context.Context, opts driver.TxOptions) (t driver.Tx, err error) {
return newTx(c)
return newTx(c, opts)
}
// Close invalidates and potentially stops any current prepared statements and