Files
aorm/null/null.go
tangpanqing 75415ac229 update
2023-01-12 10:42:28 +08:00

251 lines
4.9 KiB
Go

package null
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"time"
)
// Int 整数
type Int struct {
sql.NullInt64
}
// IntFrom 创建整数
func IntFrom(i int64) Int {
return Int{
NullInt64: sql.NullInt64{
Int64: i,
Valid: true,
},
}
}
// String 字符串
type String struct {
sql.NullString
}
// StringFrom 创建字符串
func StringFrom(s string) String {
return String{
NullString: sql.NullString{
String: s,
Valid: true,
},
}
}
// Float 浮点数
type Float struct {
sql.NullFloat64
}
// FloatFrom 创建浮点数
func FloatFrom(f float64) Float {
return Float{
NullFloat64: sql.NullFloat64{
Float64: f,
Valid: true,
},
}
}
// Bool 布尔值
type Bool struct {
sql.NullBool
}
// BoolFrom 创建布尔值
func BoolFrom(b bool) Bool {
return Bool{
NullBool: sql.NullBool{
Bool: b,
Valid: true,
},
}
}
// Time 时间
type Time struct {
sql.NullTime
}
// TimeFrom 创建时间
func TimeFrom(t time.Time) Time {
return Time{
NullTime: sql.NullTime{
Time: t,
Valid: true,
},
}
}
var nullBytes = []byte("null")
// UnmarshalJSON 反序列化浮点数
func (f *Float) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, nullBytes) {
f.Valid = false
return nil
}
if err := json.Unmarshal(data, &f.Float64); err != nil {
var typeError *json.UnmarshalTypeError
if errors.As(err, &typeError) {
// special case: accept string input
if typeError.Value != "string" {
return fmt.Errorf("null: JSON input is invalid driver (need float or string): %w", err)
}
var str string
if err := json.Unmarshal(data, &str); err != nil {
return fmt.Errorf("null: couldn't unmarshal number string: %w", err)
}
n, err := strconv.ParseFloat(str, 64)
if err != nil {
return fmt.Errorf("null: couldn't convert string to float: %w", err)
}
f.Float64 = n
f.Valid = true
return nil
}
return fmt.Errorf("null: couldn't unmarshal JSON: %w", err)
}
f.Valid = true
return nil
}
// MarshalJSON 序列化浮点数
func (f Float) MarshalJSON() ([]byte, error) {
if !f.Valid {
return []byte("null"), nil
}
if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) {
return nil, &json.UnsupportedValueError{
Value: reflect.ValueOf(f.Float64),
Str: strconv.FormatFloat(f.Float64, 'g', -1, 64),
}
}
return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil
}
// UnmarshalJSON 反序列化布尔值
func (b *Bool) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, nullBytes) {
b.Valid = false
return nil
}
if err := json.Unmarshal(data, &b.Bool); err != nil {
return fmt.Errorf("null: couldn't unmarshal JSON: %w", err)
}
b.Valid = true
return nil
}
// MarshalJSON 序列化布尔值
func (b Bool) MarshalJSON() ([]byte, error) {
if !b.Valid {
return []byte("null"), nil
}
if !b.Bool {
return []byte("false"), nil
}
return []byte("true"), nil
}
// UnmarshalJSON 反序列化时间
func (t *Time) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, nullBytes) {
t.Valid = false
return nil
}
if err := json.Unmarshal(data, &t.Time); err != nil {
return fmt.Errorf("null: couldn't unmarshal JSON: %w", err)
}
t.Valid = true
return nil
}
// MarshalJSON 序列化时间
func (t Time) MarshalJSON() ([]byte, error) {
if !t.Valid {
return []byte("null"), nil
}
return t.Time.MarshalJSON()
}
// UnmarshalJSON 反序列化整数
func (i *Int) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, nullBytes) {
i.Valid = false
return nil
}
if err := json.Unmarshal(data, &i.Int64); err != nil {
var typeError *json.UnmarshalTypeError
if errors.As(err, &typeError) {
// special case: accept string input
if typeError.Value != "string" {
return fmt.Errorf("null: JSON input is invalid driver (need int or string): %w", err)
}
var str string
if err := json.Unmarshal(data, &str); err != nil {
return fmt.Errorf("null: couldn't unmarshal number string: %w", err)
}
n, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return fmt.Errorf("null: couldn't convert string to int: %w", err)
}
i.Int64 = n
i.Valid = true
return nil
}
return fmt.Errorf("null: couldn't unmarshal JSON: %w", err)
}
i.Valid = true
return nil
}
// MarshalJSON 序列化整数
func (i Int) MarshalJSON() ([]byte, error) {
if !i.Valid {
return []byte("null"), nil
}
return []byte(strconv.FormatInt(i.Int64, 10)), nil
}
// UnmarshalJSON 反序列化字符串
func (s *String) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, nullBytes) {
s.Valid = false
return nil
}
if err := json.Unmarshal(data, &s.String); err != nil {
return fmt.Errorf("null: couldn't unmarshal JSON: %w", err)
}
s.Valid = true
return nil
}
// MarshalJSON 序列化字符串
func (s String) MarshalJSON() ([]byte, error) {
if !s.Valid {
return []byte("null"), nil
}
return json.Marshal(s.String)
}