Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

3
vendor/github.com/uptrace/bun/.gitignore generated vendored Normal file
View file

@ -0,0 +1,3 @@
*.s3db
*.prof
*.test

6
vendor/github.com/uptrace/bun/.prettierrc.yaml generated vendored Normal file
View file

@ -0,0 +1,6 @@
trailingComma: all
tabWidth: 2
semi: false
singleQuote: true
proseWrap: always
printWidth: 100

99
vendor/github.com/uptrace/bun/CHANGELOG.md generated vendored Normal file
View file

@ -0,0 +1,99 @@
# Changelog
## v0.4.1 - Aug 18 2021
- Fixed migrate package to properly rollback migrations.
- Added `allowzero` tag option that undoes `nullzero` option.
## v0.4.0 - Aug 11 2021
- Changed `WhereGroup` function to accept `*SelectQuery`.
- Fixed query hooks for count queries.
## v0.3.4 - Jul 19 2021
- Renamed `migrate.CreateGo` to `CreateGoMigration`.
- Added `migrate.WithPackageName` to customize the Go package name in generated migrations.
- Renamed `migrate.CreateSQL` to `CreateSQLMigrations` and changed `CreateSQLMigrations` to create
both up and down migration files.
## v0.3.1 - Jul 12 2021
- Renamed `alias` field struct tag to `alt` so it is not confused with column alias.
- Reworked migrate package API. See
[migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details.
## v0.3.0 - Jul 09 2021
- Changed migrate package to return structured data instead of logging the progress. See
[migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details.
## v0.2.14 - Jul 01 2021
- Added [sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) by
[Ivan Trubach](https://github.com/tie).
- Added support for MySQL 5.7 in addition to MySQL 8.
## v0.2.12 - Jun 29 2021
- Fixed scanners for net.IP and net.IPNet.
## v0.2.10 - Jun 29 2021
- Fixed pgdriver to format passed query args.
## v0.2.9 - Jun 27 2021
- Added support for prepared statements in pgdriver.
## v0.2.7 - Jun 26 2021
- Added `UpdateQuery.Bulk` helper to generate bulk-update queries.
Before:
```go
models := []Model{
{42, "hello"},
{43, "world"},
}
return db.NewUpdate().
With("_data", db.NewValues(&models)).
Model(&models).
Table("_data").
Set("model.str = _data.str").
Where("model.id = _data.id")
```
Now:
```go
db.NewUpdate().
Model(&models).
Bulk()
```
## v0.2.5 - Jun 25 2021
- Changed time.Time to always append zero time as `NULL`.
- Added `db.RunInTx` helper.
## v0.2.4 - Jun 21 2021
- Added SSL support to pgdriver.
## v0.2.3 - Jun 20 2021
- Replaced `ForceDelete(ctx)` with `ForceDelete().Exec(ctx)` for soft deletes.
## v0.2.1 - Jun 17 2021
- Renamed `DBI` to `IConn`. `IConn` is a common interface for `*sql.DB`, `*sql.Conn`, and `*sql.Tx`.
- Added `IDB`. `IDB` is a common interface for `*bun.DB`, `bun.Conn`, and `bun.Tx`.
## v0.2.0 - Jun 16 2021
- Changed [model hooks](https://bun.uptrace.dev/guide/hooks.html#model-hooks). See
[model-hooks](example/model-hooks) example.
- Renamed `has-one` to `belongs-to`. Renamed `belongs-to` to `has-one`. Previously Bun used
incorrect names for these relations.

24
vendor/github.com/uptrace/bun/LICENSE generated vendored Normal file
View file

@ -0,0 +1,24 @@
Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

21
vendor/github.com/uptrace/bun/Makefile generated vendored Normal file
View file

@ -0,0 +1,21 @@
ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
test:
set -e; for dir in $(ALL_GO_MOD_DIRS); do \
echo "go test in $${dir}"; \
(cd "$${dir}" && \
go test ./... && \
go vet); \
done
go_mod_tidy:
set -e; for dir in $(ALL_GO_MOD_DIRS); do \
echo "go mod tidy in $${dir}"; \
(cd "$${dir}" && \
go get -d ./... && \
go mod tidy); \
done
fmt:
gofmt -w -s ./
goimports -w -local github.com/uptrace/bun ./

267
vendor/github.com/uptrace/bun/README.md generated vendored Normal file
View file

@ -0,0 +1,267 @@
<p align="center">
<a href="https://uptrace.dev/?utm_source=gh-redis&utm_campaign=gh-redis-banner1">
<img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png" alt="All-in-one tool to optimize performance and monitor errors & logs">
</a>
</p>
# Simple and performant SQL database client
[![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun)
[![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
Main features are:
- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql),
[MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql),
[SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite).
- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars.
- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert).
- [Bulk updates](https://bun.uptrace.dev/guide/queries.html#update) using common table expressions.
- [Bulk deletes](https://bun.uptrace.dev/guide/queries.html#delete).
- [Fixtures](https://bun.uptrace.dev/guide/fixtures.html).
- [Migrations](https://bun.uptrace.dev/guide/migrations.html).
- [Soft deletes](https://bun.uptrace.dev/guide/soft-deletes.html).
Resources:
- [Examples](https://github.com/uptrace/bun/tree/master/example)
- [Documentation](https://bun.uptrace.dev/)
- [Reference](https://pkg.go.dev/github.com/uptrace/bun)
- [Starter kit](https://github.com/go-bun/bun-starter-kit)
- [RealWorld app](https://github.com/go-bun/bun-realworld-app)
<details>
<summary>github.com/frederikhors/orm-benchmark results</summary>
```
4000 times - Insert
raw_stmt: 0.38s 94280 ns/op 718 B/op 14 allocs/op
raw: 0.39s 96719 ns/op 718 B/op 13 allocs/op
beego_orm: 0.48s 118994 ns/op 2411 B/op 56 allocs/op
bun: 0.57s 142285 ns/op 918 B/op 12 allocs/op
pg: 0.58s 145496 ns/op 1235 B/op 12 allocs/op
gorm: 0.70s 175294 ns/op 6665 B/op 88 allocs/op
xorm: 0.76s 189533 ns/op 3032 B/op 94 allocs/op
4000 times - MultiInsert 100 row
raw: 4.59s 1147385 ns/op 135155 B/op 916 allocs/op
raw_stmt: 4.59s 1148137 ns/op 131076 B/op 916 allocs/op
beego_orm: 5.50s 1375637 ns/op 179962 B/op 2747 allocs/op
bun: 6.18s 1544648 ns/op 4265 B/op 214 allocs/op
pg: 7.01s 1753495 ns/op 5039 B/op 114 allocs/op
gorm: 9.52s 2379219 ns/op 293956 B/op 3729 allocs/op
xorm: 11.66s 2915478 ns/op 286140 B/op 7422 allocs/op
4000 times - Update
raw_stmt: 0.26s 65781 ns/op 773 B/op 14 allocs/op
raw: 0.31s 77209 ns/op 757 B/op 13 allocs/op
beego_orm: 0.43s 107064 ns/op 1802 B/op 47 allocs/op
bun: 0.56s 139839 ns/op 589 B/op 4 allocs/op
pg: 0.60s 149608 ns/op 896 B/op 11 allocs/op
gorm: 0.74s 185970 ns/op 6604 B/op 81 allocs/op
xorm: 0.81s 203240 ns/op 2994 B/op 119 allocs/op
4000 times - Read
raw: 0.33s 81671 ns/op 2081 B/op 49 allocs/op
raw_stmt: 0.34s 85847 ns/op 2112 B/op 50 allocs/op
beego_orm: 0.38s 94777 ns/op 2106 B/op 75 allocs/op
pg: 0.42s 106148 ns/op 1526 B/op 22 allocs/op
bun: 0.43s 106904 ns/op 1319 B/op 18 allocs/op
gorm: 0.65s 162221 ns/op 5240 B/op 108 allocs/op
xorm: 1.13s 281738 ns/op 8326 B/op 237 allocs/op
4000 times - MultiRead limit 100
raw: 1.52s 380351 ns/op 38356 B/op 1037 allocs/op
raw_stmt: 1.54s 385541 ns/op 38388 B/op 1038 allocs/op
pg: 1.86s 465468 ns/op 24045 B/op 631 allocs/op
bun: 2.58s 645354 ns/op 30009 B/op 1122 allocs/op
beego_orm: 2.93s 732028 ns/op 55280 B/op 3077 allocs/op
gorm: 4.97s 1241831 ns/op 71628 B/op 3877 allocs/op
xorm: doesn't work
```
</details>
## Installation
```go
go get github.com/uptrace/bun
```
You also need to install a database/sql driver and the corresponding Bun
[dialect](https://bun.uptrace.dev/guide/drivers.html).
## Quickstart
First you need to create a `sql.DB`. Here we are using the
[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses
between [modernc.org/sqlite](https://modernc.org/sqlite/) and
[mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform.
```go
import "github.com/uptrace/bun/driver/sqliteshim"
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
if err != nil {
panic(err)
}
```
And then create a `bun.DB` on top of it using the corresponding SQLite dialect:
```go
import (
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
)
db := bun.NewDB(sqldb, sqlitedialect.New())
```
Now you are ready to issue some queries:
```go
type User struct {
ID int64
Name string
}
user := new(User)
err := db.NewSelect().
Model(user).
Where("name != ?", "").
OrderExpr("id ASC").
Limit(1).
Scan(ctx)
```
The code above is equivalent to:
```go
query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1"
rows, err := sqldb.QueryContext(ctx, query)
if err != nil {
panic(err)
}
if !rows.Next() {
panic(sql.ErrNoRows)
}
user := new(User)
if err := db.ScanRow(ctx, rows, user); err != nil {
panic(err)
}
if err := rows.Err(); err != nil {
panic(err)
}
```
## Basic example
To provide initial data for our [example](/example/basic/), we will use Bun
[fixtures](https://bun.uptrace.dev/guide/fixtures.html):
```go
import "github.com/uptrace/bun/dbfixture"
// Register models for the fixture.
db.RegisterModel((*User)(nil), (*Story)(nil))
// WithRecreateTables tells Bun to drop existing tables and create new ones.
fixture := dbfixture.New(db, dbfixture.WithRecreateTables())
// Load fixture.yaml which contains data for User and Story models.
if err := fixture.Load(ctx, os.DirFS("."), "fixture.yaml"); err != nil {
panic(err)
}
```
The `fixture.yaml` looks like this:
```yaml
- model: User
rows:
- _id: admin
name: admin
emails: ['admin1@admin', 'admin2@admin']
- _id: root
name: root
emails: ['root1@root', 'root2@root']
- model: Story
rows:
- title: Cool story
author_id: '{{ $.User.admin.ID }}'
```
To select all users:
```go
users := make([]User, 0)
if err := db.NewSelect().Model(&users).OrderExpr("id ASC").Scan(ctx); err != nil {
panic(err)
}
```
To select a single user by id:
```go
user1 := new(User)
if err := db.NewSelect().Model(user1).Where("id = ?", 1).Scan(ctx); err != nil {
panic(err)
}
```
To select a story and the associated author in a single query:
```go
story := new(Story)
if err := db.NewSelect().
Model(story).
Relation("Author").
Limit(1).
Scan(ctx); err != nil {
panic(err)
}
```
To select a user into a map:
```go
m := make(map[string]interface{})
if err := db.NewSelect().
Model((*User)(nil)).
Limit(1).
Scan(ctx, &m); err != nil {
panic(err)
}
```
To select all users scanning each column into a separate slice:
```go
var ids []int64
var names []string
if err := db.NewSelect().
ColumnExpr("id, name").
Model((*User)(nil)).
OrderExpr("id ASC").
Scan(ctx, &ids, &names); err != nil {
panic(err)
}
```
For more details, please consult [docs](https://bun.uptrace.dev/) and check [examples](example).
## Contributors
Thanks to all the people who already contributed!
<a href="https://github.com/uptrace/bun/graphs/contributors">
<img src="https://contributors-img.web.app/image?repo=uptrace/bun" />
</a>

21
vendor/github.com/uptrace/bun/RELEASING.md generated vendored Normal file
View file

@ -0,0 +1,21 @@
# Releasing
1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub:
```shell
./scripts/release.sh -t v1.0.0
```
2. Open a pull request and wait for the build to finish.
3. Merge the pull request and run `tag.sh` to create tags for packages:
```shell
./scripts/tag.sh -t v1.0.0
```
4. Push the tags:
```shell
git push origin --tags
```

122
vendor/github.com/uptrace/bun/bun.go generated vendored Normal file
View file

@ -0,0 +1,122 @@
package bun
import (
"context"
"fmt"
"reflect"
"github.com/uptrace/bun/schema"
)
type (
Safe = schema.Safe
Ident = schema.Ident
)
type NullTime = schema.NullTime
type BaseModel = schema.BaseModel
type (
BeforeScanHook = schema.BeforeScanHook
AfterScanHook = schema.AfterScanHook
)
type BeforeSelectHook interface {
BeforeSelect(ctx context.Context, query *SelectQuery) error
}
type AfterSelectHook interface {
AfterSelect(ctx context.Context, query *SelectQuery) error
}
type BeforeInsertHook interface {
BeforeInsert(ctx context.Context, query *InsertQuery) error
}
type AfterInsertHook interface {
AfterInsert(ctx context.Context, query *InsertQuery) error
}
type BeforeUpdateHook interface {
BeforeUpdate(ctx context.Context, query *UpdateQuery) error
}
type AfterUpdateHook interface {
AfterUpdate(ctx context.Context, query *UpdateQuery) error
}
type BeforeDeleteHook interface {
BeforeDelete(ctx context.Context, query *DeleteQuery) error
}
type AfterDeleteHook interface {
AfterDelete(ctx context.Context, query *DeleteQuery) error
}
type BeforeCreateTableHook interface {
BeforeCreateTable(ctx context.Context, query *CreateTableQuery) error
}
type AfterCreateTableHook interface {
AfterCreateTable(ctx context.Context, query *CreateTableQuery) error
}
type BeforeDropTableHook interface {
BeforeDropTable(ctx context.Context, query *DropTableQuery) error
}
type AfterDropTableHook interface {
AfterDropTable(ctx context.Context, query *DropTableQuery) error
}
//------------------------------------------------------------------------------
type InValues struct {
slice reflect.Value
err error
}
var _ schema.QueryAppender = InValues{}
func In(slice interface{}) InValues {
v := reflect.ValueOf(slice)
if v.Kind() != reflect.Slice {
return InValues{
err: fmt.Errorf("bun: In(non-slice %T)", slice),
}
}
return InValues{
slice: v,
}
}
func (in InValues) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if in.err != nil {
return nil, in.err
}
return appendIn(fmter, b, in.slice), nil
}
func appendIn(fmter schema.Formatter, b []byte, slice reflect.Value) []byte {
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, ", "...)
}
elem := slice.Index(i)
if elem.Kind() == reflect.Interface {
elem = elem.Elem()
}
if elem.Kind() == reflect.Slice {
b = append(b, '(')
b = appendIn(fmter, b, elem)
b = append(b, ')')
} else {
b = fmter.AppendValue(b, elem)
}
}
return b
}

502
vendor/github.com/uptrace/bun/db.go generated vendored Normal file
View file

@ -0,0 +1,502 @@
package bun
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"sync/atomic"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
const (
discardUnknownColumns internal.Flag = 1 << iota
)
type DBStats struct {
Queries uint64
Errors uint64
}
type DBOption func(db *DB)
func WithDiscardUnknownColumns() DBOption {
return func(db *DB) {
db.flags = db.flags.Set(discardUnknownColumns)
}
}
type DB struct {
*sql.DB
dialect schema.Dialect
features feature.Feature
queryHooks []QueryHook
fmter schema.Formatter
flags internal.Flag
stats DBStats
}
func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)
db := &DB{
DB: sqldb,
dialect: dialect,
features: dialect.Features(),
fmter: schema.NewFormatter(dialect),
}
for _, opt := range opts {
opt(db)
}
return db
}
func (db *DB) String() string {
var b strings.Builder
b.WriteString("DB<dialect=")
b.WriteString(db.dialect.Name().String())
b.WriteString(">")
return b.String()
}
func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint64(&db.stats.Queries),
Errors: atomic.LoadUint64(&db.stats.Errors),
}
}
func (db *DB) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(db, model)
}
func (db *DB) NewSelect() *SelectQuery {
return NewSelectQuery(db)
}
func (db *DB) NewInsert() *InsertQuery {
return NewInsertQuery(db)
}
func (db *DB) NewUpdate() *UpdateQuery {
return NewUpdateQuery(db)
}
func (db *DB) NewDelete() *DeleteQuery {
return NewDeleteQuery(db)
}
func (db *DB) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(db)
}
func (db *DB) NewDropTable() *DropTableQuery {
return NewDropTableQuery(db)
}
func (db *DB) NewCreateIndex() *CreateIndexQuery {
return NewCreateIndexQuery(db)
}
func (db *DB) NewDropIndex() *DropIndexQuery {
return NewDropIndexQuery(db)
}
func (db *DB) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(db)
}
func (db *DB) NewAddColumn() *AddColumnQuery {
return NewAddColumnQuery(db)
}
func (db *DB) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(db)
}
func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
for _, model := range models {
if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil {
return err
}
if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
return err
}
}
return nil
}
func (db *DB) Dialect() schema.Dialect {
return db.dialect
}
func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
model, err := newModel(db, dest)
if err != nil {
return err
}
_, err = model.ScanRows(ctx, rows)
return err
}
func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
model, err := newModel(db, dest)
if err != nil {
return err
}
rs, ok := model.(rowScanner)
if !ok {
return fmt.Errorf("bun: %T does not support ScanRow", model)
}
return rs.ScanRow(ctx, rows)
}
func (db *DB) AddQueryHook(hook QueryHook) {
db.queryHooks = append(db.queryHooks, hook)
}
func (db *DB) Table(typ reflect.Type) *schema.Table {
return db.dialect.Tables().Get(typ)
}
func (db *DB) RegisterModel(models ...interface{}) {
db.dialect.Tables().Register(models...)
}
func (db *DB) clone() *DB {
clone := *db
l := len(clone.queryHooks)
clone.queryHooks = clone.queryHooks[:l:l]
return &clone
}
func (db *DB) WithNamedArg(name string, value interface{}) *DB {
clone := db.clone()
clone.fmter = clone.fmter.WithNamedArg(name, value)
return clone
}
func (db *DB) Formatter() schema.Formatter {
return db.fmter
}
//------------------------------------------------------------------------------
func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
func (db *DB) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := db.beforeQuery(ctx, nil, query, args)
res, err := db.DB.ExecContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, res, err)
return res, err
}
func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
func (db *DB) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := db.beforeQuery(ctx, nil, query, args)
rows, err := db.DB.QueryContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, nil, err)
return rows, err
}
func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
return db.QueryRowContext(context.Background(), query, args...)
}
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := db.beforeQuery(ctx, nil, query, args)
row := db.DB.QueryRowContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, nil, row.Err())
return row
}
func (db *DB) format(query string, args []interface{}) string {
return db.fmter.FormatQuery(query, args...)
}
//------------------------------------------------------------------------------
type Conn struct {
db *DB
*sql.Conn
}
func (db *DB) Conn(ctx context.Context) (Conn, error) {
conn, err := db.DB.Conn(ctx)
if err != nil {
return Conn{}, err
}
return Conn{
db: db,
Conn: conn,
}, nil
}
func (c Conn) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
res, err := c.Conn.ExecContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, res, err)
return res, err
}
func (c Conn) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, nil, err)
return rows, err
}
func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
row := c.Conn.QueryRowContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, nil, row.Err())
return row
}
func (c Conn) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(c.db, model).Conn(c)
}
func (c Conn) NewSelect() *SelectQuery {
return NewSelectQuery(c.db).Conn(c)
}
func (c Conn) NewInsert() *InsertQuery {
return NewInsertQuery(c.db).Conn(c)
}
func (c Conn) NewUpdate() *UpdateQuery {
return NewUpdateQuery(c.db).Conn(c)
}
func (c Conn) NewDelete() *DeleteQuery {
return NewDeleteQuery(c.db).Conn(c)
}
func (c Conn) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(c.db).Conn(c)
}
func (c Conn) NewDropTable() *DropTableQuery {
return NewDropTableQuery(c.db).Conn(c)
}
func (c Conn) NewCreateIndex() *CreateIndexQuery {
return NewCreateIndexQuery(c.db).Conn(c)
}
func (c Conn) NewDropIndex() *DropIndexQuery {
return NewDropIndexQuery(c.db).Conn(c)
}
func (c Conn) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(c.db).Conn(c)
}
func (c Conn) NewAddColumn() *AddColumnQuery {
return NewAddColumnQuery(c.db).Conn(c)
}
func (c Conn) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(c.db).Conn(c)
}
//------------------------------------------------------------------------------
type Stmt struct {
*sql.Stmt
}
func (db *DB) Prepare(query string) (Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
stmt, err := db.DB.PrepareContext(ctx, query)
if err != nil {
return Stmt{}, err
}
return Stmt{Stmt: stmt}, nil
}
//------------------------------------------------------------------------------
type Tx struct {
db *DB
*sql.Tx
}
// RunInTx runs the function in a transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func (db *DB) RunInTx(
ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
) error {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
}
defer tx.Rollback() //nolint:errcheck
if err := fn(ctx, tx); err != nil {
return err
}
return tx.Commit()
}
func (db *DB) Begin() (Tx, error) {
return db.BeginTx(context.Background(), nil)
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return Tx{}, err
}
return Tx{
db: db,
Tx: tx,
}, nil
}
func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}
func (tx Tx) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, res, err)
return res, err
}
func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
return tx.QueryContext(context.TODO(), query, args...)
}
func (tx Tx) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, nil, err)
return rows, err
}
func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row {
return tx.QueryRowContext(context.TODO(), query, args...)
}
func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, nil, row.Err())
return row
}
//------------------------------------------------------------------------------
func (tx Tx) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(tx.db, model).Conn(tx)
}
func (tx Tx) NewSelect() *SelectQuery {
return NewSelectQuery(tx.db).Conn(tx)
}
func (tx Tx) NewInsert() *InsertQuery {
return NewInsertQuery(tx.db).Conn(tx)
}
func (tx Tx) NewUpdate() *UpdateQuery {
return NewUpdateQuery(tx.db).Conn(tx)
}
func (tx Tx) NewDelete() *DeleteQuery {
return NewDeleteQuery(tx.db).Conn(tx)
}
func (tx Tx) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(tx.db).Conn(tx)
}
func (tx Tx) NewDropTable() *DropTableQuery {
return NewDropTableQuery(tx.db).Conn(tx)
}
func (tx Tx) NewCreateIndex() *CreateIndexQuery {
return NewCreateIndexQuery(tx.db).Conn(tx)
}
func (tx Tx) NewDropIndex() *DropIndexQuery {
return NewDropIndexQuery(tx.db).Conn(tx)
}
func (tx Tx) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(tx.db).Conn(tx)
}
func (tx Tx) NewAddColumn() *AddColumnQuery {
return NewAddColumnQuery(tx.db).Conn(tx)
}
func (tx Tx) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(tx.db).Conn(tx)
}
//------------------------------------------------------------------------------0
func (db *DB) makeQueryBytes() []byte {
// TODO: make this configurable?
return make([]byte, 0, 4096)
}
//------------------------------------------------------------------------------
type result struct {
r sql.Result
n int
}
func (r result) RowsAffected() (int64, error) {
if r.r != nil {
return r.r.RowsAffected()
}
return int64(r.n), nil
}
func (r result) LastInsertId() (int64, error) {
if r.r != nil {
return r.r.LastInsertId()
}
return 0, errors.New("LastInsertId is not available")
}

178
vendor/github.com/uptrace/bun/dialect/append.go generated vendored Normal file
View file

