fix: identifiers casing

This commit is contained in:
Asdine El Hrychy
2025-09-27 21:47:24 +05:30
parent 3a1e556155
commit d6fb3ddbef
3 changed files with 52 additions and 20 deletions

View File

@@ -30,3 +30,11 @@ tidy:
go mod tidy
cd sqltests && go mod tidy && cd ..
cd cmd/chai && go mod tidy && cd ../..
pg:
docker run -d --name chai-pg \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=postgres \
-e POSTGRES_DB=postgres \
-p 5432:5432 \
postgres:latest

View File

@@ -2,6 +2,7 @@ package database
import (
"fmt"
"maps"
"math"
"sort"
"strings"
@@ -609,35 +610,34 @@ func newCatalogCache() *catalogCache {
func (c *catalogCache) Load(tables []TableInfo, indexes []IndexInfo, sequences []Sequence) {
for i := range tables {
c.tables[tables[i].TableName] = &TableInfoRelation{Info: &tables[i]}
lc := strings.ToLower(tables[i].TableName)
c.tables[lc] = &TableInfoRelation{Info: &tables[i]}
}
for i := range indexes {
c.indexes[indexes[i].IndexName] = &IndexInfoRelation{Info: &indexes[i]}
lc := strings.ToLower(indexes[i].IndexName)
c.indexes[lc] = &IndexInfoRelation{Info: &indexes[i]}
}
for i := range sequences {
c.sequences[sequences[i].Info.Name] = &sequences[i]
lc := strings.ToLower(sequences[i].Info.Name)
c.sequences[lc] = &sequences[i]
}
}
func (c *catalogCache) Clone() *catalogCache {
clone := newCatalogCache()
for k, v := range c.tables {
clone.tables[k] = v
}
for k, v := range c.indexes {
clone.indexes[k] = v
}
for k, v := range c.sequences {
clone.sequences[k] = v
}
maps.Copy(clone.tables, c.tables)
maps.Copy(clone.indexes, c.indexes)
maps.Copy(clone.sequences, c.sequences)
return clone
}
func (c *catalogCache) objectExists(name string) bool {
name = strings.ToLower(name)
// checking if table exists with the same name
if _, ok := c.tables[name]; ok {
return true
@@ -699,10 +699,11 @@ func (c *catalogCache) Add(tx *Transaction, o Relation) error {
}
m := c.getMapByType(o.Type())
m[name] = o
lc := strings.ToLower(name)
m[lc] = o
tx.OnRollbackHooks = append(tx.OnRollbackHooks, func() {
delete(m, name)
delete(m, lc)
})
return nil
@@ -711,15 +712,16 @@ func (c *catalogCache) Add(tx *Transaction, o Relation) error {
func (c *catalogCache) Replace(tx *Transaction, o Relation) error {
m := c.getMapByType(o.Type())
old, ok := m[o.Name()]
name := strings.ToLower(o.Name())
old, ok := m[name]
if !ok {
return errs.NewNotFoundError(o.Name())
}
m[o.Name()] = o
m[name] = o
tx.OnRollbackHooks = append(tx.OnRollbackHooks, func() {
m[o.Name()] = old
m[name] = old
})
return nil
@@ -728,6 +730,7 @@ func (c *catalogCache) Replace(tx *Transaction, o Relation) error {
func (c *catalogCache) Delete(tx *Transaction, tp, name string) (Relation, error) {
m := c.getMapByType(tp)
name = strings.ToLower(name)
o, ok := m[name]
if !ok {
return nil, errs.NewNotFoundError(name)
@@ -745,7 +748,8 @@ func (c *catalogCache) Delete(tx *Transaction, tp, name string) (Relation, error
func (c *catalogCache) Get(tp, name string) (Relation, error) {
m := c.getMapByType(tp)
o, ok := m[name]
lc := strings.ToLower(name)
o, ok := m[lc]
if !ok {
return nil, errs.NewNotFoundError(name)
}
@@ -757,8 +761,8 @@ func (c *catalogCache) ListObjects(tp string) []string {
m := c.getMapByType(tp)
list := make([]string, 0, len(m))
for name := range m {
list = append(list, name)
for _, r := range m {
list = append(list, r.Name())
}
sort.Strings(list)

20
sqltests/MISC/casing.sql Normal file
View File

@@ -0,0 +1,20 @@
-- test: Identifiers are not case sensitive
CREATE TABLE Test(pk INT PRIMARY KEY);
INSERT INTO test (pk) VALUES (1) RETURNING *;
/* result:
{"pk": 1}
*/
-- test: Identifiers are not case sensitive
CREATE TABLE Test(pk INT PRIMARY KEY);
INSERT INTO tEst (pk) VALUES (1) RETURNING *;
/* result:
{"pk": 1}
*/
-- test: Identifiers are not case sensitive
CREATE TABLE TEST(pk INT PRIMARY KEY);
INSERT INTO test (pk) VALUES (1) RETURNING *;
/* result:
{"pk": 1}
*/