@ -0,0 +1,178 @@
package dialect
import (
"encoding/hex"
"math"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
)
func AppendError(b []byte, err error) []byte {
b = append(b, "?!("...)
b = append(b, err.Error()...)
b = append(b, ')')
return b
}
func AppendNull(b []byte) []byte {
return append(b, "NULL"...)
}
func AppendBool(b []byte, v bool) []byte {
if v {
return append(b, "TRUE"...)
}
return append(b, "FALSE"...)
}
func AppendFloat32(b []byte, v float32) []byte {
return appendFloat(b, float64(v), 32)
}
func AppendFloat64(b []byte, v float64) []byte {
return appendFloat(b, v, 64)
}
func appendFloat(b []byte, v float64, bitSize int) []byte {
switch {
case math.IsNaN(v):
return append(b, "'NaN'"...)
case math.IsInf(v, 1):
return append(b, "'Infinity'"...)
case math.IsInf(v, -1):
return append(b, "'-Infinity'"...)
default:
return strconv.AppendFloat(b, v, 'f', -1, bitSize)
}
}
func AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for _, r := range s {
if r == '\000' {
continue
}
if r == '\'' {
b = append(b, '\'', '\'')
continue
}
if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}
func AppendBytes(b []byte, bytes []byte) []byte {
if bytes == nil {
return AppendNull(b)
}
b = append(b, `'\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...)
hex.Encode(b[s:], bytes)
b = append(b, '\'')
return b
}
func AppendTime(b []byte, tm time.Time) []byte {
if tm.IsZero() {
return AppendNull(b)
}
b = append(b, '\'')
b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00")
b = append(b, '\'')
return b
}
func AppendJSON(b, jsonb []byte) []byte {
b = append(b, '\'')
p := parser.New(jsonb)
for p.Valid() {
c := p.Read()
switch c {
case '"':
b = append(b, '"')
case '\'':
b = append(b, "''"...)
case '\000':
continue
case '\\':
if p.SkipBytes([]byte("u0000")) {
b = append(b, `\\u0000`...)
} else {
b = append(b, '\\')
if p.Valid() {
b = append(b, p.Read())
}
}
default:
b = append(b, c)
}
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func AppendIdent(b []byte, field string, quote byte) []byte {
return appendIdent(b, internal.Bytes(field), quote)
}
func appendIdent(b, src []byte, quote byte) []byte {
var quoted bool
loop:
for _, c := range src {
switch c {
case '*':
if !quoted {
b = append(b, '*')
continue loop
}
case '.':
if quoted {
b = append(b, quote)
quoted = false
}
b = append(b, '.')
continue loop
}
if !quoted {
b = append(b, quote)
quoted = true
}
if c == quote {
b = append(b, quote, quote)
} else {
b = append(b, c)
}
}
if quoted {
b = append(b, quote)
}
return b
}

26
vendor/github.com/uptrace/bun/dialect/dialect.go generated vendored Normal file
View file

@ -0,0 +1,26 @@
package dialect
type Name int
func (n Name) String() string {
switch n {
case PG:
return "pg"
case SQLite:
return "sqlite"
case MySQL5:
return "mysql5"
case MySQL8:
return "mysql8"
default:
return "invalid"
}
}
const (
Invalid Name = iota
PG
SQLite
MySQL5
MySQL8
)

View file

@ -0,0 +1,22 @@
package feature
import "github.com/uptrace/bun/internal"
type Feature = internal.Flag
const DefaultFeatures = Returning | TableCascade
const (
Returning Feature = 1 << iota
DefaultPlaceholder
DoubleColonCast
ValuesRow
UpdateMultiTable
InsertTableAlias
DeleteTableAlias
AutoIncrement
TableCascade
TableIdentity
TableTruncate
OnDuplicateKey
)

View file

@ -0,0 +1,24 @@
Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,303 @@
package pgdialect
import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/schema"
)
var (
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
stringType = reflect.TypeOf((*string)(nil)).Elem()
sliceStringType = reflect.TypeOf([]string(nil))
intType = reflect.TypeOf((*int)(nil)).Elem()
sliceIntType = reflect.TypeOf([]int(nil))
int64Type = reflect.TypeOf((*int64)(nil)).Elem()
sliceInt64Type = reflect.TypeOf([]int64(nil))
float64Type = reflect.TypeOf((*float64)(nil)).Elem()
sliceFloat64Type = reflect.TypeOf([]float64(nil))
)
func customAppender(typ reflect.Type) schema.AppenderFunc {
switch typ.Kind() {
case reflect.Uint32:
return appendUint32ValueAsInt
case reflect.Uint, reflect.Uint64:
return appendUint64ValueAsInt
}
return nil
}
func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(int32(v.Uint())), 10)
}
func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(v.Uint()), 10)
}
//------------------------------------------------------------------------------
func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case string:
return arrayAppendString(b, v)
case time.Time:
return dialect.AppendTime(b, v)
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
}
}
func arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Kind() == reflect.String {
return arrayAppendStringValue
}
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
return schema.Appender(typ, customAppender)
}
func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}
func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
iface, err := v.Interface().(driver.Valuer).Value()
if err != nil {
return dialect.AppendError(b, err)
}
return arrayAppend(fmter, b, iface)
}
//------------------------------------------------------------------------------
func arrayAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()
if kind == reflect.Ptr {
typ = typ.Elem()
kind = typ.Kind()
}
switch kind {
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return appendStringSliceValue
case intType:
return appendIntSliceValue
case int64Type:
return appendInt64SliceValue
case float64Type:
return appendFloat64SliceValue
}
}
appendElem := arrayElemAppender(elemType)
if appendElem == nil {
panic(fmt.Errorf("pgdialect: %s is not supported", typ))
}
return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
kind := v.Kind()
switch kind {
case reflect.Ptr, reflect.Slice:
if v.IsNil() {
return dialect.AppendNull(b)
}
}
if kind == reflect.Ptr {
v = v.Elem()
}
b = append(b, '\'')
b = append(b, '{')
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
b = appendElem(fmter, b, elem)
b = append(b, ',')
}
if v.Len() > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
}
func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ss := v.Convert(sliceStringType).Interface().([]string)
return appendStringSlice(b, ss)
}
func appendStringSlice(b []byte, ss []string) []byte {
if ss == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, s := range ss {
b = arrayAppendString(b, s)
b = append(b, ',')
}
if len(ss) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceIntType).Interface().([]int)
return appendIntSlice(b, ints)
}
func appendIntSlice(b []byte, ints []int) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, int64(n), 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceInt64Type).Interface().([]int64)
return appendInt64Slice(b, ints)
}
func appendInt64Slice(b []byte, ints []int64) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, n, 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
floats := v.Convert(sliceFloat64Type).Interface().([]float64)
return appendFloat64Slice(b, floats)
}
func appendFloat64Slice(b []byte, floats []float64) []byte {
if floats == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
for _, r := range s {
switch r {
case 0:
// ignore
case '\'':
b = append(b, "'''"...)
case '"':
b = append(b, '\\', '"')
case '\\':
b = append(b, '\\', '\\')
default:
if r < utf8.RuneSelf {
b = append(b, byte(r))
break
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
}
b = append(b, '"')
return b
}

View file

@ -0,0 +1,65 @@
package pgdialect
import (
"database/sql"
"fmt"
"reflect"
"github.com/uptrace/bun/schema"
)
type ArrayValue struct {
v reflect.Value
append schema.AppenderFunc
scan schema.ScannerFunc
}
// Array accepts a slice and returns a wrapper for working with PostgreSQL
// array data type.
//
// For struct fields you can use array tag:
//
// Emails []string `bun:",array"`
func Array(vi interface{}) *ArrayValue {
v := reflect.ValueOf(vi)
if !v.IsValid() {
panic(fmt.Errorf("bun: Array(nil)"))
}
return &ArrayValue{
v: v,
append: arrayAppender(v.Type()),
scan: arrayScanner(v.Type()),
}
}
var (
_ schema.QueryAppender = (*ArrayValue)(nil)
_ sql.Scanner = (*ArrayValue)(nil)
)
func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
if a.append == nil {
panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()))
}
return a.append(fmter, b, a.v), nil
}
func (a *ArrayValue) Scan(src interface{}) error {
if a.scan == nil {
return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())
}
if a.v.Kind() != reflect.Ptr {
return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type())
}
return a.scan(a.v, src)
}
func (a *ArrayValue) Value() interface{} {
if a.v.IsValid() {
return a.v.Interface()
}
return nil
}

View file

@ -0,0 +1,146 @@
package pgdialect
import (
"bytes"
"fmt"
"io"
)
type arrayParser struct {
b []byte
i int
buf []byte
err error
}
func newArrayParser(b []byte) *arrayParser {
p := &arrayParser{
b: b,
i: 1,
}
if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
p.err = fmt.Errorf("bun: can't parse array: %q", b)
}
return p
}
func (p *arrayParser) NextElem() ([]byte, error) {
if p.err != nil {
return nil, p.err
}
c, err := p.readByte()
if err != nil {
return nil, err
}
switch c {
case '}':
return nil, io.EOF
case '"':
b, err := p.readSubstring()
if err != nil {
return nil, err
}
if p.peek() == ',' {
p.skipNext()
}
return b, nil
default:
b := p.readSimple()
if bytes.Equal(b, []byte("NULL")) {
b = nil
}
if p.peek() == ',' {
p.skipNext()
}
return b, nil
}
}
func (p *arrayParser) readSimple() []byte {
p.unreadByte()
if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 {
b := p.b[p.i : p.i+i]
p.i += i
return b
}
b := p.b[p.i : len(p.b)-1]
p.i = len(p.b) - 1
return b
}
func (p *arrayParser) readSubstring() ([]byte, error) {
c, err := p.readByte()
if err != nil {
return nil, err
}
p.buf = p.buf[:0]
for {
if c == '"' {
break
}
next, err := p.readByte()
if err != nil {
return nil, err
}
if c == '\\' {
switch next {
case '\\', '"':
p.buf = append(p.buf, next)
c, err = p.readByte()
if err != nil {
return nil, err
}
default:
p.buf = append(p.buf, '\\')
c = next
}
continue
}
p.buf = append(p.buf, c)
c = next
}
return p.buf, nil
}
func (p *arrayParser) valid() bool {
return p.i < len(p.b)
}
func (p *arrayParser) readByte() (byte, error) {
if p.valid() {
c := p.b[p.i]
p.i++
return c, nil
}
return 0, io.EOF
}
func (p *arrayParser) unreadByte() {
p.i--
}
func (p *arrayParser) peek() byte {
if p.valid() {
return p.b[p.i]
}
return 0
}
func (p *arrayParser) skipNext() {
p.i++
}

View file

@ -0,0 +1,302 @@
package pgdialect
import (
"fmt"
"io"
"reflect"
"strconv"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
func arrayScanner(typ reflect.Type) schema.ScannerFunc {
kind := typ.Kind()
if kind == reflect.Ptr {
typ = typ.Elem()
kind = typ.Kind()
}
switch kind {
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return scanStringSliceValue
case intType:
return scanIntSliceValue
case int64Type:
return scanInt64SliceValue
case float64Type:
return scanFloat64SliceValue
}
}
scanElem := schema.Scanner(elemType)
return func(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
kind := dest.Kind()
if src == nil {
if kind != reflect.Slice || !dest.IsNil() {
dest.Set(reflect.Zero(dest.Type()))
}
return nil
}
if kind == reflect.Slice {
if dest.IsNil() {
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
} else if dest.Len() > 0 {
dest.Set(dest.Slice(0, 0))
}
}
b, err := toBytes(src)
if err != nil {
return err
}
p := newArrayParser(b)
nextValue := internal.MakeSliceNextElemFunc(dest)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return err
}
elemValue := nextValue()
if err := scanElem(elemValue, elem); err != nil {
return err
}
}
return nil
}
}
func scanStringSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeStringSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeStringSlice(src interface{}) ([]string, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]string, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
slice = append(slice, string(elem))
}
return slice, nil
}
func scanIntSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeIntSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeIntSlice(src interface{}) ([]int, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.Atoi(bytesToString(elem))
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func scanInt64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeInt64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeInt64Slice(src interface{}) ([]int64, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int64, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseInt(bytesToString(elem), 10, 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := scanFloat64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func scanFloat64Slice(src interface{}) ([]float64, error) {
if src == -1 {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]float64, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseFloat(bytesToString(elem), 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return stringToBytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}

View file

@ -0,0 +1,150 @@
package pgdialect
import (
"database/sql"
"reflect"
"strconv"
"sync"
"time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
)
type Dialect struct {
tables *schema.Tables
features feature.Feature
appenderMap sync.Map
scannerMap sync.Map
}
func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.Returning |
feature.DefaultPlaceholder |
feature.DoubleColonCast |
feature.InsertTableAlias |
feature.DeleteTableAlias |
feature.TableCascade |
feature.TableIdentity |
feature.TableTruncate
return d
}
func (d *Dialect) Init(*sql.DB) {}
func (d *Dialect) Name() dialect.Name {
return dialect.PG
}
func (d *Dialect) Features() feature.Feature {
return d.features
}
func (d *Dialect) Tables() *schema.Tables {
return d.tables
}
func (d *Dialect) OnTable(table *schema.Table) {
for _, field := range table.FieldMap {
d.onField(field)
}
}
func (d *Dialect) onField(field *schema.Field) {
field.DiscoveredSQLType = fieldSQLType(field)
if field.AutoIncrement {
switch field.DiscoveredSQLType {
case sqltype.SmallInt:
field.CreateTableSQLType = pgTypeSmallSerial
case sqltype.Integer:
field.CreateTableSQLType = pgTypeSerial
case sqltype.BigInt:
field.CreateTableSQLType = pgTypeBigSerial
}
}
if field.Tag.HasOption("array") {
field.Append = arrayAppender(field.IndirectType)
field.Scan = arrayScanner(field.IndirectType)
}
}
func (d *Dialect) IdentQuote() byte {
return '"'
}
func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case nil:
return dialect.AppendNull(b)
case bool:
return dialect.AppendBool(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendInt(b, int64(v), 10)
case uint32:
return strconv.AppendInt(b, int64(v), 10)
case uint64:
return strconv.AppendInt(b, int64(v), 10)
case float32:
return dialect.AppendFloat32(b, v)
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
case time.Time:
return dialect.AppendTime(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case schema.QueryAppender:
return schema.AppendQueryAppender(fmter, b, v)
default:
vv := reflect.ValueOf(v)
if vv.Kind() == reflect.Ptr && vv.IsNil() {
return dialect.AppendNull(b)
}
appender := d.Appender(vv.Type())
return appender(fmter, b, vv)
}
}
func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc {
if v, ok := d.appenderMap.Load(typ); ok {
return v.(schema.AppenderFunc)
}
fn := schema.Appender(typ, customAppender)
if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
return v.(schema.AppenderFunc)
}
return fn
}
func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc {
return schema.FieldAppender(d, field)
}
func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc {
if v, ok := d.scannerMap.Load(typ); ok {
return v.(schema.ScannerFunc)
}
fn := scanner(typ)
if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
return v.(schema.ScannerFunc)
}
return fn
}

View file

@ -0,0 +1,7 @@
module github.com/uptrace/bun/dialect/pgdialect
go 1.16
replace github.com/uptrace/bun => ../..
require github.com/uptrace/bun v0.4.3

22
vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum generated vendored Normal file
View file

@ -0,0 +1,22 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -0,0 +1,11 @@
// +build appengine
package pgdialect
func bytesToString(b []byte) string {
return string(b)
}
func stringToBytes(s string) []byte {
return []byte(s)
}

View file

@ -0,0 +1,28 @@
package pgdialect
import (
"fmt"
"reflect"
"github.com/uptrace/bun/schema"
)
func scanner(typ reflect.Type) schema.ScannerFunc {
if typ.Kind() == reflect.Interface {
return scanInterface
}
return schema.Scanner(typ)
}
func scanInterface(dest reflect.Value, src interface{}) error {
if dest.IsNil() {
dest.Set(reflect.ValueOf(src))
return nil
}
dest = dest.Elem()
if fn := scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}

View file

@ -0,0 +1,104 @@
package pgdialect
import (
"encoding/json"
"net"
"reflect"
"time"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
)
const (
// Date / Time
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time Interval
// Network Addresses
pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks
pgTypeCidr = "CIDR" // IPv4 or IPv6 networks
pgTypeMacaddr = "MACADDR" // MAC addresses
// Serial Types
pgTypeSmallSerial = "SMALLSERIAL" // 2 byte autoincrementing integer
pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer
pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeText = "TEXT" // variable length string without limit
// JSON Types
pgTypeJSON = "JSON" // text representation of json data
pgTypeJSONB = "JSONB" // binary representation of json data
// Binary Data Types
pgTypeBytea = "BYTEA" // binary string
)
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
)
func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
}
if v, ok := field.Tag.Options["composite"]; ok {
return v
}
if _, ok := field.Tag.Options["hstore"]; ok {
return "hstore"
}
if _, ok := field.Tag.Options["array"]; ok {
switch field.IndirectType.Kind() {
case reflect.Slice, reflect.Array:
sqlType := sqlType(field.IndirectType.Elem())
return sqlType + "[]"
}
}
return sqlType(field.IndirectType)
}
func sqlType(typ reflect.Type) string {
switch typ {
case ipType:
return pgTypeInet
case ipNetType:
return pgTypeCidr
case jsonRawMessageType:
return pgTypeJSONB
}
sqlType := schema.DiscoverSQLType(typ)
switch sqlType {
case sqltype.Timestamp:
sqlType = pgTypeTimestampTz
}
switch typ.Kind() {
case reflect.Map, reflect.Struct:
if sqlType == sqltype.VarChar {
return pgTypeJSONB
}
return sqlType
case reflect.Array, reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return pgTypeBytea
}
return pgTypeJSONB
}
return sqlType
}

View file

@ -0,0 +1,18 @@
// +build !appengine
package pgdialect
import "unsafe"
func bytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
func stringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

View file

@ -0,0 +1,14 @@
package sqltype
const (
Boolean = "BOOLEAN"
SmallInt = "SMALLINT"
Integer = "INTEGER"
BigInt = "BIGINT"
Real = "REAL"
DoublePrecision = "DOUBLE PRECISION"
VarChar = "VARCHAR"
Timestamp = "TIMESTAMP"
JSON = "JSON"
JSONB = "JSONB"
)

26
vendor/github.com/uptrace/bun/extra/bunjson/json.go generated vendored Normal file
View file

@ -0,0 +1,26 @@
package bunjson
import (
"encoding/json"
"io"
)
var _ Provider = (*StdProvider)(nil)
type StdProvider struct{}
func (StdProvider) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
func (StdProvider) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
func (StdProvider) NewEncoder(w io.Writer) Encoder {
return json.NewEncoder(w)
}
func (StdProvider) NewDecoder(r io.Reader) Decoder {
return json.NewDecoder(r)
}

View file

@ -0,0 +1,43 @@
package bunjson
import (
"io"
)
var provider Provider = StdProvider{}
func SetProvider(p Provider) {
provider = p
}
type Provider interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
NewEncoder(w io.Writer) Encoder
NewDecoder(r io.Reader) Decoder
}
type Decoder interface {
Decode(v interface{}) error
UseNumber()
}
type Encoder interface {
Encode(v interface{}) error
}
func Marshal(v interface{}) ([]byte, error) {
return provider.Marshal(v)
}
func Unmarshal(data []byte, v interface{}) error {
return provider.Unmarshal(data, v)
}
func NewEncoder(w io.Writer) Encoder {
return provider.NewEncoder(w)
}
func NewDecoder(r io.Reader) Decoder {
return provider.NewDecoder(r)
}

12
vendor/github.com/uptrace/bun/go.mod generated vendored Normal file
View file

@ -0,0 +1,12 @@
module github.com/uptrace/bun
go 1.16
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/stretchr/testify v1.7.0
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
github.com/vmihailenco/msgpack/v5 v5.3.4
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect
)

23
vendor/github.com/uptrace/bun/go.sum generated vendored Normal file
View file

@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

98
vendor/github.com/uptrace/bun/hook.go generated vendored Normal file
View file

@ -0,0 +1,98 @@
package bun
import (
"context"
"database/sql"
"reflect"
"sync/atomic"
"time"
"github.com/uptrace/bun/schema"
)
type QueryEvent struct {
DB *DB
QueryAppender schema.QueryAppender
Query string
QueryArgs []interface{}
StartTime time.Time
Result sql.Result
Err error
Stash map[interface{}]interface{}
}
type QueryHook interface {
BeforeQuery(context.Context, *QueryEvent) context.Context
AfterQuery(context.Context, *QueryEvent)
}
func (db *DB) beforeQuery(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
queryArgs []interface{},
) (context.Context, *QueryEvent) {
atomic.AddUint64(&db.stats.Queries, 1)
if len(db.queryHooks) == 0 {
return ctx, nil
}
event := &QueryEvent{
DB: db,
QueryAppender: queryApp,
Query: query,
QueryArgs: queryArgs,
StartTime: time.Now(),
}
for _, hook := range db.queryHooks {
ctx = hook.BeforeQuery(ctx, event)
}
return ctx, event
}
func (db *DB) afterQuery(
ctx context.Context,
event *QueryEvent,
res sql.Result,
err error,
) {
switch err {
case nil, sql.ErrNoRows:
// nothing
default:
atomic.AddUint64(&db.stats.Errors, 1)
}
if event == nil {
return
}
event.Result = res
event.Err = err
db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
}
func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) {
for ; hookIndex >= 0; hookIndex-- {
db.queryHooks[hookIndex].AfterQuery(ctx, event)
}
}
//------------------------------------------------------------------------------
func callBeforeScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx)
}
func callAfterScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(schema.AfterScanHook).AfterScan(ctx)
}

16
vendor/github.com/uptrace/bun/internal/flag.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
package internal
type Flag uint64
func (flag Flag) Has(other Flag) bool {
return flag&other == other
}
func (flag Flag) Set(other Flag) Flag {
return flag | other
}
func (flag Flag) Remove(other Flag) Flag {
flag &= ^other
return flag
}

43
vendor/github.com/uptrace/bun/internal/hex.go generated vendored Normal file
View file

@ -0,0 +1,43 @@
package internal
import (
fasthex "github.com/tmthrgd/go-hex"
)
type HexEncoder struct {
b []byte
written bool
}
func NewHexEncoder(b []byte) *HexEncoder {
return &HexEncoder{
b: b,
}
}
func (enc *HexEncoder) Bytes() []byte {
return enc.b
}
func (enc *HexEncoder) Write(b []byte) (int, error) {
if !enc.written {
enc.b = append(enc.b, '\'')
enc.b = append(enc.b, `\x`...)
enc.written = true
}
i := len(enc.b)
enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...)
fasthex.Encode(enc.b[i:], b)
return len(b), nil
}
func (enc *HexEncoder) Close() error {
if enc.written {
enc.b = append(enc.b, '\'')
} else {
enc.b = append(enc.b, "NULL"...)
}
return nil
}

27
vendor/github.com/uptrace/bun/internal/logger.go generated vendored Normal file
View file

@ -0,0 +1,27 @@
package internal
import (
"fmt"
"log"
"os"
)
var Warn = log.New(os.Stderr, "WARN: bun: ", log.LstdFlags)
var Deprecated = log.New(os.Stderr, "DEPRECATED: bun: ", log.LstdFlags)
type Logging interface {
Printf(format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
var Logger Logging = &logger{
log: log.New(os.Stderr, "bun: ", log.LstdFlags|log.Lshortfile),
}

67
vendor/github.com/uptrace/bun/internal/map_key.go generated vendored Normal file
View file

@ -0,0 +1,67 @@
package internal
import "reflect"
var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem()
type MapKey struct {
iface interface{}
}
func NewMapKey(is []interface{}) MapKey {
return MapKey{
iface: newMapKey(is),
}
}
func newMapKey(is []interface{}) interface{} {
switch len(is) {
case 1:
ptr := new([1]interface{})
copy((*ptr)[:], is)
return *ptr
case 2:
ptr := new([2]interface{})
copy((*ptr)[:], is)
return *ptr
case 3:
ptr := new([3]interface{})
copy((*ptr)[:], is)
return *ptr
case 4:
ptr := new([4]interface{})
copy((*ptr)[:], is)
return *ptr
case 5:
ptr := new([5]interface{})
copy((*ptr)[:], is)
return *ptr
case 6:
ptr := new([6]interface{})
copy((*ptr)[:], is)
return *ptr
case 7:
ptr := new([7]interface{})
copy((*ptr)[:], is)
return *ptr
case 8:
ptr := new([8]interface{})
copy((*ptr)[:], is)
return *ptr
case 9:
ptr := new([9]interface{})
copy((*ptr)[:], is)
return *ptr
case 10:
ptr := new([10]interface{})
copy((*ptr)[:], is)
return *ptr
default:
}
at := reflect.New(reflect.ArrayOf(len(is), ifaceType)).Elem()
for i, v := range is {
*(at.Index(i).Addr().Interface().(*interface{})) = v
}
return at.Interface()
}

141
vendor/github.com/uptrace/bun/internal/parser/parser.go generated vendored Normal file
View file

@ -0,0 +1,141 @@
package parser
import (
"bytes"
"strconv"
"github.com/uptrace/bun/internal"
)
type Parser struct {
b []byte
i int
}
func New(b []byte) *Parser {
return &Parser{
b: b,
}
}
func NewString(s string) *Parser {
return New(internal.Bytes(s))
}
func (p *Parser) Valid() bool {
return p.i < len(p.b)
}
func (p *Parser) Bytes() []byte {
return p.b[p.i:]
}
func (p *Parser) Read() byte {
if p.Valid() {
c := p.b[p.i]
p.Advance()
return c
}
return 0
}
func (p *Parser) Peek() byte {
if p.Valid() {
return p.b[p.i]
}
return 0
}
func (p *Parser) Advance() {
p.i++
}
func (p *Parser) Skip(skip byte) bool {
if p.Peek() == skip {
p.Advance()
return true
}
return false
}
func (p *Parser) SkipBytes(skip []byte) bool {
if len(skip) > len(p.b[p.i:]) {
return false
}
if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) {
return false
}
p.i += len(skip)
return true
}
func (p *Parser) ReadSep(sep byte) ([]byte, bool) {
ind := bytes.IndexByte(p.b[p.i:], sep)
if ind == -1 {
b := p.b[p.i:]
p.i = len(p.b)
return b, false
}
b := p.b[p.i : p.i+ind]
p.i += ind + 1
return b, true
}
func (p *Parser) ReadIdentifier() (string, bool) {
if p.i < len(p.b) && p.b[p.i] == '(' {
s := p.i + 1
if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 {
b := p.b[s : s+ind]
p.i = s + ind + 1
return internal.String(b), false
}
}
ind := len(p.b) - p.i
var alpha bool
for i, c := range p.b[p.i:] {
if isNum(c) {
continue
}
if isAlpha(c) || (i > 0 && alpha && c == '_') {
alpha = true
continue
}
ind = i
break
}
if ind == 0 {
return "", false
}
b := p.b[p.i : p.i+ind]
p.i += ind
return internal.String(b), !alpha
}
func (p *Parser) ReadNumber() int {
ind := len(p.b) - p.i
for i, c := range p.b[p.i:] {
if !isNum(c) {
ind = i
break
}
}
if ind == 0 {
return 0
}
n, err := strconv.Atoi(string(p.b[p.i : p.i+ind]))
if err != nil {
panic(err)
}
p.i += ind
return n
}
func isNum(c byte) bool {
return c >= '0' && c <= '9'
}
func isAlpha(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

11
vendor/github.com/uptrace/bun/internal/safe.go generated vendored Normal file
View file

@ -0,0 +1,11 @@
// +build appengine
package internal
func String(b []byte) string {
return string(b)
}
func Bytes(s string) []byte {
return []byte(s)
}

View file

@ -0,0 +1,147 @@
package tagparser
import (
"strings"
)
type Tag struct {
Name string
Options map[string]string
}
func (t Tag) HasOption(name string) bool {
_, ok := t.Options[name]
return ok
}
func Parse(s string) Tag {
p := parser{
s: s,
}
p.parse()
return p.tag
}
type parser struct {
s string
i int
tag Tag
seenName bool // for empty names
}
func (p *parser) setName(name string) {
if p.seenName {
p.addOption(name, "")
} else {
p.seenName = true
p.tag.Name = name
}
}
func (p *parser) addOption(key, value string) {
p.seenName = true
if key == "" {
return
}
if p.tag.Options == nil {
p.tag.Options = make(map[string]string)
}
p.tag.Options[key] = value
}
func (p *parser) parse() {
for p.valid() {
p.parseKeyValue()
if p.peek() == ',' {
p.i++
}
}
}
func (p *parser) parseKeyValue() {
start := p.i
for p.valid() {
switch c := p.read(); c {
case ',':
key := p.s[start : p.i-1]
p.setName(key)
return
case ':':
key := p.s[start : p.i-1]
value := p.parseValue()
p.addOption(key, value)
return
case '"':
key := p.parseQuotedValue()
p.setName(key)
return
}
}
key := p.s[start:p.i]
p.setName(key)
}
func (p *parser) parseValue() string {
start := p.i
for p.valid() {
switch c := p.read(); c {
case '"':
return p.parseQuotedValue()
case ',':
return p.s[start : p.i-1]
}
}
if p.i == start {
return ""
}
return p.s[start:p.i]
}
func (p *parser) parseQuotedValue() string {
if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' {
s := p.s[p.i : p.i+i]
p.i += i + 1
return s
}
b := make([]byte, 0, 16)
for p.valid() {
switch c := p.read(); c {
case '\\':
b = append(b, p.read())
case '"':
return string(b)
default:
b = append(b, c)
}
}
return ""
}
func (p *parser) valid() bool {
return p.i < len(p.s)
}
func (p *parser) read() byte {
if !p.valid() {
return 0
}
c := p.s[p.i]
p.i++
return c
}
func (p *parser) peek() byte {
if !p.valid() {
return 0
}
c := p.s[p.i]
return c
}

41
vendor/github.com/uptrace/bun/internal/time.go generated vendored Normal file
View file

@ -0,0 +1,41 @@
package internal
import (
"fmt"
"time"
)
const (
dateFormat = "2006-01-02"
timeFormat = "15:04:05.999999999"
timestampFormat = "2006-01-02 15:04:05.999999999"
timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00"
timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00"
timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07"
)
func ParseTime(s string) (time.Time, error) {
switch l := len(s); {
case l < len("15:04:05"):
return time.Time{}, fmt.Errorf("bun: can't parse time=%q", s)
case l <= len(timeFormat):
if s[2] == ':' {
return time.ParseInLocation(timeFormat, s, time.UTC)
}
return time.ParseInLocation(dateFormat, s, time.UTC)
default:
if s[10] == 'T' {
return time.Parse(time.RFC3339Nano, s)
}
if c := s[l-9]; c == '+' || c == '-' {
return time.Parse(timestamptzFormat, s)
}
if c := s[l-6]; c == '+' || c == '-' {
return time.Parse(timestamptzFormat2, s)
}
if c := s[l-3]; c == '+' || c == '-' {
return time.Parse(timestamptzFormat3, s)
}
return time.ParseInLocation(timestampFormat, s, time.UTC)
}
}

67
vendor/github.com/uptrace/bun/internal/underscore.go generated vendored Normal file
View file

@ -0,0 +1,67 @@
package internal
func IsUpper(c byte) bool {
return c >= 'A' && c <= 'Z'
}
func IsLower(c byte) bool {
return c >= 'a' && c <= 'z'
}
func ToUpper(c byte) byte {
return c - 32
}
func ToLower(c byte) byte {
return c + 32
}
// Underscore converts "CamelCasedString" to "camel_cased_string".
func Underscore(s string) string {
r := make([]byte, 0, len(s)+5)
for i := 0; i < len(s); i++ {
c := s[i]
if IsUpper(c) {
if i > 0 && i+1 < len(s) && (IsLower(s[i-1]) || IsLower(s[i+1])) {
r = append(r, '_', ToLower(c))
} else {
r = append(r, ToLower(c))
}
} else {
r = append(r, c)
}
}
return string(r)
}
func CamelCased(s string) string {
r := make([]byte, 0, len(s))
upperNext := true
for i := 0; i < len(s); i++ {
c := s[i]
if c == '_' {
upperNext = true
continue
}
if upperNext {
if IsLower(c) {
c = ToUpper(c)
}
upperNext = false
}
r = append(r, c)
}
return string(r)
}
func ToExported(s string) string {
if len(s) == 0 {
return s
}
if c := s[0]; IsLower(c) {
b := []byte(s)
b[0] = ToUpper(c)
return string(b)
}
return s
}

20
vendor/github.com/uptrace/bun/internal/unsafe.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
// +build !appengine
package internal
import "unsafe"
// String converts byte slice to string.
func String(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// Bytes converts string to byte slice.
func Bytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

57
vendor/github.com/uptrace/bun/internal/util.go generated vendored Normal file
View file

@ -0,0 +1,57 @@
package internal
import (
"reflect"
)
func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
if v.Kind() == reflect.Array {
var pos int
return func() reflect.Value {
v := v.Index(pos)
pos++
return v
}
}
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}
func Unwrap(err error) error {
u, ok := err.(interface {
Unwrap() error
})
if !ok {
return nil
}
return u.Unwrap()
}

308
vendor/github.com/uptrace/bun/join.go generated vendored Normal file
View file

@ -0,0 +1,308 @@
package bun
import (
"context"
"reflect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type join struct {
Parent *join
BaseModel tableModel
JoinModel tableModel
Relation *schema.Relation
ApplyQueryFunc func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
}
func (j *join) applyQuery(q *SelectQuery) {
if j.ApplyQueryFunc == nil {
return
}
var table *schema.Table
var columns []schema.QueryWithArgs
// Save state.
table, q.table = q.table, j.JoinModel.Table()
columns, q.columns = q.columns, nil
q = j.ApplyQueryFunc(q)
// Restore state.
q.table = table
j.columns, q.columns = q.columns, columns
}
func (j *join) Select(ctx context.Context, q *SelectQuery) error {
switch j.Relation.Type {
case schema.HasManyRelation:
return j.selectMany(ctx, q)
case schema.ManyToManyRelation:
return j.selectM2M(ctx, q)
}
panic("not reached")
}
func (j *join) selectMany(ctx context.Context, q *SelectQuery) error {
q = j.manyQuery(q)
if q == nil {
return nil
}
return q.Scan(ctx)
}
func (j *join) manyQuery(q *SelectQuery) *SelectQuery {
hasManyModel := newHasManyModel(j)
if hasManyModel == nil {
return nil
}
q = q.Model(hasManyModel)
var where []byte
if len(j.Relation.JoinFields) > 1 {
where = append(where, '(')
}
where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields)
if len(j.Relation.JoinFields) > 1 {
where = append(where, ')')
}
where = append(where, " IN ("...)
where = appendChildValues(
q.db.Formatter(),
where,
j.JoinModel.Root(),
j.JoinModel.ParentIndex(),
j.Relation.BaseFields,
)
where = append(where, ")"...)
q = q.Where(internal.String(where))
if j.Relation.PolymorphicField != nil {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}
j.applyQuery(q)
q = q.Apply(j.hasManyColumns)
return q
}
func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery {
if j.Relation.M2MTable != nil {
q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*")
}
b := make([]byte, 0, 32)
if len(j.columns) > 0 {
for i, col := range j.columns {
if i > 0 {
b = append(b, ", "...)
}
var err error
b, err = col.AppendQuery(q.db.fmter, b)
if err != nil {
q.err = err
return q
}
}
} else {
joinTable := j.JoinModel.Table()
b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields)
}
q = q.ColumnExpr(internal.String(b))
return q
}
func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error {
q = j.m2mQuery(q)
if q == nil {
return nil
}
return q.Scan(ctx)
}
func (j *join) m2mQuery(q *SelectQuery) *SelectQuery {
fmter := q.db.fmter
m2mModel := newM2MModel(j)
if m2mModel == nil {
return nil
}
q = q.Model(m2mModel)
index := j.JoinModel.ParentIndex()
baseTable := j.BaseModel.Table()
//nolint
var join []byte
join = append(join, "JOIN "...)
join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name))
join = append(join, " AS "...)
join = append(join, j.Relation.M2MTable.SQLAlias...)
join = append(join, " ON ("...)
for i, col := range j.Relation.M2MBaseFields {
if i > 0 {
join = append(join, ", "...)
}
join = append(join, j.Relation.M2MTable.SQLAlias...)
join = append(join, '.')
join = append(join, col.SQLName...)
}
join = append(join, ") IN ("...)
join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs)
join = append(join, ")"...)
q = q.Join(internal.String(join))
joinTable := j.JoinModel.Table()
for i, m2mJoinField := range j.Relation.M2MJoinFields {
joinField := j.Relation.JoinFields[i]
q = q.Where("?.? = ?.?",
joinTable.SQLAlias, joinField.SQLName,
j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName)
}
j.applyQuery(q)
q = q.Apply(j.hasManyColumns)
return q
}
func (j *join) hasParent() bool {
if j.Parent != nil {
switch j.Parent.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
return true
}
}
return false
}
func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte {
quote := fmter.IdentQuote()
b = append(b, quote)
b = appendAlias(b, j)
b = append(b, quote)
return b
}
func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte {
quote := fmter.IdentQuote()
b = append(b, quote)
b = appendAlias(b, j)
b = append(b, "__"...)
b = append(b, column...)
b = append(b, quote)
return b
}
func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
quote := fmter.IdentQuote()
if j.hasParent() {
b = append(b, quote)
b = appendAlias(b, j.Parent)
b = append(b, quote)
return b
}
return append(b, j.BaseModel.Table().SQLAlias...)
}
func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte {
b = append(b, '.')
b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...)
if flags.Has(deletedFlag) {
b = append(b, " IS NOT NULL"...)
} else {
b = append(b, " IS NULL"...)
}
return b
}
func appendAlias(b []byte, j *join) []byte {
if j.hasParent() {
b = appendAlias(b, j.Parent)
b = append(b, "__"...)
}
b = append(b, j.Relation.Field.Name...)
return b
}
func (j *join) appendHasOneJoin(
fmter schema.Formatter, b []byte, q *SelectQuery,
) (_ []byte, err error) {
isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
b = append(b, "LEFT JOIN "...)
b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects))
b = append(b, " AS "...)
b = j.appendAlias(fmter, b)
b = append(b, " ON "...)
b = append(b, '(')
for i, baseField := range j.Relation.BaseFields {
if i > 0 {
b = append(b, " AND "...)
}
b = j.appendAlias(fmter, b)
b = append(b, '.')
b = append(b, j.Relation.JoinFields[i].SQLName...)
b = append(b, " = "...)
b = j.appendBaseAlias(fmter, b)
b = append(b, '.')
b = append(b, baseField.SQLName...)
}
b = append(b, ')')
if isSoftDelete {
b = append(b, " AND "...)
b = j.appendAlias(fmter, b)
b = j.appendSoftDelete(b, q.flags)
}
return b, nil
}
func appendChildValues(
fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field,
) []byte {
seen := make(map[string]struct{})
walk(v, index, func(v reflect.Value) {
start := len(b)
if len(fields) > 1 {
b = append(b, '(')
}
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
b = f.AppendValue(fmter, b, v)
}
if len(fields) > 1 {
b = append(b, ')')
}
b = append(b, ", "...)
if _, ok := seen[string(b[start:])]; ok {
b = b[:start]
} else {
seen[string(b[start:])] = struct{}{}
}
})
if len(seen) > 0 {
b = b[:len(b)-2] // trim ", "
}
return b
}

207
vendor/github.com/uptrace/bun/model.go generated vendored Normal file
View file

@ -0,0 +1,207 @@
package bun
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"time"
"github.com/uptrace/bun/schema"
)
var errNilModel = errors.New("bun: Model(nil)")
var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
type Model interface {
ScanRows(ctx context.Context, rows *sql.Rows) (int, error)
Value() interface{}
}
type rowScanner interface {
ScanRow(ctx context.Context, rows *sql.Rows) error
}
type model interface {
Model
}
type tableModel interface {
model
schema.BeforeScanHook
schema.AfterScanHook
ScanColumn(column string, src interface{}) error
Table() *schema.Table
Relation() *schema.Relation
Join(string, func(*SelectQuery) *SelectQuery) *join
GetJoin(string) *join
GetJoins() []join
AddJoin(join) *join
Root() reflect.Value
ParentIndex() []int
Mount(reflect.Value)
updateSoftDeleteField() error
}
func newModel(db *DB, dest []interface{}) (model, error) {
if len(dest) == 1 {
return _newModel(db, dest[0], true)
}
values := make([]reflect.Value, len(dest))
for i, el := range dest {
v := reflect.ValueOf(el)
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("bun: Scan(non-pointer %T)", dest)
}
v = v.Elem()
if v.Kind() != reflect.Slice {
return newScanModel(db, dest), nil
}
values[i] = v
}
return newSliceModel(db, dest, values), nil
}
func newSingleModel(db *DB, dest interface{}) (model, error) {
return _newModel(db, dest, false)
}
func _newModel(db *DB, dest interface{}, scan bool) (model, error) {
switch dest := dest.(type) {
case nil:
return nil, errNilModel
case Model:
return dest, nil
case sql.Scanner:
if !scan {
return nil, fmt.Errorf("bun: Model(unsupported %T)", dest)
}
return newScanModel(db, []interface{}{dest}), nil
}
v := reflect.ValueOf(dest)
if !v.IsValid() {
return nil, errNilModel
}
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("bun: Model(non-pointer %T)", dest)
}
if v.IsNil() {
typ := v.Type().Elem()
if typ.Kind() == reflect.Struct {
return newStructTableModel(db, dest, db.Table(typ)), nil
}
return nil, fmt.Errorf("bun: Model(nil %T)", dest)
}
v = v.Elem()
switch v.Kind() {
case reflect.Map:
typ := v.Type()
if err := validMap(typ); err != nil {
return nil, err
}
mapPtr := v.Addr().Interface().(*map[string]interface{})
return newMapModel(db, mapPtr), nil
case reflect.Struct:
if v.Type() != timeType {
return newStructTableModelValue(db, dest, v), nil
}
case reflect.Slice:
switch elemType := sliceElemType(v); elemType.Kind() {
case reflect.Struct:
if elemType != timeType {
return newSliceTableModel(db, dest, v, elemType), nil
}
case reflect.Map:
if err := validMap(elemType); err != nil {
return nil, err
}
slicePtr := v.Addr().Interface().(*[]map[string]interface{})
return newMapSliceModel(db, slicePtr), nil
}
return newSliceModel(db, []interface{}{dest}, []reflect.Value{v}), nil
}
if scan {
return newScanModel(db, []interface{}{dest}), nil
}
return nil, fmt.Errorf("bun: Model(unsupported %T)", dest)
}
func newTableModelIndex(
db *DB,
table *schema.Table,
root reflect.Value,
index []int,
rel *schema.Relation,
) (tableModel, error) {
typ := typeByIndex(table.Type, index)
if typ.Kind() == reflect.Struct {
return &structTableModel{
db: db,
table: table.Dialect().Tables().Get(typ),
rel: rel,
root: root,
index: index,
}, nil
}
if typ.Kind() == reflect.Slice {
structType := indirectType(typ.Elem())
if structType.Kind() == reflect.Struct {
m := sliceTableModel{
structTableModel: structTableModel{
db: db,
table: table.Dialect().Tables().Get(structType),
rel: rel,
root: root,
index: index,
},
}
m.init(typ)
return &m, nil
}
}
return nil, fmt.Errorf("bun: NewModel(%s)", typ)
}
func validMap(typ reflect.Type) error {
if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface {
return fmt.Errorf("bun: Model(unsupported %s) (expected *map[string]interface{})",
typ)
}
return nil
}
//------------------------------------------------------------------------------
func isSingleRowModel(m model) bool {
switch m.(type) {
case *mapModel,
*structTableModel,
*scanModel:
return true
default:
return false
}
}

183
vendor/github.com/uptrace/bun/model_map.go generated vendored Normal file
View file

@ -0,0 +1,183 @@
package bun
import (
"context"
"database/sql"
"reflect"
"sort"
"github.com/uptrace/bun/schema"
)
type mapModel struct {
db *DB
dest *map[string]interface{}
m map[string]interface{}
rows *sql.Rows
columns []string
_columnTypes []*sql.ColumnType
scanIndex int
}
var _ model = (*mapModel)(nil)
func newMapModel(db *DB, dest *map[string]interface{}) *mapModel {
m := &mapModel{
db: db,
dest: dest,
}
if dest != nil {
m.m = *dest
}
return m
}
func (m *mapModel) Value() interface{} {
return m.dest
}
func (m *mapModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
if !rows.Next() {
return 0, rows.Err()
}
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.rows = rows
m.columns = columns
dest := makeDest(m, len(columns))
if m.m == nil {
m.m = make(map[string]interface{}, len(m.columns))
}
m.scanIndex = 0
if err := rows.Scan(dest...); err != nil {
return 0, err
}
*m.dest = m.m
return 1, nil
}
func (m *mapModel) Scan(src interface{}) error {
if _, ok := src.([]byte); !ok {
return m.scanRaw(src)
}
columnTypes, err := m.columnTypes()
if err != nil {
return err
}
scanType := columnTypes[m.scanIndex].ScanType()
switch scanType.Kind() {
case reflect.Interface:
return m.scanRaw(src)
case reflect.Slice:
if scanType.Elem().Kind() == reflect.Uint8 {
return m.scanRaw(src)
}
}
dest := reflect.New(scanType).Elem()
if err := schema.Scanner(scanType)(dest, src); err != nil {
return err
}
return m.scanRaw(dest.Interface())
}
func (m *mapModel) columnTypes() ([]*sql.ColumnType, error) {
if m._columnTypes == nil {
columnTypes, err := m.rows.ColumnTypes()
if err != nil {
return nil, err
}
m._columnTypes = columnTypes
}
return m._columnTypes, nil
}
func (m *mapModel) scanRaw(src interface{}) error {
columnName := m.columns[m.scanIndex]
m.scanIndex++
m.m[columnName] = src
return nil
}
func (m *mapModel) appendColumnsValues(fmter schema.Formatter, b []byte) []byte {
keys := make([]string, 0, len(m.m))
for k := range m.m {
keys = append(keys, k)
}
sort.Strings(keys)
b = append(b, " ("...)
for i, k := range keys {
if i > 0 {
b = append(b, ", "...)
}
b = fmter.AppendIdent(b, k)
}
b = append(b, ") VALUES ("...)
isTemplate := fmter.IsNop()
for i, k := range keys {
if i > 0 {
b = append(b, ", "...)
}
if isTemplate {
b = append(b, '?')
} else {
b = fmter.Dialect().Append(fmter, b, m.m[k])
}
}
b = append(b, ")"...)
return b
}
func (m *mapModel) appendSet(fmter schema.Formatter, b []byte) []byte {
keys := make([]string, 0, len(m.m))
for k := range m.m {
keys = append(keys, k)
}
sort.Strings(keys)
isTemplate := fmter.IsNop()
for i, k := range keys {
if i > 0 {
b = append(b, ", "...)
}
b = fmter.AppendIdent(b, k)
b = append(b, " = "...)
if isTemplate {
b = append(b, '?')
} else {
b = fmter.Dialect().Append(fmter, b, m.m[k])
}
}
return b
}
func makeDest(v interface{}, n int) []interface{} {
dest := make([]interface{}, n)
for i := range dest {
dest[i] = v
}
return dest
}

162
vendor/github.com/uptrace/bun/model_map_slice.go generated vendored Normal file
View file

@ -0,0 +1,162 @@
package bun
import (
"context"
"database/sql"
"errors"
"sort"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/schema"
)
type mapSliceModel struct {
mapModel
dest *[]map[string]interface{}
keys []string
}
var _ model = (*mapSliceModel)(nil)
func newMapSliceModel(db *DB, dest *[]map[string]interface{}) *mapSliceModel {
return &mapSliceModel{
mapModel: mapModel{
db: db,
},
dest: dest,
}
}
func (m *mapSliceModel) Value() interface{} {
return m.dest
}
func (m *mapSliceModel) SetCap(cap int) {
if cap > 100 {
cap = 100
}
if slice := *m.dest; len(slice) < cap {
*m.dest = make([]map[string]interface{}, 0, cap)
}
}
func (m *mapSliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.rows = rows
m.columns = columns
dest := makeDest(m, len(columns))
slice := *m.dest
if len(slice) > 0 {
slice = slice[:0]
}
var n int
for rows.Next() {
m.m = make(map[string]interface{}, len(m.columns))
m.scanIndex = 0
if err := rows.Scan(dest...); err != nil {
return 0, err
}
slice = append(slice, m.m)
n++
}
if err := rows.Err(); err != nil {
return 0, err
}
*m.dest = slice
return n, nil
}
func (m *mapSliceModel) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if err := m.initKeys(); err != nil {
return nil, err
}
for i, k := range m.keys {
if i > 0 {
b = append(b, ", "...)
}
b = fmter.AppendIdent(b, k)
}
return b, nil
}
func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if err := m.initKeys(); err != nil {
return nil, err
}
slice := *m.dest
b = append(b, "VALUES "...)
if m.db.features.Has(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
}
if fmter.IsNop() {
for i := range m.keys {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, '?')
}
return b, nil
}
for i, el := range slice {
if i > 0 {
b = append(b, "), "...)
if m.db.features.Has(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
}
}
for j, key := range m.keys {
if j > 0 {
b = append(b, ", "...)
}
b = fmter.Dialect().Append(fmter, b, el[key])
}
}
b = append(b, ')')
return b, nil
}
func (m *mapSliceModel) initKeys() error {
if m.keys != nil {
return nil
}
slice := *m.dest
if len(slice) == 0 {
return errors.New("bun: map slice is empty")
}
first := slice[0]
keys := make([]string, 0, len(first))
for k := range first {
keys = append(keys, k)
}
sort.Strings(keys)
m.keys = keys
return nil
}

54
vendor/github.com/uptrace/bun/model_scan.go generated vendored Normal file
View file

@ -0,0 +1,54 @@
package bun
import (
"context"
"database/sql"
"reflect"
)
type scanModel struct {
db *DB
dest []interface{}
scanIndex int
}
var _ model = (*scanModel)(nil)
func newScanModel(db *DB, dest []interface{}) *scanModel {
return &scanModel{
db: db,
dest: dest,
}
}
func (m *scanModel) Value() interface{} {
return m.dest
}
func (m *scanModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
if !rows.Next() {
return 0, rows.Err()
}
dest := makeDest(m, len(m.dest))
m.scanIndex = 0
if err := rows.Scan(dest...); err != nil {
return 0, err
}
return 1, nil
}
func (m *scanModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
return rows.Scan(m.dest...)
}
func (m *scanModel) Scan(src interface{}) error {
dest := reflect.ValueOf(m.dest[m.scanIndex])
m.scanIndex++
scanner := m.db.dialect.Scanner(dest.Type())
return scanner(dest, src)
}

82
vendor/github.com/uptrace/bun/model_slice.go generated vendored Normal file
View file

@ -0,0 +1,82 @@
package bun
import (
"context"
"database/sql"
"reflect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type sliceInfo struct {
nextElem func() reflect.Value
scan schema.ScannerFunc
}
type sliceModel struct {
dest []interface{}
values []reflect.Value
scanIndex int
info []sliceInfo
}
var _ model = (*sliceModel)(nil)
func newSliceModel(db *DB, dest []interface{}, values []reflect.Value) *sliceModel {
return &sliceModel{
dest: dest,
values: values,
}
}
func (m *sliceModel) Value() interface{} {
return m.dest
}
func (m *sliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.info = make([]sliceInfo, len(m.values))
for i, v := range m.values {
if v.IsValid() && v.Len() > 0 {
v.Set(v.Slice(0, 0))
}
m.info[i] = sliceInfo{
nextElem: internal.MakeSliceNextElemFunc(v),
scan: schema.Scanner(v.Type().Elem()),
}
}
if len(columns) == 0 {
return 0, nil
}
dest := makeDest(m, len(columns))
var n int
for rows.Next() {
m.scanIndex = 0
if err := rows.Scan(dest...); err != nil {
return 0, err
}
n++
}
if err := rows.Err(); err != nil {
return 0, err
}
return n, nil
}
func (m *sliceModel) Scan(src interface{}) error {
info := m.info[m.scanIndex]
m.scanIndex++
dest := info.nextElem()
return info.scan(dest, src)
}

149
vendor/github.com/uptrace/bun/model_table_has_many.go generated vendored Normal file
View file

@ -0,0 +1,149 @@
package bun
import (
"context"
"database/sql"
"fmt"
"reflect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type hasManyModel struct {
*sliceTableModel
baseTable *schema.Table
rel *schema.Relation
baseValues map[internal.MapKey][]reflect.Value
structKey []interface{}
}
var _ tableModel = (*hasManyModel)(nil)
func newHasManyModel(j *join) *hasManyModel {
baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, j.Relation.BaseFields)
if len(baseValues) == 0 {
return nil
}
m := hasManyModel{
sliceTableModel: joinModel,
baseTable: baseTable,
rel: j.Relation,
baseValues: baseValues,
}
if !m.sliceOfPtr {
m.strct = reflect.New(m.table.Type).Elem()
}
return &m
}
func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.columns = columns
dest := makeDest(m, len(columns))
var n int
for rows.Next() {
if m.sliceOfPtr {
m.strct = reflect.New(m.table.Type).Elem()
} else {
m.strct.Set(m.table.ZeroValue)
}
m.structInited = false
m.scanIndex = 0
m.structKey = m.structKey[:0]
if err := rows.Scan(dest...); err != nil {
return 0, err
}
if err := m.parkStruct(); err != nil {
return 0, err
}
n++
}
if err := rows.Err(); err != nil {
return 0, err
}
return n, nil
}
func (m *hasManyModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex]
m.scanIndex++
field, err := m.table.Field(column)
if err != nil {
return err
}
if err := field.ScanValue(m.strct, src); err != nil {
return err
}
for _, f := range m.rel.JoinFields {
if f.Name == field.Name {
m.structKey = append(m.structKey, field.Value(m.strct).Interface())
break
}
}
return nil
}
func (m *hasManyModel) parkStruct() error {
baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
if !ok {
return fmt.Errorf(
"bun: has-many relation=%s does not have base %s with id=%q (check join conditions)",
m.rel.Field.GoName, m.baseTable, m.structKey)
}
for i, v := range baseValues {
if !m.sliceOfPtr {
v.Set(reflect.Append(v, m.strct))
continue
}
if i == 0 {
v.Set(reflect.Append(v, m.strct.Addr()))
continue
}
clone := reflect.New(m.strct.Type()).Elem()
clone.Set(m.strct)
v.Set(reflect.Append(v, clone.Addr()))
}
return nil
}
func baseValues(model tableModel, fields []*schema.Field) map[internal.MapKey][]reflect.Value {
fieldIndex := model.Relation().Field.Index
m := make(map[internal.MapKey][]reflect.Value)
key := make([]interface{}, 0, len(fields))
walk(model.Root(), model.ParentIndex(), func(v reflect.Value) {
key = modelKey(key[:0], v, fields)
mapKey := internal.NewMapKey(key)
m[mapKey] = append(m[mapKey], v.FieldByIndex(fieldIndex))
})
return m
}
func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} {
for _, f := range fields {
key = append(key, f.Value(strct).Interface())
}
return key
}

138
vendor/github.com/uptrace/bun/model_table_m2m.go generated vendored Normal file
View file

@ -0,0 +1,138 @@
package bun
import (
"context"
"database/sql"
"fmt"
"reflect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type m2mModel struct {
*sliceTableModel
baseTable *schema.Table
rel *schema.Relation
baseValues map[internal.MapKey][]reflect.Value
structKey []interface{}
}
var _ tableModel = (*m2mModel)(nil)
func newM2MModel(j *join) *m2mModel {
baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, baseTable.PKs)
if len(baseValues) == 0 {
return nil
}
m := &m2mModel{
sliceTableModel: joinModel,
baseTable: baseTable,
rel: j.Relation,
baseValues: baseValues,
}
if !m.sliceOfPtr {
m.strct = reflect.New(m.table.Type).Elem()
}
return m
}
func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.columns = columns
dest := makeDest(m, len(columns))
var n int
for rows.Next() {
if m.sliceOfPtr {
m.strct = reflect.New(m.table.Type).Elem()
} else {
m.strct.Set(m.table.ZeroValue)
}
m.structInited = false
m.scanIndex = 0
m.structKey = m.structKey[:0]
if err := rows.Scan(dest...); err != nil {
return 0, err
}
if err := m.parkStruct(); err != nil {
return 0, err
}
n++
}
if err := rows.Err(); err != nil {
return 0, err
}
return n, nil
}
func (m *m2mModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex]
m.scanIndex++
field, ok := m.table.FieldMap[column]
if !ok {
return m.scanM2MColumn(column, src)
}
if err := field.ScanValue(m.strct, src); err != nil {
return err
}
for _, fk := range m.rel.M2MBaseFields {
if fk.Name == field.Name {
m.structKey = append(m.structKey, field.Value(m.strct).Interface())
break
}
}
return nil
}
func (m *m2mModel) scanM2MColumn(column string, src interface{}) error {
for _, field := range m.rel.M2MBaseFields {
if field.Name == column {
dest := reflect.New(field.IndirectType).Elem()
if err := field.Scan(dest, src); err != nil {
return err
}
m.structKey = append(m.structKey, dest.Interface())
break
}
}
_, err := m.scanColumn(column, src)
return err
}
func (m *m2mModel) parkStruct() error {
baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
if !ok {
return fmt.Errorf(
"bun: m2m relation=%s does not have base %s with key=%q (check join conditions)",
m.rel.Field.GoName, m.baseTable, m.structKey)
}
for _, v := range baseValues {
if m.sliceOfPtr {
v.Set(reflect.Append(v, m.strct.Addr()))
} else {
v.Set(reflect.Append(v, m.strct))
}
}
return nil
}

113
vendor/github.com/uptrace/bun/model_table_slice.go generated vendored Normal file
View file

@ -0,0 +1,113 @@
package bun
import (
"context"
"database/sql"
"reflect"
"github.com/uptrace/bun/schema"
)
type sliceTableModel struct {
structTableModel
slice reflect.Value
sliceLen int
sliceOfPtr bool
nextElem func() reflect.Value
}
var _ tableModel = (*sliceTableModel)(nil)
func newSliceTableModel(
db *DB, dest interface{}, slice reflect.Value, elemType reflect.Type,
) *sliceTableModel {
m := &sliceTableModel{
structTableModel: structTableModel{
db: db,
table: db.Table(elemType),
dest: dest,
root: slice,
},
slice: slice,
sliceLen: slice.Len(),
nextElem: makeSliceNextElemFunc(slice),
}
m.init(slice.Type())
return m
}
func (m *sliceTableModel) init(sliceType reflect.Type) {
switch sliceType.Elem().Kind() {
case reflect.Ptr, reflect.Interface:
m.sliceOfPtr = true
}
}
func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
return m.join(m.slice, name, apply)
}
func (m *sliceTableModel) Bind(bind reflect.Value) {
m.slice = bind.Field(m.index[len(m.index)-1])
}
func (m *sliceTableModel) SetCap(cap int) {
if cap > 100 {
cap = 100
}
if m.slice.Cap() < cap {
m.slice.Set(reflect.MakeSlice(m.slice.Type(), 0, cap))
}
}
func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
m.columns = columns
dest := makeDest(m, len(columns))
if m.slice.IsValid() && m.slice.Len() > 0 {
m.slice.Set(m.slice.Slice(0, 0))
}
var n int
for rows.Next() {
m.strct = m.nextElem()
m.structInited = false
if err := m.scanRow(ctx, rows, dest); err != nil {
return 0, err
}
n++
}
if err := rows.Err(); err != nil {
return 0, err
}
return n, nil
}
// Inherit these hooks from structTableModel.
var (
_ schema.BeforeScanHook = (*sliceTableModel)(nil)
_ schema.AfterScanHook = (*sliceTableModel)(nil)
)
func (m *sliceTableModel) updateSoftDeleteField() error {
sliceLen := m.slice.Len()
for i := 0; i < sliceLen; i++ {
strct := indirect(m.slice.Index(i))
fv := m.table.SoftDeleteField.Value(strct)
if err := m.table.UpdateSoftDeleteField(fv); err != nil {
return err
}
}
return nil
}

345
vendor/github.com/uptrace/bun/model_table_struct.go generated vendored Normal file
View file

@ -0,0 +1,345 @@
package bun
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/schema"
)
type structTableModel struct {
db *DB
table *schema.Table
rel *schema.Relation
joins []join
dest interface{}
root reflect.Value
index []int
strct reflect.Value
structInited bool
structInitErr error
columns []string
scanIndex int
}
var _ tableModel = (*structTableModel)(nil)
func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel {
return &structTableModel{
db: db,
table: table,
dest: dest,
}
}
func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel {
return &structTableModel{
db: db,
table: db.Table(v.Type()),
dest: dest,
root: v,
strct: v,
}
}
func (m *structTableModel) Value() interface{} {
return m.dest
}
func (m *structTableModel) Table() *schema.Table {
return m.table
}
func (m *structTableModel) Relation() *schema.Relation {
return m.rel
}
func (m *structTableModel) Root() reflect.Value {
return m.root
}
func (m *structTableModel) Index() []int {
return m.index
}
func (m *structTableModel) ParentIndex() []int {
return m.index[:len(m.index)-len(m.rel.Field.Index)]
}
func (m *structTableModel) Mount(host reflect.Value) {
m.strct = host.FieldByIndex(m.rel.Field.Index)
m.structInited = false
}
func (m *structTableModel) initStruct() error {
if m.structInited {
return m.structInitErr
}
m.structInited = true
switch m.strct.Kind() {
case reflect.Invalid:
m.structInitErr = errNilModel
return m.structInitErr
case reflect.Interface:
m.strct = m.strct.Elem()
}
if m.strct.Kind() == reflect.Ptr {
if m.strct.IsNil() {
m.strct.Set(reflect.New(m.strct.Type().Elem()))
m.strct = m.strct.Elem()
} else {
m.strct = m.strct.Elem()
}
}
m.mountJoins()
return nil
}
func (m *structTableModel) mountJoins() {
for i := range m.joins {
j := &m.joins[i]
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
j.JoinModel.Mount(m.strct)
}
}
}
var _ schema.BeforeScanHook = (*structTableModel)(nil)
func (m *structTableModel) BeforeScan(ctx context.Context) error {
if !m.table.HasBeforeScanHook() {
return nil
}
return callBeforeScanHook(ctx, m.strct.Addr())
}
var _ schema.AfterScanHook = (*structTableModel)(nil)
func (m *structTableModel) AfterScan(ctx context.Context) error {
if !m.table.HasAfterScanHook() || !m.structInited {
return nil
}
var firstErr error
if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil {
firstErr = err
}
for _, j := range m.joins {
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil {
firstErr = err
}
}
}
return firstErr
}
func (m *structTableModel) GetJoin(name string) *join {
for i := range m.joins {
j := &m.joins[i]
if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
return j
}
}
return nil
}
func (m *structTableModel) GetJoins() []join {
return m.joins
}
func (m *structTableModel) AddJoin(j join) *join {
m.joins = append(m.joins, j)
return &m.joins[len(m.joins)-1]
}
func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
return m.join(m.strct, name, apply)
}
func (m *structTableModel) join(
bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery,
) *join {
path := strings.Split(name, ".")
index := make([]int, 0, len(path))
currJoin := join{
BaseModel: m,
JoinModel: m,
}
var lastJoin *join
for _, name := range path {
relation, ok := currJoin.JoinModel.Table().Relations[name]
if !ok {
return nil
}
currJoin.Relation = relation
index = append(index, relation.Field.Index...)
if j := currJoin.JoinModel.GetJoin(name); j != nil {
currJoin.BaseModel = j.BaseModel
currJoin.JoinModel = j.JoinModel
lastJoin = j
} else {
model, err := newTableModelIndex(m.db, m.table, bind, index, relation)
if err != nil {
return nil
}
currJoin.Parent = lastJoin
currJoin.BaseModel = currJoin.JoinModel
currJoin.JoinModel = model
lastJoin = currJoin.BaseModel.AddJoin(currJoin)
}
}
// No joins with such name.
if lastJoin == nil {
return nil
}
if apply != nil {
lastJoin.ApplyQueryFunc = apply
}
return lastJoin
}
func (m *structTableModel) updateSoftDeleteField() error {
fv := m.table.SoftDeleteField.Value(m.strct)
return m.table.UpdateSoftDeleteField(fv)
}
func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
if !rows.Next() {
return 0, rows.Err()
}
if err := m.ScanRow(ctx, rows); err != nil {
return 0, err
}
// For inserts, SQLite3 can return a row like it was inserted sucessfully and then
// an actual error for the next row. See issues/100.
if m.db.dialect.Name() == dialect.SQLite {
_ = rows.Next()
if err := rows.Err(); err != nil {
return 0, err
}
}
return 1, nil
}
func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
columns, err := rows.Columns()
if err != nil {
return err
}
m.columns = columns
dest := makeDest(m, len(columns))
return m.scanRow(ctx, rows, dest)
}
func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
if err := m.BeforeScan(ctx); err != nil {
return err
}
m.scanIndex = 0
if err := rows.Scan(dest...); err != nil {
return err
}
if err := m.AfterScan(ctx); err != nil {
return err
}
return nil
}
func (m *structTableModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex]
m.scanIndex++
return m.ScanColumn(unquote(column), src)
}
func (m *structTableModel) ScanColumn(column string, src interface{}) error {
if ok, err := m.scanColumn(column, src); ok {
return err
}
if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) {
return nil
}
return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
}
func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) {
if src != nil {
if err := m.initStruct(); err != nil {
return true, err
}
}
if field, ok := m.table.FieldMap[column]; ok {
return true, field.ScanValue(m.strct, src)
}
if joinName, column := splitColumn(column); joinName != "" {
if join := m.GetJoin(joinName); join != nil {
return true, join.JoinModel.ScanColumn(column, src)
}
if m.table.ModelName == joinName {
return true, m.ScanColumn(column, src)
}
}
return false, nil
}
func (m *structTableModel) AppendNamedArg(
fmter schema.Formatter, b []byte, name string,
) ([]byte, bool) {
return m.table.AppendNamedArg(fmter, b, name, m.strct)
}
// sqlite3 sometimes does not unquote columns.
func unquote(s string) string {
if s == "" {
return s
}
if s[0] == '"' && s[len(s)-1] == '"' {
return s[1 : len(s)-1]
}
return s
}
func splitColumn(s string) (string, string) {
if i := strings.Index(s, "__"); i >= 0 {
return s[:i], s[i+2:]
}
return "", s
}

874
vendor/github.com/uptrace/bun/query_base.go generated vendored Normal file
View file

@ -0,0 +1,874 @@
package bun
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
const (
wherePKFlag internal.Flag = 1 << iota
forceDeleteFlag
deletedFlag
allWithDeletedFlag
)
type withQuery struct {
name string
query schema.QueryAppender
}
// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx.
type IConn interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
var (
_ IConn = (*sql.DB)(nil)
_ IConn = (*sql.Conn)(nil)
_ IConn = (*sql.Tx)(nil)
_ IConn = (*DB)(nil)
_ IConn = (*Conn)(nil)
_ IConn = (*Tx)(nil)
)
// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx.
type IDB interface {
IConn
NewValues(model interface{}) *ValuesQuery
NewSelect() *SelectQuery
NewInsert() *InsertQuery
NewUpdate() *UpdateQuery
NewDelete() *DeleteQuery
NewCreateTable() *CreateTableQuery
NewDropTable() *DropTableQuery
NewCreateIndex() *CreateIndexQuery
NewDropIndex() *DropIndexQuery
NewTruncateTable() *TruncateTableQuery
NewAddColumn() *AddColumnQuery
NewDropColumn() *DropColumnQuery
}
var (
_ IConn = (*DB)(nil)
_ IConn = (*Conn)(nil)
_ IConn = (*Tx)(nil)
)
type baseQuery struct {
db *DB
conn IConn
model model
err error
tableModel tableModel
table *schema.Table
with []withQuery
modelTable schema.QueryWithArgs
tables []schema.QueryWithArgs
columns []schema.QueryWithArgs
flags internal.Flag
}
func (q *baseQuery) DB() *DB {
return q.db
}
func (q *baseQuery) GetModel() Model {
return q.model
}
func (q *baseQuery) setConn(db IConn) {
// Unwrap Bun wrappers to not call query hooks twice.
switch db := db.(type) {
case *DB:
q.conn = db.DB
case Conn:
q.conn = db.Conn
case Tx:
q.conn = db.Tx
default:
q.conn = db
}
}
// TODO: rename to setModel
func (q *baseQuery) setTableModel(modeli interface{}) {
model, err := newSingleModel(q.db, modeli)
if err != nil {
q.setErr(err)
return
}
q.model = model
if tm, ok := model.(tableModel); ok {
q.tableModel = tm
q.table = tm.Table()
}
}
func (q *baseQuery) setErr(err error) {
if q.err == nil {
q.err = err
}
}
func (q *baseQuery) getModel(dest []interface{}) (model, error) {
if len(dest) == 0 {
if q.model != nil {
return q.model, nil
}
return nil, errNilModel
}
return newModel(q.db, dest)
}
//------------------------------------------------------------------------------
func (q *baseQuery) checkSoftDelete() error {
if q.table == nil {
return errors.New("bun: can't use soft deletes without a table")
}
if q.table.SoftDeleteField == nil {
return fmt.Errorf("%s does not have a soft delete field", q.table)
}
if q.tableModel == nil {
return errors.New("bun: can't use soft deletes without a table model")
}
return nil
}
// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models.
func (q *baseQuery) whereDeleted() {
if err := q.checkSoftDelete(); err != nil {
q.setErr(err)
return
}
q.flags = q.flags.Set(deletedFlag)
q.flags = q.flags.Remove(allWithDeletedFlag)
}
// AllWithDeleted changes query to return all rows including soft deleted ones.
func (q *baseQuery) whereAllWithDeleted() {
if err := q.checkSoftDelete(); err != nil {
q.setErr(err)
return
}
q.flags = q.flags.Set(allWithDeletedFlag)
q.flags = q.flags.Remove(deletedFlag)
}
func (q *baseQuery) isSoftDelete() bool {
if q.table != nil {
return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
}
return false
}
//------------------------------------------------------------------------------
func (q *baseQuery) addWith(name string, query schema.QueryAppender) {
q.with = append(q.with, withQuery{
name: name,
query: query,
})
}
func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if len(q.with) == 0 {
return b, nil
}
b = append(b, "WITH "...)
for i, with := range q.with {
if i > 0 {
b = append(b, ", "...)
}
b = fmter.AppendIdent(b, with.name)
if q, ok := with.query.(schema.ColumnsAppender); ok {
b = append(b, " ("...)
b, err = q.AppendColumns(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ")"...)
}
b = append(b, " AS ("...)
b, err = with.query.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
b = append(b, ' ')
return b, nil
}
//------------------------------------------------------------------------------
func (q *baseQuery) addTable(table schema.QueryWithArgs) {
q.tables = append(q.tables, table)
}
func (q *baseQuery) addColumn(column schema.QueryWithArgs) {
q.columns = append(q.columns, column)
}
func (q *baseQuery) excludeColumn(columns []string) {
if q.columns == nil {
for _, f := range q.table.Fields {
q.columns = append(q.columns, schema.UnsafeIdent(f.Name))
}
}
if len(columns) == 1 && columns[0] == "*" {
q.columns = make([]schema.QueryWithArgs, 0)
return
}
for _, column := range columns {
if !q._excludeColumn(column) {
q.setErr(fmt.Errorf("bun: can't find column=%q", column))
return
}
}
}
func (q *baseQuery) _excludeColumn(column string) bool {
for i, col := range q.columns {
if col.Args == nil && col.Query == column {
q.columns = append(q.columns[:i], q.columns[i+1:]...)
return true
}
}
return false
}
//------------------------------------------------------------------------------
func (q *baseQuery) modelHasTableName() bool {
return !q.modelTable.IsZero() || q.table != nil
}
func (q *baseQuery) hasTables() bool {
return q.modelHasTableName() || len(q.tables) > 0
}
func (q *baseQuery) appendTables(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
return q._appendTables(fmter, b, false)
}
func (q *baseQuery) appendTablesWithAlias(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
return q._appendTables(fmter, b, true)
}
func (q *baseQuery) _appendTables(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
startLen := len(b)
if q.modelHasTableName() {
if !q.modelTable.IsZero() {
b, err = q.modelTable.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
} else {
b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects))
if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects {
b = append(b, " AS "...)
b = append(b, q.table.SQLAlias...)
}
}
}
for _, table := range q.tables {
if len(b) > startLen {
b = append(b, ", "...)
}
b, err = table.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) {
return q._appendFirstTable(fmter, b, false)
}
func (q *baseQuery) appendFirstTableWithAlias(
fmter schema.Formatter, b []byte,
) ([]byte, error) {
return q._appendFirstTable(fmter, b, true)
}
func (q *baseQuery) _appendFirstTable(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
if !q.modelTable.IsZero() {
return q.modelTable.AppendQuery(fmter, b)
}
if q.table != nil {
b = fmter.AppendQuery(b, string(q.table.SQLName))
if withAlias {
b = append(b, " AS "...)
b = append(b, q.table.SQLAlias...)
}
return b, nil
}
if len(q.tables) > 0 {
return q.tables[0].AppendQuery(fmter, b)
}
return nil, errors.New("bun: query does not have a table")
}
func (q *baseQuery) hasMultiTables() bool {
if q.modelHasTableName() {
return len(q.tables) >= 1
}
return len(q.tables) >= 2
}
func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
tables := q.tables
if !q.modelHasTableName() {
tables = tables[1:]
}
for i, table := range tables {
if i > 0 {
b = append(b, ", "...)
}
b, err = table.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
for i, f := range q.columns {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *baseQuery) getFields() ([]*schema.Field, error) {
table := q.tableModel.Table()
if len(q.columns) == 0 {
return table.Fields, nil
}
fields, err := q._getFields(false)
if err != nil {
return nil, err
}
return fields, nil
}
func (q *baseQuery) getDataFields() ([]*schema.Field, error) {
if len(q.columns) == 0 {
return q.table.DataFields, nil
}
return q._getFields(true)
}
func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {
fields := make([]*schema.Field, 0, len(q.columns))
for _, col := range q.columns {
if col.Args != nil {
continue
}
field, err := q.table.Field(col.Query)
if err != nil {
return nil, err
}
if omitPK && field.IsPK {
continue
}
fields = append(fields, field)
}
return fields, nil
}
func (q *baseQuery) scan(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
model model,
hasDest bool,
) (res result, _ error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
}
defer rows.Close()
n, err := model.ScanRows(ctx, rows)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
}
res.n = n
if n == 0 && hasDest && isSingleRowModel(model) {
err = sql.ErrNoRows
}
q.db.afterQuery(ctx, event, nil, err)
return res, err
}
func (q *baseQuery) exec(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
) (res result, _ error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
r, err := q.conn.ExecContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
}
res.r = r
q.db.afterQuery(ctx, event, nil, err)
return res, nil
}
//------------------------------------------------------------------------------
func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
if q.table == nil {
return b, false
}
if m, ok := q.tableModel.(*structTableModel); ok {
if b, ok := m.AppendNamedArg(fmter, b, name); ok {
return b, ok
}
}
switch name {
case "TableName":
b = fmter.AppendQuery(b, string(q.table.SQLName))
return b, true
case "TableAlias":
b = fmter.AppendQuery(b, string(q.table.SQLAlias))
return b, true
case "PKs":
b = appendColumns(b, "", q.table.PKs)
return b, true
case "TablePKs":
b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
return b, true
case "Columns":
b = appendColumns(b, "", q.table.Fields)
return b, true
case "TableColumns":
b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
return b, true
}
return b, false
}
func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte {
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
if len(table) > 0 {
b = append(b, table...)
b = append(b, '.')
}
b = append(b, f.SQLName...)
}
return b
}
func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter {
if fmter.IsNop() {
return fmter
}
return fmter.WithArg(model)
}
//------------------------------------------------------------------------------
type whereBaseQuery struct {
baseQuery
where []schema.QueryWithSep
}
func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
q.where = append(q.where, where)
}
func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) {
if len(where) == 0 {
return
}
where[0].Sep = ""
q.addWhere(schema.SafeQueryWithSep("", nil, sep+"("))
q.where = append(q.where, where...)
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
}
func (q *whereBaseQuery) mustAppendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
if len(q.where) == 0 && !q.flags.Has(wherePKFlag) {
err := errors.New("bun: Update and Delete queries require at least one Where")
return nil, err
}
return q.appendWhere(fmter, b, withAlias)
}
func (q *whereBaseQuery) appendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) {
return b, nil
}
b = append(b, " WHERE "...)
startLen := len(b)
if len(q.where) > 0 {
b, err = appendWhere(fmter, b, q.where)
if err != nil {
return nil, err
}
}
if q.isSoftDelete() {
if len(b) > startLen {
b = append(b, " AND "...)
}
if withAlias {
b = append(b, q.tableModel.Table().SQLAlias...)
b = append(b, '.')
}
b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...)
if q.flags.Has(deletedFlag) {
b = append(b, " IS NOT NULL"...)
} else {
b = append(b, " IS NULL"...)
}
}
if q.flags.Has(wherePKFlag) {
if len(b) > startLen {
b = append(b, " AND "...)
}
b, err = q.appendWherePK(fmter, b, withAlias)
if err != nil {
return nil, err
}
}
return b, nil
}
func appendWhere(
fmter schema.Formatter, b []byte, where []schema.QueryWithSep,
) (_ []byte, err error) {
for i, where := range where {
if i > 0 || where.Sep == "(" {
b = append(b, where.Sep...)
}
if where.Query == "" && where.Args == nil {
continue
}
b = append(b, '(')
b, err = where.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
return b, nil
}
func (q *whereBaseQuery) appendWherePK(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
if q.table == nil {
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
return nil, err
}
if err := q.table.CheckPKs(); err != nil {
return nil, err
}
switch model := q.tableModel.(type) {
case *structTableModel:
return q.appendWherePKStruct(fmter, b, model, withAlias)
case *sliceTableModel:
return q.appendWherePKSlice(fmter, b, model, withAlias)
}
return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
}
func (q *whereBaseQuery) appendWherePKStruct(
fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
) (_ []byte, err error) {
if !model.strct.IsValid() {
return nil, errNilModel
}
isTemplate := fmter.IsNop()
b = append(b, '(')
for i, f := range q.table.PKs {
if i > 0 {
b = append(b, " AND "...)
}
if withAlias {
b = append(b, q.table.SQLAlias...)
b = append(b, '.')
}
b = append(b, f.SQLName...)
b = append(b, " = "...)
if isTemplate {
b = append(b, '?')
} else {
b = f.AppendValue(fmter, b, model.strct)
}
}
b = append(b, ')')
return b, nil
}
func (q *whereBaseQuery) appendWherePKSlice(
fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
) (_ []byte, err error) {
if len(q.table.PKs) > 1 {
b = append(b, '(')
}
if withAlias {
b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
} else {
b = appendColumns(b, "", q.table.PKs)
}
if len(q.table.PKs) > 1 {
b = append(b, ')')
}
b = append(b, " IN ("...)
isTemplate := fmter.IsNop()
slice := model.slice
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
if isTemplate {
break
}
b = append(b, ", "...)
}
el := indirect(slice.Index(i))
if len(q.table.PKs) > 1 {
b = append(b, '(')
}
for i, f := range q.table.PKs {
if i > 0 {
b = append(b, ", "...)
}
if isTemplate {
b = append(b, '?')
} else {
b = f.AppendValue(fmter, b, el)
}
}
if len(q.table.PKs) > 1 {
b = append(b, ')')
}
}
b = append(b, ')')
return b, nil
}
//------------------------------------------------------------------------------
type returningQuery struct {
returning []schema.QueryWithArgs
returningFields []*schema.Field
}
func (q *returningQuery) addReturning(ret schema.QueryWithArgs) {
q.returning = append(q.returning, ret)
}
func (q *returningQuery) addReturningField(field *schema.Field) {
if len(q.returning) > 0 {
return
}
for _, f := range q.returningFields {
if f == field {
return
}
}
q.returningFields = append(q.returningFields, field)
}
func (q *returningQuery) hasReturning() bool {
if len(q.returning) == 1 {
switch q.returning[0].Query {
case "null", "NULL":
return false
}
}
return len(q.returning) > 0 || len(q.returningFields) > 0
}
func (q *returningQuery) appendReturning(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
if !q.hasReturning() {
return b, nil
}
b = append(b, " RETURNING "...)
for i, f := range q.returning {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
if len(q.returning) > 0 {
return b, nil
}
b = appendColumns(b, "", q.returningFields)
return b, nil
}
//------------------------------------------------------------------------------
type columnValue struct {
column string
value schema.QueryWithArgs
}
type customValueQuery struct {
modelValues map[string]schema.QueryWithArgs
extraValues []columnValue
}
func (q *customValueQuery) addValue(
table *schema.Table, column string, value string, args []interface{},
) {
if _, ok := table.FieldMap[column]; ok {
if q.modelValues == nil {
q.modelValues = make(map[string]schema.QueryWithArgs)
}
q.modelValues[column] = schema.SafeQuery(value, args)
} else {
q.extraValues = append(q.extraValues, columnValue{
column: column,
value: schema.SafeQuery(value, args),
})
}
}
//------------------------------------------------------------------------------
type setQuery struct {
set []schema.QueryWithArgs
}
func (q *setQuery) addSet(set schema.QueryWithArgs) {
q.set = append(q.set, set)
}
func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
for i, f := range q.set {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
//------------------------------------------------------------------------------
type cascadeQuery struct {
restrict bool
}
func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte {
if !fmter.HasFeature(feature.TableCascade) {
return b
}
if q.restrict {
b = append(b, " RESTRICT"...)
} else {
b = append(b, " CASCADE"...)
}
return b
}

105
vendor/github.com/uptrace/bun/query_column_add.go generated vendored Normal file
View file

@ -0,0 +1,105 @@
package bun
import (
"context"
"database/sql"
"fmt"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type AddColumnQuery struct {
baseQuery
}
func NewAddColumnQuery(db *DB) *AddColumnQuery {
q := &AddColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *AddColumnQuery) Conn(db IConn) *AddColumnQuery {
q.setConn(db)
return q
}
func (q *AddColumnQuery) Model(model interface{}) *AddColumnQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *AddColumnQuery) Table(tables ...string) *AddColumnQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *AddColumnQuery) TableExpr(query string, args ...interface{}) *AddColumnQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *AddColumnQuery) ModelTableExpr(query string, args ...interface{}) *AddColumnQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColumnQuery {
q.addColumn(schema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if len(q.columns) != 1 {
return nil, fmt.Errorf("bun: AddColumnQuery requires exactly one column")
}
b = append(b, "ALTER TABLE "...)
b, err = q.appendFirstTable(fmter, b)
if err != nil {
return nil, err
}
b = append(b, " ADD "...)
b, err = q.columns[0].AppendQuery(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

112
vendor/github.com/uptrace/bun/query_column_drop.go generated vendored Normal file
View file

@ -0,0 +1,112 @@
package bun
import (
"context"
"database/sql"
"fmt"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type DropColumnQuery struct {
baseQuery
}
func NewDropColumnQuery(db *DB) *DropColumnQuery {
q := &DropColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *DropColumnQuery) Conn(db IConn) *DropColumnQuery {
q.setConn(db)
return q
}
func (q *DropColumnQuery) Model(model interface{}) *DropColumnQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *DropColumnQuery) Table(tables ...string) *DropColumnQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *DropColumnQuery) TableExpr(query string, args ...interface{}) *DropColumnQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *DropColumnQuery) ModelTableExpr(query string, args ...interface{}) *DropColumnQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *DropColumnQuery) Column(columns ...string) *DropColumnQuery {
for _, column := range columns {
q.addColumn(schema.UnsafeIdent(column))
}
return q
}
func (q *DropColumnQuery) ColumnExpr(query string, args ...interface{}) *DropColumnQuery {
q.addColumn(schema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if len(q.columns) != 1 {
return nil, fmt.Errorf("bun: DropColumnQuery requires exactly one column")
}
b = append(b, "ALTER TABLE "...)
b, err = q.appendFirstTable(fmter, b)
if err != nil {
return nil, err
}
b = append(b, " DROP COLUMN "...)
b, err = q.columns[0].AppendQuery(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

256
vendor/github.com/uptrace/bun/query_delete.go generated vendored Normal file
View file

@ -0,0 +1,256 @@
package bun
import (
"context"
"database/sql"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type DeleteQuery struct {
whereBaseQuery
returningQuery
}
func NewDeleteQuery(db *DB) *DeleteQuery {
q := &DeleteQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
},
}
return q
}
func (q *DeleteQuery) Conn(db IConn) *DeleteQuery {
q.setConn(db)
return q
}
func (q *DeleteQuery) Model(model interface{}) *DeleteQuery {
q.setTableModel(model)
return q
}
// Apply calls the fn passing the DeleteQuery as an argument.
func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
return fn(q)
}
func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery {
q.addWith(name, query)
return q
}
func (q *DeleteQuery) Table(tables ...string) *DeleteQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *DeleteQuery) TableExpr(query string, args ...interface{}) *DeleteQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *DeleteQuery) WherePK() *DeleteQuery {
q.flags = q.flags.Set(wherePKFlag)
return q
}
func (q *DeleteQuery) Where(query string, args ...interface{}) *DeleteQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *DeleteQuery) WhereOr(query string, args ...interface{}) *DeleteQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
return q
}
func (q *DeleteQuery) WhereGroup(sep string, fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
saved := q.where
q.where = nil
q = fn(q)
where := q.where
q.where = saved
q.addWhereGroup(sep, where)
return q
}
func (q *DeleteQuery) WhereDeleted() *DeleteQuery {
q.whereDeleted()
return q
}
func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery {
q.whereAllWithDeleted()
return q
}
func (q *DeleteQuery) ForceDelete() *DeleteQuery {
q.flags = q.flags.Set(forceDeleteFlag)
return q
}
//------------------------------------------------------------------------------
// Returning adds a RETURNING clause to the query.
//
// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery {
q.addReturning(schema.SafeQuery(query, args))
return q
}
func (q *DeleteQuery) hasReturning() bool {
if !q.db.features.Has(feature.Returning) {
return false
}
return q.returningQuery.hasReturning()
}
//------------------------------------------------------------------------------
func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
fmter = formatterWithModel(fmter, q)
if q.isSoftDelete() {
if err := q.tableModel.updateSoftDeleteField(); err != nil {
return nil, err
}
upd := UpdateQuery{
whereBaseQuery: q.whereBaseQuery,
returningQuery: q.returningQuery,
}
upd.Column(q.table.SoftDeleteField.Name)
return upd.AppendQuery(fmter, b)
}
q = q.WhereAllWithDeleted()
withAlias := q.db.features.Has(feature.DeleteTableAlias)
b, err = q.appendWith(fmter, b)
if err != nil {
return nil, err
}
b = append(b, "DELETE FROM "...)
if withAlias {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
}
if err != nil {
return nil, err
}
if q.hasMultiTables() {
b = append(b, " USING "...)
b, err = q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.mustAppendWhere(fmter, b, withAlias)
if err != nil {
return nil, err
}
if len(q.returning) > 0 {
b, err = q.appendReturning(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *DeleteQuery) isSoftDelete() bool {
return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag)
}
//------------------------------------------------------------------------------
func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.table != nil {
if err := q.beforeDeleteHook(ctx); err != nil {
return nil, err
}
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
var res sql.Result
if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
model, err := q.getModel(dest)
if err != nil {
return nil, err
}
res, err = q.scan(ctx, q, query, model, hasDest)
if err != nil {
return nil, err
}
} else {
res, err = q.exec(ctx, q, query)
if err != nil {
return nil, err
}
}
if q.table != nil {
if err := q.afterDeleteHook(ctx); err != nil {
return nil, err
}
}
return res, nil
}
func (q *DeleteQuery) beforeDeleteHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeDeleteHook); ok {
if err := hook.BeforeDelete(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *DeleteQuery) afterDeleteHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterDeleteHook); ok {
if err := hook.AfterDelete(ctx, q); err != nil {
return err
}
}
return nil
}

242
vendor/github.com/uptrace/bun/query_index_create.go generated vendored Normal file
View file

@ -0,0 +1,242 @@
package bun
import (
"context"
"database/sql"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type CreateIndexQuery struct {
whereBaseQuery
unique bool
fulltext bool
spatial bool
concurrently bool
ifNotExists bool
index schema.QueryWithArgs
using schema.QueryWithArgs
include []schema.QueryWithArgs
}
func NewCreateIndexQuery(db *DB) *CreateIndexQuery {
q := &CreateIndexQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
},
}
return q
}
func (q *CreateIndexQuery) Conn(db IConn) *CreateIndexQuery {
q.setConn(db)
return q
}
func (q *CreateIndexQuery) Model(model interface{}) *CreateIndexQuery {
q.setTableModel(model)
return q
}
func (q *CreateIndexQuery) Unique() *CreateIndexQuery {
q.unique = true
return q
}
func (q *CreateIndexQuery) Concurrently() *CreateIndexQuery {
q.concurrently = true
return q
}
func (q *CreateIndexQuery) IfNotExists() *CreateIndexQuery {
q.ifNotExists = true
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Index(query string) *CreateIndexQuery {
q.index = schema.UnsafeIdent(query)
return q
}
func (q *CreateIndexQuery) IndexExpr(query string, args ...interface{}) *CreateIndexQuery {
q.index = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Table(tables ...string) *CreateIndexQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *CreateIndexQuery) TableExpr(query string, args ...interface{}) *CreateIndexQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *CreateIndexQuery) ModelTableExpr(query string, args ...interface{}) *CreateIndexQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
func (q *CreateIndexQuery) Using(query string, args ...interface{}) *CreateIndexQuery {
q.using = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Column(columns ...string) *CreateIndexQuery {
for _, column := range columns {
q.addColumn(schema.UnsafeIdent(column))
}
return q
}
func (q *CreateIndexQuery) ColumnExpr(query string, args ...interface{}) *CreateIndexQuery {
q.addColumn(schema.SafeQuery(query, args))
return q
}
func (q *CreateIndexQuery) ExcludeColumn(columns ...string) *CreateIndexQuery {
q.excludeColumn(columns)
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Include(columns ...string) *CreateIndexQuery {
for _, column := range columns {
q.include = append(q.include, schema.UnsafeIdent(column))
}
return q
}
func (q *CreateIndexQuery) IncludeExpr(query string, args ...interface{}) *CreateIndexQuery {
q.include = append(q.include, schema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Where(query string, args ...interface{}) *CreateIndexQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *CreateIndexQuery) WhereOr(query string, args ...interface{}) *CreateIndexQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
return q
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "CREATE "...)
if q.unique {
b = append(b, "UNIQUE "...)
}
if q.fulltext {
b = append(b, "FULLTEXT "...)
}
if q.spatial {
b = append(b, "SPATIAL "...)
}
b = append(b, "INDEX "...)
if q.concurrently {
b = append(b, "CONCURRENTLY "...)
}
if q.ifNotExists {
b = append(b, "IF NOT EXISTS "...)
}
b, err = q.index.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, " ON "...)
b, err = q.appendFirstTable(fmter, b)
if err != nil {
return nil, err
}
if !q.using.IsZero() {
b = append(b, " USING "...)
b, err = q.using.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, " ("...)
for i, col := range q.columns {
if i > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, ')')
if len(q.include) > 0 {
b = append(b, " INCLUDE ("...)
for i, col := range q.include {
if i > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, ')')
}
if len(q.where) > 0 {
b, err = appendWhere(fmter, b, q.where)
if err != nil {
return nil, err
}
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

105
vendor/github.com/uptrace/bun/query_index_drop.go generated vendored Normal file
View file

@ -0,0 +1,105 @@
package bun
import (
"context"
"database/sql"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type DropIndexQuery struct {
baseQuery
cascadeQuery
concurrently bool
ifExists bool
index schema.QueryWithArgs
}
func NewDropIndexQuery(db *DB) *DropIndexQuery {
q := &DropIndexQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *DropIndexQuery) Conn(db IConn) *DropIndexQuery {
q.setConn(db)
return q
}
func (q *DropIndexQuery) Model(model interface{}) *DropIndexQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *DropIndexQuery) Concurrently() *DropIndexQuery {
q.concurrently = true
return q
}
func (q *DropIndexQuery) IfExists() *DropIndexQuery {
q.ifExists = true
return q
}
func (q *DropIndexQuery) Restrict() *DropIndexQuery {
q.restrict = true
return q
}
func (q *DropIndexQuery) Index(query string, args ...interface{}) *DropIndexQuery {
q.index = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "DROP INDEX "...)
if q.concurrently {
b = append(b, "CONCURRENTLY "...)
}
if q.ifExists {
b = append(b, "IF EXISTS "...)
}
b, err = q.index.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = q.appendCascade(fmter, b)
return b, nil
}
//------------------------------------------------------------------------------
func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

551
vendor/github.com/uptrace/bun/query_insert.go generated vendored Normal file
View file

@ -0,0 +1,551 @@
package bun
import (
"context"
"database/sql"
"fmt"
"reflect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type InsertQuery struct {
whereBaseQuery
returningQuery
customValueQuery
onConflict schema.QueryWithArgs
setQuery
ignore bool
replace bool
}
func NewInsertQuery(db *DB) *InsertQuery {
q := &InsertQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
},
}
return q
}
func (q *InsertQuery) Conn(db IConn) *InsertQuery {
q.setConn(db)
return q
}
func (q *InsertQuery) Model(model interface{}) *InsertQuery {
q.setTableModel(model)
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery {
return fn(q)
}
func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery {
q.addWith(name, query)
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Table(tables ...string) *InsertQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Column(columns ...string) *InsertQuery {
for _, column := range columns {
q.addColumn(schema.UnsafeIdent(column))
}
return q
}
func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
q.excludeColumn(columns)
return q
}
// Value overwrites model value for the column in INSERT and UPDATE queries.
func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery {
if q.table == nil {
q.err = errNilModel
return q
}
q.addValue(q.table, column, value, args)
return q
}
func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
return q
}
//------------------------------------------------------------------------------
// Returning adds a RETURNING clause to the query.
//
// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery {
q.addReturning(schema.SafeQuery(query, args))
return q
}
func (q *InsertQuery) hasReturning() bool {
if !q.db.features.Has(feature.Returning) {
return false
}
return q.returningQuery.hasReturning()
}
//------------------------------------------------------------------------------
// Ignore generates an `INSERT IGNORE INTO` query (MySQL).
func (q *InsertQuery) Ignore() *InsertQuery {
q.ignore = true
return q
}
// Replaces generates a `REPLACE INTO` query (MySQL).
func (q *InsertQuery) Replace() *InsertQuery {
q.replace = true
return q
}
//------------------------------------------------------------------------------
func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
fmter = formatterWithModel(fmter, q)
b, err = q.appendWith(fmter, b)
if err != nil {
return nil, err
}
if q.replace {
b = append(b, "REPLACE "...)
} else {
b = append(b, "INSERT "...)
if q.ignore {
b = append(b, "IGNORE "...)
}
}
b = append(b, "INTO "...)
if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
}
if err != nil {
return nil, err
}
b, err = q.appendColumnsValues(fmter, b)
if err != nil {
return nil, err
}
b, err = q.appendOn(fmter, b)
if err != nil {
return nil, err
}
if q.hasReturning() {
b, err = q.appendReturning(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *InsertQuery) appendColumnsValues(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
if q.hasMultiTables() {
if q.columns != nil {
b = append(b, " ("...)
b, err = q.appendColumns(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ")"...)
}
b = append(b, " SELECT * FROM "...)
b, err = q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
if m, ok := q.model.(*mapModel); ok {
return m.appendColumnsValues(fmter, b), nil
}
if _, ok := q.model.(*mapSliceModel); ok {
return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported")
}
if q.model == nil {
return nil, errNilModel
}
fields, err := q.getFields()
if err != nil {
return nil, err
}
b = append(b, " ("...)
b = q.appendFields(fmter, b, fields)
b = append(b, ") VALUES ("...)
switch model := q.tableModel.(type) {
case *structTableModel:
b, err = q.appendStructValues(fmter, b, fields, model.strct)
if err != nil {
return nil, err
}
case *sliceTableModel:
b, err = q.appendSliceValues(fmter, b, fields, model.slice)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel)
}
b = append(b, ')')
return b, nil
}
func (q *InsertQuery) appendStructValues(
fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value,
) (_ []byte, err error) {
isTemplate := fmter.IsNop()
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
app, ok := q.modelValues[f.Name]
if ok {
b, err = app.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
q.addReturningField(f)
continue
}
switch {
case isTemplate:
b = append(b, '?')
case f.NullZero && f.HasZeroValue(strct):
if q.db.features.Has(feature.DefaultPlaceholder) {
b = append(b, "DEFAULT"...)
} else if f.SQLDefault != "" {
b = append(b, f.SQLDefault...)
} else {
b = append(b, "NULL"...)
}
q.addReturningField(f)
default:
b = f.AppendValue(fmter, b, strct)
}
}
for i, v := range q.extraValues {
if i > 0 || len(fields) > 0 {
b = append(b, ", "...)
}
b, err = v.value.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *InsertQuery) appendSliceValues(
fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value,
) (_ []byte, err error) {
if fmter.IsNop() {
return q.appendStructValues(fmter, b, fields, reflect.Value{})
}
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, "), ("...)
}
el := indirect(slice.Index(i))
b, err = q.appendStructValues(fmter, b, fields, el)
if err != nil {
return nil, err
}
}
for i, v := range q.extraValues {
if i > 0 || len(fields) > 0 {
b = append(b, ", "...)
}
b, err = v.value.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *InsertQuery) getFields() ([]*schema.Field, error) {
if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 {
return q.baseQuery.getFields()
}
var strct reflect.Value
switch model := q.tableModel.(type) {
case *structTableModel:
strct = model.strct
case *sliceTableModel:
if model.sliceLen == 0 {
return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type())
}
strct = indirect(model.slice.Index(0))
}
fields := make([]*schema.Field, 0, len(q.table.Fields))
for _, f := range q.table.Fields {
if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) {
q.addReturningField(f)
continue
}
fields = append(fields, f)
}
return fields, nil
}
func (q *InsertQuery) appendFields(
fmter schema.Formatter, b []byte, fields []*schema.Field,
) []byte {
b = appendColumns(b, "", fields)
for i, v := range q.extraValues {
if i > 0 || len(fields) > 0 {
b = append(b, ", "...)
}
b = fmter.AppendIdent(b, v.column)
}
return b
}
//------------------------------------------------------------------------------
func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery {
q.onConflict = schema.SafeQuery(s, args)
return q
}
func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery {
q.addSet(schema.SafeQuery(query, args))
return q
}
func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.onConflict.IsZero() {
return b, nil
}
b = append(b, " ON "...)
b, err = q.onConflict.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if len(q.set) > 0 {
if fmter.HasFeature(feature.OnDuplicateKey) {
b = append(b, ' ')
} else {
b = append(b, " SET "...)
}
b, err = q.appendSet(fmter, b)
if err != nil {
return nil, err
}
} else if len(q.columns) > 0 {
fields, err := q.getDataFields()
if err != nil {
return nil, err
}
if len(fields) == 0 {
fields = q.tableModel.Table().DataFields
}
b = q.appendSetExcluded(b, fields)
}
b, err = q.appendWhere(fmter, b, true)
if err != nil {
return nil, err
}
return b, nil
}
func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte {
b = append(b, " SET "...)
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, f.SQLName...)
b = append(b, " = EXCLUDED."...)
b = append(b, f.SQLName...)
}
return b
}
//------------------------------------------------------------------------------
func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.table != nil {
if err := q.beforeInsertHook(ctx); err != nil {
return nil, err
}
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
var res sql.Result
if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
model, err := q.getModel(dest)
if err != nil {
return nil, err
}
res, err = q.scan(ctx, q, query, model, hasDest)
if err != nil {
return nil, err
}
} else {
res, err = q.exec(ctx, q, query)
if err != nil {
return nil, err
}
if err := q.tryLastInsertID(res, dest); err != nil {
return nil, err
}
}
if q.table != nil {
if err := q.afterInsertHook(ctx); err != nil {
return nil, err
}
}
return res, nil
}
func (q *InsertQuery) beforeInsertHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok {
if err := hook.BeforeInsert(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok {
if err := hook.AfterInsert(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 {
return nil
}
id, err := res.LastInsertId()
if err != nil {
return err
}
if id == 0 {
return nil
}
model, err := q.getModel(dest)
if err != nil {
return err
}
pk := q.table.PKs[0]
switch model := model.(type) {
case *structTableModel:
if err := pk.ScanValue(model.strct, id); err != nil {
return err
}
case *sliceTableModel:
sliceLen := model.slice.Len()
for i := 0; i < sliceLen; i++ {
strct := indirect(model.slice.Index(i))
if err := pk.ScanValue(strct, id); err != nil {
return err
}
id++
}
}
return nil
}

830
vendor/github.com/uptrace/bun/query_select.go generated vendored Normal file
View file

@ -0,0 +1,830 @@
package bun
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type union struct {
expr string
query *SelectQuery
}
type SelectQuery struct {
whereBaseQuery
distinctOn []schema.QueryWithArgs
joins []joinQuery
group []schema.QueryWithArgs
having []schema.QueryWithArgs
order []schema.QueryWithArgs
limit int32
offset int32
selFor schema.QueryWithArgs
union []union
}
func NewSelectQuery(db *DB) *SelectQuery {
return &SelectQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
},
}
}
func (q *SelectQuery) Conn(db IConn) *SelectQuery {
q.setConn(db)
return q
}
func (q *SelectQuery) Model(model interface{}) *SelectQuery {
q.setTableModel(model)
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery {
return fn(q)
}
func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery {
q.addWith(name, query)
return q
}
func (q *SelectQuery) Distinct() *SelectQuery {
q.distinctOn = make([]schema.QueryWithArgs, 0)
return q
}
func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery {
q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Table(tables ...string) *SelectQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Column(columns ...string) *SelectQuery {
for _, column := range columns {
q.addColumn(schema.UnsafeIdent(column))
}
return q
}
func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery {
q.addColumn(schema.SafeQuery(query, args))
return q
}
func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery {
q.excludeColumn(columns)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) WherePK() *SelectQuery {
q.flags = q.flags.Set(wherePKFlag)
return q
}
func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
return q
}
func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery {
saved := q.where
q.where = nil
q = fn(q)
where := q.where
q.where = saved
q.addWhereGroup(sep, where)
return q
}
func (q *SelectQuery) WhereDeleted() *SelectQuery {
q.whereDeleted()
return q
}
func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery {
q.whereAllWithDeleted()
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Group(columns ...string) *SelectQuery {
for _, column := range columns {
q.group = append(q.group, schema.UnsafeIdent(column))
}
return q
}
func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery {
q.group = append(q.group, schema.SafeQuery(group, args))
return q
}
func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery {
q.having = append(q.having, schema.SafeQuery(having, args))
return q
}
func (q *SelectQuery) Order(orders ...string) *SelectQuery {
for _, order := range orders {
if order == "" {
continue
}
index := strings.IndexByte(order, ' ')
if index == -1 {
q.order = append(q.order, schema.UnsafeIdent(order))
continue
}
field := order[:index]
sort := order[index+1:]
switch strings.ToUpper(sort) {
case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
"ASC NULLS LAST", "DESC NULLS LAST":
q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{
Ident(field),
Safe(sort),
}))
default:
q.order = append(q.order, schema.UnsafeIdent(order))
}
}
return q
}
func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery {
q.order = append(q.order, schema.SafeQuery(query, args))
return q
}
func (q *SelectQuery) Limit(n int) *SelectQuery {
q.limit = int32(n)
return q
}
func (q *SelectQuery) Offset(n int) *SelectQuery {
q.offset = int32(n)
return q
}
func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery {
q.selFor = schema.SafeQuery(s, args)
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery {
return q.addUnion(" UNION ", other)
}
func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery {
return q.addUnion(" UNION ALL ", other)
}
func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery {
return q.addUnion(" INTERSECT ", other)
}
func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery {
return q.addUnion(" INTERSECT ALL ", other)
}
func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery {
return q.addUnion(" EXCEPT ", other)
}
func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery {
return q.addUnion(" EXCEPT ALL ", other)
}
func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery {
q.union = append(q.union, union{
expr: expr,
query: other,
})
return q
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery {
q.joins = append(q.joins, joinQuery{
join: schema.SafeQuery(join, args),
})
return q
}
func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery {
return q.joinOn(cond, args, " AND ")
}
func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery {
return q.joinOn(cond, args, " OR ")
}
func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery {
if len(q.joins) == 0 {
q.err = errors.New("bun: query has no joins")
return q
}
j := &q.joins[len(q.joins)-1]
j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep))
return q
}
//------------------------------------------------------------------------------
// Relation adds a relation to the query. Relation name can be:
// - RelationName to select all columns,
// - RelationName.column_name,
// - RelationName._ to join relation without selecting relation columns.
func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery {
if q.tableModel == nil {
q.setErr(errNilModel)
return q
}
var fn func(*SelectQuery) *SelectQuery
if len(apply) == 1 {
fn = apply[0]
} else if len(apply) > 1 {
panic("only one apply function is supported")
}
join := q.tableModel.Join(name, fn)
if join == nil {
q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name))
return q
}
return q
}
func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error {
if q.tableModel == nil {
return nil
}
return q._forEachHasOneJoin(fn, q.tableModel.GetJoins())
}
func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error {
for i := range joins {
j := &joins[i]
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := fn(j); err != nil {
return err
}
if err := q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()); err != nil {
return err
}
}
}
return nil
}
func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error {
var err error
for i := range joins {
j := &joins[i]
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
err = q.selectJoins(ctx, j.JoinModel.GetJoins())
default:
err = j.Select(ctx, q.db.NewSelect())
}
if err != nil {
return err
}
}
return nil
}
//------------------------------------------------------------------------------
func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
return q.appendQuery(fmter, b, false)
}
func (q *SelectQuery) appendQuery(
fmter schema.Formatter, b []byte, count bool,
) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
fmter = formatterWithModel(fmter, q)
cteCount := count && (len(q.group) > 0 || q.distinctOn != nil)
if cteCount {
b = append(b, "WITH _count_wrapper AS ("...)
}
if len(q.union) > 0 {
b = append(b, '(')
}
b, err = q.appendWith(fmter, b)
if err != nil {
return nil, err
}
b = append(b, "SELECT "...)
if len(q.distinctOn) > 0 {
b = append(b, "DISTINCT ON ("...)
for i, app := range q.distinctOn {
if i > 0 {
b = append(b, ", "...)
}
b, err = app.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b = append(b, ") "...)
} else if q.distinctOn != nil {
b = append(b, "DISTINCT "...)
}
if count && !cteCount {
b = append(b, "count(*)"...)
} else {
b, err = q.appendColumns(fmter, b)
if err != nil {
return nil, err
}
}
if q.hasTables() {
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
}
if err := q.forEachHasOneJoin(func(j *join) error {
b = append(b, ' ')
b, err = j.appendHasOneJoin(fmter, b, q)
return err
}); err != nil {
return nil, err
}
for _, j := range q.joins {
b, err = j.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.appendWhere(fmter, b, true)
if err != nil {
return nil, err
}
if len(q.group) > 0 {
b = append(b, " GROUP BY "...)
for i, f := range q.group {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}
if len(q.having) > 0 {
b = append(b, " HAVING "...)
for i, f := range q.having {
if i > 0 {
b = append(b, " AND "...)
}
b = append(b, '(')
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
}
if !count {
b, err = q.appendOrder(fmter, b)
if err != nil {
return nil, err
}
if q.limit != 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
if q.offset != 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
}
if !q.selFor.IsZero() {
b = append(b, " FOR "...)
b, err = q.selFor.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}
if len(q.union) > 0 {
b = append(b, ')')
for _, u := range q.union {
b = append(b, u.expr...)
b = append(b, '(')
b, err = u.query.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
}
if cteCount {
b = append(b, ") SELECT count(*) FROM _count_wrapper"...)
}
return b, nil
}
func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
start := len(b)
switch {
case q.columns != nil:
for i, col := range q.columns {
if i > 0 {
b = append(b, ", "...)
}
if col.Args == nil {
if field, ok := q.table.FieldMap[col.Query]; ok {
b = append(b, q.table.SQLAlias...)
b = append(b, '.')
b = append(b, field.SQLName...)
continue
}
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
case q.table != nil:
if len(q.table.Fields) > 10 && fmter.IsNop() {
b = append(b, q.table.SQLAlias...)
b = append(b, '.')
b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
} else {
b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
}
default:
b = append(b, '*')
}
if err := q.forEachHasOneJoin(func(j *join) error {
if len(b) != start {
b = append(b, ", "...)
start = len(b)
}
b, err = q.appendHasOneColumns(fmter, b, j)
if err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
b = bytes.TrimSuffix(b, []byte(", "))
return b, nil
}
func (q *SelectQuery) appendHasOneColumns(
fmter schema.Formatter, b []byte, join *join,
) (_ []byte, err error) {
join.applyQuery(q)
if join.columns != nil {
for i, col := range join.columns {
if i > 0 {
b = append(b, ", "...)
}
if col.Args == nil {
if field, ok := q.table.FieldMap[col.Query]; ok {
b = join.appendAlias(fmter, b)
b = append(b, '.')
b = append(b, field.SQLName...)
b = append(b, " AS "...)
b = join.appendAliasColumn(fmter, b, field.Name)
continue
}
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
for i, field := range join.JoinModel.Table().Fields {
if i > 0 {
b = append(b, ", "...)
}
b = join.appendAlias(fmter, b)
b = append(b, '.')
b = append(b, field.SQLName...)
b = append(b, " AS "...)
b = join.appendAliasColumn(fmter, b, field.Name)
}
return b, nil
}
func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, " FROM "...)
return q.appendTablesWithAlias(fmter, b)
}
func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if len(q.order) > 0 {
b = append(b, " ORDER BY "...)
for i, f := range q.order {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
return q.conn.QueryContext(ctx, query)
}
func (q *SelectQuery) Exec(ctx context.Context) (res sql.Result, err error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err = q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}
func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
model, err := q.getModel(dest)
if err != nil {
return err
}
if q.limit > 1 {
if model, ok := model.(interface{ SetCap(int) }); ok {
model.SetCap(int(q.limit))
}
}
if q.table != nil {
if err := q.beforeSelectHook(ctx); err != nil {
return err
}
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return err
}
query := internal.String(queryBytes)
res, err := q.scan(ctx, q, query, model, true)
if err != nil {
return err
}
if res.n > 0 {
if tableModel, ok := model.(tableModel); ok {
if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil {
return err
}
}
}
if q.table != nil {
if err := q.afterSelectHook(ctx); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery) beforeSelectHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok {
if err := hook.BeforeSelect(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery) afterSelectHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok {
if err := hook.AfterSelect(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery) Count(ctx context.Context) (int, error) {
qq := countQuery{q}
queryBytes, err := qq.appendQuery(q.db.fmter, nil, true)
if err != nil {
return 0, err
}
query := internal.String(queryBytes)
ctx, event := q.db.beforeQuery(ctx, qq, query, nil)
var num int
err = q.conn.QueryRowContext(ctx, query).Scan(&num)
q.db.afterQuery(ctx, event, nil, err)
return num, err
}
func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) {
var count int
var wg sync.WaitGroup
var mu sync.Mutex
var firstErr error
if q.limit >= 0 {
wg.Add(1)
go func() {
defer wg.Done()
if err := q.Scan(ctx, dest...); err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
}
}()
}
wg.Add(1)
go func() {
defer wg.Done()
var err error
count, err = q.Count(ctx)
if err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
}
}()
wg.Wait()
return count, firstErr
}
//------------------------------------------------------------------------------
type joinQuery struct {
join schema.QueryWithArgs
on []schema.QueryWithSep
}
func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, ' ')
b, err = j.join.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if len(j.on) > 0 {
b = append(b, " ON "...)
for i, on := range j.on {
if i > 0 {
b = append(b, on.Sep...)
}
b = append(b, '(')
b, err = on.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}
}
return b, nil
}
//------------------------------------------------------------------------------
type countQuery struct {
*SelectQuery
}
func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
return q.appendQuery(fmter, b, true)
}

275
vendor/github.com/uptrace/bun/query_table_create.go generated vendored Normal file
View file

@ -0,0 +1,275 @@
package bun
import (
"context"
"database/sql"
"sort"
"strconv"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type CreateTableQuery struct {
baseQuery
temp bool
ifNotExists bool
varchar int
fks []schema.QueryWithArgs
partitionBy schema.QueryWithArgs
tablespace schema.QueryWithArgs
}
func NewCreateTableQuery(db *DB) *CreateTableQuery {
q := &CreateTableQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery {
q.setConn(db)
return q
}
func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *CreateTableQuery) Temp() *CreateTableQuery {
q.temp = true
return q
}
func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
q.ifNotExists = true
return q
}
func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
q.varchar = n
return q
}
func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery {
q.fks = append(q.fks, schema.SafeQuery(query, args))
return q
}
func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if q.table == nil {
return nil, errNilModel
}
b = append(b, "CREATE "...)
if q.temp {
b = append(b, "TEMP "...)
}
b = append(b, "TABLE "...)
if q.ifNotExists {
b = append(b, "IF NOT EXISTS "...)
}
b, err = q.appendFirstTable(fmter, b)
if err != nil {
return nil, err
}
b = append(b, " ("...)
for i, field := range q.table.Fields {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, field.SQLName...)
b = append(b, " "...)
b = q.appendSQLType(b, field)
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement {
b = append(b, " AUTO_INCREMENT"...)
}
if field.SQLDefault != "" {
b = append(b, " DEFAULT "...)
b = append(b, field.SQLDefault...)
}
}
b = q.appendPKConstraint(b, q.table.PKs)
b = q.appendUniqueConstraints(fmter, b)
b, err = q.appenFKConstraints(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ")"...)
if !q.partitionBy.IsZero() {
b = append(b, " PARTITION BY "...)
b, err = q.partitionBy.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
if !q.tablespace.IsZero() {
b = append(b, " TABLESPACE "...)
b, err = q.tablespace.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
if field.CreateTableSQLType != field.DiscoveredSQLType {
return append(b, field.CreateTableSQLType...)
}
if q.varchar > 0 &&
field.CreateTableSQLType == sqltype.VarChar {
b = append(b, "varchar("...)
b = strconv.AppendInt(b, int64(q.varchar), 10)
b = append(b, ")"...)
return b
}
return append(b, field.CreateTableSQLType...)
}
func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
unique := q.table.Unique
keys := make([]string, 0, len(unique))
for key := range unique {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
b = q.appendUniqueConstraint(fmter, b, key, unique[key])
}
return b
}
func (q *CreateTableQuery) appendUniqueConstraint(
fmter schema.Formatter, b []byte, name string, fields []*schema.Field,
) []byte {
if name != "" {
b = append(b, ", CONSTRAINT "...)
b = fmter.AppendIdent(b, name)
} else {
b = append(b, ","...)
}
b = append(b, " UNIQUE ("...)
b = appendColumns(b, "", fields)
b = append(b, ")"...)
return b
}
func (q *CreateTableQuery) appenFKConstraints(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
for _, fk := range q.fks {
b = append(b, ", FOREIGN KEY "...)
b, err = fk.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte {
if len(pks) == 0 {
return b
}
b = append(b, ", PRIMARY KEY ("...)
b = appendColumns(b, "", pks)
b = append(b, ")"...)
return b
}
//------------------------------------------------------------------------------
func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if err := q.beforeCreateTableHook(ctx); err != nil {
return nil, err
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
if q.table != nil {
if err := q.afterCreateTableHook(ctx); err != nil {
return nil, err
}
}
return res, nil
}
func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok {
if err := hook.BeforeCreateTable(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok {
if err := hook.AfterCreateTable(ctx, q); err != nil {
return err
}
}
return nil
}

137
vendor/github.com/uptrace/bun/query_table_drop.go generated vendored Normal file
View file

@ -0,0 +1,137 @@
package bun
import (
"context"
"database/sql"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type DropTableQuery struct {
baseQuery
cascadeQuery
ifExists bool
}
func NewDropTableQuery(db *DB) *DropTableQuery {
q := &DropTableQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *DropTableQuery) Conn(db IConn) *DropTableQuery {
q.setConn(db)
return q
}
func (q *DropTableQuery) Model(model interface{}) *DropTableQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) Table(tables ...string) *DropTableQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *DropTableQuery) TableExpr(query string, args ...interface{}) *DropTableQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *DropTableQuery) ModelTableExpr(query string, args ...interface{}) *DropTableQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) IfExists() *DropTableQuery {
q.ifExists = true
return q
}
func (q *DropTableQuery) Restrict() *DropTableQuery {
q.restrict = true
return q
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
b = append(b, "DROP TABLE "...)
if q.ifExists {
b = append(b, "IF EXISTS "...)
}
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
b = q.appendCascade(fmter, b)
return b, nil
}
//------------------------------------------------------------------------------
func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.table != nil {
if err := q.beforeDropTableHook(ctx); err != nil {
return nil, err
}
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
if q.table != nil {
if err := q.afterDropTableHook(ctx); err != nil {
return nil, err
}
}
return res, nil
}
func (q *DropTableQuery) beforeDropTableHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeDropTableHook); ok {
if err := hook.BeforeDropTable(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterDropTableHook); ok {
if err := hook.AfterDropTable(ctx, q); err != nil {
return err
}
}
return nil
}

121
vendor/github.com/uptrace/bun/query_table_truncate.go generated vendored Normal file
View file

@ -0,0 +1,121 @@
package bun
import (
"context"
"database/sql"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type TruncateTableQuery struct {
baseQuery
cascadeQuery
continueIdentity bool
}
func NewTruncateTableQuery(db *DB) *TruncateTableQuery {
q := &TruncateTableQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
return q
}
func (q *TruncateTableQuery) Conn(db IConn) *TruncateTableQuery {
q.setConn(db)
return q
}
func (q *TruncateTableQuery) Model(model interface{}) *TruncateTableQuery {
q.setTableModel(model)
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *TruncateTableQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery {
q.continueIdentity = true
return q
}
func (q *TruncateTableQuery) Restrict() *TruncateTableQuery {
q.restrict = true
return q
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) AppendQuery(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if !fmter.HasFeature(feature.TableTruncate) {
b = append(b, "DELETE FROM "...)
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
b = append(b, "TRUNCATE TABLE "...)
b, err = q.appendTables(fmter, b)
if err != nil {
return nil, err
}
if q.db.features.Has(feature.TableIdentity) {
if q.continueIdentity {
b = append(b, " CONTINUE IDENTITY"...)
} else {
b = append(b, " RESTART IDENTITY"...)
}
}
b = q.appendCascade(fmter, b)
return b, nil
}
//------------------------------------------------------------------------------
func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
}

432
vendor/github.com/uptrace/bun/query_update.go generated vendored Normal file
View file

@ -0,0 +1,432 @@
package bun
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type UpdateQuery struct {
whereBaseQuery
returningQuery
customValueQuery
setQuery
omitZero bool
}
func NewUpdateQuery(db *DB) *UpdateQuery {
q := &UpdateQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
},
}
return q
}
func (q *UpdateQuery) Conn(db IConn) *UpdateQuery {
q.setConn(db)
return q
}
func (q *UpdateQuery) Model(model interface{}) *UpdateQuery {
q.setTableModel(model)
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
return fn(q)
}
func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery {
q.addWith(name, query)
return q
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) Table(tables ...string) *UpdateQuery {
for _, table := range tables {
q.addTable(schema.UnsafeIdent(table))
}
return q
}
func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery {
q.addTable(schema.SafeQuery(query, args))
return q
}
func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery {
q.modelTable = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) Column(columns ...string) *UpdateQuery {
for _, column := range columns {
q.addColumn(schema.UnsafeIdent(column))
}
return q
}
func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery {
q.excludeColumn(columns)
return q
}
func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery {
q.addSet(schema.SafeQuery(query, args))
return q
}
// Value overwrites model value for the column in INSERT and UPDATE queries.
func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery {
if q.table == nil {
q.err = errNilModel
return q
}
q.addValue(q.table, column, value, args)
return q
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) WherePK() *UpdateQuery {
q.flags = q.flags.Set(wherePKFlag)
return q
}
func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
return q
}
func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery {
q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
return q
}
func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
saved := q.where
q.where = nil
q = fn(q)
where := q.where
q.where = saved
q.addWhereGroup(sep, where)
return q
}
func (q *UpdateQuery) WhereDeleted() *UpdateQuery {
q.whereDeleted()
return q
}
func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery {
q.whereAllWithDeleted()
return q
}
//------------------------------------------------------------------------------
// Returning adds a RETURNING clause to the query.
//
// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery {
q.addReturning(schema.SafeQuery(query, args))
return q
}
func (q *UpdateQuery) hasReturning() bool {
if !q.db.features.Has(feature.Returning) {
return false
}
return q.returningQuery.hasReturning()
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
fmter = formatterWithModel(fmter, q)
withAlias := fmter.HasFeature(feature.UpdateMultiTable)
b, err = q.appendWith(fmter, b)
if err != nil {
return nil, err
}
b = append(b, "UPDATE "...)
if withAlias {
b, err = q.appendTablesWithAlias(fmter, b)
} else {
b, err = q.appendFirstTableWithAlias(fmter, b)
}
if err != nil {
return nil, err
}
b, err = q.mustAppendSet(fmter, b)
if err != nil {
return nil, err
}
if !fmter.HasFeature(feature.UpdateMultiTable) {
b, err = q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.mustAppendWhere(fmter, b, withAlias)
if err != nil {
return nil, err
}
if len(q.returning) > 0 {
b, err = q.appendReturning(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, " SET "...)
if len(q.set) > 0 {
return q.appendSet(fmter, b)
}
if m, ok := q.model.(*mapModel); ok {
return m.appendSet(fmter, b), nil
}
if q.tableModel == nil {
return nil, errNilModel
}
switch model := q.tableModel.(type) {
case *structTableModel:
b, err = q.appendSetStruct(fmter, b, model)
if err != nil {
return nil, err
}
case *sliceTableModel:
return nil, errors.New("bun: to bulk Update, use CTE and VALUES")
default:
return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel)
}
return b, nil
}
func (q *UpdateQuery) appendSetStruct(
fmter schema.Formatter, b []byte, model *structTableModel,
) ([]byte, error) {
fields, err := q.getDataFields()
if err != nil {
return nil, err
}
isTemplate := fmter.IsNop()
pos := len(b)
for _, f := range fields {
if q.omitZero && f.NullZero && f.HasZeroValue(model.strct) {
continue
}
if len(b) != pos {
b = append(b, ", "...)
pos = len(b)
}
b = append(b, f.SQLName...)
b = append(b, " = "...)
if isTemplate {
b = append(b, '?')
continue
}
app, ok := q.modelValues[f.Name]
if ok {
b, err = app.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
} else {
b = f.AppendValue(fmter, b, model.strct)
}
}
for i, v := range q.extraValues {
if i > 0 || len(fields) > 0 {
b = append(b, ", "...)
}
b = append(b, v.column...)
b = append(b, " = "...)
b, err = v.value.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if !q.hasMultiTables() {
return b, nil
}
b = append(b, " FROM "...)
b, err = q.whereBaseQuery.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) Bulk() *UpdateQuery {
model, ok := q.model.(*sliceTableModel)
if !ok {
q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model))
return q
}
return q.With("_data", q.db.NewValues(model)).
Model(model).
TableExpr("_data").
Set(q.updateSliceSet(model)).
Where(q.updateSliceWhere(model))
}
func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string {
var b []byte
for i, field := range model.table.DataFields {
if i > 0 {
b = append(b, ", "...)
}
if q.db.fmter.HasFeature(feature.UpdateMultiTable) {
b = append(b, model.table.SQLAlias...)
b = append(b, '.')
}
b = append(b, field.SQLName...)
b = append(b, " = _data."...)
b = append(b, field.SQLName...)
}
return internal.String(b)
}
func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string {
var b []byte
for i, pk := range model.table.PKs {
if i > 0 {
b = append(b, " AND "...)
}
b = append(b, model.table.SQLAlias...)
b = append(b, '.')
b = append(b, pk.SQLName...)
b = append(b, " = _data."...)
b = append(b, pk.SQLName...)
}
return internal.String(b)
}
//------------------------------------------------------------------------------
func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.table != nil {
if err := q.beforeUpdateHook(ctx); err != nil {
return nil, err
}
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
}
query := internal.String(queryBytes)
var res sql.Result
if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
model, err := q.getModel(dest)
if err != nil {
return nil, err
}
res, err = q.scan(ctx, q, query, model, hasDest)
if err != nil {
return nil, err
}
} else {
res, err = q.exec(ctx, q, query)
if err != nil {
return nil, err
}
}
if q.table != nil {
if err := q.afterUpdateHook(ctx); err != nil {
return nil, err
}
}
return res, nil
}
func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok {
if err := hook.BeforeUpdate(ctx, q); err != nil {
return err
}
}
return nil
}
func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error {
if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok {
if err := hook.AfterUpdate(ctx, q); err != nil {
return err
}
}
return nil
}
// FQN returns a fully qualified column name. For MySQL, it returns the column name with
// the table alias. For other RDBMS, it returns just the column name.
func (q *UpdateQuery) FQN(name string) Ident {
if q.db.fmter.HasFeature(feature.UpdateMultiTable) {
return Ident(q.table.Alias + "." + name)
}
return Ident(name)
}

198
vendor/github.com/uptrace/bun/query_values.go generated vendored Normal file
View file

@ -0,0 +1,198 @@
package bun
import (
"fmt"
"reflect"
"strconv"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/schema"
)
type ValuesQuery struct {
baseQuery
customValueQuery
withOrder bool
}
var _ schema.NamedArgAppender = (*ValuesQuery)(nil)
func NewValuesQuery(db *DB, model interface{}) *ValuesQuery {
q := &ValuesQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
}
q.setTableModel(model)
return q
}
func (q *ValuesQuery) Conn(db IConn) *ValuesQuery {
q.setConn(db)
return q
}
func (q *ValuesQuery) WithOrder() *ValuesQuery {
q.withOrder = true
return q
}
func (q *ValuesQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
switch name {
case "Columns":
bb, err := q.AppendColumns(fmter, b)
if err != nil {
q.setErr(err)
return b, true
}
return bb, true
}
return b, false
}
// AppendColumns appends the table columns. It is used by CTE.
func (q *ValuesQuery) AppendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if q.model == nil {
return nil, errNilModel
}
if q.tableModel != nil {
fields, err := q.getFields()
if err != nil {
return nil, err
}
b = appendColumns(b, "", fields)
if q.withOrder {
b = append(b, ", _order"...)
}
return b, nil
}
switch model := q.model.(type) {
case *mapSliceModel:
return model.appendColumns(fmter, b)
}
return nil, fmt.Errorf("bun: Values does not support %T", q.model)
}
func (q *ValuesQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if q.err != nil {
return nil, q.err
}
if q.model == nil {
return nil, errNilModel
}
fmter = formatterWithModel(fmter, q)
if q.tableModel != nil {
fields, err := q.getFields()
if err != nil {
return nil, err
}
return q.appendQuery(fmter, b, fields)
}
switch model := q.model.(type) {
case *mapSliceModel:
return model.appendValues(fmter, b)
}
return nil, fmt.Errorf("bun: Values does not support %T", q.model)
}
func (q *ValuesQuery) appendQuery(
fmter schema.Formatter,
b []byte,
fields []*schema.Field,
) (_ []byte, err error) {
b = append(b, "VALUES "...)
if q.db.features.Has(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
}
switch model := q.tableModel.(type) {
case *structTableModel:
b, err = q.appendValues(fmter, b, fields, model.strct)
if err != nil {
return nil, err
}
if q.withOrder {
b = append(b, ", "...)
b = strconv.AppendInt(b, 0, 10)
}
case *sliceTableModel:
slice := model.slice
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, "), "...)
if q.db.features.Has(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
}
}
b, err = q.appendValues(fmter, b, fields, slice.Index(i))
if err != nil {
return nil, err
}
if q.withOrder {
b = append(b, ", "...)
b = strconv.AppendInt(b, int64(i), 10)
}
}
default:
return nil, fmt.Errorf("bun: Values does not support %T", q.model)
}
b = append(b, ')')
return b, nil
}
func (q *ValuesQuery) appendValues(
fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value,
) (_ []byte, err error) {
isTemplate := fmter.IsNop()
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
app, ok := q.modelValues[f.Name]
if ok {
b, err = app.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
continue
}
if isTemplate {
b = append(b, '?')
} else {
b = f.AppendValue(fmter, b, indirect(strct))
}
if fmter.HasFeature(feature.DoubleColonCast) {
b = append(b, "::"...)
b = append(b, f.UserSQLType...)
}
}
return b, nil
}

93
vendor/github.com/uptrace/bun/schema/append.go generated vendored Normal file
View file

@ -0,0 +1,93 @@
package schema
import (
"reflect"
"strconv"
"strings"
"time"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal"
)
func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
if field.Tag.HasOption("msgpack") {
return appendMsgpack
}
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
return AppendJSONValue
}
return dialect.Appender(field.StructField.Type)
}
func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte {
switch v := v.(type) {
case nil:
return dialect.AppendNull(b)
case bool:
return dialect.AppendBool(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case float32:
return dialect.AppendFloat32(b, v)
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
case time.Time:
return dialect.AppendTime(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case QueryAppender:
return AppendQueryAppender(fmter, b, v)
default:
vv := reflect.ValueOf(v)
if vv.Kind() == reflect.Ptr && vv.IsNil() {
return dialect.AppendNull(b)
}
appender := Appender(vv.Type(), custom)
return appender(fmter, b, vv)
}
}
func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
hexEnc := internal.NewHexEncoder(b)
enc := msgpack.GetEncoder()
defer msgpack.PutEncoder(enc)
enc.Reset(hexEnc)
if err := enc.EncodeValue(v); err != nil {
return dialect.AppendError(b, err)
}
if err := hexEnc.Close(); err != nil {
return dialect.AppendError(b, err)
}
return hexEnc.Bytes()
}
func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
bb, err := app.AppendQuery(fmter, b)
if err != nil {
return dialect.AppendError(b, err)
}
return bb
}

237
vendor/github.com/uptrace/bun/schema/append_value.go generated vendored Normal file
View file

@ -0,0 +1,237 @@
package schema
import (
"database/sql/driver"
"encoding/json"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
)
type (
AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
CustomAppender func(typ reflect.Type) AppenderFunc
)
var appenders = []AppenderFunc{
reflect.Bool: AppendBoolValue,
reflect.Int: AppendIntValue,
reflect.Int8: AppendIntValue,
reflect.Int16: AppendIntValue,
reflect.Int32: AppendIntValue,
reflect.Int64: AppendIntValue,
reflect.Uint: AppendUintValue,
reflect.Uint8: AppendUintValue,
reflect.Uint16: AppendUintValue,
reflect.Uint32: AppendUintValue,
reflect.Uint64: AppendUintValue,
reflect.Uintptr: nil,
reflect.Float32: AppendFloat32Value,
reflect.Float64: AppendFloat64Value,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: AppendJSONValue,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: nil,
reflect.Map: AppendJSONValue,
reflect.Ptr: nil,
reflect.Slice: AppendJSONValue,
reflect.String: AppendStringValue,
reflect.Struct: AppendJSONValue,
reflect.UnsafePointer: nil,
}
func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
switch typ {
case timeType:
return appendTimeValue
case ipType:
return appendIPValue
case ipNetType:
return appendIPNetValue
case jsonRawMessageType:
return appendJSONRawMessageValue
}
if typ.Implements(queryAppenderType) {
return appendQueryAppenderValue
}
if typ.Implements(driverValuerType) {
return driverValueAppender(custom)
}
kind := typ.Kind()
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) {
return addrAppender(appendQueryAppenderValue, custom)
}
if ptr.Implements(driverValuerType) {
return addrAppender(driverValueAppender(custom), custom)
}
}
switch kind {
case reflect.Interface:
return ifaceAppenderFunc(typ, custom)
case reflect.Ptr:
return ptrAppenderFunc(typ, custom)
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesValue
}
case reflect.Array:
if typ.Elem().Kind() == reflect.Uint8 {
return appendArrayBytesValue
}
}
if custom != nil {
if fn := custom(typ); fn != nil {
return fn
}
}
return appenders[typ.Kind()]
}
func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
elem := v.Elem()
appender := Appender(elem.Type(), custom)
return appender(fmter, b, elem)
}
}
func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
appender := Appender(typ.Elem(), custom)
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
return appender(fmter, b, v.Elem())
}
}
func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendBool(b, v.Bool())
}
func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, v.Int(), 10)
}
func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendUint(b, v.Uint(), 10)
}
func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendFloat32(b, float32(v.Float()))
}
func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendFloat64(b, float64(v.Float()))
}
func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendBytes(b, v.Bytes())
}
func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.CanAddr() {
return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes())
}
tmp := make([]byte, v.Len())
reflect.Copy(reflect.ValueOf(tmp), v)
b = dialect.AppendBytes(b, tmp)
return b
}
func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendString(b, v.String())
}
func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bb, err := bunjson.Marshal(v.Interface())
if err != nil {
return dialect.AppendError(b, err)
}
if len(bb) > 0 && bb[len(bb)-1] == '\n' {
bb = bb[:len(bb)-1]
}
return dialect.AppendJSON(b, bb)
}
func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
tm := v.Interface().(time.Time)
return dialect.AppendTime(b, tm)
}
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return dialect.AppendString(b, ip.String())
}
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet)
return dialect.AppendString(b, ipnet.String())
}
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bytes := v.Bytes()
if bytes == nil {
return dialect.AppendNull(b)
}
return dialect.AppendString(b, internal.String(bytes))
}
func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
}
func driverValueAppender(custom CustomAppender) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom)
}
}
func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte {
value, err := v.Value()
if err != nil {
return dialect.AppendError(b, err)
}
return Append(fmter, b, value, custom)
}
func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if !v.CanAddr() {
err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
return dialect.AppendError(b, err)
}
return fn(fmter, b, v.Addr())
}
}

99
vendor/github.com/uptrace/bun/schema/dialect.go generated vendored Normal file
View file

@ -0,0 +1,99 @@
package schema
import (
"database/sql"
"reflect"
"sync"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
)
type Dialect interface {
Init(db *sql.DB)
Name() dialect.Name
Features() feature.Feature
Tables() *Tables
OnTable(table *Table)
IdentQuote() byte
Append(fmter Formatter, b []byte, v interface{}) []byte
Appender(typ reflect.Type) AppenderFunc
FieldAppender(field *Field) AppenderFunc
Scanner(typ reflect.Type) ScannerFunc
}
//------------------------------------------------------------------------------
type nopDialect struct {
tables *Tables
features feature.Feature
appenderMap sync.Map
scannerMap sync.Map
}
func newNopDialect() *nopDialect {
d := new(nopDialect)
d.tables = NewTables(d)
d.features = feature.Returning
return d
}
func (d *nopDialect) Init(*sql.DB) {}
func (d *nopDialect) Name() dialect.Name {
return dialect.Invalid
}
func (d *nopDialect) Features() feature.Feature {
return d.features
}
func (d *nopDialect) Tables() *Tables {
return d.tables
}
func (d *nopDialect) OnField(field *Field) {}
func (d *nopDialect) OnTable(table *Table) {}
func (d *nopDialect) IdentQuote() byte {
return '"'
}
func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte {
return Append(fmter, b, v, nil)
}
func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc {
if v, ok := d.appenderMap.Load(typ); ok {
return v.(AppenderFunc)
}
fn := Appender(typ, nil)
if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
return v.(AppenderFunc)
}
return fn
}
func (d *nopDialect) FieldAppender(field *Field) AppenderFunc {
return FieldAppender(d, field)
}
func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc {
if v, ok := d.scannerMap.Load(typ); ok {
return v.(ScannerFunc)
}
fn := Scanner(typ)
if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
return v.(ScannerFunc)
}
return fn
}

117
vendor/github.com/uptrace/bun/schema/field.go generated vendored Normal file
View file

@ -0,0 +1,117 @@
package schema
import (
"fmt"
"reflect"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal/tagparser"
)
type Field struct {
StructField reflect.StructField
Tag tagparser.Tag
IndirectType reflect.Type
Index []int
Name string // SQL name, .e.g. id
SQLName Safe // escaped SQL name, e.g. "id"
GoName string // struct field name, e.g. Id
DiscoveredSQLType string
UserSQLType string
CreateTableSQLType string
SQLDefault string
OnDelete string
OnUpdate string
IsPK bool
NotNull bool
NullZero bool
AutoIncrement bool
Append AppenderFunc
Scan ScannerFunc
IsZero IsZeroerFunc
}
func (f *Field) String() string {
return f.Name
}
func (f *Field) Clone() *Field {
cp := *f
cp.Index = cp.Index[:len(f.Index):len(f.Index)]
return &cp
}
func (f *Field) Value(strct reflect.Value) reflect.Value {
return fieldByIndexAlloc(strct, f.Index)
}
func (f *Field) HasZeroValue(v reflect.Value) bool {
for _, idx := range f.Index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return true
}
v = v.Elem()
}
v = v.Field(idx)
}
return f.IsZero(v)
}
func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
fv, ok := fieldByIndex(strct, f.Index)
if !ok {
return dialect.AppendNull(b)
}
if f.NullZero && f.IsZero(fv) {
return dialect.AppendNull(b)
}
if f.Append == nil {
panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type()))
}
return f.Append(fmter, b, fv)
}
func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
if f.Scan == nil {
return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
}
return f.Scan(fv, src)
}
func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
if src == nil {
if fv, ok := fieldByIndex(strct, f.Index); ok {
return f.ScanWithCheck(fv, src)
}
return nil
}
fv := fieldByIndexAlloc(strct, f.Index)
return f.ScanWithCheck(fv, src)
}
func (f *Field) markAsPK() {
f.IsPK = true
f.NotNull = true
f.NullZero = true
}
func indexEqual(ind1, ind2 []int) bool {
if len(ind1) != len(ind2) {
return false
}
for i, ind := range ind1 {
if ind != ind2[i] {
return false
}
}
return true
}

248
vendor/github.com/uptrace/bun/schema/formatter.go generated vendored Normal file
View file

@ -0,0 +1,248 @@
package schema
import (
"reflect"
"strconv"
"strings"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
)
var nopFormatter = Formatter{
dialect: newNopDialect(),
}
type Formatter struct {
dialect Dialect
args *namedArgList
}
func NewFormatter(dialect Dialect) Formatter {
return Formatter{
dialect: dialect,
}
}
func NewNopFormatter() Formatter {
return nopFormatter
}
func (f Formatter) IsNop() bool {
return f.dialect.Name() == dialect.Invalid
}
func (f Formatter) Dialect() Dialect {
return f.dialect
}
func (f Formatter) IdentQuote() byte {
return f.dialect.IdentQuote()
}
func (f Formatter) AppendIdent(b []byte, ident string) []byte {
return dialect.AppendIdent(b, ident, f.IdentQuote())
}
func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte {
if v.Kind() == reflect.Ptr && v.IsNil() {
return dialect.AppendNull(b)
}
appender := f.dialect.Appender(v.Type())
return appender(f, b, v)
}
func (f Formatter) HasFeature(feature feature.Feature) bool {
return f.dialect.Features().Has(feature)
}
func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
return Formatter{
dialect: f.dialect,
args: f.args.WithArg(arg),
}
}
func (f Formatter) WithNamedArg(name string, value interface{}) Formatter {
return Formatter{
dialect: f.dialect,
args: f.args.WithArg(&namedArg{name: name, value: value}),
}
}
func (f Formatter) FormatQuery(query string, args ...interface{}) string {
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return query
}
return internal.String(f.AppendQuery(nil, query, args...))
}
func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte {
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return append(dst, query...)
}
return f.append(dst, parser.NewString(query), args)
}
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
var ok bool
namedArgs, ok = args[0].(NamedArgAppender)
if !ok {
namedArgs, _ = newStructArgs(f, args[0])
}
}
var argIndex int
for p.Valid() {
b, ok := p.ReadSep('?')
if !ok {
dst = append(dst, b...)
continue
}
if len(b) > 0 && b[len(b)-1] == '\\' {
dst = append(dst, b[:len(b)-1]...)
dst = append(dst, '?')
continue
}
dst = append(dst, b...)
name, numeric := p.ReadIdentifier()
if name != "" {
if numeric {
idx, err := strconv.Atoi(name)
if err != nil {
goto restore_arg
}
if idx >= len(args) {
goto restore_arg
}
dst = f.appendArg(dst, args[idx])
continue
}
if namedArgs != nil {
dst, ok = namedArgs.AppendNamedArg(f, dst, name)
if ok {
continue
}
}
dst, ok = f.args.AppendNamedArg(f, dst, name)
if ok {
continue
}
restore_arg:
dst = append(dst, '?')
dst = append(dst, name...)
continue
}
if argIndex >= len(args) {
dst = append(dst, '?')
continue
}
arg := args[argIndex]
argIndex++
dst = f.appendArg(dst, arg)
}
return dst
}
func (f Formatter) appendArg(b []byte, arg interface{}) []byte {
switch arg := arg.(type) {
case QueryAppender:
bb, err := arg.AppendQuery(f, b)
if err != nil {
return dialect.AppendError(b, err)
}
return bb
default:
return f.dialect.Append(f, b, arg)
}
}
//------------------------------------------------------------------------------
type NamedArgAppender interface {
AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
}
//------------------------------------------------------------------------------
type namedArgList struct {
arg NamedArgAppender
next *namedArgList
}
func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
return &namedArgList{
arg: arg,
next: l,
}
}
func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
for l != nil && l.arg != nil {
if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
return b, true
}
l = l.next
}
return b, false
}
//------------------------------------------------------------------------------
type namedArg struct {
name string
value interface{}
}
var _ NamedArgAppender = (*namedArg)(nil)
func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
if a.name == name {
return fmter.appendArg(b, a.value), true
}
return b, false
}
//------------------------------------------------------------------------------
var _ NamedArgAppender = (*structArgs)(nil)
type structArgs struct {
table *Table
strct reflect.Value
}
func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) {
v := reflect.ValueOf(strct)
if !v.IsValid() {
return nil, false
}
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return nil, false
}
return &structArgs{
table: fmter.Dialect().Tables().Get(v.Type()),
strct: v,
}, true
}
func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
return m.table.AppendNamedArg(fmter, b, name, m.strct)
}

20
vendor/github.com/uptrace/bun/schema/hook.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
package schema
import (
"context"
"reflect"
)
type BeforeScanHook interface {
BeforeScan(context.Context) error
}
var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
//------------------------------------------------------------------------------
type AfterScanHook interface {
AfterScan(context.Context) error
}
var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()

32
vendor/github.com/uptrace/bun/schema/relation.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
package schema
import (
"fmt"
)
const (
InvalidRelation = iota
HasOneRelation
BelongsToRelation
HasManyRelation
ManyToManyRelation
)
type Relation struct {
Type int
Field *Field
JoinTable *Table
BaseFields []*Field
JoinFields []*Field
PolymorphicField *Field
PolymorphicValue string
M2MTable *Table
M2MBaseFields []*Field
M2MJoinFields []*Field
}
func (r *Relation) String() string {
return fmt.Sprintf("relation=%s", r.Field.GoName)
}

392
vendor/github.com/uptrace/bun/schema/scan.go generated vendored Normal file
View file

@ -0,0 +1,392 @@
package schema
import (
"bytes"
"database/sql"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
var scanners = []ScannerFunc{
reflect.Bool: scanBool,
reflect.Int: scanInt64,
reflect.Int8: scanInt64,
reflect.Int16: scanInt64,
reflect.Int32: scanInt64,
reflect.Int64: scanInt64,
reflect.Uint: scanUint64,
reflect.Uint8: scanUint64,
reflect.Uint16: scanUint64,
reflect.Uint32: scanUint64,
reflect.Uint64: scanUint64,
reflect.Uintptr: scanUint64,
reflect.Float32: scanFloat64,
reflect.Float64: scanFloat64,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Map: scanJSON,
reflect.Ptr: nil,
reflect.Slice: scanJSON,
reflect.String: scanString,
reflect.Struct: scanJSON,
reflect.UnsafePointer: nil,
}
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("msgpack") {
return scanMsgpack
}
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
return dialect.Scanner(field.StructField.Type)
}
func Scanner(typ reflect.Type) ScannerFunc {
kind := typ.Kind()
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
return ptrScanner(fn)
}
}
if typ.Implements(scannerType) {
return scanScanner
}
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(scannerType) {
return addrScanner(scanScanner)
}
}
switch typ {
case timeType:
return scanTime
case ipType:
return scanIP
case ipNetType:
return scanIPNet
case jsonRawMessageType:
return scanJSONRawMessage
}
return scanners[kind]
}
func scanBool(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetBool(false)
return nil
case bool:
dest.SetBool(src)
return nil
case int64:
dest.SetBool(src != 0)
return nil
case []byte:
if len(src) == 1 {
dest.SetBool(src[0] != '0')
return nil
}
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanInt64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetInt(0)
return nil
case int64:
dest.SetInt(src)
return nil
case uint64:
dest.SetInt(int64(src))
return nil
case []byte:
n, err := strconv.ParseInt(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
case string:
n, err := strconv.ParseInt(src, 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanUint64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetUint(0)
return nil
case uint64:
dest.SetUint(src)
return nil
case int64:
dest.SetUint(uint64(src))
return nil
case []byte:
n, err := strconv.ParseUint(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetUint(n)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanFloat64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetFloat(0)
return nil
case float64:
dest.SetFloat(src)
return nil
case []byte:
f, err := strconv.ParseFloat(internal.String(src), 64)
if err != nil {
return err
}
dest.SetFloat(f)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanString(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetString("")
return nil
case string:
dest.SetString(src)
return nil
case []byte:
dest.SetString(string(src))
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = time.Time{}
return nil
case time.Time:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = src
return nil
case string:
srcTime, err := internal.ParseTime(src)
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
case []byte:
srcTime, err := internal.ParseTime(internal.String(src))
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanScanner(dest reflect.Value, src interface{}) error {
return dest.Interface().(sql.Scanner).Scan(src)
}
func scanMsgpack(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := msgpack.GetDecoder()
defer msgpack.PutDecoder(dec)
dec.Reset(bytes.NewReader(b))
return dec.DecodeValue(dest)
}
func scanJSON(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
return bunjson.Unmarshal(b, dest.Addr().Interface())
}
func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := bunjson.NewDecoder(bytes.NewReader(b))
dec.UseNumber()
return dec.Decode(dest.Addr().Interface())
}
func scanIP(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
ip := net.ParseIP(internal.String(b))
if ip == nil {
return fmt.Errorf("bun: invalid ip: %q", b)
}
ptr := dest.Addr().Interface().(*net.IP)
*ptr = ip
return nil
}
func scanIPNet(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
_, ipnet, err := net.ParseCIDR(internal.String(b))
if err != nil {
return err
}
ptr := dest.Addr().Interface().(*net.IPNet)
*ptr = *ipnet
return nil
}
func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
if src == nil {
dest.SetBytes(nil)
return nil
}
b, err := toBytes(src)
if err != nil {
return err
}
dest.SetBytes(b)
return nil
}
func addrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if !dest.CanAddr() {
return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
}
return fn(dest.Addr(), src)
}
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return internal.Bytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}
func ptrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
if dest.IsNil() {
return nil
}
return fn(dest.Elem(), src)
}
if !dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return nil
}
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return fn(dest.Elem(), src)
}
}
func scanNull(dest reflect.Value) error {
if nilable(dest.Kind()) && dest.IsNil() {
return nil
}
dest.Set(reflect.New(dest.Type()).Elem())
return nil
}
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
return false
}

76
vendor/github.com/uptrace/bun/schema/sqlfmt.go generated vendored Normal file
View file

@ -0,0 +1,76 @@
package schema
type QueryAppender interface {
AppendQuery(fmter Formatter, b []byte) ([]byte, error)
}
type ColumnsAppender interface {
AppendColumns(fmter Formatter, b []byte) ([]byte, error)
}
//------------------------------------------------------------------------------
// Safe represents a safe SQL query.
type Safe string
var _ QueryAppender = (*Safe)(nil)
func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return append(b, s...), nil
}
//------------------------------------------------------------------------------
// Ident represents a SQL identifier, for example, table or column name.
type Ident string
var _ QueryAppender = (*Ident)(nil)
func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendIdent(b, string(s)), nil
}
//------------------------------------------------------------------------------
type QueryWithArgs struct {
Query string
Args []interface{}
}
var _ QueryAppender = QueryWithArgs{}
func SafeQuery(query string, args []interface{}) QueryWithArgs {
if query != "" && args == nil {
args = make([]interface{}, 0)
}
return QueryWithArgs{Query: query, Args: args}
}
func UnsafeIdent(ident string) QueryWithArgs {
return QueryWithArgs{Query: ident}
}
func (q QueryWithArgs) IsZero() bool {
return q.Query == "" && q.Args == nil
}
func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if q.Args == nil {
return fmter.AppendIdent(b, q.Query), nil
}
return fmter.AppendQuery(b, q.Query, q.Args...), nil
}
//------------------------------------------------------------------------------
type QueryWithSep struct {
QueryWithArgs
Sep string
}
func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep {
return QueryWithSep{
QueryWithArgs: SafeQuery(query, args),
Sep: sep,
}
}

129
vendor/github.com/uptrace/bun/schema/sqltype.go generated vendored Normal file
View file

@ -0,0 +1,129 @@
package schema
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"reflect"
"time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal"
)
var (
bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
)
var sqlTypes = []string{
reflect.Bool: sqltype.Boolean,
reflect.Int: sqltype.BigInt,
reflect.Int8: sqltype.SmallInt,
reflect.Int16: sqltype.SmallInt,
reflect.Int32: sqltype.Integer,
reflect.Int64: sqltype.BigInt,
reflect.Uint: sqltype.BigInt,
reflect.Uint8: sqltype.SmallInt,
reflect.Uint16: sqltype.SmallInt,
reflect.Uint32: sqltype.Integer,
reflect.Uint64: sqltype.BigInt,
reflect.Uintptr: sqltype.BigInt,
reflect.Float32: sqltype.Real,
reflect.Float64: sqltype.DoublePrecision,
reflect.Complex64: "",
reflect.Complex128: "",
reflect.Array: "",
reflect.Chan: "",
reflect.Func: "",
reflect.Interface: "",
reflect.Map: sqltype.VarChar,
reflect.Ptr: "",
reflect.Slice: sqltype.VarChar,
reflect.String: sqltype.VarChar,
reflect.Struct: sqltype.VarChar,
reflect.UnsafePointer: "",
}
func DiscoverSQLType(typ reflect.Type) string {
switch typ {
case timeType, nullTimeType, bunNullTimeType:
return sqltype.Timestamp
case nullBoolType:
return sqltype.Boolean
case nullFloatType:
return sqltype.DoublePrecision
case nullIntType:
return sqltype.BigInt
case nullStringType:
return sqltype.VarChar
}
return sqlTypes[typ.Kind()]
}
//------------------------------------------------------------------------------
var jsonNull = []byte("null")
// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL.
type NullTime struct {
time.Time
}
var (
_ json.Marshaler = (*NullTime)(nil)
_ json.Unmarshaler = (*NullTime)(nil)
_ sql.Scanner = (*NullTime)(nil)
_ QueryAppender = (*NullTime)(nil)
)
func (tm NullTime) MarshalJSON() ([]byte, error) {
if tm.IsZero() {
return jsonNull, nil
}
return tm.Time.MarshalJSON()
}
func (tm *NullTime) UnmarshalJSON(b []byte) error {
if bytes.Equal(b, jsonNull) {
tm.Time = time.Time{}
return nil
}
return tm.Time.UnmarshalJSON(b)
}
func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if tm.IsZero() {
return dialect.AppendNull(b), nil
}
return dialect.AppendTime(b, tm.Time), nil
}
func (tm *NullTime) Scan(src interface{}) error {
if src == nil {
tm.Time = time.Time{}
return nil
}
switch src := src.(type) {
case []byte:
newtm, err := internal.ParseTime(internal.String(src))
if err != nil {
return err
}
tm.Time = newtm
return nil
case time.Time:
tm.Time = src
return nil
default:
return fmt.Errorf("bun: can't scan %#v into NullTime", src)
}
}

948
vendor/github.com/uptrace/bun/schema/table.go generated vendored Normal file
View file

@ -0,0 +1,948 @@
package schema
import (
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/jinzhu/inflection"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/tagparser"
)
const (
beforeScanHookFlag internal.Flag = 1 << iota
afterScanHookFlag
)
var (
baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
tableNameInflector = inflection.Plural
)
type BaseModel struct{}
// SetTableNameInflector overrides the default func that pluralizes
// model name to get table name, e.g. my_article becomes my_articles.
func SetTableNameInflector(fn func(string) string) {
tableNameInflector = fn
}
// Table represents a SQL table created from Go struct.
type Table struct {
dialect Dialect
Type reflect.Type
ZeroValue reflect.Value // reflect.Struct
ZeroIface interface{} // struct pointer
TypeName string
ModelName string
Name string
SQLName Safe
SQLNameForSelects Safe
Alias string
SQLAlias Safe
Fields []*Field // PKs + DataFields
PKs []*Field
DataFields []*Field
fieldsMapMu sync.RWMutex
FieldMap map[string]*Field
Relations map[string]*Relation
Unique map[string][]*Field
SoftDeleteField *Field
UpdateSoftDeleteField func(fv reflect.Value) error
allFields []*Field // read only
skippedFields []*Field
flags internal.Flag
}
func newTable(dialect Dialect, typ reflect.Type) *Table {
t := new(Table)
t.dialect = dialect
t.Type = typ
t.ZeroValue = reflect.New(t.Type).Elem()
t.ZeroIface = reflect.New(t.Type).Interface()
t.TypeName = internal.ToExported(t.Type.Name())
t.ModelName = internal.Underscore(t.Type.Name())
tableName := tableNameInflector(t.ModelName)
t.setName(tableName)
t.Alias = t.ModelName
t.SQLAlias = t.quoteIdent(t.ModelName)
hooks := []struct {
typ reflect.Type
flag internal.Flag
}{
{beforeScanHookType, beforeScanHookFlag},
{afterScanHookType, afterScanHookFlag},
}
typ = reflect.PtrTo(t.Type)
for _, hook := range hooks {
if typ.Implements(hook.typ) {
t.flags = t.flags.Set(hook.flag)
}
}
return t
}
func (t *Table) init1() {
t.initFields()
}
func (t *Table) init2() {
t.initInlines()
t.initRelations()
t.skippedFields = nil
}
func (t *Table) setName(name string) {
t.Name = name
t.SQLName = t.quoteIdent(name)
t.SQLNameForSelects = t.quoteIdent(name)
if t.SQLAlias == "" {
t.Alias = name
t.SQLAlias = t.quoteIdent(name)
}
}
func (t *Table) String() string {
return "model=" + t.TypeName
}
func (t *Table) CheckPKs() error {
if len(t.PKs) == 0 {
return fmt.Errorf("bun: %s does not have primary keys", t)
}
return nil
}
func (t *Table) addField(field *Field) {
t.Fields = append(t.Fields, field)
if field.IsPK {
t.PKs = append(t.PKs, field)
} else {
t.DataFields = append(t.DataFields, field)
}
t.FieldMap[field.Name] = field
}
func (t *Table) removeField(field *Field) {
t.Fields = removeField(t.Fields, field)
if field.IsPK {
t.PKs = removeField(t.PKs, field)
} else {
t.DataFields = removeField(t.DataFields, field)
}
delete(t.FieldMap, field.Name)
}
func (t *Table) fieldWithLock(name string) *Field {
t.fieldsMapMu.RLock()
field := t.FieldMap[name]
t.fieldsMapMu.RUnlock()
return field
}
func (t *Table) HasField(name string) bool {
_, ok := t.FieldMap[name]
return ok
}
func (t *Table) Field(name string) (*Field, error) {
field, ok := t.FieldMap[name]
if !ok {
return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
}
return field, nil
}
func (t *Table) fieldByGoName(name string) *Field {
for _, f := range t.allFields {
if f.GoName == name {
return f
}
}
return nil
}
func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, nil)
if len(t.PKs) > 0 {
return
}
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
if field, ok := t.FieldMap[name]; ok {
field.markAsPK()
t.PKs = []*Field{field}
t.DataFields = removeField(t.DataFields, field)
break
}
}
if len(t.PKs) == 1 {
switch t.PKs[0].IndirectType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
t.PKs[0].AutoIncrement = true
}
}
}
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
// Make a copy so slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
if f.Tag.Get("bun") == "-" {
continue
}
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
}
continue
}
fieldType := indirectType(f.Type)
if fieldType.Kind() != reflect.Struct {
continue
}
t.addFields(fieldType, append(index, f.Index...))
tag := tagparser.Parse(f.Tag.Get("bun"))
if _, inherit := tag.Options["inherit"]; inherit {
embeddedTable := t.dialect.Tables().Ref(fieldType)
t.TypeName = embeddedTable.TypeName
t.SQLName = embeddedTable.SQLName
t.SQLNameForSelects = embeddedTable.SQLNameForSelects
t.Alias = embeddedTable.Alias
t.SQLAlias = embeddedTable.SQLAlias
t.ModelName = embeddedTable.ModelName
}
continue
}
field := t.newField(f, index)
if field != nil {
t.addField(field)
}
}
}
func (t *Table) processBaseModelField(f reflect.StructField) {
tag := tagparser.Parse(f.Tag.Get("bun"))
if isKnownTableOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
for name := range tag.Options {
if !isKnownTableOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
}
}
if tag.Name != "" {
t.setName(tag.Name)
}
if s, ok := tag.Options["select"]; ok {
t.SQLNameForSelects = t.quoteTableName(s)
}
if s, ok := tag.Options["alias"]; ok {
t.Alias = s
t.SQLAlias = t.quoteIdent(s)
}
}
//nolint
func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
if f.PkgPath != "" {
return nil
}
sqlName := internal.Underscore(f.Name)
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
for name := range tag.Options {
if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
}
}
skip := tag.Name == "-"
if !skip && tag.Name != "" {
sqlName = tag.Name
}
index = append(index, f.Index...)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
return field
}
t.removeField(field)
}
field := &Field{
StructField: f,
Tag: tag,
IndirectType: indirectType(f.Type),
Index: index,
Name: sqlName,
GoName: f.Name,
SQLName: t.quoteIdent(sqlName),
}
field.NotNull = tag.HasOption("notnull")
field.NullZero = tag.HasOption("nullzero")
field.AutoIncrement = tag.HasOption("autoincrement")
if tag.HasOption("pk") {
field.markAsPK()
}
if tag.HasOption("allowzero") {
if tag.HasOption("nullzero") {
internal.Warn.Printf(
"%s.%s: nullzero and allowzero options are mutually exclusive",
t.TypeName, f.Name,
)
}
field.NullZero = false
}
if v, ok := tag.Options["unique"]; ok {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
for _, uniqueName := range strings.Split(v, ",") {
if t.Unique == nil {
t.Unique = make(map[string][]*Field)
}
t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
}
}
if s, ok := tag.Options["default"]; ok {
field.SQLDefault = s
}
if s, ok := field.Tag.Options["type"]; ok {
field.UserSQLType = s
}
field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
field.Append = t.dialect.FieldAppender(field)
field.Scan = FieldScanner(t.dialect, field)
field.IsZero = FieldZeroChecker(field)
if v, ok := tag.Options["alt"]; ok {
t.FieldMap[v] = field
}
t.allFields = append(t.allFields, field)
if skip {
t.skippedFields = append(t.skippedFields, field)
t.FieldMap[field.Name] = field
return nil
}
if _, ok := tag.Options["soft_delete"]; ok {
field.NullZero = true
t.SoftDeleteField = field
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
}
return field
}
func (t *Table) initInlines() {
for _, f := range t.skippedFields {
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
for i := 0; i < len(t.Fields); {
f := t.Fields[i]
if t.tryRelation(f) {
t.Fields = removeField(t.Fields, f)
t.DataFields = removeField(t.DataFields, f)
} else {
i++
}
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
func (t *Table) tryRelation(field *Field) bool {
if rel, ok := field.Tag.Options["rel"]; ok {
t.initRelation(field, rel)
return true
}
if field.Tag.HasOption("m2m") {
t.addRelation(t.m2mRelation(field))
return true
}
if field.Tag.HasOption("join") {
internal.Warn.Printf(
`%s.%s option "join" requires a relation type`,
t.TypeName, field.GoName,
)
}
return false
}
func (t *Table) initRelation(field *Field, rel string) {
switch rel {
case "belongs-to":
t.addRelation(t.belongsToRelation(field))
case "has-one":
t.addRelation(t.hasOneRelation(field))
case "has-many":
t.addRelation(t.hasManyRelation(field))
default:
panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
}
}
func (t *Table) addRelation(rel *Relation) {
if t.Relations == nil {
t.Relations = make(map[string]*Relation)
}
_, ok := t.Relations[rel.Field.GoName]
if ok {
panic(fmt.Errorf("%s already has %s", t, rel))
}
t.Relations[rel.Field.GoName] = rel
}
func (t *Table) belongsToRelation(field *Field) *Relation {
joinTable := t.dialect.Tables().Ref(field.IndirectType)
if err := joinTable.CheckPKs(); err != nil {
panic(err)
}
rel := &Relation{
Type: HasOneRelation,
Field: field,
JoinTable: joinTable,
}
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
joinColumn := joinColumns[i]
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
}
return rel
}
rel.JoinFields = joinTable.PKs
fkPrefix := internal.Underscore(field.GoName) + "_"
for _, joinPK := range joinTable.PKs {
fkName := fkPrefix + joinPK.Name
if fk := t.fieldWithLock(fkName); fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
if fk := t.fieldWithLock(joinPK.Name); fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on %s field)",
t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
))
}
return rel
}
func (t *Table) hasOneRelation(field *Field) *Relation {
if err := t.CheckPKs(); err != nil {
panic(err)
}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
rel := &Relation{
Type: BelongsToRelation,
Field: field,
JoinTable: joinTable,
}
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
))
}
joinColumn := joinColumns[i]
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
))
}
}
return rel
}
rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_"
for _, pk := range t.PKs {
fkName := fkPrefix + pk.Name
if f := joinTable.fieldWithLock(fkName); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
if f := joinTable.fieldWithLock(pk.Name); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on %s field)",
field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
))
}
return rel
}
func (t *Table) hasManyRelation(field *Field) *Relation {
if err := t.CheckPKs(); err != nil {
panic(err)
}
if field.IndirectType.Kind() != reflect.Slice {
panic(fmt.Errorf(
"bun: %s.%s has-many relation requires slice, got %q",
t.TypeName, field.GoName, field.IndirectType.Kind(),
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
rel := &Relation{
Type: HasManyRelation,
Field: field,
JoinTable: joinTable,
}
var polymorphicColumn string
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
joinColumn := joinColumns[i]
if isPolymorphic && baseColumn == "type" {
polymorphicColumn = joinColumn
continue
}
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
}
} else {
rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_"
if isPolymorphic {
polymorphicColumn = fkPrefix + "type"
}
for _, pk := range t.PKs {
joinColumn := fkPrefix + pk.Name
if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
panic(fmt.Errorf(
"bun: %s has-many %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on the field %s)",
t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
))
}
}
if isPolymorphic {
rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
if rel.PolymorphicField == nil {
panic(fmt.Errorf(
"bun: %s has-many %s: %s must have polymorphic column %s",
t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
))
}
if polymorphicValue == "" {
polymorphicValue = t.ModelName
}
rel.PolymorphicValue = polymorphicValue
}
return rel
}
func (t *Table) m2mRelation(field *Field) *Relation {
if field.IndirectType.Kind() != reflect.Slice {
panic(fmt.Errorf(
"bun: %s.%s m2m relation requires slice, got %q",
t.TypeName, field.GoName, field.IndirectType.Kind(),
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
if err := t.CheckPKs(); err != nil {
panic(err)
}
if err := joinTable.CheckPKs(); err != nil {
panic(err)
}
m2mTableName, ok := field.Tag.Options["m2m"]
if !ok {
panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
}
m2mTable := t.dialect.Tables().ByName(m2mTableName)
if m2mTable == nil {
panic(fmt.Errorf(
"bun: can't find m2m %s table (use db.RegisterModel)",
m2mTableName,
))
}
rel := &Relation{
Type: ManyToManyRelation,
Field: field,
JoinTable: joinTable,
M2MTable: m2mTable,
}
var leftColumn, rightColumn string
if join, ok := field.Tag.Options["join"]; ok {
left, right := parseRelationJoin(join)
leftColumn = left[0]
rightColumn = right[0]
} else {
leftColumn = t.TypeName
rightColumn = joinTable.TypeName
}
leftField := m2mTable.fieldByGoName(leftColumn)
if leftField == nil {
panic(fmt.Errorf(
"bun: %s many-to-many %s: %s must have field %s "+
"(to override, use tag join:LeftField=RightField on field %s.%s",
t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
))
}
rightField := m2mTable.fieldByGoName(rightColumn)
if rightField == nil {
panic(fmt.Errorf(
"bun: %s many-to-many %s: %s must have field %s "+
"(to override, use tag join:LeftField=RightField on field %s.%s",
t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
))
}
leftRel := m2mTable.belongsToRelation(leftField)
rel.BaseFields = leftRel.JoinFields
rel.M2MBaseFields = leftRel.BaseFields
rightRel := m2mTable.belongsToRelation(rightField)
rel.JoinFields = rightRel.JoinFields
rel.M2MJoinFields = rightRel.BaseFields
return rel
}
func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
if path == nil {
path = map[reflect.Type]struct{}{
t.Type: {},
}
}
if _, ok := path[field.IndirectType]; ok {
return
}
path[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
f = f.Clone()
f.GoName = field.GoName + "_" + f.GoName
f.Name = field.Name + "__" + f.Name
f.SQLName = t.quoteIdent(f.Name)
f.Index = appendNew(field.Index, f.Index...)
t.fieldsMapMu.Lock()
if _, ok := t.FieldMap[f.Name]; !ok {
t.FieldMap[f.Name] = f
}
t.fieldsMapMu.Unlock()
if f.IndirectType.Kind() != reflect.Struct {
continue
}
if _, ok := path[f.IndirectType]; !ok {
t.inlineFields(f, path)
}
}
}
//------------------------------------------------------------------------------
func (t *Table) Dialect() Dialect { return t.dialect }
//------------------------------------------------------------------------------
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
//------------------------------------------------------------------------------
func (t *Table) AppendNamedArg(
fmter Formatter, b []byte, name string, strct reflect.Value,
) ([]byte, bool) {
if field, ok := t.FieldMap[name]; ok {
return fmter.appendArg(b, field.Value(strct).Interface()), true
}
return b, false
}
func (t *Table) quoteTableName(s string) Safe {
// Don't quote if table name contains placeholder (?) or parentheses.
if strings.IndexByte(s, '?') >= 0 ||
strings.IndexByte(s, '(') >= 0 ||
strings.IndexByte(s, ')') >= 0 {
return Safe(s)
}
return t.quoteIdent(s)
}
func (t *Table) quoteIdent(s string) Safe {
return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
}
func appendNew(dst []int, src ...int) []int {
cp := make([]int, len(dst)+len(src))
copy(cp, dst)
copy(cp[len(dst):], src)
return cp
}
func isKnownTableOption(name string) bool {
switch name {
case "alias", "select":
return true
}
return false
}
func isKnownFieldOption(name string) bool {
switch name {
case "alias",
"type",
"array",
"hstore",
"composite",
"json_use_number",
"msgpack",
"notnull",
"nullzero",
"allowzero",
"default",
"unique",
"soft_delete",
"pk",
"autoincrement",
"rel",
"join",
"m2m",
"polymorphic":
return true
}
return false
}
func removeField(fields []*Field, field *Field) []*Field {
for i, f := range fields {
if f == field {
return append(fields[:i], fields[i+1:]...)
}
}
return fields
}
func parseRelationJoin(join string) ([]string, []string) {
ss := strings.Split(join, ",")
baseColumns := make([]string, len(ss))
joinColumns := make([]string, len(ss))
for i, s := range ss {
ss := strings.Split(strings.TrimSpace(s), "=")
if len(ss) != 2 {
panic(fmt.Errorf("can't parse relation join: %q", join))
}
baseColumns[i] = ss[0]
joinColumns[i] = ss[1]
}
return baseColumns, joinColumns
}
//------------------------------------------------------------------------------
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
typ := field.StructField.Type
switch typ {
case timeType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*time.Time)
*ptr = time.Now()
return nil
}
case nullTimeType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*sql.NullTime)
*ptr = sql.NullTime{Time: time.Now()}
return nil
}
case nullIntType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*sql.NullInt64)
*ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
return nil
}
}
switch field.IndirectType.Kind() {
case reflect.Int64:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*int64)
*ptr = time.Now().UnixNano()
return nil
}
case reflect.Ptr:
typ = typ.Elem()
default:
return softDeleteFieldUpdaterFallback(field)
}
switch typ { //nolint:gocritic
case timeType:
return func(fv reflect.Value) error {
now := time.Now()
fv.Set(reflect.ValueOf(&now))
return nil
}
}
switch typ.Kind() { //nolint:gocritic
case reflect.Int64:
return func(fv reflect.Value) error {
utime := time.Now().UnixNano()
fv.Set(reflect.ValueOf(&utime))
return nil
}
}
return softDeleteFieldUpdaterFallback(field)
}
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
return func(fv reflect.Value) error {
return field.ScanWithCheck(fv, time.Now())
}
}

148
vendor/github.com/uptrace/bun/schema/tables.go generated vendored Normal file
View file

@ -0,0 +1,148 @@
package schema
import (
"fmt"
"reflect"
"sync"
)
type tableInProgress struct {
table *Table
init1Once sync.Once
init2Once sync.Once
}
func newTableInProgress(table *Table) *tableInProgress {
return &tableInProgress{
table: table,
}
}
func (inp *tableInProgress) init1() bool {
var inited bool
inp.init1Once.Do(func() {
inp.table.init1()
inited = true
})
return inited
}
func (inp *tableInProgress) init2() bool {
var inited bool
inp.init2Once.Do(func() {
inp.table.init2()
inited = true
})
return inited
}
type Tables struct {
dialect Dialect
tables sync.Map
mu sync.RWMutex
inProgress map[reflect.Type]*tableInProgress
}
func NewTables(dialect Dialect) *Tables {
return &Tables{
dialect: dialect,
inProgress: make(map[reflect.Type]*tableInProgress),
}
}
func (t *Tables) Register(models ...interface{}) {
for _, model := range models {
_ = t.Get(reflect.TypeOf(model).Elem())
}
}
func (t *Tables) Get(typ reflect.Type) *Table {
return t.table(typ, false)
}
func (t *Tables) Ref(typ reflect.Type) *Table {
return t.table(typ, true)
}
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}
if v, ok := t.tables.Load(typ); ok {
return v.(*Table)
}
t.mu.Lock()
if v, ok := t.tables.Load(typ); ok {
t.mu.Unlock()
return v.(*Table)
}
var table *Table
inProgress := t.inProgress[typ]
if inProgress == nil {
table = newTable(t.dialect, typ)
inProgress = newTableInProgress(table)
t.inProgress[typ] = inProgress
} else {
table = inProgress.table
}
t.mu.Unlock()
inProgress.init1()
if allowInProgress {
return table
}
if inProgress.init2() {
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
}
t.dialect.OnTable(table)
for _, field := range table.FieldMap {
if field.UserSQLType == "" {
field.UserSQLType = field.DiscoveredSQLType
}
if field.CreateTableSQLType == "" {
field.CreateTableSQLType = field.UserSQLType
}
}
return table
}
func (t *Tables) ByModel(name string) *Table {
var found *Table
t.tables.Range(func(key, value interface{}) bool {
t := value.(*Table)
if t.TypeName == name {
found = t
return false
}
return true
})
return found
}
func (t *Tables) ByName(name string) *Table {
var found *Table
t.tables.Range(func(key, value interface{}) bool {
t := value.(*Table)
if t.Name == name {
found = t
return false
}
return true
})
return found
}

53
vendor/github.com/uptrace/bun/schema/util.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
package schema
import "reflect"
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
if len(index) == 1 {
return v.Field(index[0]), true
}
for i, idx := range index {
if i > 0 {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return v, false
}
v = v.Elem()
}
}
v = v.Field(idx)
}
return v, true
}
func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}
for i, idx := range index {
if i > 0 {
v = indirectNil(v)
}
v = v.Field(idx)
}
return v
}
func indirectNil(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}

126
vendor/github.com/uptrace/bun/schema/zerochecker.go generated vendored Normal file
View file

@ -0,0 +1,126 @@
package schema
import (
"database/sql/driver"
"reflect"
)
var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
type isZeroer interface {
IsZero() bool
}
type IsZeroerFunc func(reflect.Value) bool
func FieldZeroChecker(field *Field) IsZeroerFunc {
return zeroChecker(field.IndirectType)
}
func zeroChecker(typ reflect.Type) IsZeroerFunc {
if typ.Implements(isZeroerType) {
return isZeroInterface
}
kind := typ.Kind()
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(isZeroerType) {
return addrChecker(isZeroInterface)
}
}
switch kind {
case reflect.Array:
if typ.Elem().Kind() == reflect.Uint8 {
return isZeroBytes
}
return isZeroLen
case reflect.String:
return isZeroLen
case reflect.Bool:
return isZeroBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return isZeroInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return isZeroUint
case reflect.Float32, reflect.Float64:
return isZeroFloat
case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map:
return isNil
}
if typ.Implements(driverValuerType) {
return isZeroDriverValue
}
return notZero
}
func addrChecker(fn IsZeroerFunc) IsZeroerFunc {
return func(v reflect.Value) bool {
if !v.CanAddr() {
return false
}
return fn(v.Addr())
}
}
func isZeroInterface(v reflect.Value) bool {
if v.Kind() == reflect.Ptr && v.IsNil() {
return true
}
return v.Interface().(isZeroer).IsZero()
}
func isZeroDriverValue(v reflect.Value) bool {
if v.Kind() == reflect.Ptr {
return v.IsNil()
}
valuer := v.Interface().(driver.Valuer)
value, err := valuer.Value()
if err != nil {
return false
}
return value == nil
}
func isZeroLen(v reflect.Value) bool {
return v.Len() == 0
}
func isNil(v reflect.Value) bool {
return v.IsNil()
}
func isZeroBool(v reflect.Value) bool {
return !v.Bool()
}
func isZeroInt(v reflect.Value) bool {
return v.Int() == 0
}
func isZeroUint(v reflect.Value) bool {
return v.Uint() == 0
}
func isZeroFloat(v reflect.Value) bool {
return v.Float() == 0
}
func isZeroBytes(v reflect.Value) bool {
b := v.Slice(0, v.Len()).Bytes()
for _, c := range b {
if c != 0 {
return false
}
}
return true
}
func notZero(v reflect.Value) bool {
return false
}

114
vendor/github.com/uptrace/bun/util.go generated vendored Normal file
View file

@ -0,0 +1,114 @@
package bun
import "reflect"
func indirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Interface:
return indirect(v.Elem())
case reflect.Ptr:
return v.Elem()
default:
return v
}
}
func walk(v reflect.Value, index []int, fn func(reflect.Value)) {
v = reflect.Indirect(v)
switch v.Kind() {
case reflect.Slice:
sliceLen := v.Len()
for i := 0; i < sliceLen; i++ {
visitField(v.Index(i), index, fn)
}
default:
visitField(v, index, fn)
}
}
func visitField(v reflect.Value, index []int, fn func(reflect.Value)) {
v = reflect.Indirect(v)
if len(index) > 0 {
v = v.Field(index[0])
if v.Kind() == reflect.Ptr && v.IsNil() {
return
}
walk(v, index[1:], fn)
} else {
fn(v)
}
}
func typeByIndex(t reflect.Type, index []int) reflect.Type {
for _, x := range index {
switch t.Kind() {
case reflect.Ptr:
t = t.Elem()
case reflect.Slice:
t = indirectType(t.Elem())
}
t = t.Field(x).Type
}
return indirectType(t)
}
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
func sliceElemType(v reflect.Value) reflect.Type {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Interface && v.Len() > 0 {
return indirect(v.Index(0).Elem()).Type()
}
return indirectType(elemType)
}
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
if v.Kind() == reflect.Array {
var pos int
return func() reflect.Value {
v := v.Index(pos)
pos++
return v
}
}
sliceType := v.Type()
elemType := sliceType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
l := v.Len()
c := v.Cap()
if l < c {
v.Set(v.Slice(0, l+1))
return v.Index(l)
}
v.Set(reflect.Append(v, zero))
return v.Index(l)
}
}

6
vendor/github.com/uptrace/bun/version.go generated vendored Normal file
View file

@ -0,0 +1,6 @@
package bun
// Version is the current release version.
func Version() string {
return "0.4.3"
}