From 29b4adfb69439ef627b9dffa63b9fd3909940368 Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Mon, 23 Aug 2021 16:54:26 +0200 Subject: [PATCH] start moving to bun --- go.mod | 6 +- go.sum | 18 +- internal/db/account.go | 23 +- internal/db/admin.go | 11 +- internal/db/basic.go | 26 +- internal/db/db.go | 8 +- internal/db/domain.go | 13 +- internal/db/instance.go | 14 +- internal/db/media.go | 8 +- internal/db/mention.go | 10 +- internal/db/notification.go | 10 +- internal/db/pg/account.go | 100 +- internal/db/pg/account_test.go | 3 +- internal/db/pg/admin.go | 12 +- internal/db/pg/basic.go | 3 +- internal/db/pg/domain.go | 4 +- internal/db/pg/instance.go | 3 +- internal/db/pg/media.go | 4 +- internal/db/pg/mention.go | 4 +- internal/db/pg/notification.go | 4 +- internal/db/pg/pg.go | 59 +- internal/db/pg/relationship.go | 3 +- internal/db/pg/status.go | 3 +- internal/db/pg/timeline.go | 3 +- internal/db/pg/util.go | 5 +- internal/db/relationship.go | 30 +- internal/db/status.go | 36 +- internal/db/timeline.go | 12 +- internal/federation/dereferencing/account.go | 40 +- internal/federation/dereferencing/announce.go | 9 +- internal/federation/dereferencing/blocked.go | 41 - .../dereferencing/collectionpage.go | 6 +- .../federation/dereferencing/dereferencer.go | 17 +- .../federation/dereferencing/handshake.go | 7 +- internal/federation/dereferencing/instance.go | 6 +- internal/federation/dereferencing/status.go | 48 +- internal/federation/dereferencing/thread.go | 31 +- internal/processing/account.go | 46 +- internal/processing/account/account.go | 29 +- internal/processing/account/create.go | 9 +- internal/processing/account/createblock.go | 25 +- internal/processing/account/createfollow.go | 23 +- internal/processing/account/delete.go | 47 +- internal/processing/account/get.go | 7 +- internal/processing/account/getfollowers.go | 11 +- internal/processing/account/getfollowing.go | 11 +- .../processing/account/getrelationship.go | 5 +- internal/processing/account/getstatuses.go | 7 +- internal/processing/account/removeblock.go | 11 +- internal/processing/account/removefollow.go | 17 +- internal/processing/account/update.go | 29 +- internal/processing/admin.go | 26 +- internal/processing/admin/admin.go | 13 +- .../processing/admin/createdomainblock.go | 19 +- .../processing/admin/deletedomainblock.go | 15 +- internal/processing/admin/emoji.go | 5 +- internal/processing/admin/getdomainblock.go | 5 +- internal/processing/admin/getdomainblocks.go | 6 +- .../processing/admin/importdomainblocks.go | 5 +- internal/processing/app.go | 8 +- internal/processing/blocks.go | 5 +- internal/processing/federation.go | 32 +- internal/processing/followrequest.go | 20 +- internal/processing/fromclientapi.go | 106 +- internal/processing/fromcommon.go | 47 +- internal/processing/fromfederator.go | 23 +- internal/processing/instance.go | 21 +- internal/processing/media.go | 18 +- internal/processing/processor.go | 98 +- internal/transport/controller.go | 7 +- internal/transport/deliver.go | 26 +- internal/transport/dereference.go | 22 +- internal/transport/derefinstance.go | 41 +- internal/transport/derefmedia.go | 23 +- internal/transport/finger.go | 24 +- internal/transport/transport.go | 24 +- testrig/db.go | 39 +- .../github.com/go-pg/pg/extra/pgdebug/go.mod | 7 - .../github.com/go-pg/pg/extra/pgdebug/go.sum | 161 --- .../go-pg/pg/extra/pgdebug/pgdebug.go | 42 - vendor/github.com/uptrace/bun/.gitignore | 3 + .../github.com/uptrace/bun/.prettierrc.yaml | 6 + vendor/github.com/uptrace/bun/CHANGELOG.md | 99 ++ .../pg/extra/pgdebug => uptrace/bun}/LICENSE | 2 +- vendor/github.com/uptrace/bun/Makefile | 21 + vendor/github.com/uptrace/bun/README.md | 267 ++++ vendor/github.com/uptrace/bun/RELEASING.md | 21 + vendor/github.com/uptrace/bun/bun.go | 122 ++ vendor/github.com/uptrace/bun/db.go | 502 ++++++++ .../github.com/uptrace/bun/dialect/append.go | 178 +++ .../github.com/uptrace/bun/dialect/dialect.go | 26 + .../uptrace/bun/dialect/feature/feature.go | 22 + .../uptrace/bun/dialect/pgdialect/LICENSE | 24 + .../uptrace/bun/dialect/pgdialect/append.go | 303 +++++ .../uptrace/bun/dialect/pgdialect/array.go | 65 + .../bun/dialect/pgdialect/array_parser.go | 146 +++ .../bun/dialect/pgdialect/array_scan.go | 302 +++++ .../uptrace/bun/dialect/pgdialect/dialect.go | 150 +++ .../uptrace/bun/dialect/pgdialect/go.mod | 7 + .../uptrace/bun/dialect/pgdialect/go.sum | 22 + .../uptrace/bun/dialect/pgdialect/safe.go | 11 + .../uptrace/bun/dialect/pgdialect/scan.go | 28 + .../uptrace/bun/dialect/pgdialect/sqltype.go | 104 ++ .../uptrace/bun/dialect/pgdialect/unsafe.go | 18 + .../uptrace/bun/dialect/sqltype/sqltype.go | 14 + .../uptrace/bun/driver/pgdriver/LICENSE | 24 + .../uptrace/bun/driver/pgdriver/README.md | 36 + .../uptrace/bun/driver/pgdriver/column.go | 192 +++ .../uptrace/bun/driver/pgdriver/config.go | 233 ++++ .../uptrace/bun/driver/pgdriver/driver.go | 606 +++++++++ .../uptrace/bun/driver/pgdriver/error.go | 66 + .../uptrace/bun/driver/pgdriver/format.go | 188 +++ .../uptrace/bun/driver/pgdriver/go.mod | 11 + .../uptrace/bun/driver/pgdriver/go.sum | 27 + .../uptrace/bun/driver/pgdriver/listener.go | 392 ++++++ .../uptrace/bun/driver/pgdriver/proto.go | 1127 +++++++++++++++++ .../uptrace/bun/driver/pgdriver/safe.go | 11 + .../uptrace/bun/driver/pgdriver/unsafe.go | 19 + .../bun/driver/pgdriver/write_buffer.go | 112 ++ .../uptrace/bun/extra/bunjson/json.go | 26 + .../uptrace/bun/extra/bunjson/provider.go | 43 + vendor/github.com/uptrace/bun/go.mod | 12 + vendor/github.com/uptrace/bun/go.sum | 23 + vendor/github.com/uptrace/bun/hook.go | 98 ++ .../github.com/uptrace/bun/internal/flag.go | 16 + vendor/github.com/uptrace/bun/internal/hex.go | 43 + .../github.com/uptrace/bun/internal/logger.go | 27 + .../uptrace/bun/internal/map_key.go | 67 + .../uptrace/bun/internal/parser/parser.go | 141 +++ .../github.com/uptrace/bun/internal/safe.go | 11 + .../uptrace/bun/internal/tagparser/parser.go | 147 +++ .../github.com/uptrace/bun/internal/time.go | 41 + .../uptrace/bun/internal/underscore.go | 67 + .../github.com/uptrace/bun/internal/unsafe.go | 20 + .../github.com/uptrace/bun/internal/util.go | 57 + vendor/github.com/uptrace/bun/join.go | 308 +++++ vendor/github.com/uptrace/bun/model.go | 207 +++ vendor/github.com/uptrace/bun/model_map.go | 183 +++ .../github.com/uptrace/bun/model_map_slice.go | 162 +++ vendor/github.com/uptrace/bun/model_scan.go | 54 + vendor/github.com/uptrace/bun/model_slice.go | 82 ++ .../uptrace/bun/model_table_has_many.go | 149 +++ .../github.com/uptrace/bun/model_table_m2m.go | 138 ++ .../uptrace/bun/model_table_slice.go | 113 ++ .../uptrace/bun/model_table_struct.go | 345 +++++ vendor/github.com/uptrace/bun/query_base.go | 874 +++++++++++++ .../uptrace/bun/query_column_add.go | 105 ++ .../uptrace/bun/query_column_drop.go | 112 ++ vendor/github.com/uptrace/bun/query_delete.go | 256 ++++ .../uptrace/bun/query_index_create.go | 242 ++++ .../uptrace/bun/query_index_drop.go | 105 ++ vendor/github.com/uptrace/bun/query_insert.go | 551 ++++++++ vendor/github.com/uptrace/bun/query_select.go | 830 ++++++++++++ .../uptrace/bun/query_table_create.go | 275 ++++ .../uptrace/bun/query_table_drop.go | 137 ++ .../uptrace/bun/query_table_truncate.go | 121 ++ vendor/github.com/uptrace/bun/query_update.go | 432 +++++++ vendor/github.com/uptrace/bun/query_values.go | 198 +++ .../github.com/uptrace/bun/schema/append.go | 93 ++ .../uptrace/bun/schema/append_value.go | 237 ++++ .../github.com/uptrace/bun/schema/dialect.go | 99 ++ vendor/github.com/uptrace/bun/schema/field.go | 117 ++ .../uptrace/bun/schema/formatter.go | 248 ++++ vendor/github.com/uptrace/bun/schema/hook.go | 20 + .../github.com/uptrace/bun/schema/relation.go | 32 + vendor/github.com/uptrace/bun/schema/scan.go | 392 ++++++ .../github.com/uptrace/bun/schema/sqlfmt.go | 76 ++ .../github.com/uptrace/bun/schema/sqltype.go | 129 ++ vendor/github.com/uptrace/bun/schema/table.go | 948 ++++++++++++++ .../github.com/uptrace/bun/schema/tables.go | 148 +++ vendor/github.com/uptrace/bun/schema/util.go | 53 + .../uptrace/bun/schema/zerochecker.go | 126 ++ vendor/github.com/uptrace/bun/util.go | 114 ++ vendor/github.com/uptrace/bun/version.go | 6 + vendor/modules.txt | 22 +- 175 files changed, 16045 insertions(+), 945 deletions(-) delete mode 100644 internal/federation/dereferencing/blocked.go delete mode 100644 vendor/github.com/go-pg/pg/extra/pgdebug/go.mod delete mode 100644 vendor/github.com/go-pg/pg/extra/pgdebug/go.sum delete mode 100644 vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go create mode 100644 vendor/github.com/uptrace/bun/.gitignore create mode 100644 vendor/github.com/uptrace/bun/.prettierrc.yaml create mode 100644 vendor/github.com/uptrace/bun/CHANGELOG.md rename vendor/github.com/{go-pg/pg/extra/pgdebug => uptrace/bun}/LICENSE (94%) create mode 100644 vendor/github.com/uptrace/bun/Makefile create mode 100644 vendor/github.com/uptrace/bun/README.md create mode 100644 vendor/github.com/uptrace/bun/RELEASING.md create mode 100644 vendor/github.com/uptrace/bun/bun.go create mode 100644 vendor/github.com/uptrace/bun/db.go create mode 100644 vendor/github.com/uptrace/bun/dialect/append.go create mode 100644 vendor/github.com/uptrace/bun/dialect/dialect.go create mode 100644 vendor/github.com/uptrace/bun/dialect/feature/feature.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/append.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/array.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go create mode 100644 vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go create mode 100644 vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/README.md create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/column.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/config.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/driver.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/error.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/format.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/go.mod create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/go.sum create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/listener.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/proto.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/safe.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go create mode 100644 vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go create mode 100644 vendor/github.com/uptrace/bun/extra/bunjson/json.go create mode 100644 vendor/github.com/uptrace/bun/extra/bunjson/provider.go create mode 100644 vendor/github.com/uptrace/bun/go.mod create mode 100644 vendor/github.com/uptrace/bun/go.sum create mode 100644 vendor/github.com/uptrace/bun/hook.go create mode 100644 vendor/github.com/uptrace/bun/internal/flag.go create mode 100644 vendor/github.com/uptrace/bun/internal/hex.go create mode 100644 vendor/github.com/uptrace/bun/internal/logger.go create mode 100644 vendor/github.com/uptrace/bun/internal/map_key.go create mode 100644 vendor/github.com/uptrace/bun/internal/parser/parser.go create mode 100644 vendor/github.com/uptrace/bun/internal/safe.go create mode 100644 vendor/github.com/uptrace/bun/internal/tagparser/parser.go create mode 100644 vendor/github.com/uptrace/bun/internal/time.go create mode 100644 vendor/github.com/uptrace/bun/internal/underscore.go create mode 100644 vendor/github.com/uptrace/bun/internal/unsafe.go create mode 100644 vendor/github.com/uptrace/bun/internal/util.go create mode 100644 vendor/github.com/uptrace/bun/join.go create mode 100644 vendor/github.com/uptrace/bun/model.go create mode 100644 vendor/github.com/uptrace/bun/model_map.go create mode 100644 vendor/github.com/uptrace/bun/model_map_slice.go create mode 100644 vendor/github.com/uptrace/bun/model_scan.go create mode 100644 vendor/github.com/uptrace/bun/model_slice.go create mode 100644 vendor/github.com/uptrace/bun/model_table_has_many.go create mode 100644 vendor/github.com/uptrace/bun/model_table_m2m.go create mode 100644 vendor/github.com/uptrace/bun/model_table_slice.go create mode 100644 vendor/github.com/uptrace/bun/model_table_struct.go create mode 100644 vendor/github.com/uptrace/bun/query_base.go create mode 100644 vendor/github.com/uptrace/bun/query_column_add.go create mode 100644 vendor/github.com/uptrace/bun/query_column_drop.go create mode 100644 vendor/github.com/uptrace/bun/query_delete.go create mode 100644 vendor/github.com/uptrace/bun/query_index_create.go create mode 100644 vendor/github.com/uptrace/bun/query_index_drop.go create mode 100644 vendor/github.com/uptrace/bun/query_insert.go create mode 100644 vendor/github.com/uptrace/bun/query_select.go create mode 100644 vendor/github.com/uptrace/bun/query_table_create.go create mode 100644 vendor/github.com/uptrace/bun/query_table_drop.go create mode 100644 vendor/github.com/uptrace/bun/query_table_truncate.go create mode 100644 vendor/github.com/uptrace/bun/query_update.go create mode 100644 vendor/github.com/uptrace/bun/query_values.go create mode 100644 vendor/github.com/uptrace/bun/schema/append.go create mode 100644 vendor/github.com/uptrace/bun/schema/append_value.go create mode 100644 vendor/github.com/uptrace/bun/schema/dialect.go create mode 100644 vendor/github.com/uptrace/bun/schema/field.go create mode 100644 vendor/github.com/uptrace/bun/schema/formatter.go create mode 100644 vendor/github.com/uptrace/bun/schema/hook.go create mode 100644 vendor/github.com/uptrace/bun/schema/relation.go create mode 100644 vendor/github.com/uptrace/bun/schema/scan.go create mode 100644 vendor/github.com/uptrace/bun/schema/sqlfmt.go create mode 100644 vendor/github.com/uptrace/bun/schema/sqltype.go create mode 100644 vendor/github.com/uptrace/bun/schema/table.go create mode 100644 vendor/github.com/uptrace/bun/schema/tables.go create mode 100644 vendor/github.com/uptrace/bun/schema/util.go create mode 100644 vendor/github.com/uptrace/bun/schema/zerochecker.go create mode 100644 vendor/github.com/uptrace/bun/util.go create mode 100644 vendor/github.com/uptrace/bun/version.go diff --git a/go.mod b/go.mod index 01b5338fa..b8752992c 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,6 @@ require ( github.com/go-errors/errors v1.4.0 // indirect github.com/go-fed/activity v1.0.1-0.20210803212804-d866ba75dd0f github.com/go-fed/httpsig v1.1.0 - github.com/go-pg/pg/extra/pgdebug v0.2.0 github.com/go-pg/pg/v10 v10.10.3 github.com/go-playground/validator/v10 v10.7.0 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect @@ -47,13 +46,16 @@ require ( github.com/superseriousbusiness/oauth2/v4 v4.3.0-SSB github.com/tdewolff/minify/v2 v2.9.21 github.com/tidwall/buntdb v1.2.4 // indirect + github.com/uptrace/bun v0.4.3 + github.com/uptrace/bun/dialect/pgdialect v0.4.3 + github.com/uptrace/bun/driver/pgdriver v0.4.3 github.com/urfave/cli/v2 v2.3.0 - github.com/vmihailenco/msgpack/v5 v5.3.4 // indirect github.com/wagslane/go-password-validator v0.3.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect golang.org/x/text v0.3.6 + google.golang.org/appengine v1.6.7 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect diff --git a/go.sum b/go.sum index d074f2c62..9acb6dc7a 100644 --- a/go.sum +++ b/go.sum @@ -131,9 +131,6 @@ github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7 github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-pg/pg/extra/pgdebug v0.2.0 h1:t62UhMiV6KYAxSWojwIJiyX06TdepkzCeIzdeb00184= -github.com/go-pg/pg/extra/pgdebug v0.2.0/go.mod h1:KmW//PLshMAQunfInLv9mFIbYXuGplOY9bc6qo3CaY0= -github.com/go-pg/pg/v10 v10.6.2/go.mod h1:BfgPoQnD2wXNd986RYEHzikqv9iE875PrFaZ9vXvtNM= github.com/go-pg/pg/v10 v10.10.3 h1:WobSfk5I+v7XwD1h9x2B7n4slDzjdBIonJ5PID95Aag= github.com/go-pg/pg/v10 v10.10.3/go.mod h1:EmoJGYErc+stNN/1Jf+o4csXuprjxcRztBnn6cHe38E= github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= @@ -200,7 +197,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= @@ -382,6 +378,12 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ= github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw= +github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE= +github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= +github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo= +github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8= +github.com/uptrace/bun/driver/pgdriver v0.4.3 h1:WLtUL3xtnZuryRcXII8PV8dm6UfEMWQniHFmV5T4tEw= +github.com/uptrace/bun/driver/pgdriver v0.4.3/go.mod h1:CQsGmzHK9Sq70avzRy7aFYAomoT3XihjGPtRDSToV+0= github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -391,12 +393,9 @@ github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0 github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= -github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= -github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1/go.mod h1:xlngVLeyQ/Qi05oQxhQ+oTuqa03RjMwMfk/7/TCs+QI= github.com/vmihailenco/msgpack/v5 v5.3.1/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= @@ -426,7 +425,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opentelemetry.io/otel v0.13.0/go.mod h1:dlSNewoRYikTkotEnxdmuBHgzT+k/idJSfDv/FxEnOY= go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4= go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -436,7 +434,6 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -502,7 +499,6 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= @@ -562,7 +558,6 @@ golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -570,6 +565,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= diff --git a/internal/db/account.go b/internal/db/account.go index 0e1575f9b..61d97bf8c 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -19,6 +19,7 @@ package db import ( + "context" "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -27,40 +28,40 @@ import ( // Account contains functions related to account getting/setting/creation. type Account interface { // GetAccountByID returns one account with the given ID, or an error if something goes wrong. - GetAccountByID(id string) (*gtsmodel.Account, Error) + GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error) // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. - GetAccountByURI(uri string) (*gtsmodel.Account, Error) + GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. - GetAccountByURL(uri string) (*gtsmodel.Account, Error) + GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) // GetLocalAccountByUsername returns an account on this instance by its username. - GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error) + GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error) // GetAccountFaves fetches faves/likes created by the target accountID. - GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error) + GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error) // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(accountID string) (int, Error) + CountAccountStatuses(ctx context.Context, accountID string) (int, Error) // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // In case of no entries, a 'no entries' error will be returned - GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) + GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) - GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) + GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(accountID string) (time.Time, Error) + GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error) // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. - SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error + SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error // GetInstanceAccount returns the instance account for the given domain. // If domain is empty, this instance account will be returned. - GetInstanceAccount(domain string) (*gtsmodel.Account, Error) + GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error) } diff --git a/internal/db/admin.go b/internal/db/admin.go index aa2b22f47..f9b4f821e 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -19,6 +19,7 @@ package db import ( + "context" "net" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -28,26 +29,26 @@ import ( type Admin interface { // IsUsernameAvailable checks whether a given username is available on our domain. // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(username string) Error + IsUsernameAvailable(ctx context.Context, username string) Error // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // Return an error if: // A) the email is already associated with an account // B) we block signups from this email domain // C) something went wrong in the db - IsEmailAvailable(email string) Error + IsEmailAvailable(ctx context.Context, email string) Error // NewSignup creates a new user in the database with the given parameters. // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) + NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount() Error + CreateInstanceAccount(ctx context.Context) Error // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance() Error + CreateInstanceInstance(ctx context.Context) Error } diff --git a/internal/db/basic.go b/internal/db/basic.go index 729920bba..d48ffb018 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -24,15 +24,15 @@ import "context" type Basic interface { // CreateTable creates a table for the given interface. // For implementations that don't use tables, this can just return nil. - CreateTable(i interface{}) Error + CreateTable(ctx context.Context, i interface{}) Error // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. - DropTable(i interface{}) Error + DropTable(ctx context.Context, i interface{}) Error // RegisterTable registers a table for use in many2many relations. // For implementations that don't use tables, or many2many relations, this can just return nil. - RegisterTable(i interface{}) Error + RegisterTable(ctx context.Context, i interface{}) Error // Stop should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. @@ -45,43 +45,43 @@ type Basic interface { // for other implementations (for example, in-memory) it might just be the key of a map. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetByID(id string, i interface{}) Error + GetByID(ctx context.Context, id string, i interface{}) Error // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(where []Where, i interface{}) Error + GetWhere(ctx context.Context, where []Where, i interface{}) Error // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetAll(i interface{}) Error + GetAll(ctx context.Context, i interface{}) Error // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(i interface{}) Error + Put(ctx context.Context, i interface{}) Error // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ // It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Upsert(i interface{}, conflictColumn string) Error + Upsert(ctx context.Context, i interface{}, conflictColumn string) Error // UpdateByID updates i with id id. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(id string, i interface{}) Error + UpdateByID(ctx context.Context, id string, i interface{}) Error // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. - UpdateOneByID(id string, key string, value interface{}, i interface{}) Error + UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error + UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. - DeleteByID(id string, i interface{}) Error + DeleteByID(ctx context.Context, id string, i interface{}) Error // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(where []Where, i interface{}) Error + DeleteWhere(ctx context.Context, where []Where, i interface{}) Error } diff --git a/internal/db/db.go b/internal/db/db.go index d6ac883e4..71ac887bb 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -19,6 +19,8 @@ package db import ( + "context" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -52,7 +54,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) + MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) // TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -61,7 +63,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking // if they exist in the db already, and conveniently returning them, or creating new tag structs. - TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) + TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) // EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -69,5 +71,5 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) + EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) } diff --git a/internal/db/domain.go b/internal/db/domain.go index a6583c80c..df50a6770 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -18,19 +18,22 @@ package db -import "net/url" +import ( + "context" + "net/url" +) // Domain contains DB functions related to domains and domain blocks. type Domain interface { // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). - IsDomainBlocked(domain string) (bool, Error) + IsDomainBlocked(ctx context.Context, domain string) (bool, Error) // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. - AreDomainsBlocked(domains []string) (bool, Error) + AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error) // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). - IsURIBlocked(uri *url.URL) (bool, Error) + IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error) // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. - AreURIsBlocked(uris []*url.URL) (bool, Error) + AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error) } diff --git a/internal/db/instance.go b/internal/db/instance.go index 1f7c83e4f..dcd978a81 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -18,19 +18,23 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Instance contains functions for instance-level actions (counting instance users etc.). type Instance interface { // CountInstanceUsers returns the number of known accounts registered with the given domain. - CountInstanceUsers(domain string) (int, Error) + CountInstanceUsers(ctx context.Context, domain string) (int, Error) // CountInstanceStatuses returns the number of known statuses posted from the given domain. - CountInstanceStatuses(domain string) (int, Error) + CountInstanceStatuses(ctx context.Context, domain string) (int, Error) // CountInstanceDomains returns the number of known instances known that the given domain federates with. - CountInstanceDomains(domain string) (int, Error) + CountInstanceDomains(ctx context.Context, domain string) (int, Error) // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. - GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) + GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) } diff --git a/internal/db/media.go b/internal/db/media.go index db4db3411..b779dd276 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -18,10 +18,14 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Media contains functions related to creating/getting/removing media attachments. type Media interface { // GetAttachmentByID gets a single attachment by its ID - GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error) + GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) } diff --git a/internal/db/mention.go b/internal/db/mention.go index cb1c56dc1..b9b45546a 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -18,13 +18,17 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Mention contains functions for getting/creating mentions in the database. type Mention interface { // GetMention gets a single mention by ID - GetMention(id string) (*gtsmodel.Mention, Error) + GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error) // GetMentions gets multiple mentions. - GetMentions(ids []string) ([]*gtsmodel.Mention, Error) + GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) } diff --git a/internal/db/notification.go b/internal/db/notification.go index 326f0f149..09c17f031 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -18,14 +18,18 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Notification contains functions for creating and getting notifications. type Notification interface { // GetNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) + GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) // GetNotification returns one notification according to its id. - GetNotification(id string) (*gtsmodel.Notification, Error) + GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error) } diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go index 3889c6601..4cf69df18 100644 --- a/internal/db/pg/account.go +++ b/internal/db/pg/account.go @@ -24,61 +24,62 @@ import ( "fmt" "time" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type accountDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { - return a.conn.Model(account). +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { + return a.conn. + NewSelect(). + Model(account). Relation("AvatarMediaAttachment"). Relation("HeaderMediaAttachment") } -func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { account := >smodel.Account{} q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { account := >smodel.Account{} q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { account := >smodel.Account{} q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { account := >smodel.Account{} q := a.newAccountQ(account) @@ -90,29 +91,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err } else { q = q. Where("account.username = ?", domain). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) } - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) { +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { status := >smodel.Status{} - q := a.conn.Model(status). + q := a.conn. + NewSelect(). + Model(status). Order("id DESC"). Limit(1). Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return status.CreatedAt, err } -func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { if mediaAttachment.Avatar && mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -127,51 +130,58 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta } // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if _, err := a.conn.NewInsert().Model(mediaAttachment).On("CONFLICT (id) DO UPDATE").Exec(ctx); err != nil { return err } - if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { + if _, err := a.conn.NewInsert().Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Exec(ctx); err != nil { return err } return nil } -func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) { account := >smodel.Account{} q := a.newAccountQ(account). Where("username = ?", username). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) { +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { faves := []*gtsmodel.StatusFave{} - if err := a.conn.Model(&faves). + if err := a.conn. + NewSelect(). + Model(&faves). Where("account_id = ?", accountID). - Select(); err != nil { - if err == pg.ErrNoRows { - return faves, nil - } + Scan(ctx); err != nil { return nil, err } return faves, nil } -func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) { - return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { + return a.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("account_id = ?", accountID). + Count(ctx) } -func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} - q := a.conn.Model(&statuses).Order("id DESC") + q := a.conn. + NewSelect(). + Model(&statuses). + Order("id DESC") + if accountID != "" { q = q.Where("account_id = ?", accountID) } @@ -181,7 +191,7 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli } if excludeReplies { - q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) + q = q.Where("? IS NULL", bun.Ident("in_reply_to_id")) } if pinnedOnly { @@ -189,8 +199,10 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli } if mediaOnly { - q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { - return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil + q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("? IS NOT NULL", bun.Ident("attachments")). + WhereOr("attachments != '{}'") }) } @@ -198,10 +210,7 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli q = q.Where("id < ?", maxID) } - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } + if err := q.Scan(ctx); err != nil { return nil, err } @@ -213,10 +222,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli return statuses, nil } -func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { blocks := []*gtsmodel.Block{} - fq := a.conn.Model(&blocks). + fq := a.conn. + NewSelect(). + Model(&blocks). Where("block.account_id = ?", accountID). Relation("TargetAccount"). Order("block.id DESC") @@ -233,11 +244,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str fq = fq.Limit(limit) } - err := fq.Select() + err := fq.Scan(ctx) if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } return nil, "", "", err } diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go index 7ea5ff39a..df4d244bf 100644 --- a/internal/db/pg/account_test.go +++ b/internal/db/pg/account_test.go @@ -19,6 +19,7 @@ package pg_test import ( + "context" "testing" "github.com/stretchr/testify/suite" @@ -54,7 +55,7 @@ func (suite *AccountTestSuite) TearDownTest() { } func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID) + account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go index 854f56ef0..6319ba273 100644 --- a/internal/db/pg/admin.go +++ b/internal/db/pg/admin.go @@ -35,21 +35,27 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) type adminDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } -func (a *adminDB) IsUsernameAvailable(username string) db.Error { +func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) db.Error { // if no error we fail because it means we found something // if error but it's not pg.ErrNoRows then we fail // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { + if err := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("domain = ?", nil). + Scan(ctx); err == nil { return fmt.Errorf("username %s already in use", username) } else if err != pg.ErrNoRows { return fmt.Errorf("db error: %s", err) diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go index 6e76b4450..f43b80afe 100644 --- a/internal/db/pg/basic.go +++ b/internal/db/pg/basic.go @@ -29,11 +29,12 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" ) type basicDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go index 4e9b2ab48..1666d8b11 100644 --- a/internal/db/pg/domain.go +++ b/internal/db/pg/domain.go @@ -22,17 +22,17 @@ import ( "context" "net/url" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" ) type domainDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go index 968832ca5..2f0c326ca 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/pg/instance.go @@ -26,11 +26,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type instanceDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/media.go b/internal/db/pg/media.go index 618030af3..4b421cca4 100644 --- a/internal/db/pg/media.go +++ b/internal/db/pg/media.go @@ -21,17 +21,17 @@ package pg import ( "context" - "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mediaDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go index b31f07b67..fac7ca5ad 100644 --- a/internal/db/pg/mention.go +++ b/internal/db/pg/mention.go @@ -21,18 +21,18 @@ package pg import ( "context" - "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mentionDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc cache cache.Cache diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go index 281a76d85..27fe00b1e 100644 --- a/internal/db/pg/notification.go +++ b/internal/db/pg/notification.go @@ -21,18 +21,18 @@ package pg import ( "context" - "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type notificationDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc cache cache.Cache diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 0437baf02..7626386ee 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "database/sql" "encoding/pem" "errors" "fmt" @@ -29,7 +30,6 @@ import ( "strings" "time" - "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" @@ -37,6 +37,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" + "github.com/uptrace/bun/driver/pgdriver" ) var registerTables []interface{} = []interface{}{ @@ -77,105 +80,75 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge } log.Debugf("using pg options: %+v", opts) - // create a connection - pgCtx, cancel := context.WithCancel(ctx) - conn := pg.Connect(opts).WithContext(pgCtx) + sqldb := sql.OpenDB(pgdriver.NewConnector(opts...)) - // this will break the logfmt format we normally log in, - // since we can't choose where pg outputs to and it defaults to - // stdout. So use this option with care! - if log.GetLevel() >= logrus.TraceLevel { - conn.AddQueryHook(pgdebug.DebugHook{ - // Print all queries. - Verbose: true, - }) - } + conn := bun.NewDB(sqldb, pgdialect.New()) // actually *begin* the connection so that we can tell if the db is there and listening - if err := conn.Ping(ctx); err != nil { - cancel() + if err := conn.Ping(); err != nil { return nil, fmt.Errorf("db connection error: %s", err) } - - // print out discovered postgres version - var version string - if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil { - cancel() - return nil, fmt.Errorf("db connection error: %s", err) - } - log.Infof("connected to postgres version: %s", version) + log.Info("connected to postgres") ps := &postgresService{ Account: &accountDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Admin: &adminDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Basic: &basicDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Domain: &domainDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Instance: &instanceDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Media: &mediaDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Mention: &mentionDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Notification: ¬ificationDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Relationship: &relationshipDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Status: &statusDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Timeline: &timelineDB{ config: c, conn: conn, log: log, - cancel: cancel, }, config: c, conn: conn, log: log, - cancel: cancel, } // we can confidently return this useable postgres service now @@ -188,7 +161,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge // derivePGOptions takes an application config and returns either a ready-to-use *pg.Options // with sensible defaults, or an error if it's not satisfied by the provided config. -func derivePGOptions(c *config.Config) (*pg.Options, error) { +func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) { if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) } @@ -268,13 +241,13 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { // We can rely on the pg library we're using to set // sensible defaults for everything we don't set here. - options := &pg.Options{ - Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port), - User: c.DBConfig.User, - Password: c.DBConfig.Password, - Database: c.DBConfig.Database, - ApplicationName: c.ApplicationName, - TLSConfig: tlsConfig, + options := []pgdriver.DriverOption{ + pgdriver.WithAddr(fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port)), + pgdriver.WithUser(c.DBConfig.User), + pgdriver.WithPassword(c.DBConfig.Password), + pgdriver.WithDatabase(c.DBConfig.Database), + pgdriver.WithApplicationName(c.ApplicationName), + pgdriver.WithTLSConfig(tlsConfig), } return options, nil diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go index 76bd50c76..35f5a7eab 100644 --- a/internal/db/pg/relationship.go +++ b/internal/db/pg/relationship.go @@ -28,11 +28,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type relationshipDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go index 99790428e..312845765 100644 --- a/internal/db/pg/status.go +++ b/internal/db/pg/status.go @@ -31,11 +31,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type statusDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc cache cache.Cache diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go index fa8b07aab..a00e39942 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/pg/timeline.go @@ -27,11 +27,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type timelineDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger cancel context.CancelFunc } diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go index 17c09b720..e6d31780a 100644 --- a/internal/db/pg/util.go +++ b/internal/db/pg/util.go @@ -3,7 +3,6 @@ package pg import ( "strings" - "github.com/go-pg/pg/v10" "github.com/superseriousbusiness/gotosocial/internal/db" ) @@ -12,9 +11,9 @@ func processErrorResponse(err error) db.Error { switch err { case nil: return nil - case pg.ErrNoRows: + case bun.ErrNoRows: return db.ErrNoEntries - case pg.ErrMultiRows: + case bun.ErrMultiRows: return db.ErrMultipleEntries default: if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 85f64d72b..804526425 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -18,54 +18,58 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Relationship contains functions for getting or modifying the relationship between two accounts. type Relationship interface { // IsBlocked checks whether account 1 has a block in place against block2. // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1. - IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error) + IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error) // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't. // // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason, // not if you're just checking for the existence of a block. - GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error) + GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error) // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) + IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. // // It will return the newly created follow for further processing. - AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) + AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) // GetAccountFollowRequests returns all follow requests targeting the given account. - GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error) + GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error) // GetAccountFollows returns a slice of follows owned by the given accountID. - GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error) + GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error) // CountAccountFollows returns the amount of accounts that the given accountID is following. // // If localOnly is set to true, then only follows from *this instance* will be returned. - CountAccountFollows(accountID string, localOnly bool) (int, Error) + CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error) // GetAccountFollowedBy fetches follows that target given accountID. // // If localOnly is set to true, then only follows from *this instance* will be returned. - GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) + GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) // CountAccountFollowedBy returns the amounts that the given ID is followed by. - CountAccountFollowedBy(accountID string, localOnly bool) (int, Error) + CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error) } diff --git a/internal/db/status.go b/internal/db/status.go index 9d206c198..7430433c4 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -18,58 +18,62 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { // GetStatusByID returns one status from the database, with all rel fields populated (if possible). - GetStatusByID(id string) (*gtsmodel.Status, Error) + GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). - GetStatusByURI(uri string) (*gtsmodel.Status, Error) + GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) // GetStatusByURL returns one status from the database, with all rel fields populated (if possible). - GetStatusByURL(uri string) (*gtsmodel.Status, Error) + GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) // PutStatus stores one status in the database. - PutStatus(status *gtsmodel.Status) Error + PutStatus(ctx context.Context, status *gtsmodel.Status) Error // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong - CountStatusReplies(status *gtsmodel.Status) (int, Error) + CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - CountStatusReblogs(status *gtsmodel.Status) (int, Error) + CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong - CountStatusFaves(status *gtsmodel.Status) (int, Error) + CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error) // GetStatusParents gets the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) // GetStatusChildren gets the child statuses of a given status. // // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) + GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) // IsStatusFavedBy checks if a given status has been faved by a given account ID - IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusMutedBy checks if a given status has been muted by a given account ID - IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // GetStatusFaves returns a slice of faves/likes of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) + GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error) + GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error) } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 74aa5c781..83fb3a959 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -18,20 +18,24 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Timeline contains functionality for retrieving home/public/faved etc timelines for an account. type Timeline interface { // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. // // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // It will use the given filters and try to return as many statuses as possible up to the limit. // // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // It will use the given filters and try to return as many statuses as possible up to the limit. @@ -40,5 +44,5 @@ type Timeline interface { // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) } diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index ba6766061..50ec04f51 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -39,12 +39,12 @@ import ( // // EnrichRemoteAccount is mostly useful for calling after an account has been initially created by // the federatingDB's Create function, or during the federated authorization flow. -func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { - if err := d.PopulateAccountFields(account, username, false); err != nil { +func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { + if err := d.PopulateAccountFields(ctx, account, username, false); err != nil { return nil, err } - if err := d.db.UpdateByID(account.ID, account); err != nil { + if err := d.db.UpdateByID(ctx, account.ID, account); err != nil { return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err) } @@ -60,22 +60,22 @@ func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) // the remote instance again. // // SIDE EFFECTS: remote account will be stored in the database, or updated if it already exists (and refresh is true). -func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { +func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { new := true // check if we already have the account in our db - maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String()) + maybeAccount, err := d.db.GetAccountByURI(ctx, remoteAccountID.String()) if err == nil { // we've seen this account before so it's not new new = false if !refresh { // we're not being asked to refresh, but just in case we don't have the avatar/header cached yet.... - maybeAccount, err = d.EnrichRemoteAccount(username, maybeAccount) + maybeAccount, err = d.EnrichRemoteAccount(ctx, username, maybeAccount) return maybeAccount, new, err } } - accountable, err := d.dereferenceAccountable(username, remoteAccountID) + accountable, err := d.dereferenceAccountable(ctx, username, remoteAccountID) if err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error dereferencing accountable: %s", err) } @@ -93,22 +93,22 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr } gtsAccount.ID = ulid - if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil { + if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err) } - if err := d.db.Put(gtsAccount); err != nil { + if err := d.db.Put(ctx, gtsAccount); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error putting new account: %s", err) } } else { // take the id we already have and do an update gtsAccount.ID = maybeAccount.ID - if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil { + if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err) } - if err := d.db.UpdateByID(gtsAccount.ID, gtsAccount); err != nil { + if err := d.db.UpdateByID(ctx, gtsAccount.ID, gtsAccount); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err) } } @@ -120,15 +120,15 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr // it finds as something that an account model can be constructed out of. // // Will work for Person, Application, or Service models. -func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL) (ap.Accountable, error) { +func (d *deref) dereferenceAccountable(ctx context.Context, username string, remoteAccountID *url.URL) (ap.Accountable, error) { d.startHandshake(username, remoteAccountID) defer d.stopHandshake(username, remoteAccountID) - if blocked, err := d.blockedDomain(remoteAccountID.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteAccountID.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceAccountable: domain %s is blocked", remoteAccountID.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceAccountable: transport err: %s", err) } @@ -174,7 +174,7 @@ func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL // PopulateAccountFields populates any fields on the given account that weren't populated by the initial // dereferencing. This includes things like header and avatar etc. -func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsername string, refresh bool) error { +func (d *deref) PopulateAccountFields(ctx context.Context, account *gtsmodel.Account, requestingUsername string, refresh bool) error { l := d.log.WithFields(logrus.Fields{ "func": "PopulateAccountFields", "requestingUsername": requestingUsername, @@ -184,17 +184,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern if err != nil { return fmt.Errorf("PopulateAccountFields: couldn't parse account URI %s: %s", account.URI, err) } - if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil { return fmt.Errorf("PopulateAccountFields: domain %s is blocked", accountURI.Host) } - t, err := d.transportController.NewTransportForUsername(requestingUsername) + t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername) if err != nil { return fmt.Errorf("PopulateAccountFields: error getting transport for user: %s", err) } // fetch the header and avatar - if err := d.fetchHeaderAndAviForAccount(account, t, refresh); err != nil { + if err := d.fetchHeaderAndAviForAccount(ctx, account, t, refresh); err != nil { // if this doesn't work, just skip it -- we can do it later l.Debugf("error fetching header/avi for account: %s", err) } @@ -208,12 +208,12 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern // targetAccount's AvatarMediaAttachmentID and HeaderMediaAttachmentID will be updated as necessary. // // SIDE EFFECTS: remote header and avatar will be stored in local storage. -func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error { +func (d *deref) fetchHeaderAndAviForAccount(ctx context.Context, targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error { accountURI, err := url.Parse(targetAccount.URI) if err != nil { return fmt.Errorf("fetchHeaderAndAviForAccount: couldn't parse account URI %s: %s", targetAccount.URI, err) } - if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil { return fmt.Errorf("fetchHeaderAndAviForAccount: domain %s is blocked", accountURI.Host) } diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go index 6773db425..33af74ebe 100644 --- a/internal/federation/dereferencing/announce.go +++ b/internal/federation/dereferencing/announce.go @@ -19,6 +19,7 @@ package dereferencing import ( + "context" "errors" "fmt" "net/url" @@ -26,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error { +func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error { if announce.BoostOf == nil || announce.BoostOf.URI == "" { // we can't do anything unfortunately return errors.New("DereferenceAnnounce: no URI to dereference") @@ -36,16 +37,16 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam if err != nil { return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err) } - if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, boostedStatusURI.Host); blocked || err != nil { return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host) } // dereference statuses in the thread of the boosted status - if err := d.DereferenceThread(requestingUsername, boostedStatusURI); err != nil { + if err := d.DereferenceThread(ctx, requestingUsername, boostedStatusURI); err != nil { return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err) } - boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false) + boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false) if err != nil { return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err) } diff --git a/internal/federation/dereferencing/blocked.go b/internal/federation/dereferencing/blocked.go deleted file mode 100644 index c8a4c6ade..000000000 --- a/internal/federation/dereferencing/blocked.go +++ /dev/null @@ -1,41 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package dereferencing - -import ( - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (d *deref) blockedDomain(host string) (bool, error) { - b := >smodel.DomainBlock{} - err := d.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b) - if err == nil { - // block exists - return true, nil - } - - if err == db.ErrNoEntries { - // there are no entries so there's no block - return false, nil - } - - // there's an actual error - return false, err -} diff --git a/internal/federation/dereferencing/collectionpage.go b/internal/federation/dereferencing/collectionpage.go index 5feadc1ad..6f0beeaf6 100644 --- a/internal/federation/dereferencing/collectionpage.go +++ b/internal/federation/dereferencing/collectionpage.go @@ -32,12 +32,12 @@ import ( ) // DereferenceCollectionPage returns the activitystreams CollectionPage at the specified IRI, or an error if something goes wrong. -func (d *deref) DereferenceCollectionPage(username string, pageIRI *url.URL) (ap.CollectionPageable, error) { - if blocked, err := d.blockedDomain(pageIRI.Host); blocked || err != nil { +func (d *deref) DereferenceCollectionPage(ctx context.Context, username string, pageIRI *url.URL) (ap.CollectionPageable, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, pageIRI.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceCollectionPage: domain %s is blocked", pageIRI.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceCollectionPage: error creating transport: %s", err) } diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go index 03b90569a..71625ed88 100644 --- a/internal/federation/dereferencing/dereferencer.go +++ b/internal/federation/dereferencing/dereferencer.go @@ -19,6 +19,7 @@ package dereferencing import ( + "context" "net/url" "sync" @@ -34,18 +35,18 @@ import ( // Dereferencer wraps logic and functionality for doing dereferencing of remote accounts, statuses, etc, from federated instances. type Dereferencer interface { - GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) - EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) + GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) + EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) - GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) - EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) + GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) + EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) - GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) + GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) - DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error - DereferenceThread(username string, statusIRI *url.URL) error + DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error + DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error - Handshaking(username string, remoteAccountID *url.URL) bool + Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool } type deref struct { diff --git a/internal/federation/dereferencing/handshake.go b/internal/federation/dereferencing/handshake.go index cda8eafd0..17003be84 100644 --- a/internal/federation/dereferencing/handshake.go +++ b/internal/federation/dereferencing/handshake.go @@ -18,9 +18,12 @@ package dereferencing -import "net/url" +import ( + "context" + "net/url" +) -func (d *deref) Handshaking(username string, remoteAccountID *url.URL) bool { +func (d *deref) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool { d.handshakeSync.Lock() defer d.handshakeSync.Unlock() diff --git a/internal/federation/dereferencing/instance.go b/internal/federation/dereferencing/instance.go index 80f626662..ec3c3f13d 100644 --- a/internal/federation/dereferencing/instance.go +++ b/internal/federation/dereferencing/instance.go @@ -26,12 +26,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (d *deref) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { - if blocked, err := d.blockedDomain(remoteInstanceURI.Host); blocked || err != nil { +func (d *deref) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteInstanceURI.Host); blocked || err != nil { return nil, fmt.Errorf("GetRemoteInstance: domain %s is blocked", remoteInstanceURI.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("transport err: %s", err) } diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 68693c021..ac42be14e 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -39,12 +39,12 @@ import ( // // EnrichRemoteStatus is mostly useful for calling after a status has been initially created by // the federatingDB's Create function, but additional dereferencing is needed on it. -func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { - if err := d.populateStatusFields(status, username); err != nil { +func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { + if err := d.populateStatusFields(ctx, status, username); err != nil { return nil, err } - if err := d.db.UpdateByID(status.ID, status); err != nil { + if err := d.db.UpdateByID(ctx, status.ID, status); err != nil { return nil, fmt.Errorf("EnrichRemoteStatus: error updating status: %s", err) } @@ -62,11 +62,11 @@ func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*g // If a dereference was performed, then the function also returns the ap.Statusable representation for further processing. // // SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored. -func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { +func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { new := true // check if we already have the status in our db - maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String()) + maybeStatus, err := d.db.GetStatusByURI(ctx, remoteStatusID.String()) if err == nil { // we've seen this status before so it's not new new = false @@ -77,7 +77,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } } - statusable, err := d.dereferenceStatusable(username, remoteStatusID) + statusable, err := d.dereferenceStatusable(ctx, username, remoteStatusID) if err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error dereferencing statusable: %s", err) } @@ -88,7 +88,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } // do this so we know we have the remote account of the status in the db - _, _, err = d.GetRemoteAccount(username, accountURI, false) + _, _, err = d.GetRemoteAccount(ctx, username, accountURI, false) if err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: couldn't derive status author: %s", err) } @@ -105,21 +105,21 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } gtsStatus.ID = ulid - if err := d.populateStatusFields(gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } - if err := d.db.PutStatus(gtsStatus); err != nil { + if err := d.db.PutStatus(ctx, gtsStatus); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err) } } else { gtsStatus.ID = maybeStatus.ID - if err := d.populateStatusFields(gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } - if err := d.db.UpdateByID(gtsStatus.ID, gtsStatus); err != nil { + if err := d.db.UpdateByID(ctx, gtsStatus.ID, gtsStatus); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error updating status: %s", err) } } @@ -127,12 +127,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres return gtsStatus, statusable, new, nil } -func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) (ap.Statusable, error) { - if blocked, err := d.blockedDomain(remoteStatusID.Host); blocked || err != nil { +func (d *deref) dereferenceStatusable(ctx context.Context, username string, remoteStatusID *url.URL) (ap.Statusable, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteStatusID.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceStatusable: domain %s is blocked", remoteStatusID.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceStatusable: transport err: %s", err) } @@ -236,7 +236,7 @@ func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) // This function will deference all of the above, insert them in the database as necessary, // and attach them to the status. The status itself will not be added to the database yet, // that's up the caller to do. -func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername string) error { +func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error { l := d.log.WithFields(logrus.Fields{ "func": "dereferenceStatusFields", "status": fmt.Sprintf("%+v", status), @@ -248,12 +248,12 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername if err != nil { return fmt.Errorf("DereferenceStatusFields: couldn't parse status URI %s: %s", status.URI, err) } - if blocked, err := d.blockedDomain(statusURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, statusURI.Host); blocked || err != nil { return fmt.Errorf("DereferenceStatusFields: domain %s is blocked", statusURI.Host) } // we can continue -- create a new transport here because we'll probably need it - t, err := d.transportController.NewTransportForUsername(requestingUsername) + t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername) if err != nil { return fmt.Errorf("error creating transport: %s", err) } @@ -281,7 +281,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername // it might have been processed elsewhere so check first if it's already in the database or not maybeAttachment := >smodel.MediaAttachment{} - err := d.db.GetWhere([]db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment) + err := d.db.GetWhere(ctx, []db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment) if err == nil { // we already have it in the db, dereferenced, no need to do it again l.Tracef("attachment already exists with id %s", maybeAttachment.ID) @@ -302,7 +302,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername l.Debugf("dereferenced attachment: %+v", deferencedAttachment) deferencedAttachment.StatusID = status.ID deferencedAttachment.Description = a.Description - if err := d.db.Put(deferencedAttachment); err != nil { + if err := d.db.Put(ctx, deferencedAttachment); err != nil { return fmt.Errorf("error inserting dereferenced attachment with remote url %s: %s", a.RemoteURL, err) } attachmentIDs = append(attachmentIDs, deferencedAttachment.ID) @@ -338,9 +338,9 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername } var targetAccount *gtsmodel.Account - if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil { + if a, err := d.db.GetAccountByURL(ctx, targetAccountURI.String()); err == nil { targetAccount = a - } else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil { + } else if a, _, err := d.GetRemoteAccount(ctx, requestingUsername, targetAccountURI, false); err == nil { targetAccount = a } else { // we can't find the target account so bail @@ -369,7 +369,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername TargetAccountURL: targetAccount.URL, } - if err := d.db.Put(m); err != nil { + if err := d.db.Put(ctx, m); err != nil { return fmt.Errorf("error creating mention: %s", err) } mentionIDs = append(mentionIDs, m.ID) @@ -382,13 +382,13 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername if err != nil { return err } - if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil { + if replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err == nil { // we have the status status.InReplyToID = replyToStatus.ID status.InReplyTo = replyToStatus status.InReplyToAccountID = replyToStatus.AccountID status.InReplyToAccount = replyToStatus.Account - } else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil { + } else if replyToStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err == nil { // we got the status status.InReplyToID = replyToStatus.ID status.InReplyTo = replyToStatus diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go index 2a407f923..328a1c4ee 100644 --- a/internal/federation/dereferencing/thread.go +++ b/internal/federation/dereferencing/thread.go @@ -19,6 +19,7 @@ package dereferencing import ( + "context" "fmt" "net/url" @@ -34,7 +35,7 @@ import ( // This process involves working up and down the chain of replies, and parsing through the collections of IDs // presented by remote instances as part of their replies collections, and will likely involve making several calls to // multiple different hosts. -func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { +func (d *deref) DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error { l := d.log.WithFields(logrus.Fields{ "func": "DereferenceThread", "username": username, @@ -49,18 +50,18 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { } // first make sure we have this status in our db - _, statusable, _, err := d.GetRemoteStatus(username, statusIRI, true) + _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true) if err != nil { return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err) } // first iterate up through ancestors, dereferencing if necessary as we go - if err := d.iterateAncestors(username, *statusIRI); err != nil { + if err := d.iterateAncestors(ctx, username, *statusIRI); err != nil { return fmt.Errorf("error iterating ancestors of status %s: %s", statusIRI.String(), err) } // now iterate down through descendants, again dereferencing as we go - if err := d.iterateDescendants(username, *statusIRI, statusable); err != nil { + if err := d.iterateDescendants(ctx, username, *statusIRI, statusable); err != nil { return fmt.Errorf("error iterating descendants of status %s: %s", statusIRI.String(), err) } @@ -68,7 +69,7 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { } // iterateAncestors has the goal of reaching the oldest ancestor of a given status, and stashing all statuses along the way. -func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { +func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI url.URL) error { l := d.log.WithFields(logrus.Fields{ "func": "iterateAncestors", "username": username, @@ -87,7 +88,7 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { } status := >smodel.Status{} - if err := d.db.GetByID(id, status); err != nil { + if err := d.db.GetByID(ctx, id, status); err != nil { return err } @@ -99,12 +100,12 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { if err != nil { return err } - return d.iterateAncestors(username, *nextIRI) + return d.iterateAncestors(ctx, username, *nextIRI) } // If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus // We call it with refresh to true because we want the statusable representation to parse inReplyTo from. - status, statusable, _, err := d.GetRemoteStatus(username, &statusIRI, true) + status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true) if err != nil { l.Debugf("error getting remote status: %s", err) return nil @@ -117,22 +118,22 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { } // get the ancestor status into our database if we don't have it yet - if _, _, _, err := d.GetRemoteStatus(username, inReplyTo, false); err != nil { + if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil { l.Debugf("error getting remote status: %s", err) return nil } // now enrich the current status, since we should have the ancestor in the db - if _, err := d.EnrichRemoteStatus(username, status); err != nil { + if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil { l.Debugf("error enriching remote status: %s", err) return nil } // now move up to the next ancestor - return d.iterateAncestors(username, *inReplyTo) + return d.iterateAncestors(ctx, username, *inReplyTo) } -func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusable ap.Statusable) error { +func (d *deref) iterateDescendants(ctx context.Context, username string, statusIRI url.URL, statusable ap.Statusable) error { l := d.log.WithFields(logrus.Fields{ "func": "iterateDescendants", "username": username, @@ -182,7 +183,7 @@ func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusabl pageLoop: for { l.Debugf("dereferencing page %s", currentPageIRI) - nextPage, err := d.DereferenceCollectionPage(username, currentPageIRI) + nextPage, err := d.DereferenceCollectionPage(ctx, username, currentPageIRI) if err != nil { return nil } @@ -226,10 +227,10 @@ pageLoop: foundReplies = foundReplies + 1 // get the remote statusable and put it in the db - _, statusable, new, err := d.GetRemoteStatus(username, itemURI, false) + _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false) if new && err == nil && statusable != nil { // now iterate descendants of *that* status - if err := d.iterateDescendants(username, *itemURI, statusable); err != nil { + if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil { continue } } diff --git a/internal/processing/account.go b/internal/processing/account.go index f722c88eb..94ba596ac 100644 --- a/internal/processing/account.go +++ b/internal/processing/account.go @@ -19,51 +19,53 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { - return p.accountProcessor.Create(authed.Token, authed.Application, form) +func (p *processor) AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { + return p.accountProcessor.Create(ctx, authed.Token, authed.Application, form) } -func (p *processor) AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) { - return p.accountProcessor.Get(authed.Account, targetAccountID) +func (p *processor) AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) { + return p.accountProcessor.Get(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { - return p.accountProcessor.Update(authed.Account, form) +func (p *processor) AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { + return p.accountProcessor.Update(ctx, authed.Account, form) } -func (p *processor) AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { - return p.accountProcessor.StatusesGet(authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) +func (p *processor) AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { + return p.accountProcessor.StatusesGet(ctx, authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) } -func (p *processor) AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - return p.accountProcessor.FollowersGet(authed.Account, targetAccountID) +func (p *processor) AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + return p.accountProcessor.FollowersGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - return p.accountProcessor.FollowingGet(authed.Account, targetAccountID) +func (p *processor) AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + return p.accountProcessor.FollowingGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.RelationshipGet(authed.Account, targetAccountID) +func (p *processor) AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.RelationshipGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.FollowCreate(authed.Account, form) +func (p *processor) AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.FollowCreate(ctx, authed.Account, form) } -func (p *processor) AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.FollowRemove(authed.Account, targetAccountID) +func (p *processor) AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.FollowRemove(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.BlockCreate(authed.Account, targetAccountID) +func (p *processor) AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.BlockCreate(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.BlockRemove(authed.Account, targetAccountID) +func (p *processor) AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.BlockRemove(ctx, authed.Account, targetAccountID) } diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 7b8910149..81701fd7c 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -19,6 +19,7 @@ package account import ( + "context" "mime/multipart" "github.com/sirupsen/logrus" @@ -38,40 +39,40 @@ import ( // Processor wraps a bunch of functions for processing account actions. type Processor interface { // Create processes the given form for creating a new account, returning an oauth token for that account if successful. - Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) + Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) // Delete deletes an account, and all of that account's statuses, media, follows, notifications, etc etc etc. // The origin passed here should be either the ID of the account doing the delete (can be itself), or the ID of a domain block. - Delete(account *gtsmodel.Account, origin string) error + Delete(ctx context.Context, account *gtsmodel.Account, origin string) error // Get processes the given request for account information. - Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) + Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) // Update processes the update of an account with the given form - Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) + Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) // StatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for // the account given in authed. - StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) + StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) // FollowersGet fetches a list of the target account's followers. - FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // FollowingGet fetches a list of the accounts that target account is following. - FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. - RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // FollowCreate handles a follow request to an account, either remote or local. - FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) + FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. - FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. - BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local. - BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // UpdateHeader does the dirty work of checking the header part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new header image. - UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) + UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) // UpdateAvatar does the dirty work of checking the avatar part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new avatar image. - UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) + UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) } type processor struct { diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go index 83e76973d..1eae90d03 100644 --- a/internal/processing/account/create.go +++ b/internal/processing/account/create.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,14 +28,14 @@ import ( "github.com/superseriousbusiness/oauth2/v4" ) -func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { +func (p *processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { l := p.log.WithField("func", "accountCreate") - if err := p.db.IsEmailAvailable(form.Email); err != nil { + if err := p.db.IsEmailAvailable(ctx, form.Email); err != nil { return nil, err } - if err := p.db.IsUsernameAvailable(form.Username); err != nil { + if err := p.db.IsUsernameAvailable(ctx, form.Username); err != nil { return nil, err } @@ -45,7 +46,7 @@ func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmo } l.Trace("creating new username and account") - user, err := p.db.NewSignup(form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false) + user, err := p.db.NewSignup(ctx, form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false) if err != nil { return nil, fmt.Errorf("error creating new signup in the database: %s", err) } diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go index f10a2efa3..06f82b37d 100644 --- a/internal/processing/account/createblock.go +++ b/internal/processing/account/createblock.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,18 +30,18 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(targetAccountID) + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } // if requestingAccount already blocks target account, we don't need to do anything - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err)) } else if blocked { - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } // make the block @@ -57,18 +58,18 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID) // whack it in the database - if err := p.db.Put(block); err != nil { + if err := p.db.Put(ctx, block); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err)) } // clear any follows or follow requests from the blocked account to the target account -- this is a simple delete - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.Follow{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err)) } - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.FollowRequest{}); err != nil { @@ -82,12 +83,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(fr.ID, fr); err != nil { + if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err)) } frChanged = true @@ -97,12 +98,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(f.ID, f); err != nil { + if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err)) } fChanged = true @@ -147,5 +148,5 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou TargetAccount: targetAccount, } - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go index 8c856a50e..a7767afea 100644 --- a/internal/processing/account/createfollow.go +++ b/internal/processing/account/createfollow.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,16 +30,16 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't create the request ofc - if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } // make sure the target account actually exists in our db - targetAcct, err := p.db.GetAccountByID(form.ID) + targetAcct, err := p.db.GetAccountByID(ctx, form.ID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) @@ -47,19 +48,19 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // check if a follow exists already - if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil { + if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err)) } else if follows { // already follows so just return the relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // check if a follow request exists already - if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil { + if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err)) } else if followRequested { // already follow requested so just return the relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // make the follow request @@ -84,17 +85,17 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // whack it in the database - if err := p.db.Put(fr); err != nil { + if err := p.db.Put(ctx, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err)) } // if it's a local account that's not locked we can just straight up accept the follow request if !targetAcct.Locked && targetAcct.Domain == "" { - if _, err := p.db.AcceptFollowRequest(requestingAccount.ID, form.ID); err != nil { + if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err)) } // return the new relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // otherwise we leave the follow request as it is and we handle the rest of the process asynchronously @@ -107,5 +108,5 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // return whatever relationship results from this - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index e8840abae..a0758c846 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -19,6 +19,7 @@ package account import ( + "context" "time" "github.com/sirupsen/logrus" @@ -48,7 +49,7 @@ import ( // 16. Delete account's user // 17. Delete account's timeline // 18. Delete account itself -func (p *processor) Delete(account *gtsmodel.Account, origin string) error { +func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origin string) error { l := p.log.WithFields(logrus.Fields{ "func": "Delete", "username": account.Username, @@ -61,22 +62,22 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { if account.Domain == "" { // see if we can get a user for this account u := >smodel.User{} - if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { // we got one! select all tokens with the user's ID tokens := []*oauth.Token{} - if err := p.db.GetWhere([]db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { // we have some tokens to delete for _, t := range tokens { // delete client(s) associated with this token - if err := p.db.DeleteByID(t.ClientID, &oauth.Client{}); err != nil { + if err := p.db.DeleteByID(ctx, t.ClientID, &oauth.Client{}); err != nil { l.Errorf("error deleting oauth client: %s", err) } // delete application(s) associated with this token - if err := p.db.DeleteWhere([]db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { l.Errorf("error deleting application: %s", err) } // delete the token itself - if err := p.db.DeleteByID(t.ID, t); err != nil { + if err := p.db.DeleteByID(ctx, t.ID, t); err != nil { l.Errorf("error deleting oauth token: %s", err) } } @@ -87,12 +88,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // 2. Delete account's blocks l.Debug("deleting account blocks") // first delete any blocks that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { l.Errorf("error deleting blocks created by account: %s", err) } // now delete any blocks that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { l.Errorf("error deleting blocks targeting account: %s", err) } @@ -103,12 +104,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // TODO: federate these if necessary l.Debug("deleting account follow requests") // first delete any follow requests that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests created by account: %s", err) } // now delete any follow requests that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests targeting account: %s", err) } @@ -116,12 +117,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // TODO: federate these if necessary l.Debug("deleting account follows") // first delete any follows that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows created by account: %s", err) } // now delete any follows that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows targeting account: %s", err) } @@ -133,7 +134,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { var maxID string selectStatusesLoop: for { - statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false) + statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, maxID, false, false) if err != nil { if err == db.ErrNoEntries { // no statuses left for this instance so we're done @@ -157,7 +158,7 @@ selectStatusesLoop: TargetAccount: account, } - if err := p.db.DeleteByID(s.ID, s); err != nil { + if err := p.db.DeleteByID(ctx, s.ID, s); err != nil { if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err) @@ -167,7 +168,7 @@ selectStatusesLoop: // if there are any boosts of this status, delete them as well boosts := []*gtsmodel.Status{} - if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { if err != db.ErrNoEntries { // an actual error has occurred l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err) @@ -177,7 +178,7 @@ selectStatusesLoop: for _, b := range boosts { oa := >smodel.Account{} - if err := p.db.GetByID(b.AccountID, oa); err == nil { + if err := p.db.GetByID(ctx, b.AccountID, oa); err == nil { l.Debug("putting boost undo in the client api channel") p.fromClientAPI <- gtsmodel.FromClientAPI{ @@ -189,7 +190,7 @@ selectStatusesLoop: } } - if err := p.db.DeleteByID(b.ID, b); err != nil { + if err := p.db.DeleteByID(ctx, b.ID, b); err != nil { if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err) @@ -208,26 +209,26 @@ selectStatusesLoop: // 10. Delete account's notifications l.Debug("deleting account notifications") - if err := p.db.DeleteWhere([]db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { l.Errorf("error deleting notifications created by account: %s", err) } // 11. Delete account's bookmarks l.Debug("deleting account bookmarks") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { l.Errorf("error deleting bookmarks created by account: %s", err) } // 12. Delete account's faves // TODO: federate these if necessary l.Debug("deleting account faves") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { l.Errorf("error deleting faves created by account: %s", err) } // 13. Delete account's mutes l.Debug("deleting account mutes") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { l.Errorf("error deleting status mutes created by account: %s", err) } @@ -239,7 +240,7 @@ selectStatusesLoop: // 16. Delete account's user l.Debug("deleting account user") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { return err } @@ -266,7 +267,7 @@ selectStatusesLoop: account.SuspendedAt = time.Now() account.SuspensionOrigin = origin - if err := p.db.UpdateByID(account.ID, account); err != nil { + if err := p.db.UpdateByID(ctx, account.ID, account); err != nil { return err } diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 3dfc54b51..01a0fb51a 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -19,6 +19,7 @@ package account import ( + "context" "errors" "fmt" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { +func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { targetAccount := >smodel.Account{} - if err := p.db.GetByID(targetAccountID, targetAccount); err != nil { + if err := p.db.GetByID(ctx, targetAccountID, targetAccount); err != nil { if err == db.ErrNoEntries { return nil, errors.New("account not found") } @@ -39,7 +40,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str var blocked bool var err error if requestingAccount != nil { - blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true) + blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { return nil, fmt.Errorf("error checking account block: %s", err) } diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go index 4f66b40ee..f90b0f767 100644 --- a/internal/processing/account/getfollowers.go +++ b/internal/processing/account/getfollowers.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,15 +28,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollowedBy(targetAccountID, false) + follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -44,7 +45,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco } for _, f := range follows { - blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -53,7 +54,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco } if f.Account == nil { - a, err := p.db.GetAccountByID(f.AccountID) + a, err := p.db.GetAccountByID(ctx, f.AccountID) if err != nil { if err == db.ErrNoEntries { continue diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go index c7fb426f9..4082e89c1 100644 --- a/internal/processing/account/getfollowing.go +++ b/internal/processing/account/getfollowing.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,15 +28,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollows(targetAccountID) + follows, err := p.db.GetAccountFollows(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -44,7 +45,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco } for _, f := range follows { - blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -53,7 +54,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco } if f.TargetAccount == nil { - a, err := p.db.GetAccountByID(f.TargetAccountID) + a, err := p.db.GetAccountByID(ctx, f.TargetAccountID) if err != nil { if err == db.ErrNoEntries { continue diff --git a/internal/processing/account/getrelationship.go b/internal/processing/account/getrelationship.go index a0a93a4c2..615f30d5a 100644 --- a/internal/processing/account/getrelationship.go +++ b/internal/processing/account/getrelationship.go @@ -19,6 +19,7 @@ package account import ( + "context" "errors" "fmt" @@ -27,12 +28,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { if requestingAccount == nil { return nil, gtserror.NewErrorForbidden(errors.New("not authed")) } - gtsR, err := p.db.GetRelationship(requestingAccount.ID, targetAccountID) + gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) } diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go index dc21e7006..dae9fee33 100644 --- a/internal/processing/account/getstatuses.go +++ b/internal/processing/account/getstatuses.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) @@ -36,7 +37,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou apiStatuses := []apimodel.Status{} - statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) + statuses, err := p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) if err != nil { if err == db.ErrNoEntries { return apiStatuses, nil diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go index 7c1f2bc17..7e3d78076 100644 --- a/internal/processing/account/removeblock.go +++ b/internal/processing/account/removeblock.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(targetAccountID) + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } @@ -37,13 +38,13 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou // check if a block exists, and remove it if it does (storing the URI for later) var blockChanged bool block := >smodel.Block{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, block); err == nil { block.Account = requestingAccount block.TargetAccount = targetAccount - if err := p.db.DeleteByID(block.ID, >smodel.Block{}); err != nil { + if err := p.db.DeleteByID(ctx, block.ID, >smodel.Block{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err)) } blockChanged = true @@ -61,5 +62,5 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou } // return whatever relationship results from all this - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go index 6646d694e..c271de79f 100644 --- a/internal/processing/account/removefollow.go +++ b/internal/processing/account/removefollow.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't do anything - blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -39,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco // make sure the target account actually exists in our db targetAcct := >smodel.Account{} - if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { + if err := p.db.GetByID(ctx, targetAccountID, targetAcct); err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) } @@ -49,12 +50,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(fr.ID, fr); err != nil { + if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err)) } frChanged = true @@ -64,12 +65,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(f.ID, f); err != nil { + if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err)) } fChanged = true @@ -106,5 +107,5 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco } // return whatever relationship results from all this - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index df842bacd..46aa10ce1 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -20,6 +20,7 @@ package account import ( "bytes" + "context" "errors" "fmt" "io" @@ -32,17 +33,17 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { +func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { l := p.log.WithField("func", "AccountUpdate") if form.Discoverable != nil { - if err := p.db.UpdateOneByID(account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil { return nil, fmt.Errorf("error updating discoverable: %s", err) } } if form.Bot != nil { - if err := p.db.UpdateOneByID(account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil { return nil, fmt.Errorf("error updating bot: %s", err) } } @@ -52,7 +53,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede return nil, err } displayName := text.RemoveHTML(*form.DisplayName) // no html allowed in display name - if err := p.db.UpdateOneByID(account.ID, "display_name", displayName, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "display_name", displayName, >smodel.Account{}); err != nil { return nil, err } } @@ -62,13 +63,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede return nil, err } note := text.SanitizeHTML(*form.Note) // html OK in note but sanitize it - if err := p.db.UpdateOneByID(account.ID, "note", note, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "note", note, >smodel.Account{}); err != nil { return nil, err } } if form.Avatar != nil && form.Avatar.Size != 0 { - avatarInfo, err := p.UpdateAvatar(form.Avatar, account.ID) + avatarInfo, err := p.UpdateAvatar(ctx, form.Avatar, account.ID) if err != nil { return nil, err } @@ -76,7 +77,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede } if form.Header != nil && form.Header.Size != 0 { - headerInfo, err := p.UpdateHeader(form.Header, account.ID) + headerInfo, err := p.UpdateHeader(ctx, form.Header, account.ID) if err != nil { return nil, err } @@ -84,7 +85,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede } if form.Locked != nil { - if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { return nil, err } } @@ -94,13 +95,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede if err := util.ValidateLanguage(*form.Source.Language); err != nil { return nil, err } - if err := p.db.UpdateOneByID(account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil { return nil, err } } if form.Source.Sensitive != nil { - if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { return nil, err } } @@ -109,7 +110,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil { return nil, err } - if err := p.db.UpdateOneByID(account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil { return nil, err } } @@ -117,7 +118,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede // fetch the account with all updated values set updatedAccount := >smodel.Account{} - if err := p.db.GetByID(account.ID, updatedAccount); err != nil { + if err := p.db.GetByID(ctx, account.ID, updatedAccount); err != nil { return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err) } @@ -138,7 +139,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede // UpdateAvatar does the dirty work of checking the avatar part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new avatar image. -func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { +func (p *processor) UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { var err error if int(avatar.Size) > p.config.MediaConfig.MaxImageSize { err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, p.config.MediaConfig.MaxImageSize) @@ -171,7 +172,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) // UpdateHeader does the dirty work of checking the header part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new header image. -func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { +func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { var err error if int(header.Size) > p.config.MediaConfig.MaxImageSize { err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, p.config.MediaConfig.MaxImageSize) diff --git a/internal/processing/admin.go b/internal/processing/admin.go index 9a38f5ec1..48faee986 100644 --- a/internal/processing/admin.go +++ b/internal/processing/admin.go @@ -19,31 +19,33 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { - return p.adminProcessor.EmojiCreate(authed.Account, authed.User, form) +func (p *processor) AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { + return p.adminProcessor.EmojiCreate(ctx, authed.Account, authed.User, form) } -func (p *processor) AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockCreate(authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") +func (p *processor) AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockCreate(ctx, authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") } -func (p *processor) AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlocksImport(authed.Account, form.Domains) +func (p *processor) AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlocksImport(ctx, authed.Account, form.Domains) } -func (p *processor) AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlocksGet(authed.Account, export) +func (p *processor) AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlocksGet(ctx, authed.Account, export) } -func (p *processor) AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockGet(authed.Account, id, export) +func (p *processor) AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockGet(ctx, authed.Account, id, export) } -func (p *processor) AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockDelete(authed.Account, id) +func (p *processor) AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockDelete(ctx, authed.Account, id) } diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go index fd63d8a10..de288811b 100644 --- a/internal/processing/admin/admin.go +++ b/internal/processing/admin/admin.go @@ -19,6 +19,7 @@ package admin import ( + "context" "mime/multipart" "github.com/sirupsen/logrus" @@ -33,12 +34,12 @@ import ( // Processor wraps a bunch of functions for processing admin actions. type Processor interface { - DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) - DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) - DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) - DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) - DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) - EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) + DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) + DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) + DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) + DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) + DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) + EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) } type processor struct { diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go index 624f632dc..d845fe4b9 100644 --- a/internal/processing/admin/createdomainblock.go +++ b/internal/processing/admin/createdomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" "time" @@ -31,10 +32,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" ) -func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work domainBlock := >smodel.DomainBlock{} - err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) + err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) if err != nil { if err != db.ErrNoEntries { // something went wrong in the DB @@ -59,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, } // put the new block in the database - if err := p.db.Put(domainBlock); err != nil { + if err := p.db.Put(ctx, domainBlock); err != nil { if err != db.ErrNoEntries { // there's a real error creating the block return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err)) @@ -67,7 +68,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, } // process the side effects of the domain block asynchronously since it might take a while - go p.initiateDomainBlockSideEffects(account, domainBlock) // TODO: add this to a queuing system so it can retry/resume + go p.initiateDomainBlockSideEffects(ctx, account, domainBlock) // TODO: add this to a queuing system so it can retry/resume } mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false) @@ -83,7 +84,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, // 1. Strip most info away from the instance entry for the domain. // 2. Delete the instance account for that instance if it exists. // 3. Select all accounts from this instance and pass them through the delete functionality of the processor. -func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, block *gtsmodel.DomainBlock) { +func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) { l := p.log.WithFields(logrus.Fields{ "func": "domainBlockProcessSideEffects", "domain": block.Domain, @@ -93,7 +94,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl // if we have an instance entry for this domain, update it with the new block ID and clear all fields instance := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil { instance.Title = "" instance.UpdatedAt = time.Now() instance.SuspendedAt = time.Now() @@ -105,14 +106,14 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl instance.ContactAccountUsername = "" instance.ContactAccountID = "" instance.Version = "" - if err := p.db.UpdateByID(instance.ID, instance); err != nil { + if err := p.db.UpdateByID(ctx, instance.ID, instance); err != nil { l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) } l.Debug("domainBlockProcessSideEffects: instance entry updated") } // if we have an instance account for this instance, delete it - if err := p.db.DeleteWhere([]db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) } @@ -123,7 +124,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl selectAccountsLoop: for { - accounts, err := p.db.GetInstanceAccounts(block.Domain, maxID, limit) + accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit) if err != nil { if err == db.ErrNoEntries { // no accounts left for this instance so we're done diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go index edb0a58f9..c57554772 100644 --- a/internal/processing/admin/deletedomainblock.go +++ b/internal/processing/admin/deletedomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" "time" @@ -28,10 +29,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(id, domainBlock); err != nil { + if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -47,33 +48,33 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap } // delete the domain block - if err := p.db.DeleteByID(id, domainBlock); err != nil { + if err := p.db.DeleteByID(ctx, id, domainBlock); err != nil { return nil, gtserror.NewErrorInternalError(err) } // remove the domain block reference from the instance, if we have an entry for it i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, {Key: "domain_block_id", Value: id}, }, i); err == nil { i.SuspendedAt = time.Time{} i.DomainBlockID = "" - if err := p.db.UpdateByID(i.ID, i); err != nil { + if err := p.db.UpdateByID(ctx, i.ID, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) } } // unsuspend all accounts whose suspension origin was this domain block // 1. remove the 'suspended_at' entry from their accounts - if err := p.db.UpdateWhere([]db.Where{ + if err := p.db.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err)) } // 2. remove the 'suspension_origin' entry from their accounts - if err := p.db.UpdateWhere([]db.Where{ + if err := p.db.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go index f19e173b5..23ed1b6c2 100644 --- a/internal/processing/admin/emoji.go +++ b/internal/processing/admin/emoji.go @@ -20,6 +20,7 @@ package admin import ( "bytes" + "context" "errors" "fmt" "io" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { +func (p *processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { if user.Admin { return nil, fmt.Errorf("user %s not an admin", user.ID) } @@ -65,7 +66,7 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, return nil, fmt.Errorf("error converting emoji to mastotype: %s", err) } - if err := p.db.Put(emoji); err != nil { + if err := p.db.Put(ctx, emoji); err != nil { return nil, fmt.Errorf("database error while processing emoji: %s", err) } diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go index f74010627..676aff1ae 100644 --- a/internal/processing/admin/getdomainblock.go +++ b/internal/processing/admin/getdomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,10 +28,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(id, domainBlock); err != nil { + if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go index f827d03fc..9f0940aef 100644 --- a/internal/processing/admin/getdomainblocks.go +++ b/internal/processing/admin/getdomainblocks.go @@ -19,16 +19,18 @@ package admin import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { domainBlocks := []*gtsmodel.DomainBlock{} - if err := p.db.GetAll(&domainBlocks); err != nil { + if err := p.db.GetAll(ctx, &domainBlocks); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/admin/importdomainblocks.go b/internal/processing/admin/importdomainblocks.go index ab171b712..66326bd62 100644 --- a/internal/processing/admin/importdomainblocks.go +++ b/internal/processing/admin/importdomainblocks.go @@ -20,6 +20,7 @@ package admin import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -32,7 +33,7 @@ import ( ) // DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file. -func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { f, err := domains.Open() if err != nil { @@ -54,7 +55,7 @@ func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multi blocks := []*apimodel.DomainBlock{} for _, d := range d { - block, err := p.DomainBlockCreate(account, d.Domain, false, d.PublicComment, "", "") + block, err := p.DomainBlockCreate(ctx, account, d.Domain, false, d.PublicComment, "", "") if err != nil { return nil, err diff --git a/internal/processing/app.go b/internal/processing/app.go index 7da5344ac..444c4dda2 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -19,6 +19,8 @@ package processing import ( + "context" + "github.com/google/uuid" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -26,7 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) { +func (p *processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) { // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ var scopes string if form.Scopes == "" { @@ -61,7 +63,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea } // chuck it in the db - if err := p.db.Put(app); err != nil { + if err := p.db.Put(ctx, app); err != nil { return nil, err } @@ -74,7 +76,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea } // chuck it in the db - if err := p.db.Put(oc); err != nil { + if err := p.db.Put(ctx, oc); err != nil { return nil, err } diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 809cbde8e..7451485f9 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "net/url" @@ -28,8 +29,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { - accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit) +func (p *processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { + accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) if err != nil { if err == db.ErrNoEntries { // there are just no entries diff --git a/internal/processing/federation.go b/internal/processing/federation.go index cea14b4de..2f8b15d16 100644 --- a/internal/processing/federation.go +++ b/internal/processing/federation.go @@ -36,7 +36,7 @@ import ( func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -62,7 +62,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -90,7 +90,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -106,7 +106,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -135,7 +135,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -151,7 +151,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -180,7 +180,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, requestedStatusID string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -198,7 +198,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, // authorize the request: // 1. check if a block exists between the requester and the requestee - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -209,7 +209,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, // get the status out of the database here s := >smodel.Status{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "id", Value: requestedStatusID}, {Key: "account_id", Value: requestedAccount.ID}, }, s); err != nil { @@ -240,7 +240,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, minID string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -258,7 +258,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername // authorize the request: // 1. check if a block exists between the requester and the requestee - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -269,7 +269,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername // get the status out of the database here s := >smodel.Status{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "id", Value: requestedStatusID}, {Key: "account_id", Value: requestedAccount.ID}, }, s); err != nil { @@ -320,7 +320,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername } else { // scenario 3 // get immediate children - replies, err := p.db.GetStatusChildren(s, true, minID) + replies, err := p.db.GetStatusChildren(ctx, s, true, minID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -373,7 +373,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -400,7 +400,7 @@ func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername s }, nil } -func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) { +func (p *processor) GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) { return &apimodel.WellKnownResponse{ Links: []apimodel.Link{ { @@ -411,7 +411,7 @@ func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownRe }, nil } -func (p *processor) GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) { +func (p *processor) GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) { return &apimodel.Nodeinfo{ Version: "2.0", Software: apimodel.NodeInfoSoftware{ diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index 867725023..783a69a96 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -19,6 +19,8 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" @@ -26,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { - frs, err := p.db.GetAccountFollowRequests(auth.Account.ID) +func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { + frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID) if err != nil { if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(err) @@ -37,7 +39,7 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts accts := []apimodel.Account{} for _, fr := range frs { acct := >smodel.Account{} - if err := p.db.GetByID(fr.AccountID, acct); err != nil { + if err := p.db.GetByID(ctx, fr.AccountID, acct); err != nil { return nil, gtserror.NewErrorInternalError(err) } mastoAcct, err := p.tc.AccountToMastoPublic(acct) @@ -49,19 +51,19 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts return accts, nil } -func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - follow, err := p.db.AcceptFollowRequest(accountID, auth.Account.ID) +func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { + follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } originAccount := >smodel.Account{} - if err := p.db.GetByID(follow.AccountID, originAccount); err != nil { + if err := p.db.GetByID(ctx, follow.AccountID, originAccount); err != nil { return nil, gtserror.NewErrorInternalError(err) } targetAccount := >smodel.Account{} - if err := p.db.GetByID(follow.TargetAccountID, targetAccount); err != nil { + if err := p.db.GetByID(ctx, follow.TargetAccountID, targetAccount); err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -73,7 +75,7 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap TargetAccount: targetAccount, } - gtsR, err := p.db.GetRelationship(auth.Account.ID, accountID) + gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -86,6 +88,6 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap return r, nil } -func (p *processor) FollowRequestDeny(auth *oauth.Auth) gtserror.WithCode { +func (p *processor) FollowRequestDeny(ctx context.Context, auth *oauth.Auth) gtserror.WithCode { return nil } diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index beed283c1..7c8743005 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -29,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error { +func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel.FromClientAPI) error { switch clientMsg.APActivityType { case gtsmodel.ActivityStreamsCreate: // CREATE @@ -41,16 +41,16 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("note was not parseable as *gtsmodel.Status") } - if err := p.timelineStatus(status); err != nil { + if err := p.timelineStatus(ctx, status); err != nil { return err } - if err := p.notifyStatus(status); err != nil { + if err := p.notifyStatus(ctx, status); err != nil { return err } if status.VisibilityAdvanced != nil && status.VisibilityAdvanced.Federated { - return p.federateStatus(status) + return p.federateStatus(ctx, status) } case gtsmodel.ActivityStreamsFollow: // CREATE FOLLOW REQUEST @@ -59,11 +59,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("followrequest was not parseable as *gtsmodel.FollowRequest") } - if err := p.notifyFollowRequest(followRequest, clientMsg.TargetAccount); err != nil { + if err := p.notifyFollowRequest(ctx, followRequest, clientMsg.TargetAccount); err != nil { return err } - return p.federateFollow(followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateFollow(ctx, followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsLike: // CREATE LIKE/FAVE fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave) @@ -71,11 +71,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("fave was not parseable as *gtsmodel.StatusFave") } - if err := p.notifyFave(fave, clientMsg.TargetAccount); err != nil { + if err := p.notifyFave(ctx, fave, clientMsg.TargetAccount); err != nil { return err } - return p.federateFave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateFave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsAnnounce: // CREATE BOOST/ANNOUNCE boostWrapperStatus, ok := clientMsg.GTSModel.(*gtsmodel.Status) @@ -83,15 +83,15 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("boost was not parseable as *gtsmodel.Status") } - if err := p.timelineStatus(boostWrapperStatus); err != nil { + if err := p.timelineStatus(ctx, boostWrapperStatus); err != nil { return err } - if err := p.notifyAnnounce(boostWrapperStatus); err != nil { + if err := p.notifyAnnounce(ctx, boostWrapperStatus); err != nil { return err } - return p.federateAnnounce(boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateAnnounce(ctx, boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsBlock: // CREATE BLOCK block, ok := clientMsg.GTSModel.(*gtsmodel.Block) @@ -110,7 +110,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // TODO: same with notifications // TODO: same with bookmarks - return p.federateBlock(block) + return p.federateBlock(ctx, block) } case gtsmodel.ActivityStreamsUpdate: // UPDATE @@ -122,7 +122,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("account was not parseable as *gtsmodel.Account") } - return p.federateAccountUpdate(account, clientMsg.OriginAccount) + return p.federateAccountUpdate(ctx, account, clientMsg.OriginAccount) } case gtsmodel.ActivityStreamsAccept: // ACCEPT @@ -134,11 +134,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("accept was not parseable as *gtsmodel.Follow") } - if err := p.notifyFollow(follow, clientMsg.TargetAccount); err != nil { + if err := p.notifyFollow(ctx, follow, clientMsg.TargetAccount); err != nil { return err } - return p.federateAcceptFollowRequest(follow, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateAcceptFollowRequest(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount) } case gtsmodel.ActivityStreamsUndo: // UNDO @@ -149,21 +149,21 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error if !ok { return errors.New("undo was not parseable as *gtsmodel.Follow") } - return p.federateUnfollow(follow, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnfollow(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsBlock: // UNDO BLOCK block, ok := clientMsg.GTSModel.(*gtsmodel.Block) if !ok { return errors.New("undo was not parseable as *gtsmodel.Block") } - return p.federateUnblock(block) + return p.federateUnblock(ctx, block) case gtsmodel.ActivityStreamsLike: // UNDO LIKE/FAVE fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave) if !ok { return errors.New("undo was not parseable as *gtsmodel.StatusFave") } - return p.federateUnfave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnfave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsAnnounce: // UNDO ANNOUNCE/BOOST boost, ok := clientMsg.GTSModel.(*gtsmodel.Status) @@ -175,7 +175,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return err } - return p.federateUnannounce(boost, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnannounce(ctx, boost, clientMsg.OriginAccount, clientMsg.TargetAccount) } case gtsmodel.ActivityStreamsDelete: // DELETE @@ -200,13 +200,13 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // delete all mentions for this status for _, m := range statusToDelete.MentionIDs { - if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil { + if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { return err } } // delete all notifications for this status - if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { return err } @@ -215,7 +215,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return err } - return p.federateStatusDelete(statusToDelete) + return p.federateStatusDelete(ctx, statusToDelete) case gtsmodel.ActivityStreamsProfile, gtsmodel.ActivityStreamsPerson: // DELETE ACCOUNT/PROFILE @@ -228,7 +228,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // origin is whichever account caused this message origin = clientMsg.OriginAccount.ID } - return p.accountProcessor.Delete(clientMsg.TargetAccount, origin) + return p.accountProcessor.Delete(ctx, clientMsg.TargetAccount, origin) } } return nil @@ -236,10 +236,10 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // TODO: move all the below functions into federation.Federator -func (p *processor) federateStatus(status *gtsmodel.Status) error { +func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { a := >smodel.Account{} - if err := p.db.GetByID(status.AccountID, a); err != nil { + if err := p.db.GetByID(ctx, status.AccountID, a); err != nil { return fmt.Errorf("federateStatus: error fetching status author account: %s", err) } status.Account = a @@ -260,14 +260,14 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error { return fmt.Errorf("federateStatus: error parsing outboxURI %s: %s", status.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asStatus) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asStatus) return err } -func (p *processor) federateStatusDelete(status *gtsmodel.Status) error { +func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { a := >smodel.Account{} - if err := p.db.GetByID(status.AccountID, a); err != nil { + if err := p.db.GetByID(ctx, status.AccountID, a); err != nil { return fmt.Errorf("federateStatus: error fetching status author account: %s", err) } status.Account = a @@ -310,11 +310,11 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error { delete.SetActivityStreamsTo(asStatus.GetActivityStreamsTo()) delete.SetActivityStreamsCc(asStatus.GetActivityStreamsCc()) - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, delete) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, delete) return err } -func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateFollow(ctx context.Context, followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil @@ -332,11 +332,11 @@ func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, origin return fmt.Errorf("federateFollow: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFollow) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFollow) return err } -func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnfollow(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil @@ -373,11 +373,11 @@ func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gts } // send off the Undo - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnfave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil @@ -412,11 +412,11 @@ func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gts if err != nil { return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { if originAccount.Domain != "" { // nothing to do here return nil @@ -447,11 +447,11 @@ func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gt return fmt.Errorf("federateUnannounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil @@ -497,11 +497,11 @@ func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originA } // send off the accept using the accepter's outbox - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, accept) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, accept) return err } -func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateFave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil @@ -517,11 +517,11 @@ func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmo if err != nil { return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFave) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFave) return err } -func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error { +func (p *processor) federateAnnounce(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error { announce, err := p.tc.BoostToAS(boostWrapperStatus, boostingAccount, boostedAccount) if err != nil { return fmt.Errorf("federateAnnounce: error converting status to announce: %s", err) @@ -532,11 +532,11 @@ func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boosti return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", boostingAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, announce) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, announce) return err } -func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error { +func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error { person, err := p.tc.AccountToAS(updatedAccount) if err != nil { return fmt.Errorf("federateAccountUpdate: error converting account to person: %s", err) @@ -552,14 +552,14 @@ func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, orig return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, update) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, update) return err } -func (p *processor) federateBlock(block *gtsmodel.Block) error { +func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { a := >smodel.Account{} - if err := p.db.GetByID(block.AccountID, a); err != nil { + if err := p.db.GetByID(ctx, block.AccountID, a); err != nil { return fmt.Errorf("federateBlock: error getting block account from database: %s", err) } block.Account = a @@ -567,7 +567,7 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error { if block.TargetAccount == nil { a := >smodel.Account{} - if err := p.db.GetByID(block.TargetAccountID, a); err != nil { + if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil { return fmt.Errorf("federateBlock: error getting block target account from database: %s", err) } block.TargetAccount = a @@ -588,14 +588,14 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error { return fmt.Errorf("federateBlock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asBlock) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asBlock) return err } -func (p *processor) federateUnblock(block *gtsmodel.Block) error { +func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { a := >smodel.Account{} - if err := p.db.GetByID(block.AccountID, a); err != nil { + if err := p.db.GetByID(ctx, block.AccountID, a); err != nil { return fmt.Errorf("federateUnblock: error getting block account from database: %s", err) } block.Account = a @@ -603,7 +603,7 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error { if block.TargetAccount == nil { a := >smodel.Account{} - if err := p.db.GetByID(block.TargetAccountID, a); err != nil { + if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil { return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err) } block.TargetAccount = a @@ -642,6 +642,6 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error { if err != nil { return fmt.Errorf("federateUnblock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 2c2635175..47dd907ee 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "strings" "sync" @@ -28,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) notifyStatus(status *gtsmodel.Status) error { +func (p *processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) error { // if there are no mentions in this status then just bail if len(status.MentionIDs) == 0 { return nil @@ -36,7 +37,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { if status.Mentions == nil { // there are mentions but they're not fully populated on the status yet so do this - menchies, err := p.db.GetMentions(status.MentionIDs) + menchies, err := p.db.GetMentions(ctx, status.MentionIDs) if err != nil { return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err) } @@ -47,7 +48,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { for _, m := range status.Mentions { // make sure this is a local account, otherwise we don't need to create a notification for it if m.TargetAccount == nil { - a, err := p.db.GetAccountByID(m.TargetAccountID) + a, err := p.db.GetAccountByID(ctx, m.TargetAccountID) if err != nil { // we don't have the account or there's been an error return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err) @@ -60,7 +61,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { } // make sure a notif doesn't already exist for this mention - err := p.db.GetWhere([]db.Where{ + err := p.db.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationMention}, {Key: "target_account_id", Value: m.TargetAccountID}, {Key: "origin_account_id", Value: status.AccountID}, @@ -92,7 +93,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { Status: status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyStatus: error putting notification in database: %s", err) } @@ -110,7 +111,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { return nil } -func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error { +func (p *processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error { // return if this isn't a local account if receivingAccount.Domain != "" { return nil @@ -128,7 +129,7 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r OriginAccountID: followRequest.AccountID, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err) } @@ -145,14 +146,14 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r return nil } -func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error { +func (p *processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error { // return if this isn't a local account if targetAccount.Domain != "" { return nil } // first remove the follow request notification - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationFollowRequest}, {Key: "target_account_id", Value: follow.TargetAccountID}, {Key: "origin_account_id", Value: follow.AccountID}, @@ -174,7 +175,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode OriginAccountID: follow.AccountID, OriginAccount: follow.Account, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollow: error putting notification in database: %s", err) } @@ -191,7 +192,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode return nil } -func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error { +func (p *processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error { // return if this isn't a local account if targetAccount.Domain != "" { return nil @@ -213,7 +214,7 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode Status: fave.Status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFave: error putting notification in database: %s", err) } @@ -230,14 +231,14 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode return nil } -func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { +func (p *processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) error { if status.BoostOfID == "" { // not a boost, nothing to do return nil } if status.BoostOf == nil { - boostedStatus, err := p.db.GetStatusByID(status.BoostOfID) + boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err) } @@ -245,7 +246,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { } if status.BoostOfAccount == nil { - boostedAcct, err := p.db.GetAccountByID(status.BoostOfAccountID) + boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err) } @@ -264,7 +265,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { } // make sure a notif doesn't already exist for this announce - err := p.db.GetWhere([]db.Where{ + err := p.db.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationReblog}, {Key: "target_account_id", Value: status.BoostOfAccountID}, {Key: "origin_account_id", Value: status.AccountID}, @@ -292,7 +293,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { Status: status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err) } @@ -309,10 +310,10 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { return nil } -func (p *processor) timelineStatus(status *gtsmodel.Status) error { +func (p *processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error { // make sure the author account is pinned onto the status if status.Account == nil { - a, err := p.db.GetAccountByID(status.AccountID) + a, err := p.db.GetAccountByID(ctx, status.AccountID) if err != nil { return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err) } @@ -320,7 +321,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { } // get local followers of the account that posted the status - follows, err := p.db.GetAccountFollowedBy(status.AccountID, true) + follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true) if err != nil { return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err) } @@ -338,7 +339,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { errors := make(chan error, len(follows)) for _, f := range follows { - go p.timelineStatusForAccount(status, f.AccountID, errors, &wg) + go p.timelineStatusForAccount(ctx, status, f.AccountID, errors, &wg) } // read any errors that come in from the async functions @@ -365,11 +366,11 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { return nil } -func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) { +func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) { defer wg.Done() // get the timeline owner account - timelineAccount, err := p.db.GetAccountByID(accountID) + timelineAccount, err := p.db.GetAccountByID(ctx, accountID) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err) return diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index c95c27778..6d8de289f 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -19,6 +19,7 @@ package processing import ( + "context" "errors" "fmt" "net/url" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) error { +func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmodel.FromFederator) error { l := p.log.WithFields(logrus.Fields{ "func": "processFromFederator", "federatorMsg": fmt.Sprintf("%+v", federatorMsg), @@ -53,11 +54,11 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return err } - if err := p.timelineStatus(status); err != nil { + if err := p.timelineStatus(ctx, status); err != nil { return err } - if err := p.notifyStatus(status); err != nil { + if err := p.notifyStatus(ctx, status); err != nil { return err } case gtsmodel.ActivityStreamsProfile: @@ -70,7 +71,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("like was not parseable as *gtsmodel.StatusFave") } - if err := p.notifyFave(incomingFave, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFave(ctx, incomingFave, federatorMsg.ReceivingAccount); err != nil { return err } case gtsmodel.ActivityStreamsFollow: @@ -80,7 +81,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("incomingFollowRequest was not parseable as *gtsmodel.FollowRequest") } - if err := p.notifyFollowRequest(incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFollowRequest(ctx, incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil { return err } case gtsmodel.ActivityStreamsAnnounce: @@ -100,17 +101,17 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er } incomingAnnounce.ID = incomingAnnounceID - if err := p.db.PutStatus(incomingAnnounce); err != nil { + if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil { if err != db.ErrNoEntries { return fmt.Errorf("error adding dereferenced announce to the db: %s", err) } } - if err := p.timelineStatus(incomingAnnounce); err != nil { + if err := p.timelineStatus(ctx, incomingAnnounce); err != nil { return err } - if err := p.notifyAnnounce(incomingAnnounce); err != nil { + if err := p.notifyAnnounce(ctx, incomingAnnounce); err != nil { return err } case gtsmodel.ActivityStreamsBlock: @@ -172,13 +173,13 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er // delete all mentions for this status for _, m := range statusToDelete.MentionIDs { - if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil { + if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { return err } } // delete all notifications for this status - if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { return err } @@ -198,7 +199,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("follow was not parseable as *gtsmodel.Follow") } - if err := p.notifyFollow(follow, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFollow(ctx, follow, federatorMsg.ReceivingAccount); err != nil { return err } } diff --git a/internal/processing/instance.go b/internal/processing/instance.go index b151744ef..b42be869d 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,9 +30,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) { +func (p *processor) InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) { i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain}}, i); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain}}, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err)) } @@ -43,15 +44,15 @@ func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.Wit return ai, nil } -func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) { +func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) { // fetch the instance entry from the db for processing i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err)) } // fetch the instance account from the db for processing - ia, err := p.db.GetInstanceAccount("") + ia, err := p.db.GetInstanceAccount(ctx, "") if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", p.config.Host, err)) } @@ -67,13 +68,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // validate & update site contact account if it's set on the form if form.ContactUsername != nil { // make sure the account with the given username exists in the db - contactAccount, err := p.db.GetLocalAccountByUsername(*form.ContactUsername) + contactAccount, err := p.db.GetLocalAccountByUsername(ctx, *form.ContactUsername) if err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) } // make sure it has a user associated with it contactUser := >smodel.User{} - if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) } // suspended accounts cannot be contact accounts @@ -132,7 +133,7 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // process avatar if provided if form.Avatar != nil && form.Avatar.Size != 0 { - _, err := p.accountProcessor.UpdateAvatar(form.Avatar, ia.ID) + _, err := p.accountProcessor.UpdateAvatar(ctx, form.Avatar, ia.ID) if err != nil { return nil, gtserror.NewErrorBadRequest(err, "error processing avatar") } @@ -140,13 +141,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // process header if provided if form.Header != nil && form.Header.Size != 0 { - _, err := p.accountProcessor.UpdateHeader(form.Header, ia.ID) + _, err := p.accountProcessor.UpdateHeader(ctx, form.Header, ia.ID) if err != nil { return nil, gtserror.NewErrorBadRequest(err, "error processing header") } } - if err := p.db.UpdateByID(i.ID, i); err != nil { + if err := p.db.UpdateByID(ctx, i.ID, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", p.config.Host, err)) } diff --git a/internal/processing/media.go b/internal/processing/media.go index 6ca0eda5b..0b2443893 100644 --- a/internal/processing/media.go +++ b/internal/processing/media.go @@ -19,23 +19,25 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { - return p.mediaProcessor.Create(authed.Account, form) +func (p *processor) MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { + return p.mediaProcessor.Create(ctx, authed.Account, form) } -func (p *processor) MediaGet(authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - return p.mediaProcessor.GetMedia(authed.Account, mediaAttachmentID) +func (p *processor) MediaGet(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { + return p.mediaProcessor.GetMedia(ctx, authed.Account, mediaAttachmentID) } -func (p *processor) MediaUpdate(authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { - return p.mediaProcessor.Update(authed.Account, mediaAttachmentID, form) +func (p *processor) MediaUpdate(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { + return p.mediaProcessor.Update(ctx, authed.Account, mediaAttachmentID, form) } -func (p *processor) FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { - return p.mediaProcessor.GetFile(authed.Account, form) +func (p *processor) FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { + return p.mediaProcessor.GetFile(ctx, authed.Account, form) } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 48ed2a35f..c2ddbbed9 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -64,108 +64,108 @@ type Processor interface { */ // AccountCreate processes the given form for creating a new account, returning an oauth token for that account if successful. - AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) + AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) // AccountGet processes the given request for account information. - AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) + AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) // AccountUpdate processes the update of an account with the given form - AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) + AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) // AccountStatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for // the account given in authed. - AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) + AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) // AccountFollowersGet fetches a list of the target account's followers. - AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // AccountFollowingGet fetches a list of the accounts that target account is following. - AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // AccountRelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. - AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountFollowCreate handles a follow request to an account, either remote or local. - AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) + AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) // AccountFollowRemove handles the removal of a follow/follow request to an account, either remote or local. - AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountBlockCreate handles the creation of a block from authed account to target account, either remote or local. - AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountBlockRemove handles the removal of a block from authed account to target account, either remote or local. - AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AdminEmojiCreate handles the creation of a new instance emoji by an admin, using the given form. - AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) + AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) // AdminDomainBlockCreate handles the creation of a new domain block by an admin, using the given form. - AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlocksImport handles the import of multiple domain blocks by an admin, using the given form. - AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlocksGet returns a list of currently blocked domains. - AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlockGet returns one domain block, specified by ID. - AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlockDelete deletes one domain block, specified by ID, returning the deleted domain block. - AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) // AppCreate processes the creation of a new API application - AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) + AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) // BlocksGet returns a list of accounts blocked by the requesting account. - BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) + BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) // FileGet handles the fetching of a media attachment file via the fileserver. - FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) + FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) // FollowRequestsGet handles the getting of the authed account's incoming follow requests - FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) + FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) // FollowRequestAccept handles the acceptance of a follow request from the given account ID - FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) + FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) // InstanceGet retrieves instance information for serving at api/v1/instance - InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) + InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) // InstancePatch updates this instance according to the given form. // // It should already be ascertained that the requesting account is authenticated and an admin. - InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) + InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) // MediaCreate handles the creation of a media attachment, using the given form. - MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) + MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) // MediaGet handles the GET of a media attachment with the given ID - MediaGet(authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode) + MediaGet(ctx context.Context, authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode) // MediaUpdate handles the PUT of a media attachment with the given ID and form - MediaUpdate(authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) + MediaUpdate(ctx context.Context, authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) // NotificationsGet - NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) + NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) // SearchGet performs a search with the given params, resolving/dereferencing remotely as desired - SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) + SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) // StatusCreate processes the given form to create a new status, returning the api model representation of that status if it's OK. - StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) + StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) // StatusDelete processes the delete of a given status, returning the deleted status if the delete goes through. - StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusFave processes the faving of a given status, returning the updated status if the fave goes through. - StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusBoost processes the boost/reblog of a given status, returning the newly-created boost if all is well. - StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // StatusUnboost processes the unboost/unreblog of a given status, returning the status if all is well. - StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. - StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) + StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) // StatusFavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. - StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) + StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) // StatusGet gets the given status, taking account of privacy settings and blocks etc. - StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusUnfave processes the unfaving of a given status, returning the updated status if the fave goes through. - StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusGetContext returns the context (previous and following posts) from the given status ID - StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) + StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) // HomeTimelineGet returns statuses from the home timeline, with the given filters/parameters. - HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. - PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // FavedTimelineGet returns faved statuses, with the given filters/parameters. - FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid. - AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) // OpenStreamForAccount opens a new stream for the given account, with the given stream type. - OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) /* FEDERATION API-FACING PROCESSING FUNCTIONS @@ -199,10 +199,10 @@ type Processor interface { GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) // GetNodeInfoRel returns a well known response giving the path to node info. - GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) + GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) // GetNodeInfo returns a node info struct in response to a node info request. - GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) + GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) // InboxPost handles POST requests to a user's inbox for new activitypub messages. // @@ -280,7 +280,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f } // Start starts the Processor, reading from its channels and passing messages back and forth. -func (p *processor) Start() error { +func (p *processor) Start(ctx context.Context) error { go func() { DistLoop: for { @@ -288,14 +288,14 @@ func (p *processor) Start() error { case clientMsg := <-p.fromClientAPI: p.log.Tracef("received message FROM client API: %+v", clientMsg) go func() { - if err := p.processFromClientAPI(clientMsg); err != nil { + if err := p.processFromClientAPI(ctx, clientMsg); err != nil { p.log.Error(err) } }() case federatorMsg := <-p.fromFederator: p.log.Tracef("received message FROM federator: %+v", federatorMsg) go func() { - if err := p.processFromFederator(federatorMsg); err != nil { + if err := p.processFromFederator(ctx, federatorMsg); err != nil { p.log.Error(err) } }() diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 4eb6b5658..c2f5026e0 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -19,6 +19,7 @@ package transport import ( + "context" "crypto" "fmt" "sync" @@ -33,7 +34,7 @@ import ( // Controller generates transports for use in making federation requests to other servers. type Controller interface { NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) - NewTransportForUsername(username string) (Transport, error) + NewTransportForUsername(ctx context.Context, username string) (Transport, error) } type controller struct { @@ -90,7 +91,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (T }, nil } -func (c *controller) NewTransportForUsername(username string) (Transport, error) { +func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) { // We need an account to use to create a transport for dereferecing something. // If a username has been given, we can fetch the account with that username and use it. // Otherwise, we can take the instance account and use those credentials to make the request. @@ -101,7 +102,7 @@ func (c *controller) NewTransportForUsername(username string) (Transport, error) u = username } - ourAccount, err := c.db.GetLocalAccountByUsername(u) + ourAccount, err := c.db.GetLocalAccountByUsername(ctx, u) if err != nil { return nil, fmt.Errorf("error getting account %s from db: %s", username, err) } diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 844cb6bea..fd0fb576f 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -5,12 +23,12 @@ import ( "net/url" ) -func (t *transport) BatchDeliver(c context.Context, b []byte, recipients []*url.URL) error { - return t.sigTransport.BatchDeliver(c, b, recipients) +func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { + return t.sigTransport.BatchDeliver(ctx, b, recipients) } -func (t *transport) Deliver(c context.Context, b []byte, to *url.URL) error { +func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { l := t.log.WithField("func", "Deliver") l.Debugf("performing POST to %s", to.String()) - return t.sigTransport.Deliver(c, b, to) + return t.sigTransport.Deliver(ctx, b, to) } diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index d7a28fe17..85fa370ee 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -5,8 +23,8 @@ import ( "net/url" ) -func (t *transport) Dereference(c context.Context, iri *url.URL) ([]byte, error) { +func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) { l := t.log.WithField("func", "Dereference") l.Debugf("performing GET to %s", iri.String()) - return t.sigTransport.Dereference(c, iri) + return t.sigTransport.Dereference(ctx, iri) } diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index a8b2ddfc7..fca611598 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -16,7 +34,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) { +func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) { l := t.log.WithField("func", "DereferenceInstance") var i *gtsmodel.Instance @@ -27,7 +45,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo // // This will only work with Mastodon-api compatible instances: Mastodon, some Pleroma instances, GoToSocial. l.Debugf("trying to dereference instance %s by /api/v1/instance", iri.Host) - i, err = dereferenceByAPIV1Instance(c, t, iri) + i, err = dereferenceByAPIV1Instance(ctx, t, iri) if err == nil { l.Debugf("successfully dereferenced instance using /api/v1/instance") return i, nil @@ -37,7 +55,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo // If that doesn't work, try to dereference using /.well-known/nodeinfo. // This will involve two API calls and return less info overall, but should be more widely compatible. l.Debugf("trying to dereference instance %s by /.well-known/nodeinfo", iri.Host) - i, err = dereferenceByNodeInfo(c, t, iri) + i, err = dereferenceByNodeInfo(ctx, t, iri) if err == nil { l.Debugf("successfully dereferenced instance using /.well-known/nodeinfo") return i, nil @@ -58,7 +76,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo }, nil } -func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { +func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { l := t.log.WithField("func", "dereferenceByAPIV1Instance") cleanIRI := &url.URL{ @@ -68,11 +86,10 @@ func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) ( } l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequest("GET", cleanIRI.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) @@ -216,7 +233,7 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm return i, nil } -func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.URL, error) { +func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) { l := t.log.WithField("func", "callNodeInfoWellKnown") cleanIRI := &url.URL{ @@ -226,11 +243,11 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url. } l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequest("GET", cleanIRI.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) @@ -281,15 +298,15 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url. return nodeinfoHref, nil } -func callNodeInfo(c context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { +func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { l := t.log.WithField("func", "callNodeInfo") l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index 5fa901100..e265bfdd4 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -8,14 +26,13 @@ import ( "net/url" ) -func (t *transport) DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) { +func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) { l := t.log.WithField("func", "DereferenceMedia") l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) if expectedContentType == "" { req.Header.Add("Accept", "*/*") } else { diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 12cd2fb64..bf64521c4 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -8,7 +26,7 @@ import ( "net/url" ) -func (t *transport) Finger(c context.Context, targetUsername string, targetDomain string) ([]byte, error) { +func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) { l := t.log.WithField("func", "Finger") urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain) l.Debugf("performing GET to %s", urlString) @@ -20,11 +38,11 @@ func (t *transport) Finger(c context.Context, targetUsername string, targetDomai l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/jrd+json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 04c72de5c..8d8262834 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package transport import ( @@ -17,11 +35,11 @@ import ( type Transport interface { pub.Transport // DereferenceMedia fetches the bytes of the given media attachment IRI, with the expectedContentType. - DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) + DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) // DereferenceInstance dereferences remote instance information, first by checking /api/v1/instance, and then by checking /.well-known/nodeinfo. - DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) + DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) // Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body. - Finger(c context.Context, targetUsername string, targetDomains string) ([]byte, error) + Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error) } // transport implements the Transport interface diff --git a/testrig/db.go b/testrig/db.go index c670103f1..47ff1a78c 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -84,111 +84,114 @@ func NewTestDB() db.DB { // signatures with, otherwise this function will randomly generate new keys for accounts and signature // verification will fail. func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { + ctx := context.Background() + for _, m := range testModels { - if err := db.CreateTable(m); err != nil { + if err := db.CreateTable(ctx, m); err != nil { panic(err) } } for _, v := range NewTestTokens() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestClients() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestApplications() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestUsers() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } if accounts == nil { for _, v := range NewTestAccounts() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } } else { for _, v := range accounts { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } } for _, v := range NewTestAttachments() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestStatuses() { - if err := db.PutStatus(v); err != nil { + if err := db.PutStatus(ctx, v); err != nil { panic(err) } } for _, v := range NewTestEmojis() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestTags() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestMentions() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestFaves() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestFollows() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } for _, v := range NewTestNotifications() { - if err := db.Put(v); err != nil { + if err := db.Put(ctx, v); err != nil { panic(err) } } - if err := db.CreateInstanceAccount(); err != nil { + if err := db.CreateInstanceAccount(ctx); err != nil { panic(err) } - if err := db.CreateInstanceInstance(); err != nil { + if err := db.CreateInstanceInstance(ctx); err != nil { panic(err) } } // StandardDBTeardown drops all the standard testing tables/models from the database to ensure it's clean for the next test. func StandardDBTeardown(db db.DB) { + ctx := context.Background() for _, m := range testModels { - if err := db.DropTable(m); err != nil { + if err := db.DropTable(ctx, m); err != nil { panic(err) } } diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod b/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod deleted file mode 100644 index d44ba0123..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module github.com/go-pg/pg/extra/pgdebug - -go 1.15 - -replace github.com/go-pg/pg/v10 => ../.. - -require github.com/go-pg/pg/v10 v10.6.2 diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum b/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum deleted file mode 100644 index 8483a864a..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum +++ /dev/null @@ -1,161 +0,0 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= -github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= -github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= -github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= -github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= -github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= -github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= -github.com/vmihailenco/msgpack/v5 v5.0.0 h1:nCaMMPEyfgwkGc/Y0GreJPhuvzqCqW+Ufq5lY7zLO2c= -github.com/vmihailenco/msgpack/v5 v5.0.0/go.mod h1:HVxBVPUK/+fZMonk4bi1islLa8V3cfnBug0+4dykPzo= -github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= -github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -go.opentelemetry.io/otel v0.14.0 h1:YFBEfjCk9MTjaytCNSUkp9Q8lF7QJezA06T71FbQxLQ= -go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw= -golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o= -golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= -mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go b/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go deleted file mode 100644 index bbf6ada19..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go +++ /dev/null @@ -1,42 +0,0 @@ -package pgdebug - -import ( - "context" - "fmt" - - "github.com/go-pg/pg/v10" -) - -// DebugHook is a query hook that logs an error with a query if there are any. -// It can be installed with: -// -// db.AddQueryHook(pgext.DebugHook{}) -type DebugHook struct { - // Verbose causes hook to print all queries (even those without an error). - Verbose bool - EmptyLine bool -} - -var _ pg.QueryHook = (*DebugHook)(nil) - -func (h DebugHook) BeforeQuery(ctx context.Context, evt *pg.QueryEvent) (context.Context, error) { - q, err := evt.FormattedQuery() - if err != nil { - return nil, err - } - - if evt.Err != nil { - fmt.Printf("%s executing a query:\n%s\n", evt.Err, q) - } else if h.Verbose { - if h.EmptyLine { - fmt.Println() - } - fmt.Println(string(q)) - } - - return ctx, nil -} - -func (DebugHook) AfterQuery(context.Context, *pg.QueryEvent) error { - return nil -} diff --git a/vendor/github.com/uptrace/bun/.gitignore b/vendor/github.com/uptrace/bun/.gitignore new file mode 100644 index 000000000..6f7763c71 --- /dev/null +++ b/vendor/github.com/uptrace/bun/.gitignore @@ -0,0 +1,3 @@ +*.s3db +*.prof +*.test diff --git a/vendor/github.com/uptrace/bun/.prettierrc.yaml b/vendor/github.com/uptrace/bun/.prettierrc.yaml new file mode 100644 index 000000000..decea5634 --- /dev/null +++ b/vendor/github.com/uptrace/bun/.prettierrc.yaml @@ -0,0 +1,6 @@ +trailingComma: all +tabWidth: 2 +semi: false +singleQuote: true +proseWrap: always +printWidth: 100 diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md new file mode 100644 index 000000000..01bf6ba31 --- /dev/null +++ b/vendor/github.com/uptrace/bun/CHANGELOG.md @@ -0,0 +1,99 @@ +# Changelog + +## v0.4.1 - Aug 18 2021 + +- Fixed migrate package to properly rollback migrations. +- Added `allowzero` tag option that undoes `nullzero` option. + +## v0.4.0 - Aug 11 2021 + +- Changed `WhereGroup` function to accept `*SelectQuery`. +- Fixed query hooks for count queries. + +## v0.3.4 - Jul 19 2021 + +- Renamed `migrate.CreateGo` to `CreateGoMigration`. +- Added `migrate.WithPackageName` to customize the Go package name in generated migrations. +- Renamed `migrate.CreateSQL` to `CreateSQLMigrations` and changed `CreateSQLMigrations` to create + both up and down migration files. + +## v0.3.1 - Jul 12 2021 + +- Renamed `alias` field struct tag to `alt` so it is not confused with column alias. +- Reworked migrate package API. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.3.0 - Jul 09 2021 + +- Changed migrate package to return structured data instead of logging the progress. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.2.14 - Jul 01 2021 + +- Added [sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) by + [Ivan Trubach](https://github.com/tie). +- Added support for MySQL 5.7 in addition to MySQL 8. + +## v0.2.12 - Jun 29 2021 + +- Fixed scanners for net.IP and net.IPNet. + +## v0.2.10 - Jun 29 2021 + +- Fixed pgdriver to format passed query args. + +## v0.2.9 - Jun 27 2021 + +- Added support for prepared statements in pgdriver. + +## v0.2.7 - Jun 26 2021 + +- Added `UpdateQuery.Bulk` helper to generate bulk-update queries. + + Before: + + ```go + models := []Model{ + {42, "hello"}, + {43, "world"}, + } + return db.NewUpdate(). + With("_data", db.NewValues(&models)). + Model(&models). + Table("_data"). + Set("model.str = _data.str"). + Where("model.id = _data.id") + ``` + + Now: + + ```go + db.NewUpdate(). + Model(&models). + Bulk() + ``` + +## v0.2.5 - Jun 25 2021 + +- Changed time.Time to always append zero time as `NULL`. +- Added `db.RunInTx` helper. + +## v0.2.4 - Jun 21 2021 + +- Added SSL support to pgdriver. + +## v0.2.3 - Jun 20 2021 + +- Replaced `ForceDelete(ctx)` with `ForceDelete().Exec(ctx)` for soft deletes. + +## v0.2.1 - Jun 17 2021 + +- Renamed `DBI` to `IConn`. `IConn` is a common interface for `*sql.DB`, `*sql.Conn`, and `*sql.Tx`. +- Added `IDB`. `IDB` is a common interface for `*bun.DB`, `bun.Conn`, and `bun.Tx`. + +## v0.2.0 - Jun 16 2021 + +- Changed [model hooks](https://bun.uptrace.dev/guide/hooks.html#model-hooks). See + [model-hooks](example/model-hooks) example. +- Renamed `has-one` to `belongs-to`. Renamed `belongs-to` to `has-one`. Previously Bun used + incorrect names for these relations. diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE b/vendor/github.com/uptrace/bun/LICENSE similarity index 94% rename from vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE rename to vendor/github.com/uptrace/bun/LICENSE index 7751509b8..7ec81810c 100644 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE +++ b/vendor/github.com/uptrace/bun/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/github.com/uptrace/bun/Makefile b/vendor/github.com/uptrace/bun/Makefile new file mode 100644 index 000000000..54744c617 --- /dev/null +++ b/vendor/github.com/uptrace/bun/Makefile @@ -0,0 +1,21 @@ +ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort) + +test: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go test in $${dir}"; \ + (cd "$${dir}" && \ + go test ./... && \ + go vet); \ + done + +go_mod_tidy: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go mod tidy in $${dir}"; \ + (cd "$${dir}" && \ + go get -d ./... && \ + go mod tidy); \ + done + +fmt: + gofmt -w -s ./ + goimports -w -local github.com/uptrace/bun ./ diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md new file mode 100644 index 000000000..e7cc77a60 --- /dev/null +++ b/vendor/github.com/uptrace/bun/README.md @@ -0,0 +1,267 @@ +

+ + All-in-one tool to optimize performance and monitor errors & logs + +

+ +# Simple and performant SQL database client + +[![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions) +[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun) +[![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/) +[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj) + +Main features are: + +- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql), + [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql), + [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). +- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars. +- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert). +- [Bulk updates](https://bun.uptrace.dev/guide/queries.html#update) using common table expressions. +- [Bulk deletes](https://bun.uptrace.dev/guide/queries.html#delete). +- [Fixtures](https://bun.uptrace.dev/guide/fixtures.html). +- [Migrations](https://bun.uptrace.dev/guide/migrations.html). +- [Soft deletes](https://bun.uptrace.dev/guide/soft-deletes.html). + +Resources: + +- [Examples](https://github.com/uptrace/bun/tree/master/example) +- [Documentation](https://bun.uptrace.dev/) +- [Reference](https://pkg.go.dev/github.com/uptrace/bun) +- [Starter kit](https://github.com/go-bun/bun-starter-kit) +- [RealWorld app](https://github.com/go-bun/bun-realworld-app) + +
+ github.com/frederikhors/orm-benchmark results + +``` + 4000 times - Insert + raw_stmt: 0.38s 94280 ns/op 718 B/op 14 allocs/op + raw: 0.39s 96719 ns/op 718 B/op 13 allocs/op + beego_orm: 0.48s 118994 ns/op 2411 B/op 56 allocs/op + bun: 0.57s 142285 ns/op 918 B/op 12 allocs/op + pg: 0.58s 145496 ns/op 1235 B/op 12 allocs/op + gorm: 0.70s 175294 ns/op 6665 B/op 88 allocs/op + xorm: 0.76s 189533 ns/op 3032 B/op 94 allocs/op + + 4000 times - MultiInsert 100 row + raw: 4.59s 1147385 ns/op 135155 B/op 916 allocs/op + raw_stmt: 4.59s 1148137 ns/op 131076 B/op 916 allocs/op + beego_orm: 5.50s 1375637 ns/op 179962 B/op 2747 allocs/op + bun: 6.18s 1544648 ns/op 4265 B/op 214 allocs/op + pg: 7.01s 1753495 ns/op 5039 B/op 114 allocs/op + gorm: 9.52s 2379219 ns/op 293956 B/op 3729 allocs/op + xorm: 11.66s 2915478 ns/op 286140 B/op 7422 allocs/op + + 4000 times - Update + raw_stmt: 0.26s 65781 ns/op 773 B/op 14 allocs/op + raw: 0.31s 77209 ns/op 757 B/op 13 allocs/op + beego_orm: 0.43s 107064 ns/op 1802 B/op 47 allocs/op + bun: 0.56s 139839 ns/op 589 B/op 4 allocs/op + pg: 0.60s 149608 ns/op 896 B/op 11 allocs/op + gorm: 0.74s 185970 ns/op 6604 B/op 81 allocs/op + xorm: 0.81s 203240 ns/op 2994 B/op 119 allocs/op + + 4000 times - Read + raw: 0.33s 81671 ns/op 2081 B/op 49 allocs/op + raw_stmt: 0.34s 85847 ns/op 2112 B/op 50 allocs/op + beego_orm: 0.38s 94777 ns/op 2106 B/op 75 allocs/op + pg: 0.42s 106148 ns/op 1526 B/op 22 allocs/op + bun: 0.43s 106904 ns/op 1319 B/op 18 allocs/op + gorm: 0.65s 162221 ns/op 5240 B/op 108 allocs/op + xorm: 1.13s 281738 ns/op 8326 B/op 237 allocs/op + + 4000 times - MultiRead limit 100 + raw: 1.52s 380351 ns/op 38356 B/op 1037 allocs/op + raw_stmt: 1.54s 385541 ns/op 38388 B/op 1038 allocs/op + pg: 1.86s 465468 ns/op 24045 B/op 631 allocs/op + bun: 2.58s 645354 ns/op 30009 B/op 1122 allocs/op + beego_orm: 2.93s 732028 ns/op 55280 B/op 3077 allocs/op + gorm: 4.97s 1241831 ns/op 71628 B/op 3877 allocs/op + xorm: doesn't work +``` + +
+ +## Installation + +```go +go get github.com/uptrace/bun +``` + +You also need to install a database/sql driver and the corresponding Bun +[dialect](https://bun.uptrace.dev/guide/drivers.html). + +## Quickstart + +First you need to create a `sql.DB`. Here we are using the +[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses +between [modernc.org/sqlite](https://modernc.org/sqlite/) and +[mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform. + +```go +import "github.com/uptrace/bun/driver/sqliteshim" + +sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") +if err != nil { + panic(err) +} +``` + +And then create a `bun.DB` on top of it using the corresponding SQLite dialect: + +```go +import ( + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" +) + +db := bun.NewDB(sqldb, sqlitedialect.New()) +``` + +Now you are ready to issue some queries: + +```go +type User struct { + ID int64 + Name string +} + +user := new(User) +err := db.NewSelect(). + Model(user). + Where("name != ?", ""). + OrderExpr("id ASC"). + Limit(1). + Scan(ctx) +``` + +The code above is equivalent to: + +```go +query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1" + +rows, err := sqldb.QueryContext(ctx, query) +if err != nil { + panic(err) +} + +if !rows.Next() { + panic(sql.ErrNoRows) +} + +user := new(User) +if err := db.ScanRow(ctx, rows, user); err != nil { + panic(err) +} + +if err := rows.Err(); err != nil { + panic(err) +} +``` + +## Basic example + +To provide initial data for our [example](/example/basic/), we will use Bun +[fixtures](https://bun.uptrace.dev/guide/fixtures.html): + +```go +import "github.com/uptrace/bun/dbfixture" + +// Register models for the fixture. +db.RegisterModel((*User)(nil), (*Story)(nil)) + +// WithRecreateTables tells Bun to drop existing tables and create new ones. +fixture := dbfixture.New(db, dbfixture.WithRecreateTables()) + +// Load fixture.yaml which contains data for User and Story models. +if err := fixture.Load(ctx, os.DirFS("."), "fixture.yaml"); err != nil { + panic(err) +} +``` + +The `fixture.yaml` looks like this: + +```yaml +- model: User + rows: + - _id: admin + name: admin + emails: ['admin1@admin', 'admin2@admin'] + - _id: root + name: root + emails: ['root1@root', 'root2@root'] + +- model: Story + rows: + - title: Cool story + author_id: '{{ $.User.admin.ID }}' +``` + +To select all users: + +```go +users := make([]User, 0) +if err := db.NewSelect().Model(&users).OrderExpr("id ASC").Scan(ctx); err != nil { + panic(err) +} +``` + +To select a single user by id: + +```go +user1 := new(User) +if err := db.NewSelect().Model(user1).Where("id = ?", 1).Scan(ctx); err != nil { + panic(err) +} +``` + +To select a story and the associated author in a single query: + +```go +story := new(Story) +if err := db.NewSelect(). + Model(story). + Relation("Author"). + Limit(1). + Scan(ctx); err != nil { + panic(err) +} +``` + +To select a user into a map: + +```go +m := make(map[string]interface{}) +if err := db.NewSelect(). + Model((*User)(nil)). + Limit(1). + Scan(ctx, &m); err != nil { + panic(err) +} +``` + +To select all users scanning each column into a separate slice: + +```go +var ids []int64 +var names []string +if err := db.NewSelect(). + ColumnExpr("id, name"). + Model((*User)(nil)). + OrderExpr("id ASC"). + Scan(ctx, &ids, &names); err != nil { + panic(err) +} +``` + +For more details, please consult [docs](https://bun.uptrace.dev/) and check [examples](example). + +## Contributors + +Thanks to all the people who already contributed! + + + + diff --git a/vendor/github.com/uptrace/bun/RELEASING.md b/vendor/github.com/uptrace/bun/RELEASING.md new file mode 100644 index 000000000..9e50c1063 --- /dev/null +++ b/vendor/github.com/uptrace/bun/RELEASING.md @@ -0,0 +1,21 @@ +# Releasing + +1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: + +```shell +./scripts/release.sh -t v1.0.0 +``` + +2. Open a pull request and wait for the build to finish. + +3. Merge the pull request and run `tag.sh` to create tags for packages: + +```shell +./scripts/tag.sh -t v1.0.0 +``` + +4. Push the tags: + +```shell +git push origin --tags +``` diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go new file mode 100644 index 000000000..92ebe691a --- /dev/null +++ b/vendor/github.com/uptrace/bun/bun.go @@ -0,0 +1,122 @@ +package bun + +import ( + "context" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ( + Safe = schema.Safe + Ident = schema.Ident +) + +type NullTime = schema.NullTime + +type BaseModel = schema.BaseModel + +type ( + BeforeScanHook = schema.BeforeScanHook + AfterScanHook = schema.AfterScanHook +) + +type BeforeSelectHook interface { + BeforeSelect(ctx context.Context, query *SelectQuery) error +} + +type AfterSelectHook interface { + AfterSelect(ctx context.Context, query *SelectQuery) error +} + +type BeforeInsertHook interface { + BeforeInsert(ctx context.Context, query *InsertQuery) error +} + +type AfterInsertHook interface { + AfterInsert(ctx context.Context, query *InsertQuery) error +} + +type BeforeUpdateHook interface { + BeforeUpdate(ctx context.Context, query *UpdateQuery) error +} + +type AfterUpdateHook interface { + AfterUpdate(ctx context.Context, query *UpdateQuery) error +} + +type BeforeDeleteHook interface { + BeforeDelete(ctx context.Context, query *DeleteQuery) error +} + +type AfterDeleteHook interface { + AfterDelete(ctx context.Context, query *DeleteQuery) error +} + +type BeforeCreateTableHook interface { + BeforeCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type AfterCreateTableHook interface { + AfterCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type BeforeDropTableHook interface { + BeforeDropTable(ctx context.Context, query *DropTableQuery) error +} + +type AfterDropTableHook interface { + AfterDropTable(ctx context.Context, query *DropTableQuery) error +} + +//------------------------------------------------------------------------------ + +type InValues struct { + slice reflect.Value + err error +} + +var _ schema.QueryAppender = InValues{} + +func In(slice interface{}) InValues { + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Slice { + return InValues{ + err: fmt.Errorf("bun: In(non-slice %T)", slice), + } + } + return InValues{ + slice: v, + } +} + +func (in InValues) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if in.err != nil { + return nil, in.err + } + return appendIn(fmter, b, in.slice), nil +} + +func appendIn(fmter schema.Formatter, b []byte, slice reflect.Value) []byte { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + elem := slice.Index(i) + if elem.Kind() == reflect.Interface { + elem = elem.Elem() + } + + if elem.Kind() == reflect.Slice { + b = append(b, '(') + b = appendIn(fmter, b, elem) + b = append(b, ')') + } else { + b = fmter.AppendValue(b, elem) + } + } + return b +} diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go new file mode 100644 index 000000000..d08adefb5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/db.go @@ -0,0 +1,502 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "sync/atomic" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + discardUnknownColumns internal.Flag = 1 << iota +) + +type DBStats struct { + Queries uint64 + Errors uint64 +} + +type DBOption func(db *DB) + +func WithDiscardUnknownColumns() DBOption { + return func(db *DB) { + db.flags = db.flags.Set(discardUnknownColumns) + } +} + +type DB struct { + *sql.DB + dialect schema.Dialect + features feature.Feature + + queryHooks []QueryHook + + fmter schema.Formatter + flags internal.Flag + + stats DBStats +} + +func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { + dialect.Init(sqldb) + + db := &DB{ + DB: sqldb, + dialect: dialect, + features: dialect.Features(), + fmter: schema.NewFormatter(dialect), + } + + for _, opt := range opts { + opt(db) + } + + return db +} + +func (db *DB) String() string { + var b strings.Builder + b.WriteString("DB") + return b.String() +} + +func (db *DB) DBStats() DBStats { + return DBStats{ + Queries: atomic.LoadUint64(&db.stats.Queries), + Errors: atomic.LoadUint64(&db.stats.Errors), + } +} + +func (db *DB) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(db, model) +} + +func (db *DB) NewSelect() *SelectQuery { + return NewSelectQuery(db) +} + +func (db *DB) NewInsert() *InsertQuery { + return NewInsertQuery(db) +} + +func (db *DB) NewUpdate() *UpdateQuery { + return NewUpdateQuery(db) +} + +func (db *DB) NewDelete() *DeleteQuery { + return NewDeleteQuery(db) +} + +func (db *DB) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(db) +} + +func (db *DB) NewDropTable() *DropTableQuery { + return NewDropTableQuery(db) +} + +func (db *DB) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(db) +} + +func (db *DB) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(db) +} + +func (db *DB) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(db) +} + +func (db *DB) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(db) +} + +func (db *DB) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(db) +} + +func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error { + for _, model := range models { + if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil { + return err + } + if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil { + return err + } + } + return nil +} + +func (db *DB) Dialect() schema.Dialect { + return db.dialect +} + +func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + _, err = model.ScanRows(ctx, rows) + return err +} + +func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + rs, ok := model.(rowScanner) + if !ok { + return fmt.Errorf("bun: %T does not support ScanRow", model) + } + + return rs.ScanRow(ctx, rows) +} + +func (db *DB) AddQueryHook(hook QueryHook) { + db.queryHooks = append(db.queryHooks, hook) +} + +func (db *DB) Table(typ reflect.Type) *schema.Table { + return db.dialect.Tables().Get(typ) +} + +func (db *DB) RegisterModel(models ...interface{}) { + db.dialect.Tables().Register(models...) +} + +func (db *DB) clone() *DB { + clone := *db + + l := len(clone.queryHooks) + clone.queryHooks = clone.queryHooks[:l:l] + + return &clone +} + +func (db *DB) WithNamedArg(name string, value interface{}) *DB { + clone := db.clone() + clone.fmter = clone.fmter.WithNamedArg(name, value) + return clone +} + +func (db *DB) Formatter() schema.Formatter { + return db.fmter +} + +//------------------------------------------------------------------------------ + +func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + res, err := db.DB.ExecContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, res, err) + return res, err +} + +func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + rows, err := db.DB.QueryContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := db.beforeQuery(ctx, nil, query, args) + row := db.DB.QueryRowContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (db *DB) format(query string, args []interface{}) string { + return db.fmter.FormatQuery(query, args...) +} + +//------------------------------------------------------------------------------ + +type Conn struct { + db *DB + *sql.Conn +} + +func (db *DB) Conn(ctx context.Context) (Conn, error) { + conn, err := db.DB.Conn(ctx) + if err != nil { + return Conn{}, err + } + return Conn{ + db: db, + Conn: conn, + }, nil +} + +func (c Conn) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + res, err := c.Conn.ExecContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (c Conn) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + row := c.Conn.QueryRowContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (c Conn) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(c.db, model).Conn(c) +} + +func (c Conn) NewSelect() *SelectQuery { + return NewSelectQuery(c.db).Conn(c) +} + +func (c Conn) NewInsert() *InsertQuery { + return NewInsertQuery(c.db).Conn(c) +} + +func (c Conn) NewUpdate() *UpdateQuery { + return NewUpdateQuery(c.db).Conn(c) +} + +func (c Conn) NewDelete() *DeleteQuery { + return NewDeleteQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewDropTable() *DropTableQuery { + return NewDropTableQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(c.db).Conn(c) +} + +func (c Conn) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(c.db).Conn(c) +} + +//------------------------------------------------------------------------------ + +type Stmt struct { + *sql.Stmt +} + +func (db *DB) Prepare(query string) (Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { + stmt, err := db.DB.PrepareContext(ctx, query) + if err != nil { + return Stmt{}, err + } + return Stmt{Stmt: stmt}, nil +} + +//------------------------------------------------------------------------------ + +type Tx struct { + db *DB + *sql.Tx +} + +// RunInTx runs the function in a transaction. If the function returns an error, +// the transaction is rolled back. Otherwise, the transaction is committed. +func (db *DB) RunInTx( + ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + if err := fn(ctx, tx); err != nil { + return err + } + return tx.Commit() +} + +func (db *DB) Begin() (Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return Tx{}, err + } + return Tx{ + db: db, + Tx: tx, + }, nil +} + +func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.ExecContext(context.TODO(), query, args...) +} + +func (tx Tx) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.QueryContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.QueryRowContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +//------------------------------------------------------------------------------ + +func (tx Tx) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(tx.db, model).Conn(tx) +} + +func (tx Tx) NewSelect() *SelectQuery { + return NewSelectQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewInsert() *InsertQuery { + return NewInsertQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewUpdate() *UpdateQuery { + return NewUpdateQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDelete() *DeleteQuery { + return NewDeleteQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropTable() *DropTableQuery { + return NewDropTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(tx.db).Conn(tx) +} + +//------------------------------------------------------------------------------0 + +func (db *DB) makeQueryBytes() []byte { + // TODO: make this configurable? + return make([]byte, 0, 4096) +} + +//------------------------------------------------------------------------------ + +type result struct { + r sql.Result + n int +} + +func (r result) RowsAffected() (int64, error) { + if r.r != nil { + return r.r.RowsAffected() + } + return int64(r.n), nil +} + +func (r result) LastInsertId() (int64, error) { + if r.r != nil { + return r.r.LastInsertId() + } + return 0, errors.New("LastInsertId is not available") +} diff --git a/vendor/github.com/uptrace/bun/dialect/append.go b/vendor/github.com/uptrace/bun/dialect/append.go new file mode 100644 index 000000000..7040c5155 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/append.go @@ -0,0 +1,178 @@ +package dialect + +import ( + "encoding/hex" + "math" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +func AppendError(b []byte, err error) []byte { + b = append(b, "?!("...) + b = append(b, err.Error()...) + b = append(b, ')') + return b +} + +func AppendNull(b []byte) []byte { + return append(b, "NULL"...) +} + +func AppendBool(b []byte, v bool) []byte { + if v { + return append(b, "TRUE"...) + } + return append(b, "FALSE"...) +} + +func AppendFloat32(b []byte, v float32) []byte { + return appendFloat(b, float64(v), 32) +} + +func AppendFloat64(b []byte, v float64) []byte { + return appendFloat(b, v, 64) +} + +func appendFloat(b []byte, v float64, bitSize int) []byte { + switch { + case math.IsNaN(v): + return append(b, "'NaN'"...) + case math.IsInf(v, 1): + return append(b, "'Infinity'"...) + case math.IsInf(v, -1): + return append(b, "'-Infinity'"...) + default: + return strconv.AppendFloat(b, v, 'f', -1, bitSize) + } +} + +func AppendString(b []byte, s string) []byte { + b = append(b, '\'') + for _, r := range s { + if r == '\000' { + continue + } + + if r == '\'' { + b = append(b, '\'', '\'') + continue + } + + if r < utf8.RuneSelf { + b = append(b, byte(r)) + continue + } + + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + b = append(b, '\'') + return b +} + +func AppendBytes(b []byte, bytes []byte) []byte { + if bytes == nil { + return AppendNull(b) + } + + b = append(b, `'\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) + hex.Encode(b[s:], bytes) + + b = append(b, '\'') + + return b +} + +func AppendTime(b []byte, tm time.Time) []byte { + if tm.IsZero() { + return AppendNull(b) + } + b = append(b, '\'') + b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") + b = append(b, '\'') + return b +} + +func AppendJSON(b, jsonb []byte) []byte { + b = append(b, '\'') + + p := parser.New(jsonb) + for p.Valid() { + c := p.Read() + switch c { + case '"': + b = append(b, '"') + case '\'': + b = append(b, "''"...) + case '\000': + continue + case '\\': + if p.SkipBytes([]byte("u0000")) { + b = append(b, `\\u0000`...) + } else { + b = append(b, '\\') + if p.Valid() { + b = append(b, p.Read()) + } + } + default: + b = append(b, c) + } + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func AppendIdent(b []byte, field string, quote byte) []byte { + return appendIdent(b, internal.Bytes(field), quote) +} + +func appendIdent(b, src []byte, quote byte) []byte { + var quoted bool +loop: + for _, c := range src { + switch c { + case '*': + if !quoted { + b = append(b, '*') + continue loop + } + case '.': + if quoted { + b = append(b, quote) + quoted = false + } + b = append(b, '.') + continue loop + } + + if !quoted { + b = append(b, quote) + quoted = true + } + if c == quote { + b = append(b, quote, quote) + } else { + b = append(b, c) + } + } + if quoted { + b = append(b, quote) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go new file mode 100644 index 000000000..9ff8b2461 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/dialect.go @@ -0,0 +1,26 @@ +package dialect + +type Name int + +func (n Name) String() string { + switch n { + case PG: + return "pg" + case SQLite: + return "sqlite" + case MySQL5: + return "mysql5" + case MySQL8: + return "mysql8" + default: + return "invalid" + } +} + +const ( + Invalid Name = iota + PG + SQLite + MySQL5 + MySQL8 +) diff --git a/vendor/github.com/uptrace/bun/dialect/feature/feature.go b/vendor/github.com/uptrace/bun/dialect/feature/feature.go new file mode 100644 index 000000000..ff8f1d625 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/feature/feature.go @@ -0,0 +1,22 @@ +package feature + +import "github.com/uptrace/bun/internal" + +type Feature = internal.Flag + +const DefaultFeatures = Returning | TableCascade + +const ( + Returning Feature = 1 << iota + DefaultPlaceholder + DoubleColonCast + ValuesRow + UpdateMultiTable + InsertTableAlias + DeleteTableAlias + AutoIncrement + TableCascade + TableIdentity + TableTruncate + OnDuplicateKey +) diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE new file mode 100644 index 000000000..7ec81810c --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go new file mode 100644 index 000000000..475621197 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go @@ -0,0 +1,303 @@ +package pgdialect + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +var ( + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + + stringType = reflect.TypeOf((*string)(nil)).Elem() + sliceStringType = reflect.TypeOf([]string(nil)) + + intType = reflect.TypeOf((*int)(nil)).Elem() + sliceIntType = reflect.TypeOf([]int(nil)) + + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + sliceInt64Type = reflect.TypeOf([]int64(nil)) + + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + sliceFloat64Type = reflect.TypeOf([]float64(nil)) +) + +func customAppender(typ reflect.Type) schema.AppenderFunc { + switch typ.Kind() { + case reflect.Uint32: + return appendUint32ValueAsInt + case reflect.Uint, reflect.Uint64: + return appendUint64ValueAsInt + } + return nil +} + +func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(int32(v.Uint())), 10) +} + +func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +//------------------------------------------------------------------------------ + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Kind() == reflect.String { + return arrayAppendStringValue + } + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + return schema.Appender(typ, customAppender) +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +//------------------------------------------------------------------------------ + +func arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + } + } + + appendElem := arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, '\'') + + b = append(b, '{') + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + b = appendElem(fmter, b, elem) + b = append(b, ',') + } + if v.Len() > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b + } +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "'''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go new file mode 100644 index 000000000..57f5a4384 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go @@ -0,0 +1,65 @@ +package pgdialect + +import ( + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ArrayValue struct { + v reflect.Value + + append schema.AppenderFunc + scan schema.ScannerFunc +} + +// Array accepts a slice and returns a wrapper for working with PostgreSQL +// array data type. +// +// For struct fields you can use array tag: +// +// Emails []string `bun:",array"` +func Array(vi interface{}) *ArrayValue { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("bun: Array(nil)")) + } + + return &ArrayValue{ + v: v, + + append: arrayAppender(v.Type()), + scan: arrayScanner(v.Type()), + } +} + +var ( + _ schema.QueryAppender = (*ArrayValue)(nil) + _ sql.Scanner = (*ArrayValue)(nil) +) + +func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) { + if a.append == nil { + panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())) + } + return a.append(fmter, b, a.v), nil +} + +func (a *ArrayValue) Scan(src interface{}) error { + if a.scan == nil { + return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()) + } + if a.v.Kind() != reflect.Ptr { + return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type()) + } + return a.scan(a.v, src) +} + +func (a *ArrayValue) Value() interface{} { + if a.v.IsValid() { + return a.v.Interface() + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go new file mode 100644 index 000000000..1c927fca0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go @@ -0,0 +1,146 @@ +package pgdialect + +import ( + "bytes" + "fmt" + "io" +) + +type arrayParser struct { + b []byte + i int + + buf []byte + err error +} + +func newArrayParser(b []byte) *arrayParser { + p := &arrayParser{ + b: b, + i: 1, + } + if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { + p.err = fmt.Errorf("bun: can't parse array: %q", b) + } + return p +} + +func (p *arrayParser) NextElem() ([]byte, error) { + if p.err != nil { + return nil, p.err + } + + c, err := p.readByte() + if err != nil { + return nil, err + } + + switch c { + case '}': + return nil, io.EOF + case '"': + b, err := p.readSubstring() + if err != nil { + return nil, err + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + default: + b := p.readSimple() + if bytes.Equal(b, []byte("NULL")) { + b = nil + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + } +} + +func (p *arrayParser) readSimple() []byte { + p.unreadByte() + + if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { + b := p.b[p.i : p.i+i] + p.i += i + return b + } + + b := p.b[p.i : len(p.b)-1] + p.i = len(p.b) - 1 + return b +} + +func (p *arrayParser) readSubstring() ([]byte, error) { + c, err := p.readByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if c == '"' { + break + } + + next, err := p.readByte() + if err != nil { + return nil, err + } + + if c == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + c, err = p.readByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + c = next + } + continue + } + + p.buf = append(p.buf, c) + c = next + } + + return p.buf, nil +} + +func (p *arrayParser) valid() bool { + return p.i < len(p.b) +} + +func (p *arrayParser) readByte() (byte, error) { + if p.valid() { + c := p.b[p.i] + p.i++ + return c, nil + } + return 0, io.EOF +} + +func (p *arrayParser) unreadByte() { + p.i-- +} + +func (p *arrayParser) peek() byte { + if p.valid() { + return p.b[p.i] + } + return 0 +} + +func (p *arrayParser) skipNext() { + p.i++ +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go new file mode 100644 index 000000000..33d31f325 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go @@ -0,0 +1,302 @@ +package pgdialect + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +func arrayScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringSliceValue + case intType: + return scanIntSliceValue + case int64Type: + return scanInt64SliceValue + case float64Type: + return scanFloat64SliceValue + } + } + + scanElem := schema.Scanner(elemType) + return func(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + kind := dest.Kind() + + if src == nil { + if kind != reflect.Slice || !dest.IsNil() { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if kind == reflect.Slice { + if dest.IsNil() { + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + } else if dest.Len() > 0 { + dest.Set(dest.Slice(0, 0)) + } + } + + b, err := toBytes(src) + if err != nil { + return err + } + + p := newArrayParser(b) + nextValue := internal.MakeSliceNextElemFunc(dest) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return err + } + + elemValue := nextValue() + if err := scanElem(elemValue, elem); err != nil { + return err + } + } + + return nil + } +} + +func scanStringSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeStringSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeStringSlice(src interface{}) ([]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]string, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + slice = append(slice, string(elem)) + } + + return slice, nil +} + +func scanIntSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeIntSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeIntSlice(src interface{}) ([]int, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.Atoi(bytesToString(elem)) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanInt64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeInt64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeInt64Slice(src interface{}) ([]int64, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseInt(bytesToString(elem), 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := scanFloat64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Slice(src interface{}) ([]float64, error) { + if src == -1 { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]float64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseFloat(bytesToString(elem), 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return stringToBytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go new file mode 100644 index 000000000..fb210751b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go @@ -0,0 +1,150 @@ +package pgdialect + +import ( + "database/sql" + "reflect" + "strconv" + "sync" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +type Dialect struct { + tables *schema.Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func New() *Dialect { + d := new(Dialect) + d.tables = schema.NewTables(d) + d.features = feature.Returning | + feature.DefaultPlaceholder | + feature.DoubleColonCast | + feature.InsertTableAlias | + feature.DeleteTableAlias | + feature.TableCascade | + feature.TableIdentity | + feature.TableTruncate + return d +} + +func (d *Dialect) Init(*sql.DB) {} + +func (d *Dialect) Name() dialect.Name { + return dialect.PG +} + +func (d *Dialect) Features() feature.Feature { + return d.features +} + +func (d *Dialect) Tables() *schema.Tables { + return d.tables +} + +func (d *Dialect) OnTable(table *schema.Table) { + for _, field := range table.FieldMap { + d.onField(field) + } +} + +func (d *Dialect) onField(field *schema.Field) { + field.DiscoveredSQLType = fieldSQLType(field) + + if field.AutoIncrement { + switch field.DiscoveredSQLType { + case sqltype.SmallInt: + field.CreateTableSQLType = pgTypeSmallSerial + case sqltype.Integer: + field.CreateTableSQLType = pgTypeSerial + case sqltype.BigInt: + field.CreateTableSQLType = pgTypeBigSerial + } + } + + if field.Tag.HasOption("array") { + field.Append = arrayAppender(field.IndirectType) + field.Scan = arrayScanner(field.IndirectType) + } +} + +func (d *Dialect) IdentQuote() byte { + return '"' +} + +func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendInt(b, int64(v), 10) + case uint32: + return strconv.AppendInt(b, int64(v), 10) + case uint64: + return strconv.AppendInt(b, int64(v), 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case schema.QueryAppender: + return schema.AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := d.Appender(vv.Type()) + return appender(fmter, b, vv) + } +} + +func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(schema.AppenderFunc) + } + + fn := schema.Appender(typ, customAppender) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(schema.AppenderFunc) + } + return fn +} + +func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc { + return schema.FieldAppender(d, field) +} + +func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(schema.ScannerFunc) + } + + fn := scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(schema.ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod new file mode 100644 index 000000000..0cad1ce5b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod @@ -0,0 +1,7 @@ +module github.com/uptrace/bun/dialect/pgdialect + +go 1.16 + +replace github.com/uptrace/bun => ../.. + +require github.com/uptrace/bun v0.4.3 diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum new file mode 100644 index 000000000..4d0f1c1bb --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go new file mode 100644 index 000000000..dff30b9c5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package pgdialect + +func bytesToString(b []byte) string { + return string(b) +} + +func stringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go new file mode 100644 index 000000000..9e22282f5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go @@ -0,0 +1,28 @@ +package pgdialect + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +func scanner(typ reflect.Type) schema.ScannerFunc { + if typ.Kind() == reflect.Interface { + return scanInterface + } + return schema.Scanner(typ) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go new file mode 100644 index 000000000..4c2d8075d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -0,0 +1,104 @@ +package pgdialect + +import ( + "encoding/json" + "net" + "reflect" + "time" + + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +const ( + // Date / Time + pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone + pgTypeDate = "DATE" // Date + pgTypeTime = "TIME" // Time without a time zone + pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone + pgTypeInterval = "INTERVAL" // Time Interval + + // Network Addresses + pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks + pgTypeCidr = "CIDR" // IPv4 or IPv6 networks + pgTypeMacaddr = "MACADDR" // MAC addresses + + // Serial Types + pgTypeSmallSerial = "SMALLSERIAL" // 2 byte autoincrementing integer + pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer + pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer + + // Character Types + pgTypeChar = "CHAR" // fixed length string (blank padded) + pgTypeText = "TEXT" // variable length string without limit + + // JSON Types + pgTypeJSON = "JSON" // text representation of json data + pgTypeJSONB = "JSONB" // binary representation of json data + + // Binary Data Types + pgTypeBytea = "BYTEA" // binary string +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +func fieldSQLType(field *schema.Field) string { + if field.UserSQLType != "" { + return field.UserSQLType + } + + if v, ok := field.Tag.Options["composite"]; ok { + return v + } + + if _, ok := field.Tag.Options["hstore"]; ok { + return "hstore" + } + + if _, ok := field.Tag.Options["array"]; ok { + switch field.IndirectType.Kind() { + case reflect.Slice, reflect.Array: + sqlType := sqlType(field.IndirectType.Elem()) + return sqlType + "[]" + } + } + + return sqlType(field.IndirectType) +} + +func sqlType(typ reflect.Type) string { + switch typ { + case ipType: + return pgTypeInet + case ipNetType: + return pgTypeCidr + case jsonRawMessageType: + return pgTypeJSONB + } + + sqlType := schema.DiscoverSQLType(typ) + switch sqlType { + case sqltype.Timestamp: + sqlType = pgTypeTimestampTz + } + + switch typ.Kind() { + case reflect.Map, reflect.Struct: + if sqlType == sqltype.VarChar { + return pgTypeJSONB + } + return sqlType + case reflect.Array, reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return pgTypeBytea + } + return pgTypeJSONB + } + + return sqlType +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go new file mode 100644 index 000000000..2a02a20b1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go @@ -0,0 +1,18 @@ +// +build !appengine + +package pgdialect + +import "unsafe" + +func bytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func stringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go new file mode 100644 index 000000000..84a51d26d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go @@ -0,0 +1,14 @@ +package sqltype + +const ( + Boolean = "BOOLEAN" + SmallInt = "SMALLINT" + Integer = "INTEGER" + BigInt = "BIGINT" + Real = "REAL" + DoublePrecision = "DOUBLE PRECISION" + VarChar = "VARCHAR" + Timestamp = "TIMESTAMP" + JSON = "JSON" + JSONB = "JSONB" +) diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE b/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE new file mode 100644 index 000000000..7ec81810c --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/README.md b/vendor/github.com/uptrace/bun/driver/pgdriver/README.md new file mode 100644 index 000000000..8bb641a5c --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/README.md @@ -0,0 +1,36 @@ +# pgdriver + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun/driver/pgdriver)](https://pkg.go.dev/github.com/uptrace/bun/driver/pgdriver) + +pgdriver is a database/sql driver for PostgreSQL based on [go-pg](https://github.com/go-pg/pg) code. + +You can install it with: + +```shell +github.com/uptrace/bun/driver/pgdriver +``` + +And then create a `sql.DB` using it: + +```go +import _ "github.com/uptrace/bun/driver/pgdriver" + +dsn := "postgres://postgres:@localhost:5432/test" +db, err := sql.Open("pg", dsn) +``` + +Alternatively: + +```go +dsn := "postgres://postgres:@localhost:5432/test" +db := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) +``` + +[Benchmark](https://github.com/go-bun/bun-benchmark): + +``` +BenchmarkInsert/pg-12 7254 148380 ns/op 900 B/op 13 allocs/op +BenchmarkInsert/pgx-12 6494 166391 ns/op 2076 B/op 26 allocs/op +BenchmarkSelect/pg-12 9100 132952 ns/op 1417 B/op 18 allocs/op +BenchmarkSelect/pgx-12 8199 154920 ns/op 3679 B/op 60 allocs/op +``` diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/column.go b/vendor/github.com/uptrace/bun/driver/pgdriver/column.go new file mode 100644 index 000000000..5c3626943 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/column.go @@ -0,0 +1,192 @@ +package pgdriver + +import ( + "encoding/hex" + "fmt" + "io" + "strconv" + "strings" + "time" +) + +const ( + pgBool = 16 + + pgInt2 = 21 + pgInt4 = 23 + pgInt8 = 20 + + pgFloat4 = 700 + pgFloat8 = 701 + + pgText = 25 + pgVarchar = 1043 + pgBytea = 17 + + pgDate = 1082 + pgTimestamp = 1114 + pgTimestamptz = 1184 +) + +func readColumnValue(rd *reader, dataType int32, dataLen int) (interface{}, error) { + if dataLen == -1 { + return nil, nil + } + + switch dataType { + case pgBool: + return readBoolCol(rd, dataLen) + case pgInt2: + return readIntCol(rd, dataLen, 16) + case pgInt4: + return readIntCol(rd, dataLen, 32) + case pgInt8: + return readIntCol(rd, dataLen, 64) + case pgFloat4: + return readFloatCol(rd, dataLen, 32) + case pgFloat8: + return readFloatCol(rd, dataLen, 64) + case pgTimestamp: + return readTimeCol(rd, dataLen) + case pgTimestamptz: + return readTimeCol(rd, dataLen) + case pgDate: + return readTimeCol(rd, dataLen) + case pgText, pgVarchar: + return readStringCol(rd, dataLen) + case pgBytea: + return readBytesCol(rd, dataLen) + } + + b := make([]byte, dataLen) + if _, err := io.ReadFull(rd, b); err != nil { + return nil, err + } + return b, nil +} + +func readBoolCol(rd *reader, n int) (interface{}, error) { + tmp, err := rd.ReadTemp(n) + if err != nil { + return nil, err + } + return len(tmp) == 1 && (tmp[0] == 't' || tmp[0] == '1'), nil +} + +func readIntCol(rd *reader, n int, bitSize int) (interface{}, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadTemp(n) + if err != nil { + return 0, err + } + + return strconv.ParseInt(bytesToString(tmp), 10, bitSize) +} + +func readFloatCol(rd *reader, n int, bitSize int) (interface{}, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadTemp(n) + if err != nil { + return 0, err + } + + return strconv.ParseFloat(bytesToString(tmp), bitSize) +} + +func readStringCol(rd *reader, n int) (interface{}, error) { + if n <= 0 { + return "", nil + } + + b := make([]byte, n) + + if _, err := io.ReadFull(rd, b); err != nil { + return nil, err + } + + return bytesToString(b), nil +} + +func readBytesCol(rd *reader, n int) (interface{}, error) { + if n <= 0 { + return []byte{}, nil + } + + tmp, err := rd.ReadTemp(n) + if err != nil { + return nil, err + } + + if len(tmp) < 2 || tmp[0] != '\\' || tmp[1] != 'x' { + return nil, fmt.Errorf("pgdriver: can't parse bytea: %q", tmp) + } + tmp = tmp[2:] // Cut off "\x". + + b := make([]byte, hex.DecodedLen(len(tmp))) + if _, err := hex.Decode(b, tmp); err != nil { + return nil, err + } + return b, nil +} + +func readTimeCol(rd *reader, n int) (interface{}, error) { + if n <= 0 { + return time.Time{}, nil + } + + tmp, err := rd.ReadTemp(n) + if err != nil { + return time.Time{}, err + } + + tm, err := parseTime(bytesToString(tmp)) + if err != nil { + return time.Time{}, err + } + return tm, nil +} + +const ( + dateFormat = "2006-01-02" + timeFormat = "15:04:05.999999999" + timestampFormat = "2006-01-02 15:04:05.999999999" + timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00" + timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00" + timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" +) + +func parseTime(s string) (time.Time, error) { + switch l := len(s); { + case l < len("15:04:05"): + return time.Time{}, fmt.Errorf("pgdriver: can't parse time=%q", s) + case l <= len(timeFormat): + if s[2] == ':' { + return time.ParseInLocation(timeFormat, s, time.UTC) + } + return time.ParseInLocation(dateFormat, s, time.UTC) + default: + if s[10] == 'T' { + return time.Parse(time.RFC3339Nano, s) + } + if c := s[l-9]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat, s) + } + if c := s[l-6]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat2, s) + } + if c := s[l-3]; c == '+' || c == '-' { + if strings.HasSuffix(s, "+00") { + s = s[:len(s)-3] + return time.ParseInLocation(timestampFormat, s, time.UTC) + } + return time.Parse(timestamptzFormat3, s) + } + return time.ParseInLocation(timestampFormat, s, time.UTC) + } +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/config.go b/vendor/github.com/uptrace/bun/driver/pgdriver/config.go new file mode 100644 index 000000000..8e8abfe59 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/config.go @@ -0,0 +1,233 @@ +package pgdriver + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "os" + "strings" + "time" +) + +type Config struct { + // Network type, either tcp or unix. + // Default is tcp. + Network string + // TCP host:port or Unix socket depending on Network. + Addr string + // Dial timeout for establishing new connections. + // Default is 5 seconds. + DialTimeout time.Duration + // Dialer creates new network connection and has priority over + // Network and Addr options. + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // TLS config for secure connections. + TLSConfig *tls.Config + + User string + Password string + Database string + AppName string + + // Timeout for socket reads. If reached, commands will fail + // with a timeout instead of blocking. + ReadTimeout time.Duration + // Timeout for socket writes. If reached, commands will fail + // with a timeout instead of blocking. + WriteTimeout time.Duration +} + +func newDefaultConfig() *Config { + host := env("PGHOST", "localhost") + port := env("PGPORT", "5432") + + cfg := &Config{ + Network: "tcp", + Addr: net.JoinHostPort(host, port), + DialTimeout: 5 * time.Second, + + User: env("PGUSER", "postgres"), + Database: env("PGDATABASE", "postgres"), + + ReadTimeout: 10 * time.Second, + WriteTimeout: 5 * time.Second, + } + + cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialer := &net.Dialer{ + Timeout: cfg.DialTimeout, + KeepAlive: 5 * time.Minute, + } + return netDialer.DialContext(ctx, network, addr) + } + + return cfg +} + +type DriverOption func(*Connector) + +func WithAddr(addr string) DriverOption { + if addr == "" { + panic("addr is empty") + } + return func(d *Connector) { + d.cfg.Addr = addr + } +} + +func WithTLSConfig(cfg *tls.Config) DriverOption { + return func(d *Connector) { + d.cfg.TLSConfig = cfg + } +} + +func WithUser(user string) DriverOption { + if user == "" { + panic("user is empty") + } + return func(d *Connector) { + d.cfg.User = user + } +} + +func WithPassword(password string) DriverOption { + return func(d *Connector) { + d.cfg.Password = password + } +} + +func WithDatabase(database string) DriverOption { + if database == "" { + panic("database is empty") + } + return func(d *Connector) { + d.cfg.Database = database + } +} + +func WithApplicationName(appName string) DriverOption { + return func(d *Connector) { + d.cfg.AppName = appName + } +} + +func WithTimeout(timeout time.Duration) DriverOption { + return func(d *Connector) { + d.cfg.DialTimeout = timeout + d.cfg.ReadTimeout = timeout + d.cfg.WriteTimeout = timeout + } +} + +func WithDialTimeout(dialTimeout time.Duration) DriverOption { + return func(d *Connector) { + d.cfg.DialTimeout = dialTimeout + } +} + +func WithReadTimeout(readTimeout time.Duration) DriverOption { + return func(d *Connector) { + d.cfg.ReadTimeout = readTimeout + } +} + +func WithWriteTimeout(writeTimeout time.Duration) DriverOption { + return func(d *Connector) { + d.cfg.WriteTimeout = writeTimeout + } +} + +func WithDSN(dsn string) DriverOption { + return func(d *Connector) { + opts, err := parseDSN(dsn) + if err != nil { + panic(err) + } + for _, opt := range opts { + opt(d) + } + } +} + +func parseDSN(dsn string) ([]DriverOption, error) { + u, err := url.Parse(dsn) + if err != nil { + return nil, err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme) + } + + query, err := url.ParseQuery(u.RawQuery) + if err != nil { + return nil, err + } + + var opts []DriverOption + + if u.Host != "" { + addr := u.Host + if !strings.Contains(addr, ":") { + addr += ":5432" + } + opts = append(opts, WithAddr(addr)) + } + if u.User != nil { + opts = append(opts, WithUser(u.User.Username())) + if password, ok := u.User.Password(); ok { + opts = append(opts, WithPassword(password)) + } + } + if len(u.Path) > 1 { + opts = append(opts, WithDatabase(u.Path[1:])) + } + + if appName := query.Get("application_name"); appName != "" { + opts = append(opts, WithApplicationName(appName)) + } + delete(query, "application_name") + + if sslMode := query.Get("sslmode"); sslMode != "" { + switch sslMode { + case "verify-ca", "verify-full": + opts = append(opts, WithTLSConfig(new(tls.Config))) + case "allow", "prefer", "require": + opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + case "disable": + // no TLS config + default: + return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) + } + } else { + opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + } + delete(query, "sslmode") + + for key := range query { + return nil, fmt.Errorf("pgdriver: unsupported option=%q", key) + } + + return opts, nil +} + +func env(key, defValue string) string { + if s := os.Getenv(key); s != "" { + return s + } + return defValue +} + +// verify is a method to make sure if the config is legitimate +// in the case it detects any errors, it returns with a non-nil error +// it can be extended to check other parameters +func (c *Config) verify() error { + if c.User == "" { + return errors.New("pgdriver: User option is empty (to configure, use WithUser).") + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go b/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go new file mode 100644 index 000000000..d25c3adbc --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/driver.go @@ -0,0 +1,606 @@ +package pgdriver + +import ( + "bufio" + "bytes" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strconv" + "sync" + "sync/atomic" + "time" +) + +func init() { + sql.Register("pg", NewDriver()) +} + +type logging interface { + Printf(ctx context.Context, format string, v ...interface{}) +} + +type logger struct { + log *log.Logger +} + +func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { + _ = l.log.Output(2, fmt.Sprintf(format, v...)) +} + +var Logger logging = &logger{ + log: log.New(os.Stderr, "pgdriver: ", log.LstdFlags|log.Lshortfile), +} + +//------------------------------------------------------------------------------ + +type Driver struct { + connector *Connector +} + +var _ driver.DriverContext = (*Driver)(nil) + +func NewDriver() Driver { + return Driver{} +} + +func (d Driver) OpenConnector(name string) (driver.Connector, error) { + opts, err := parseDSN(name) + if err != nil { + return nil, err + } + return NewConnector(opts...), nil +} + +func (d Driver) Open(name string) (driver.Conn, error) { + connector, err := d.OpenConnector(name) + if err != nil { + return nil, err + } + return connector.Connect(context.TODO()) +} + +//------------------------------------------------------------------------------ + +type DriverStats struct { + Queries uint64 + Errors uint64 +} + +type Connector struct { + cfg *Config + + stats DriverStats +} + +func NewConnector(opts ...DriverOption) *Connector { + d := &Connector{cfg: newDefaultConfig()} + for _, opt := range opts { + opt(d) + } + return d +} + +var _ driver.Connector = (*Connector)(nil) + +func (d *Connector) Connect(ctx context.Context) (driver.Conn, error) { + if err := d.cfg.verify(); err != nil { + return nil, err + } + + return newConn(ctx, d) +} + +func (d *Connector) Driver() driver.Driver { + return Driver{connector: d} +} + +func (d *Connector) Config() *Config { + return d.cfg +} + +func (d *Connector) Stats() DriverStats { + return DriverStats{ + Queries: atomic.LoadUint64(&d.stats.Queries), + Errors: atomic.LoadUint64(&d.stats.Errors), + } +} + +//------------------------------------------------------------------------------ + +type Conn struct { + driver *Connector + + netConn net.Conn + rd *reader + + processID int32 + secretKey int32 + + stmtCount int + + closed int32 +} + +func newConn(ctx context.Context, driver *Connector) (*Conn, error) { + netConn, err := driver.cfg.Dialer(ctx, driver.cfg.Network, driver.cfg.Addr) + if err != nil { + return nil, err + } + + cn := &Conn{ + driver: driver, + netConn: netConn, + rd: newReader(netConn), + } + + if cn.driver.cfg.TLSConfig != nil { + if err := enableSSL(ctx, cn, cn.driver.cfg.TLSConfig); err != nil { + return nil, err + } + } + + if err := startup(ctx, cn); err != nil { + return nil, err + } + + return cn, nil +} + +func (cn *Conn) reader(ctx context.Context, timeout time.Duration) *reader { + cn.setReadDeadline(ctx, timeout) + return cn.rd +} + +func (cn *Conn) withWriter( + ctx context.Context, + timeout time.Duration, + fn func(wr *bufio.Writer) error, +) error { + wr := getBufioWriter() + + cn.setWriteDeadline(ctx, timeout) + wr.Reset(cn.netConn) + + err := fn(wr) + if err == nil { + err = wr.Flush() + } + + putBufioWriter(wr) + + return err +} + +var _ driver.Conn = (*Conn)(nil) + +func (cn *Conn) Prepare(query string) (driver.Stmt, error) { + if cn.isClosed() { + return nil, driver.ErrBadConn + } + + ctx := context.TODO() + + name := fmt.Sprintf("pgdriver-%d", cn.stmtCount) + cn.stmtCount++ + + if err := writeParseDescribeSync(ctx, cn, name, query); err != nil { + return nil, err + } + + rowDesc, err := readParseDescribeSync(ctx, cn) + if err != nil { + return nil, err + } + + return newStmt(cn, name, rowDesc), nil +} + +func (cn *Conn) Close() error { + if !atomic.CompareAndSwapInt32(&cn.closed, 0, 1) { + return nil + } + return cn.netConn.Close() +} + +func (cn *Conn) isClosed() bool { + return atomic.LoadInt32(&cn.closed) == 1 +} + +func (cn *Conn) Begin() (driver.Tx, error) { + return cn.BeginTx(context.Background(), driver.TxOptions{}) +} + +var _ driver.ConnBeginTx = (*Conn)(nil) + +func (cn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + // No need to check if the conn is closed. ExecContext below handles that. + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + return nil, errors.New("pgdriver: custom IsolationLevel is not supported") + } + if opts.ReadOnly { + return nil, errors.New("pgdriver: ReadOnly transactions are not supported") + } + + if _, err := cn.ExecContext(ctx, "BEGIN", nil); err != nil { + return nil, err + } + return tx{cn: cn}, nil +} + +var _ driver.ExecerContext = (*Conn)(nil) + +func (cn *Conn) ExecContext( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Result, error) { + if cn.isClosed() { + return nil, driver.ErrBadConn + } + res, err := cn.exec(ctx, query, args) + if err != nil { + return nil, cn.checkBadConn(err) + } + return res, nil +} + +func (cn *Conn) exec( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Result, error) { + query, err := formatQuery(query, args) + if err != nil { + return nil, err + } + if err := writeQuery(ctx, cn, query); err != nil { + return nil, err + } + return readQuery(ctx, cn) +} + +var _ driver.QueryerContext = (*Conn)(nil) + +func (cn *Conn) QueryContext( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Rows, error) { + if cn.isClosed() { + return nil, driver.ErrBadConn + } + rows, err := cn.query(ctx, query, args) + if err != nil { + return nil, cn.checkBadConn(err) + } + return rows, nil +} + +func (cn *Conn) query( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Rows, error) { + query, err := formatQuery(query, args) + if err != nil { + return nil, err + } + if err := writeQuery(ctx, cn, query); err != nil { + return nil, err + } + return readQueryData(ctx, cn) +} + +var _ driver.Pinger = (*Conn)(nil) + +func (cn *Conn) Ping(ctx context.Context) error { + _, err := cn.ExecContext(ctx, "SELECT 1", nil) + return err +} + +func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) { + if timeout == -1 { + timeout = cn.driver.cfg.ReadTimeout + } + _ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)) +} + +func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) { + if timeout == -1 { + timeout = cn.driver.cfg.WriteTimeout + } + _ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)) +} + +func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { + deadline, ok := ctx.Deadline() + if !ok { + if timeout == 0 { + return time.Time{} + } + return time.Now().Add(timeout) + } + + if timeout == 0 { + return deadline + } + if tm := time.Now().Add(timeout); tm.Before(deadline) { + return tm + } + return deadline +} + +var _ driver.Validator = (*Conn)(nil) + +func (cn *Conn) IsValid() bool { + return !cn.isClosed() +} + +func (cn *Conn) checkBadConn(err error) error { + if isBadConn(err, false) { + // Close and return driver.ErrBadConn next time the conn is used. + _ = cn.Close() + } + // Always return the original error. + return err +} + +//------------------------------------------------------------------------------ + +type rows struct { + cn *Conn + rowDesc *rowDescription + reusable bool + closed bool +} + +var _ driver.Rows = (*rows)(nil) + +func newRows(cn *Conn, rowDesc *rowDescription, reusable bool) *rows { + return &rows{ + cn: cn, + rowDesc: rowDesc, + reusable: reusable, + } +} + +func (r *rows) Columns() []string { + if r.closed || r.rowDesc == nil { + return nil + } + return r.rowDesc.names +} + +func (r *rows) Close() error { + if r.closed { + return nil + } + defer r.close() + + for { + switch err := r.Next(nil); err { + case nil, io.EOF: + return nil + default: // unexpected error + _ = r.cn.Close() + return err + } + } +} + +func (r *rows) close() { + r.closed = true + + if r.rowDesc != nil { + if r.reusable { + rowDescPool.Put(r.rowDesc) + } + r.rowDesc = nil + } +} + +func (r *rows) Next(dest []driver.Value) error { + if r.closed { + return io.EOF + } + + eof, err := r.next(dest) + if err == io.EOF { + return io.ErrUnexpectedEOF + } else if err != nil { + return err + } + if eof { + return io.EOF + } + return nil +} + +func (r *rows) next(dest []driver.Value) (eof bool, _ error) { + rd := r.cn.reader(context.TODO(), -1) + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return false, err + } + + switch c { + case dataRowMsg: + return false, r.readDataRow(rd, dest) + case commandCompleteMsg: + if err := rd.Discard(msgLen); err != nil { + return false, err + } + case readyForQueryMsg: + r.close() + + if err := rd.Discard(msgLen); err != nil { + return false, err + } + + if firstErr != nil { + return false, firstErr + } + return true, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return false, err + } + if firstErr == nil { + firstErr = e + } + default: + return false, fmt.Errorf("pgdriver: Next: unexpected message %q", c) + } + } +} + +func (r *rows) readDataRow(rd *reader, dest []driver.Value) error { + numCol, err := readInt16(rd) + if err != nil { + return err + } + + if len(dest) != int(numCol) { + return fmt.Errorf("pgdriver: query returned %d columns, but Scan dest has %d items", + numCol, len(dest)) + } + + for colIdx := int16(0); colIdx < numCol; colIdx++ { + dataLen, err := readInt32(rd) + if err != nil { + return err + } + + value, err := readColumnValue(rd, r.rowDesc.types[colIdx], int(dataLen)) + if err != nil { + return err + } + + if dest != nil { + dest[colIdx] = value + } + } + + return nil +} + +//------------------------------------------------------------------------------ + +func parseResult(b []byte) (driver.RowsAffected, error) { + i := bytes.LastIndexByte(b, ' ') + if i == -1 { + return 0, nil + } + + b = b[i+1 : len(b)-1] + affected, err := strconv.ParseUint(bytesToString(b), 10, 64) + if err != nil { + return 0, nil + } + + return driver.RowsAffected(affected), nil +} + +//------------------------------------------------------------------------------ + +type tx struct { + cn *Conn +} + +var _ driver.Tx = (*tx)(nil) + +func (tx tx) Commit() error { + _, err := tx.cn.ExecContext(context.Background(), "COMMIT", nil) + return err +} + +func (tx tx) Rollback() error { + _, err := tx.cn.ExecContext(context.Background(), "ROLLBACK", nil) + return err +} + +//------------------------------------------------------------------------------ + +type stmt struct { + cn *Conn + name string + rowDesc *rowDescription +} + +var ( + _ driver.Stmt = (*stmt)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) +) + +func newStmt(cn *Conn, name string, rowDesc *rowDescription) *stmt { + return &stmt{ + cn: cn, + name: name, + rowDesc: rowDesc, + } +} + +func (stmt *stmt) Close() error { + if stmt.rowDesc != nil { + rowDescPool.Put(stmt.rowDesc) + stmt.rowDesc = nil + } + + ctx := context.TODO() + if err := writeCloseStmt(ctx, stmt.cn, stmt.name); err != nil { + return err + } + if err := readCloseStmtComplete(ctx, stmt.cn); err != nil { + return err + } + return nil +} + +func (stmt *stmt) NumInput() int { + if stmt.rowDesc == nil { + return -1 + } + return int(stmt.rowDesc.numInput) +} + +func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) { + panic("not implemented") +} + +func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil { + return nil, err + } + return readExtQuery(ctx, stmt.cn) +} + +func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) { + panic("not implemented") +} + +func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil { + return nil, err + } + return readExtQueryData(ctx, stmt.cn, stmt.rowDesc) +} + +//------------------------------------------------------------------------------ + +var bufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriter(nil) + }, +} + +func getBufioWriter() *bufio.Writer { + return bufioWriterPool.Get().(*bufio.Writer) +} + +func putBufioWriter(wr *bufio.Writer) { + bufioWriterPool.Put(wr) +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/error.go b/vendor/github.com/uptrace/bun/driver/pgdriver/error.go new file mode 100644 index 000000000..5f1f9fec6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/error.go @@ -0,0 +1,66 @@ +package pgdriver + +import ( + "fmt" + "net" +) + +// Error represents an error returned by PostgreSQL server +// using PostgreSQL ErrorResponse protocol. +// +// https://www.postgresql.org/docs/current/static/protocol-message-formats.html +type Error struct { + m map[byte]string +} + +// Field returns a string value associated with an error field. +// +// https://www.postgresql.org/docs/current/static/protocol-error-fields.html +func (err Error) Field(k byte) string { + return err.m[k] +} + +// IntegrityViolation reports whether an error is a part of +// Integrity Constraint Violation class of errors. +// +// https://www.postgresql.org/docs/current/static/errcodes-appendix.html +func (err Error) IntegrityViolation() bool { + switch err.Field('C') { + case "23000", "23001", "23502", "23503", "23505", "23514", "23P01": + return true + default: + return false + } +} + +func (err Error) Error() string { + return fmt.Sprintf("%s #%s %s", + err.Field('S'), err.Field('C'), err.Field('M')) +} + +func isBadConn(err error, allowTimeout bool) bool { + if err == nil { + return false + } + + if err, ok := err.(Error); ok { + switch err.Field('V') { + case "FATAL", "PANIC": + return true + } + switch err.Field('C') { + case "25P02", // current transaction is aborted + "57014": // canceling statement due to user request + return true + } + return false + } + + if allowTimeout { + if err, ok := err.(net.Error); ok && err.Timeout() { + return !err.Temporary() + } + } + + return true +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/format.go b/vendor/github.com/uptrace/bun/driver/pgdriver/format.go new file mode 100644 index 000000000..c85967da5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/format.go @@ -0,0 +1,188 @@ +package pgdriver + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" + "math" + "strconv" + "time" + "unicode/utf8" +) + +func formatQuery(query string, args []driver.NamedValue) (string, error) { + if len(args) == 0 { + return query, nil + } + + dst := make([]byte, 0, 2*len(query)) + + p := newParser(query) + for p.Valid() { + switch c := p.Next(); c { + case '$': + if i, ok := p.Number(); ok { + if i > len(args) { + return "", fmt.Errorf("pgdriver: got %d args, wanted %d", len(args), i) + } + + var err error + dst, err = appendArg(dst, args[i-1].Value) + if err != nil { + return "", err + } + } else { + dst = append(dst, '$') + } + case '\'': + if b, ok := p.QuotedString(); ok { + dst = append(dst, b...) + } else { + dst = append(dst, '\'') + } + default: + dst = append(dst, c) + } + } + + return bytesToString(dst), nil +} + +func appendArg(b []byte, v interface{}) ([]byte, error) { + switch v := v.(type) { + case nil: + return append(b, "NULL"...), nil + case int64: + return strconv.AppendInt(b, v, 10), nil + case float64: + switch { + case math.IsNaN(v): + return append(b, "'NaN'"...), nil + case math.IsInf(v, 1): + return append(b, "'Infinity'"...), nil + case math.IsInf(v, -1): + return append(b, "'-Infinity'"...), nil + default: + return strconv.AppendFloat(b, v, 'f', -1, 64), nil + } + case bool: + if v { + return append(b, "TRUE"...), nil + } + return append(b, "FALSE"...), nil + case []byte: + if v == nil { + return append(b, "NULL"...), nil + } + + b = append(b, `'\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(v)))...) + hex.Encode(b[s:], v) + + b = append(b, "'"...) + + return b, nil + case string: + b = append(b, '\'') + for _, r := range v { + if r == '\000' { + continue + } + + if r == '\'' { + b = append(b, '\'', '\'') + continue + } + + if r < utf8.RuneSelf { + b = append(b, byte(r)) + continue + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + b = append(b, '\'') + return b, nil + case time.Time: + if v.IsZero() { + return append(b, "NULL"...), nil + } + return v.UTC().AppendFormat(b, "'2006-01-02 15:04:05.999999-07:00'"), nil + default: + return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v) + } +} + +type parser struct { + b []byte + i int +} + +func newParser(s string) *parser { + return &parser{ + b: stringToBytes(s), + } +} + +func (p *parser) Valid() bool { + return p.i < len(p.b) +} + +func (p *parser) Next() byte { + c := p.b[p.i] + p.i++ + return c +} + +func (p *parser) Number() (int, bool) { + start := p.i + end := len(p.b) + + for i := p.i; i < len(p.b); i++ { + c := p.b[i] + if !isNum(c) { + end = i + break + } + } + + p.i = end + b := p.b[start:end] + + n, err := strconv.Atoi(bytesToString(b)) + if err != nil { + return 0, false + } + + return n, true +} + +func (p *parser) QuotedString() ([]byte, bool) { + start := p.i - 1 + end := len(p.b) + + var c byte + for i := p.i; i < len(p.b); i++ { + next := p.b[i] + if c == '\'' && next != '\'' { + end = i + break + } + c = next + } + + p.i = end + b := p.b[start:end] + + return b, true +} + +func isNum(c byte) bool { + return c >= '0' && c <= '9' +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod b/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod new file mode 100644 index 000000000..0ebe475f9 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/go.mod @@ -0,0 +1,11 @@ +module github.com/uptrace/bun/driver/pgdriver + +go 1.16 + +replace github.com/uptrace/bun => ../.. + +require ( + github.com/stretchr/testify v1.7.0 + github.com/uptrace/bun v0.4.3 + mellium.im/sasl v0.2.1 +) diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum b/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum new file mode 100644 index 000000000..db9059b35 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/go.sum @@ -0,0 +1,27 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b h1:2b9XGzhjiYsYPnKXoEfL7klWZQIt8IfyRCz62gCqqlQ= +golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= +mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go b/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go new file mode 100644 index 000000000..937c80fa7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go @@ -0,0 +1,392 @@ +package pgdriver + +import ( + "context" + "errors" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/uptrace/bun" +) + +const pingChannel = "bun:ping" + +var ( + errListenerClosed = errors.New("bun: listener is closed") + errPingTimeout = errors.New("bun: ping timeout") +) + +type Listener struct { + db *bun.DB + driver *Connector + + channels []string + + mu sync.Mutex + cn *Conn + closed bool + exit chan struct{} +} + +func NewListener(db *bun.DB) *Listener { + return &Listener{ + db: db, + driver: db.Driver().(Driver).connector, + exit: make(chan struct{}), + } +} + +// Close closes the listener, releasing any open resources. +func (ln *Listener) Close() error { + return ln.withLock(func() error { + if ln.closed { + return errListenerClosed + } + + ln.closed = true + close(ln.exit) + + return ln.closeConn(errListenerClosed) + }) +} + +func (ln *Listener) withLock(fn func() error) error { + ln.mu.Lock() + defer ln.mu.Unlock() + return fn() +} + +func (ln *Listener) conn(ctx context.Context) (*Conn, error) { + if ln.closed { + return nil, errListenerClosed + } + if ln.cn != nil { + return ln.cn, nil + } + + atomic.AddUint64(&ln.driver.stats.Queries, 1) + + cn, err := ln._conn(ctx) + if err != nil { + atomic.AddUint64(&ln.driver.stats.Errors, 1) + return nil, err + } + + ln.cn = cn + return cn, nil +} + +func (ln *Listener) _conn(ctx context.Context) (*Conn, error) { + driverConn, err := ln.driver.Connect(ctx) + if err != nil { + return nil, err + } + cn := driverConn.(*Conn) + + if len(ln.channels) > 0 { + err := ln.listen(ctx, cn, ln.channels...) + if err != nil { + _ = cn.Close() + return nil, err + } + } + + return cn, nil +} + +func (ln *Listener) checkConn(ctx context.Context, cn *Conn, err error, allowTimeout bool) { + _ = ln.withLock(func() error { + if ln.closed || ln.cn != cn { + return nil + } + if isBadConn(err, allowTimeout) { + ln.reconnect(ctx, err) + } + return nil + }) +} + +func (ln *Listener) reconnect(ctx context.Context, reason error) { + if ln.cn != nil { + Logger.Printf(ctx, "bun: discarding bad listener connection: %s", reason) + _ = ln.closeConn(reason) + } + _, _ = ln.conn(ctx) +} + +func (ln *Listener) closeConn(reason error) error { + if ln.cn == nil { + return nil + } + err := ln.cn.Close() + ln.cn = nil + return err +} + +// Listen starts listening for notifications on channels. +func (ln *Listener) Listen(ctx context.Context, channels ...string) error { + var cn *Conn + + if err := ln.withLock(func() error { + ln.channels = appendIfNotExists(ln.channels, channels...) + + var err error + cn, err = ln.conn(ctx) + return err + }); err != nil { + return err + } + + if err := ln.listen(ctx, cn, channels...); err != nil { + ln.checkConn(ctx, cn, err, false) + return err + } + return nil +} + +func (ln *Listener) listen(ctx context.Context, cn *Conn, channels ...string) error { + for _, channel := range channels { + if err := writeQuery(ctx, cn, "LISTEN "+strconv.Quote(channel)); err != nil { + return err + } + } + return nil +} + +// Unlisten stops listening for notifications on channels. +func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error { + var cn *Conn + + if err := ln.withLock(func() error { + ln.channels = removeIfExists(ln.channels, channels...) + + var err error + cn, err = ln.conn(ctx) + return err + }); err != nil { + return err + } + + if err := ln.unlisten(ctx, cn, channels...); err != nil { + ln.checkConn(ctx, cn, err, false) + return err + } + return nil +} + +func (ln *Listener) unlisten(ctx context.Context, cn *Conn, channels ...string) error { + for _, channel := range channels { + if err := writeQuery(ctx, cn, "UNLISTEN "+strconv.Quote(channel)); err != nil { + return err + } + } + return nil +} + +// Receive indefinitely waits for a notification. This is low-level API +// and in most cases Channel should be used instead. +func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) { + return ln.ReceiveTimeout(ctx, 0) +} + +// ReceiveTimeout waits for a notification until timeout is reached. +// This is low-level API and in most cases Channel should be used instead. +func (ln *Listener) ReceiveTimeout( + ctx context.Context, timeout time.Duration, +) (channel, payload string, err error) { + var cn *Conn + + if err := ln.withLock(func() error { + var err error + cn, err = ln.conn(ctx) + return err + }); err != nil { + return "", "", err + } + + rd := cn.reader(ctx, timeout) + channel, payload, err = readNotification(ctx, rd) + if err != nil { + ln.checkConn(ctx, cn, err, timeout > 0) + return "", "", err + } + + return channel, payload, nil +} + +// Channel returns a channel for concurrently receiving notifications. +// It periodically sends Ping notification to test connection health. +// +// The channel is closed with Listener. Receive* APIs can not be used +// after channel is created. +func (ln *Listener) Channel(opts ...ChannelOption) <-chan Notification { + return newChannel(ln, opts).ch +} + +//------------------------------------------------------------------------------ + +// Notification received with LISTEN command. +type Notification struct { + Channel string + Payload string +} + +type ChannelOption func(c *channel) + +func WithChannelSize(size int) ChannelOption { + return func(c *channel) { + c.size = size + } +} + +type channel struct { + ctx context.Context + ln *Listener + + size int + pingTimeout time.Duration + chanSendTimeout time.Duration + + ch chan Notification + pingCh chan struct{} +} + +func newChannel(ln *Listener, opts []ChannelOption) *channel { + c := &channel{ + ctx: context.TODO(), + ln: ln, + + size: 100, + pingTimeout: 5 * time.Second, + chanSendTimeout: time.Minute, + } + + for _, opt := range opts { + opt(c) + } + + c.ch = make(chan Notification, c.size) + c.pingCh = make(chan struct{}, 1) + _ = c.ln.Listen(c.ctx, pingChannel) + go c.startReceive() + go c.startPing() + + return c +} + +func (c *channel) startReceive() { + timer := time.NewTimer(time.Minute) + timer.Stop() + + var errCount int + for { + channel, payload, err := c.ln.Receive(c.ctx) + if err != nil { + if err == errListenerClosed { + close(c.ch) + return + } + + if errCount > 0 { + time.Sleep(500 * time.Millisecond) + } + errCount++ + + continue + } + + errCount = 0 + + // Any notification is as good as a ping. + select { + case c.pingCh <- struct{}{}: + default: + } + + switch channel { + case pingChannel: + // ignore + default: + timer.Reset(c.chanSendTimeout) + select { + case c.ch <- Notification{channel, payload}: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + Logger.Printf( + c.ctx, + "pgdriver: %s channel is full for %s (notification is dropped)", + c, + c.chanSendTimeout, + ) + } + } + } +} + +func (c *channel) startPing() { + timer := time.NewTimer(time.Minute) + timer.Stop() + + healthy := true + for { + timer.Reset(c.pingTimeout) + select { + case <-c.pingCh: + healthy = true + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + pingErr := c.ping(c.ctx) + if healthy { + healthy = false + } else { + if pingErr == nil { + pingErr = errPingTimeout + } + _ = c.ln.withLock(func() error { + c.ln.reconnect(c.ctx, pingErr) + return nil + }) + } + case <-c.ln.exit: + return + } + } +} + +func (c *channel) ping(ctx context.Context) error { + _, err := c.ln.db.ExecContext(ctx, "NOTIFY "+strconv.Quote(pingChannel)) + return err +} + +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} + +func removeIfExists(ss []string, es ...string) []string { + for _, e := range es { + for i, s := range ss { + if s == e { + last := len(ss) - 1 + ss[i] = ss[last] + ss = ss[:last] + break + } + } + } + return ss +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go b/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go new file mode 100644 index 000000000..327fc29fc --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/proto.go @@ -0,0 +1,1127 @@ +package pgdriver + +import ( + "bufio" + "context" + "crypto/md5" + "crypto/tls" + "database/sql" + "database/sql/driver" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "strconv" + "sync" + "time" + "unicode/utf8" + + "mellium.im/sasl" +) + +// https://www.postgresql.org/docs/current/protocol-message-formats.html +//nolint:deadcode,varcheck,unused +const ( + commandCompleteMsg = 'C' + errorResponseMsg = 'E' + noticeResponseMsg = 'N' + parameterStatusMsg = 'S' + authenticationOKMsg = 'R' + backendKeyDataMsg = 'K' + noDataMsg = 'n' + passwordMessageMsg = 'p' + terminateMsg = 'X' + + saslInitialResponseMsg = 'p' + authenticationSASLContinueMsg = 'R' + saslResponseMsg = 'p' + authenticationSASLFinalMsg = 'R' + + authenticationOK = 0 + authenticationCleartextPassword = 3 + authenticationMD5Password = 5 + authenticationSASL = 10 + + notificationResponseMsg = 'A' + + describeMsg = 'D' + parameterDescriptionMsg = 't' + + queryMsg = 'Q' + readyForQueryMsg = 'Z' + emptyQueryResponseMsg = 'I' + rowDescriptionMsg = 'T' + dataRowMsg = 'D' + + parseMsg = 'P' + parseCompleteMsg = '1' + + bindMsg = 'B' + bindCompleteMsg = '2' + + executeMsg = 'E' + + syncMsg = 'S' + flushMsg = 'H' + + closeMsg = 'C' + closeCompleteMsg = '3' + + copyInResponseMsg = 'G' + copyOutResponseMsg = 'H' + copyDataMsg = 'd' + copyDoneMsg = 'c' +) + +var errEmptyQuery = errors.New("pgdriver: query is empty") + +type reader struct { + *bufio.Reader + buf []byte +} + +func newReader(r io.Reader) *reader { + return &reader{ + Reader: bufio.NewReader(r), + buf: make([]byte, 128), + } +} + +func (r *reader) ReadTemp(n int) ([]byte, error) { + if n <= len(r.buf) { + b := r.buf[:n] + _, err := io.ReadFull(r.Reader, b) + return b, err + } + + b := make([]byte, n) + _, err := io.ReadFull(r.Reader, b) + return b, err +} + +func (r *reader) Discard(n int) error { + _, err := r.ReadTemp(n) + return err +} + +func enableSSL(ctx context.Context, cn *Conn, tlsConf *tls.Config) error { + if err := writeSSLMsg(ctx, cn); err != nil { + return err + } + + rd := cn.reader(ctx, -1) + + c, err := rd.ReadByte() + if err != nil { + return err + } + if c != 'S' { + return errors.New("pgdriver: SSL is not enabled on the server") + } + + cn.netConn = tls.Client(cn.netConn, tlsConf) + rd.Reset(cn.netConn) + + return nil +} + +func writeSSLMsg(ctx context.Context, cn *Conn) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(0) + wb.WriteInt32(80877103) + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +//------------------------------------------------------------------------------ + +func startup(ctx context.Context, cn *Conn) error { + if err := writeStartup(ctx, cn); err != nil { + return err + } + + rd := cn.reader(ctx, -1) + + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case backendKeyDataMsg: + processID, err := readInt32(rd) + if err != nil { + return err + } + secretKey, err := readInt32(rd) + if err != nil { + return err + } + cn.processID = processID + cn.secretKey = secretKey + case authenticationOKMsg: + if err := auth(ctx, cn, rd); err != nil { + return err + } + case readyForQueryMsg: + return rd.Discard(msgLen) + case parameterStatusMsg, noticeResponseMsg: + if err := rd.Discard(msgLen); err != nil { + return err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pgdriver: unexpected startup message: %q", c) + } + } +} + +func writeStartup(ctx context.Context, cn *Conn) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(0) + wb.WriteInt32(196608) + wb.WriteString("user") + wb.WriteString(cn.driver.cfg.User) + wb.WriteString("database") + wb.WriteString(cn.driver.cfg.Database) + if cn.driver.cfg.AppName != "" { + wb.WriteString("application_name") + wb.WriteString(cn.driver.cfg.AppName) + } + wb.WriteString("") + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +//------------------------------------------------------------------------------ + +func auth(ctx context.Context, cn *Conn, rd *reader) error { + num, err := readInt32(rd) + if err != nil { + return err + } + + switch num { + case authenticationOK: + return nil + case authenticationCleartextPassword: + return authCleartext(ctx, cn, rd) + case authenticationMD5Password: + return authMD5(ctx, cn, rd) + case authenticationSASL: + if err := authSASL(ctx, cn, rd); err != nil { + return fmt.Errorf("pgdriver: SASL: %w", err) + } + return nil + default: + return fmt.Errorf("pgdriver: unknown authentication message: %q", num) + } +} + +func authCleartext(ctx context.Context, cn *Conn, rd *reader) error { + if err := writePassword(ctx, cn, cn.driver.cfg.Password); err != nil { + return err + } + return readAuthOK(cn, rd) +} + +func readAuthOK(cn *Conn, rd *reader) error { + c, _, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationOKMsg: + num, err := readInt32(rd) + if err != nil { + return err + } + if num != 0 { + return fmt.Errorf("pgdriver: unexpected authentication code: %q", num) + } + return nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pgdriver: unknown password message: %q", c) + } +} + +//------------------------------------------------------------------------------ + +func authMD5(ctx context.Context, cn *Conn, rd *reader) error { + b, err := rd.ReadTemp(4) + if err != nil { + return err + } + + secret := "md5" + md5s(md5s(cn.driver.cfg.Password+cn.driver.cfg.User)+string(b)) + if err := writePassword(ctx, cn, secret); err != nil { + return err + } + + return readAuthOK(cn, rd) +} + +func writePassword(ctx context.Context, cn *Conn, password string) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(passwordMessageMsg) + wb.WriteString(password) + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func md5s(s string) string { + h := md5.Sum([]byte(s)) + return hex.EncodeToString(h[:]) +} + +//------------------------------------------------------------------------------ + +func authSASL(ctx context.Context, cn *Conn, rd *reader) error { + s, err := readString(rd) + if err != nil { + return err + } + + var saslMech sasl.Mechanism + + switch s { + case sasl.ScramSha256.Name: + saslMech = sasl.ScramSha256 + case sasl.ScramSha256Plus.Name: + saslMech = sasl.ScramSha256Plus + default: + return fmt.Errorf("got %q, wanted %q", s, sasl.ScramSha256.Name) + } + + c0, err := rd.ReadByte() + if err != nil { + return err + } + if c0 != 0 { + return fmt.Errorf("got %q, wanted %q", c0, 0) + } + + creds := sasl.Credentials(func() (Username, Password, Identity []byte) { + return []byte(cn.driver.cfg.User), []byte(cn.driver.cfg.Password), nil + }) + client := sasl.NewClient(saslMech, creds) + + _, resp, err := client.Step(nil) + if err != nil { + return fmt.Errorf("client.Step 1 failed: %w", err) + } + + if err := saslWriteInitialResponse(ctx, cn, saslMech, resp); err != nil { + return err + } + + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationSASLContinueMsg: + c11, err := readInt32(rd) + if err != nil { + return err + } + if c11 != 11 { + return fmt.Errorf("got %q, wanted %q", c, 11) + } + + b, err := rd.ReadTemp(msgLen - 4) + if err != nil { + return err + } + + _, resp, err = client.Step(b) + if err != nil { + return fmt.Errorf("client.Step 2 failed: %w", err) + } + + if err := saslWriteResponse(ctx, cn, resp); err != nil { + return err + } + + resp, err = saslReadAuthFinal(cn, rd) + if err != nil { + return err + } + + if _, _, err := client.Step(resp); err != nil { + return fmt.Errorf("client.Step 3 failed: %w", err) + } + + if client.State() != sasl.ValidServerResponse { + return fmt.Errorf("got state=%q, wanted %q", client.State(), sasl.ValidServerResponse) + } + + return nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("got %q, wanted %q", c, authenticationSASLContinueMsg) + } +} + +func saslWriteInitialResponse( + ctx context.Context, cn *Conn, saslMech sasl.Mechanism, resp []byte, +) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(saslInitialResponseMsg) + wb.WriteString(saslMech.Name) + wb.WriteInt32(int32(len(resp))) + if _, err := wb.Write(resp); err != nil { + return err + } + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func saslWriteResponse(ctx context.Context, cn *Conn, resp []byte) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(saslResponseMsg) + if _, err := wb.Write(resp); err != nil { + return err + } + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func saslReadAuthFinal(cn *Conn, rd *reader) ([]byte, error) { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case authenticationSASLFinalMsg: + c12, err := readInt32(rd) + if err != nil { + return nil, err + } + if c12 != 12 { + return nil, fmt.Errorf("got %q, wanted %q", c, 12) + } + + resp := make([]byte, msgLen-4) + if _, err := io.ReadFull(rd, resp); err != nil { + return nil, err + } + + if err := readAuthOK(cn, rd); err != nil { + return nil, err + } + + return resp, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + return nil, e + default: + return nil, fmt.Errorf("got %q, wanted %q", c, authenticationSASLFinalMsg) + } +} + +//------------------------------------------------------------------------------ + +func writeQuery(ctx context.Context, cn *Conn, query string) error { + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + if err := wr.WriteByte(queryMsg); err != nil { + return err + } + + binary.BigEndian.PutUint32(cn.rd.buf, uint32(len(query)+5)) + if _, err := wr.Write(cn.rd.buf[:4]); err != nil { + return err + } + + if _, err := wr.WriteString(query); err != nil { + return err + } + if err := wr.WriteByte(0x0); err != nil { + return err + } + + return nil + }) +} + +func readQuery(ctx context.Context, cn *Conn) (sql.Result, error) { + rd := cn.reader(ctx, -1) + + var res driver.Result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case commandCompleteMsg: + tmp, err := rd.ReadTemp(msgLen) + if err != nil { + firstErr = err + break + } + + r, err := parseResult(tmp) + if err != nil { + firstErr = err + } else { + res = r + } + case describeMsg, + rowDescriptionMsg, + noticeResponseMsg, + parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case readyForQueryMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + return res, firstErr + default: + return nil, fmt.Errorf("pgdriver: Exec: unexpected message %q", c) + } + } +} + +func readQueryData(ctx context.Context, cn *Conn) (*rows, error) { + rd := cn.reader(ctx, -1) + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case rowDescriptionMsg: + rowDesc, err := readRowDescription(rd) + if err != nil { + return nil, err + } + return newRows(cn, rowDesc, true), nil + case commandCompleteMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case readyForQueryMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &rows{closed: true}, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg, parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pgdriver: newRows: unexpected message %q", c) + } + } +} + +//------------------------------------------------------------------------------ + +var rowDescPool sync.Pool + +type rowDescription struct { + buf []byte + names []string + types []int32 + numInput int16 +} + +func newRowDescription(numCol int) *rowDescription { + if numCol < 16 { + numCol = 16 + } + return &rowDescription{ + buf: make([]byte, 0, 16*numCol), + names: make([]string, 0, numCol), + types: make([]int32, 0, numCol), + numInput: -1, + } +} + +func (d *rowDescription) reset(numCol int) { + d.buf = make([]byte, 0, 16*numCol) + d.names = d.names[:0] + d.types = d.types[:0] + d.numInput = -1 +} + +func (d *rowDescription) addName(name []byte) { + if len(d.buf)+len(name) > cap(d.buf) { + d.buf = make([]byte, 0, cap(d.buf)) + } + + i := len(d.buf) + d.buf = append(d.buf, name...) + d.names = append(d.names, bytesToString(d.buf[i:])) +} + +func (d *rowDescription) addType(dataType int32) { + d.types = append(d.types, dataType) +} + +func readRowDescription(rd *reader) (*rowDescription, error) { + numCol, err := readInt16(rd) + if err != nil { + return nil, err + } + + rowDesc, ok := rowDescPool.Get().(*rowDescription) + if !ok { + rowDesc = newRowDescription(int(numCol)) + } else { + rowDesc.reset(int(numCol)) + } + + for i := 0; i < int(numCol); i++ { + name, err := rd.ReadSlice(0) + if err != nil { + return nil, err + } + rowDesc.addName(name[:len(name)-1]) + + if _, err := rd.ReadTemp(6); err != nil { + return nil, err + } + + dataType, err := readInt32(rd) + if err != nil { + return nil, err + } + rowDesc.addType(dataType) + + if _, err := rd.ReadTemp(8); err != nil { + return nil, err + } + } + + return rowDesc, nil +} + +//------------------------------------------------------------------------------ + +func readNotification(ctx context.Context, rd *reader) (channel, payload string, err error) { + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return "", "", err + } + + switch c { + case commandCompleteMsg, readyForQueryMsg, noticeResponseMsg: + if err := rd.Discard(msgLen); err != nil { + return "", "", err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return "", "", err + } + return "", "", e + case notificationResponseMsg: + if err := rd.Discard(4); err != nil { + return "", "", err + } + channel, err = readString(rd) + if err != nil { + return "", "", err + } + payload, err = readString(rd) + if err != nil { + return "", "", err + } + return channel, payload, nil + default: + return "", "", fmt.Errorf("pgdriver: readNotification: unexpected message %q", c) + } + } +} + +//------------------------------------------------------------------------------ + +func writeParseDescribeSync(ctx context.Context, cn *Conn, name, query string) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(parseMsg) + wb.WriteString(name) + wb.WriteString(query) + wb.WriteInt16(0) + wb.FinishMessage() + + wb.StartMessage(describeMsg) + wb.WriteByte('S') + wb.WriteString(name) + wb.FinishMessage() + + wb.StartMessage(syncMsg) + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func readParseDescribeSync(ctx context.Context, cn *Conn) (*rowDescription, error) { + rd := cn.reader(ctx, -1) + var numParam int16 + var rowDesc *rowDescription + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case parseCompleteMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case rowDescriptionMsg: // response to DESCRIBE message. + rowDesc, err = readRowDescription(rd) + if err != nil { + return nil, err + } + rowDesc.numInput = numParam + case parameterDescriptionMsg: // response to DESCRIBE message. + numParam, err = readInt16(rd) + if err != nil { + return nil, err + } + + for i := 0; i < int(numParam); i++ { + if _, err := readInt32(rd); err != nil { + return nil, err + } + } + case noDataMsg: // response to DESCRIBE message. + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case readyForQueryMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return rowDesc, err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg, parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pgdriver: readParseDescribeSync: unexpected message %q", c) + } + } +} + +func writeBindExecute(ctx context.Context, cn *Conn, name string, args []driver.NamedValue) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(bindMsg) + wb.WriteString("") + wb.WriteString(name) + wb.WriteInt16(0) + wb.WriteInt16(int16(len(args))) + for i := range args { + wb.StartParam() + bytes, err := appendStmtArg(wb.Bytes, args[i].Value) + if err != nil { + return err + } + if bytes != nil { + wb.Bytes = bytes + wb.FinishParam() + } else { + wb.FinishNullParam() + } + } + wb.WriteInt16(0) + wb.FinishMessage() + + wb.StartMessage(executeMsg) + wb.WriteString("") + wb.WriteInt32(0) + wb.FinishMessage() + + wb.StartMessage(syncMsg) + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func readExtQuery(ctx context.Context, cn *Conn) (driver.Result, error) { + rd := cn.reader(ctx, -1) + var res driver.Result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg, dataRowMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case commandCompleteMsg: // response to EXECUTE message. + tmp, err := rd.ReadTemp(msgLen) + if err != nil { + return nil, err + } + + r, err := parseResult(tmp) + if err != nil { + if firstErr == nil { + firstErr = err + } + } else { + res = r + } + case readyForQueryMsg: // Response to SYNC message. + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg, parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pgdriver: readExtQuery: unexpected message %q", c) + } + } +} + +func readExtQueryData(ctx context.Context, cn *Conn, rowDesc *rowDescription) (*rows, error) { + rd := cn.reader(ctx, -1) + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + return newRows(cn, rowDesc, false), nil + case commandCompleteMsg: // response to EXECUTE message. + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + case readyForQueryMsg: // Response to SYNC message. + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &rows{closed: true}, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg, parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pgdriver: readExtQueryData: unexpected message %q", c) + } + } +} + +func writeCloseStmt(ctx context.Context, cn *Conn, name string) error { + wb := getWriteBuffer() + defer putWriteBuffer(wb) + + wb.StartMessage(closeMsg) + wb.WriteByte('S') //nolint + wb.WriteString(name) + wb.FinishMessage() + + wb.StartMessage(flushMsg) + wb.FinishMessage() + + return cn.withWriter(ctx, -1, func(wr *bufio.Writer) error { + _, err := wr.Write(wb.Bytes) + return err + }) +} + +func readCloseStmtComplete(ctx context.Context, cn *Conn) error { + rd := cn.reader(ctx, -1) + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case closeCompleteMsg: + return rd.Discard(msgLen) + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + case noticeResponseMsg, parameterStatusMsg: + if err := rd.Discard(msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pgdriver: readCloseCompleteMsg: unexpected message %q", c) + } + } +} + +//------------------------------------------------------------------------------ + +func readMessageType(rd *reader) (byte, int, error) { + c, err := rd.ReadByte() + if err != nil { + return 0, 0, err + } + l, err := readInt32(rd) + if err != nil { + return 0, 0, err + } + return c, int(l) - 4, nil +} + +func readInt16(rd *reader) (int16, error) { + b, err := rd.ReadTemp(2) + if err != nil { + return 0, err + } + return int16(binary.BigEndian.Uint16(b)), nil +} + +func readInt32(rd *reader) (int32, error) { + b, err := rd.ReadTemp(4) + if err != nil { + return 0, err + } + return int32(binary.BigEndian.Uint32(b)), nil +} + +func readString(rd *reader) (string, error) { + b, err := rd.ReadSlice(0) + if err != nil { + return "", err + } + return string(b[:len(b)-1]), nil +} + +func readError(rd *reader) (error, error) { + m := make(map[byte]string) + for { + c, err := rd.ReadByte() + if err != nil { + return nil, err + } + if c == 0 { + break + } + s, err := readString(rd) + if err != nil { + return nil, err + } + m[c] = s + } + return Error{m: m}, nil +} + +//------------------------------------------------------------------------------ + +func appendStmtArg(b []byte, v driver.Value) ([]byte, error) { + switch v := v.(type) { + case nil: + return nil, nil + case int64: + return strconv.AppendInt(b, v, 10), nil + case float64: + switch { + case math.IsNaN(v): + return append(b, "NaN"...), nil + case math.IsInf(v, 1): + return append(b, "Infinity"...), nil + case math.IsInf(v, -1): + return append(b, "-Infinity"...), nil + default: + return strconv.AppendFloat(b, v, 'f', -1, 64), nil + } + case bool: + if v { + return append(b, "TRUE"...), nil + } + return append(b, "FALSE"...), nil + case []byte: + if v == nil { + return nil, nil + } + + b = append(b, `\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(v)))...) + hex.Encode(b[s:], v) + + return b, nil + case string: + for _, r := range v { + if r == 0 { + continue + } + if r < utf8.RuneSelf { + b = append(b, byte(r)) + continue + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + return b, nil + case time.Time: + if v.IsZero() { + return nil, nil + } + return v.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00"), nil + default: + return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v) + } +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go b/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go new file mode 100644 index 000000000..fab151a78 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package internal + +func bytesToString(b []byte) string { + return string(b) +} + +func stringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go b/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go new file mode 100644 index 000000000..6ba868105 --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/unsafe.go @@ -0,0 +1,19 @@ +// +build !appengine + +package pgdriver + +import "unsafe" + +func bytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +//nolint:deadcode,unused +func stringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go b/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go new file mode 100644 index 000000000..cb683563d --- /dev/null +++ b/vendor/github.com/uptrace/bun/driver/pgdriver/write_buffer.go @@ -0,0 +1,112 @@ +package pgdriver + +import ( + "encoding/binary" + "io" + "sync" +) + +var wbPool = sync.Pool{ + New: func() interface{} { + return newWriteBuffer() + }, +} + +func getWriteBuffer() *writeBuffer { + wb := wbPool.Get().(*writeBuffer) + return wb +} + +func putWriteBuffer(wb *writeBuffer) { + wb.Reset() + wbPool.Put(wb) +} + +type writeBuffer struct { + Bytes []byte + + msgStart int + paramStart int +} + +func newWriteBuffer() *writeBuffer { + return &writeBuffer{ + Bytes: make([]byte, 0, 1024), + } +} + +func (b *writeBuffer) Reset() { + b.Bytes = b.Bytes[:0] +} + +func (b *writeBuffer) StartMessage(c byte) { + if c == 0 { + b.msgStart = len(b.Bytes) + b.Bytes = append(b.Bytes, 0, 0, 0, 0) + } else { + b.msgStart = len(b.Bytes) + 1 + b.Bytes = append(b.Bytes, c, 0, 0, 0, 0) + } +} + +func (b *writeBuffer) FinishMessage() { + binary.BigEndian.PutUint32( + b.Bytes[b.msgStart:], uint32(len(b.Bytes)-b.msgStart)) +} + +func (b *writeBuffer) Query() []byte { + return b.Bytes[b.msgStart+4 : len(b.Bytes)-1] +} + +func (b *writeBuffer) StartParam() { + b.paramStart = len(b.Bytes) + b.Bytes = append(b.Bytes, 0, 0, 0, 0) +} + +func (b *writeBuffer) FinishParam() { + binary.BigEndian.PutUint32( + b.Bytes[b.paramStart:], uint32(len(b.Bytes)-b.paramStart-4)) +} + +var nullParamLength = int32(-1) + +func (b *writeBuffer) FinishNullParam() { + binary.BigEndian.PutUint32( + b.Bytes[b.paramStart:], uint32(nullParamLength)) +} + +func (b *writeBuffer) Write(data []byte) (int, error) { + b.Bytes = append(b.Bytes, data...) + return len(data), nil +} + +func (b *writeBuffer) WriteInt16(num int16) { + b.Bytes = append(b.Bytes, 0, 0) + binary.BigEndian.PutUint16(b.Bytes[len(b.Bytes)-2:], uint16(num)) +} + +func (b *writeBuffer) WriteInt32(num int32) { + b.Bytes = append(b.Bytes, 0, 0, 0, 0) + binary.BigEndian.PutUint32(b.Bytes[len(b.Bytes)-4:], uint32(num)) +} + +func (b *writeBuffer) WriteString(s string) { + b.Bytes = append(b.Bytes, s...) + b.Bytes = append(b.Bytes, 0) +} + +func (b *writeBuffer) WriteBytes(data []byte) { + b.Bytes = append(b.Bytes, data...) + b.Bytes = append(b.Bytes, 0) +} + +func (b *writeBuffer) WriteByte(c byte) error { + b.Bytes = append(b.Bytes, c) + return nil +} + +func (b *writeBuffer) ReadFrom(r io.Reader) (int64, error) { + n, err := r.Read(b.Bytes[len(b.Bytes):cap(b.Bytes)]) + b.Bytes = b.Bytes[:len(b.Bytes)+n] + return int64(n), err +} diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/json.go b/vendor/github.com/uptrace/bun/extra/bunjson/json.go new file mode 100644 index 000000000..eff9d3f0e --- /dev/null +++ b/vendor/github.com/uptrace/bun/extra/bunjson/json.go @@ -0,0 +1,26 @@ +package bunjson + +import ( + "encoding/json" + "io" +) + +var _ Provider = (*StdProvider)(nil) + +type StdProvider struct{} + +func (StdProvider) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (StdProvider) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +func (StdProvider) NewEncoder(w io.Writer) Encoder { + return json.NewEncoder(w) +} + +func (StdProvider) NewDecoder(r io.Reader) Decoder { + return json.NewDecoder(r) +} diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/provider.go b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go new file mode 100644 index 000000000..7f810e122 --- /dev/null +++ b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go @@ -0,0 +1,43 @@ +package bunjson + +import ( + "io" +) + +var provider Provider = StdProvider{} + +func SetProvider(p Provider) { + provider = p +} + +type Provider interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error + NewEncoder(w io.Writer) Encoder + NewDecoder(r io.Reader) Decoder +} + +type Decoder interface { + Decode(v interface{}) error + UseNumber() +} + +type Encoder interface { + Encode(v interface{}) error +} + +func Marshal(v interface{}) ([]byte, error) { + return provider.Marshal(v) +} + +func Unmarshal(data []byte, v interface{}) error { + return provider.Unmarshal(data, v) +} + +func NewEncoder(w io.Writer) Encoder { + return provider.NewEncoder(w) +} + +func NewDecoder(r io.Reader) Decoder { + return provider.NewDecoder(r) +} diff --git a/vendor/github.com/uptrace/bun/go.mod b/vendor/github.com/uptrace/bun/go.mod new file mode 100644 index 000000000..92def2a3d --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.mod @@ -0,0 +1,12 @@ +module github.com/uptrace/bun + +go 1.16 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jinzhu/inflection v1.0.0 + github.com/stretchr/testify v1.7.0 + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc + github.com/vmihailenco/msgpack/v5 v5.3.4 + golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect +) diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum new file mode 100644 index 000000000..3bf0a4a3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.sum @@ -0,0 +1,23 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/hook.go b/vendor/github.com/uptrace/bun/hook.go new file mode 100644 index 000000000..4cfa68fa6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/hook.go @@ -0,0 +1,98 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sync/atomic" + "time" + + "github.com/uptrace/bun/schema" +) + +type QueryEvent struct { + DB *DB + + QueryAppender schema.QueryAppender + Query string + QueryArgs []interface{} + + StartTime time.Time + Result sql.Result + Err error + + Stash map[interface{}]interface{} +} + +type QueryHook interface { + BeforeQuery(context.Context, *QueryEvent) context.Context + AfterQuery(context.Context, *QueryEvent) +} + +func (db *DB) beforeQuery( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + queryArgs []interface{}, +) (context.Context, *QueryEvent) { + atomic.AddUint64(&db.stats.Queries, 1) + + if len(db.queryHooks) == 0 { + return ctx, nil + } + + event := &QueryEvent{ + DB: db, + + QueryAppender: queryApp, + Query: query, + QueryArgs: queryArgs, + + StartTime: time.Now(), + } + + for _, hook := range db.queryHooks { + ctx = hook.BeforeQuery(ctx, event) + } + + return ctx, event +} + +func (db *DB) afterQuery( + ctx context.Context, + event *QueryEvent, + res sql.Result, + err error, +) { + switch err { + case nil, sql.ErrNoRows: + // nothing + default: + atomic.AddUint64(&db.stats.Errors, 1) + } + + if event == nil { + return + } + + event.Result = res + event.Err = err + + db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) +} + +func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) { + for ; hookIndex >= 0; hookIndex-- { + db.queryHooks[hookIndex].AfterQuery(ctx, event) + } +} + +//------------------------------------------------------------------------------ + +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx) +} + +func callAfterScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.AfterScanHook).AfterScan(ctx) +} diff --git a/vendor/github.com/uptrace/bun/internal/flag.go b/vendor/github.com/uptrace/bun/internal/flag.go new file mode 100644 index 000000000..b42f59df7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/flag.go @@ -0,0 +1,16 @@ +package internal + +type Flag uint64 + +func (flag Flag) Has(other Flag) bool { + return flag&other == other +} + +func (flag Flag) Set(other Flag) Flag { + return flag | other +} + +func (flag Flag) Remove(other Flag) Flag { + flag &= ^other + return flag +} diff --git a/vendor/github.com/uptrace/bun/internal/hex.go b/vendor/github.com/uptrace/bun/internal/hex.go new file mode 100644 index 000000000..6fae2bb78 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/hex.go @@ -0,0 +1,43 @@ +package internal + +import ( + fasthex "github.com/tmthrgd/go-hex" +) + +type HexEncoder struct { + b []byte + written bool +} + +func NewHexEncoder(b []byte) *HexEncoder { + return &HexEncoder{ + b: b, + } +} + +func (enc *HexEncoder) Bytes() []byte { + return enc.b +} + +func (enc *HexEncoder) Write(b []byte) (int, error) { + if !enc.written { + enc.b = append(enc.b, '\'') + enc.b = append(enc.b, `\x`...) + enc.written = true + } + + i := len(enc.b) + enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...) + fasthex.Encode(enc.b[i:], b) + + return len(b), nil +} + +func (enc *HexEncoder) Close() error { + if enc.written { + enc.b = append(enc.b, '\'') + } else { + enc.b = append(enc.b, "NULL"...) + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/internal/logger.go b/vendor/github.com/uptrace/bun/internal/logger.go new file mode 100644 index 000000000..2e22a0893 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/logger.go @@ -0,0 +1,27 @@ +package internal + +import ( + "fmt" + "log" + "os" +) + +var Warn = log.New(os.Stderr, "WARN: bun: ", log.LstdFlags) + +var Deprecated = log.New(os.Stderr, "DEPRECATED: bun: ", log.LstdFlags) + +type Logging interface { + Printf(format string, v ...interface{}) +} + +type logger struct { + log *log.Logger +} + +func (l *logger) Printf(format string, v ...interface{}) { + _ = l.log.Output(2, fmt.Sprintf(format, v...)) +} + +var Logger Logging = &logger{ + log: log.New(os.Stderr, "bun: ", log.LstdFlags|log.Lshortfile), +} diff --git a/vendor/github.com/uptrace/bun/internal/map_key.go b/vendor/github.com/uptrace/bun/internal/map_key.go new file mode 100644 index 000000000..bb5fcca8c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/map_key.go @@ -0,0 +1,67 @@ +package internal + +import "reflect" + +var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem() + +type MapKey struct { + iface interface{} +} + +func NewMapKey(is []interface{}) MapKey { + return MapKey{ + iface: newMapKey(is), + } +} + +func newMapKey(is []interface{}) interface{} { + switch len(is) { + case 1: + ptr := new([1]interface{}) + copy((*ptr)[:], is) + return *ptr + case 2: + ptr := new([2]interface{}) + copy((*ptr)[:], is) + return *ptr + case 3: + ptr := new([3]interface{}) + copy((*ptr)[:], is) + return *ptr + case 4: + ptr := new([4]interface{}) + copy((*ptr)[:], is) + return *ptr + case 5: + ptr := new([5]interface{}) + copy((*ptr)[:], is) + return *ptr + case 6: + ptr := new([6]interface{}) + copy((*ptr)[:], is) + return *ptr + case 7: + ptr := new([7]interface{}) + copy((*ptr)[:], is) + return *ptr + case 8: + ptr := new([8]interface{}) + copy((*ptr)[:], is) + return *ptr + case 9: + ptr := new([9]interface{}) + copy((*ptr)[:], is) + return *ptr + case 10: + ptr := new([10]interface{}) + copy((*ptr)[:], is) + return *ptr + default: + } + + at := reflect.New(reflect.ArrayOf(len(is), ifaceType)).Elem() + for i, v := range is { + *(at.Index(i).Addr().Interface().(*interface{})) = v + } + return at.Interface() +} diff --git a/vendor/github.com/uptrace/bun/internal/parser/parser.go b/vendor/github.com/uptrace/bun/internal/parser/parser.go new file mode 100644 index 000000000..cdfc0be16 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/parser/parser.go @@ -0,0 +1,141 @@ +package parser + +import ( + "bytes" + "strconv" + + "github.com/uptrace/bun/internal" +) + +type Parser struct { + b []byte + i int +} + +func New(b []byte) *Parser { + return &Parser{ + b: b, + } +} + +func NewString(s string) *Parser { + return New(internal.Bytes(s)) +} + +func (p *Parser) Valid() bool { + return p.i < len(p.b) +} + +func (p *Parser) Bytes() []byte { + return p.b[p.i:] +} + +func (p *Parser) Read() byte { + if p.Valid() { + c := p.b[p.i] + p.Advance() + return c + } + return 0 +} + +func (p *Parser) Peek() byte { + if p.Valid() { + return p.b[p.i] + } + return 0 +} + +func (p *Parser) Advance() { + p.i++ +} + +func (p *Parser) Skip(skip byte) bool { + if p.Peek() == skip { + p.Advance() + return true + } + return false +} + +func (p *Parser) SkipBytes(skip []byte) bool { + if len(skip) > len(p.b[p.i:]) { + return false + } + if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + return false + } + p.i += len(skip) + return true +} + +func (p *Parser) ReadSep(sep byte) ([]byte, bool) { + ind := bytes.IndexByte(p.b[p.i:], sep) + if ind == -1 { + b := p.b[p.i:] + p.i = len(p.b) + return b, false + } + + b := p.b[p.i : p.i+ind] + p.i += ind + 1 + return b, true +} + +func (p *Parser) ReadIdentifier() (string, bool) { + if p.i < len(p.b) && p.b[p.i] == '(' { + s := p.i + 1 + if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 { + b := p.b[s : s+ind] + p.i = s + ind + 1 + return internal.String(b), false + } + } + + ind := len(p.b) - p.i + var alpha bool + for i, c := range p.b[p.i:] { + if isNum(c) { + continue + } + if isAlpha(c) || (i > 0 && alpha && c == '_') { + alpha = true + continue + } + ind = i + break + } + if ind == 0 { + return "", false + } + b := p.b[p.i : p.i+ind] + p.i += ind + return internal.String(b), !alpha +} + +func (p *Parser) ReadNumber() int { + ind := len(p.b) - p.i + for i, c := range p.b[p.i:] { + if !isNum(c) { + ind = i + break + } + } + if ind == 0 { + return 0 + } + n, err := strconv.Atoi(string(p.b[p.i : p.i+ind])) + if err != nil { + panic(err) + } + p.i += ind + return n +} + +func isNum(c byte) bool { + return c >= '0' && c <= '9' +} + +func isAlpha(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} diff --git a/vendor/github.com/uptrace/bun/internal/safe.go b/vendor/github.com/uptrace/bun/internal/safe.go new file mode 100644 index 000000000..862ff0eb3 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package internal + +func String(b []byte) string { + return string(b) +} + +func Bytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/internal/tagparser/parser.go b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go new file mode 100644 index 000000000..8ef89248c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go @@ -0,0 +1,147 @@ +package tagparser + +import ( + "strings" +) + +type Tag struct { + Name string + Options map[string]string +} + +func (t Tag) HasOption(name string) bool { + _, ok := t.Options[name] + return ok +} + +func Parse(s string) Tag { + p := parser{ + s: s, + } + p.parse() + return p.tag +} + +type parser struct { + s string + i int + + tag Tag + seenName bool // for empty names +} + +func (p *parser) setName(name string) { + if p.seenName { + p.addOption(name, "") + } else { + p.seenName = true + p.tag.Name = name + } +} + +func (p *parser) addOption(key, value string) { + p.seenName = true + if key == "" { + return + } + if p.tag.Options == nil { + p.tag.Options = make(map[string]string) + } + p.tag.Options[key] = value +} + +func (p *parser) parse() { + for p.valid() { + p.parseKeyValue() + if p.peek() == ',' { + p.i++ + } + } +} + +func (p *parser) parseKeyValue() { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case ',': + key := p.s[start : p.i-1] + p.setName(key) + return + case ':': + key := p.s[start : p.i-1] + value := p.parseValue() + p.addOption(key, value) + return + case '"': + key := p.parseQuotedValue() + p.setName(key) + return + } + } + + key := p.s[start:p.i] + p.setName(key) +} + +func (p *parser) parseValue() string { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case '"': + return p.parseQuotedValue() + case ',': + return p.s[start : p.i-1] + } + } + + if p.i == start { + return "" + } + return p.s[start:p.i] +} + +func (p *parser) parseQuotedValue() string { + if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' { + s := p.s[p.i : p.i+i] + p.i += i + 1 + return s + } + + b := make([]byte, 0, 16) + + for p.valid() { + switch c := p.read(); c { + case '\\': + b = append(b, p.read()) + case '"': + return string(b) + default: + b = append(b, c) + } + } + + return "" +} + +func (p *parser) valid() bool { + return p.i < len(p.s) +} + +func (p *parser) read() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + p.i++ + return c +} + +func (p *parser) peek() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + return c +} diff --git a/vendor/github.com/uptrace/bun/internal/time.go b/vendor/github.com/uptrace/bun/internal/time.go new file mode 100644 index 000000000..e4e0804b0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/time.go @@ -0,0 +1,41 @@ +package internal + +import ( + "fmt" + "time" +) + +const ( + dateFormat = "2006-01-02" + timeFormat = "15:04:05.999999999" + timestampFormat = "2006-01-02 15:04:05.999999999" + timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00" + timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00" + timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" +) + +func ParseTime(s string) (time.Time, error) { + switch l := len(s); { + case l < len("15:04:05"): + return time.Time{}, fmt.Errorf("bun: can't parse time=%q", s) + case l <= len(timeFormat): + if s[2] == ':' { + return time.ParseInLocation(timeFormat, s, time.UTC) + } + return time.ParseInLocation(dateFormat, s, time.UTC) + default: + if s[10] == 'T' { + return time.Parse(time.RFC3339Nano, s) + } + if c := s[l-9]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat, s) + } + if c := s[l-6]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat2, s) + } + if c := s[l-3]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat3, s) + } + return time.ParseInLocation(timestampFormat, s, time.UTC) + } +} diff --git a/vendor/github.com/uptrace/bun/internal/underscore.go b/vendor/github.com/uptrace/bun/internal/underscore.go new file mode 100644 index 000000000..9de52fb7b --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/underscore.go @@ -0,0 +1,67 @@ +package internal + +func IsUpper(c byte) bool { + return c >= 'A' && c <= 'Z' +} + +func IsLower(c byte) bool { + return c >= 'a' && c <= 'z' +} + +func ToUpper(c byte) byte { + return c - 32 +} + +func ToLower(c byte) byte { + return c + 32 +} + +// Underscore converts "CamelCasedString" to "camel_cased_string". +func Underscore(s string) string { + r := make([]byte, 0, len(s)+5) + for i := 0; i < len(s); i++ { + c := s[i] + if IsUpper(c) { + if i > 0 && i+1 < len(s) && (IsLower(s[i-1]) || IsLower(s[i+1])) { + r = append(r, '_', ToLower(c)) + } else { + r = append(r, ToLower(c)) + } + } else { + r = append(r, c) + } + } + return string(r) +} + +func CamelCased(s string) string { + r := make([]byte, 0, len(s)) + upperNext := true + for i := 0; i < len(s); i++ { + c := s[i] + if c == '_' { + upperNext = true + continue + } + if upperNext { + if IsLower(c) { + c = ToUpper(c) + } + upperNext = false + } + r = append(r, c) + } + return string(r) +} + +func ToExported(s string) string { + if len(s) == 0 { + return s + } + if c := s[0]; IsLower(c) { + b := []byte(s) + b[0] = ToUpper(c) + return string(b) + } + return s +} diff --git a/vendor/github.com/uptrace/bun/internal/unsafe.go b/vendor/github.com/uptrace/bun/internal/unsafe.go new file mode 100644 index 000000000..4bc79701f --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/unsafe.go @@ -0,0 +1,20 @@ +// +build !appengine + +package internal + +import "unsafe" + +// String converts byte slice to string. +func String(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// Bytes converts string to byte slice. +func Bytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/internal/util.go b/vendor/github.com/uptrace/bun/internal/util.go new file mode 100644 index 000000000..c831dc659 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/util.go @@ -0,0 +1,57 @@ +package internal + +import ( + "reflect" +) + +func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + elemType := v.Type().Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + return v.Index(v.Len() - 1) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(v.Len() - 1) + } +} + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/vendor/github.com/uptrace/bun/join.go b/vendor/github.com/uptrace/bun/join.go new file mode 100644 index 000000000..4557f5bc0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/join.go @@ -0,0 +1,308 @@ +package bun + +import ( + "context" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type join struct { + Parent *join + BaseModel tableModel + JoinModel tableModel + Relation *schema.Relation + + ApplyQueryFunc func(*SelectQuery) *SelectQuery + columns []schema.QueryWithArgs +} + +func (j *join) applyQuery(q *SelectQuery) { + if j.ApplyQueryFunc == nil { + return + } + + var table *schema.Table + var columns []schema.QueryWithArgs + + // Save state. + table, q.table = q.table, j.JoinModel.Table() + columns, q.columns = q.columns, nil + + q = j.ApplyQueryFunc(q) + + // Restore state. + q.table = table + j.columns, q.columns = q.columns, columns +} + +func (j *join) Select(ctx context.Context, q *SelectQuery) error { + switch j.Relation.Type { + case schema.HasManyRelation: + return j.selectMany(ctx, q) + case schema.ManyToManyRelation: + return j.selectM2M(ctx, q) + } + panic("not reached") +} + +func (j *join) selectMany(ctx context.Context, q *SelectQuery) error { + q = j.manyQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) manyQuery(q *SelectQuery) *SelectQuery { + hasManyModel := newHasManyModel(j) + if hasManyModel == nil { + return nil + } + + q = q.Model(hasManyModel) + + var where []byte + if len(j.Relation.JoinFields) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) + if len(j.Relation.JoinFields) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + q.db.Formatter(), + where, + j.JoinModel.Root(), + j.JoinModel.ParentIndex(), + j.Relation.BaseFields, + ) + where = append(where, ")"...) + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery { + if j.Relation.M2MTable != nil { + q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") + } + + b := make([]byte, 0, 32) + + if len(j.columns) > 0 { + for i, col := range j.columns { + if i > 0 { + b = append(b, ", "...) + } + + var err error + b, err = col.AppendQuery(q.db.fmter, b) + if err != nil { + q.err = err + return q + } + } + } else { + joinTable := j.JoinModel.Table() + b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) + } + + q = q.ColumnExpr(internal.String(b)) + + return q +} + +func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error { + q = j.m2mQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) m2mQuery(q *SelectQuery) *SelectQuery { + fmter := q.db.fmter + + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil + } + q = q.Model(m2mModel) + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) + join = append(join, " AS "...) + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, " ON ("...) + for i, col := range j.Relation.M2MBaseFields { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, '.') + join = append(join, col.SQLName...) + } + join = append(join, ") IN ("...) + join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.String(join)) + + joinTable := j.JoinModel.Table() + for i, m2mJoinField := range j.Relation.M2MJoinFields { + joinField := j.Relation.JoinFields[i] + q = q.Where("?.? = ?.?", + joinTable.SQLAlias, joinField.SQLName, + j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, quote) + return b +} + +func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, quote) + return b +} + +func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + if j.hasParent() { + b = append(b, quote) + b = appendAlias(b, j.Parent) + b = append(b, quote) + return b + } + return append(b, j.BaseModel.Table().SQLAlias...) +} + +func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) + if flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Relation.Field.Name...) + return b +} + +func (j *join) appendHasOneJoin( + fmter schema.Formatter, b []byte, q *SelectQuery, +) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(fmter, b) + + b = append(b, " ON "...) + + b = append(b, '(') + for i, baseField := range j.Relation.BaseFields { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, j.Relation.JoinFields[i].SQLName...) + b = append(b, " = "...) + b = j.appendBaseAlias(fmter, b) + b = append(b, '.') + b = append(b, baseField.SQLName...) + } + b = append(b, ')') + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(fmter, b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +func appendChildValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, +) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(fmter, b, v) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} diff --git a/vendor/github.com/uptrace/bun/model.go b/vendor/github.com/uptrace/bun/model.go new file mode 100644 index 000000000..c9f0f3583 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model.go @@ -0,0 +1,207 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/schema" +) + +var errNilModel = errors.New("bun: Model(nil)") + +var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + +type Model interface { + ScanRows(ctx context.Context, rows *sql.Rows) (int, error) + Value() interface{} +} + +type rowScanner interface { + ScanRow(ctx context.Context, rows *sql.Rows) error +} + +type model interface { + Model +} + +type tableModel interface { + model + + schema.BeforeScanHook + schema.AfterScanHook + ScanColumn(column string, src interface{}) error + + Table() *schema.Table + Relation() *schema.Relation + + Join(string, func(*SelectQuery) *SelectQuery) *join + GetJoin(string) *join + GetJoins() []join + AddJoin(join) *join + + Root() reflect.Value + ParentIndex() []int + Mount(reflect.Value) + + updateSoftDeleteField() error +} + +func newModel(db *DB, dest []interface{}) (model, error) { + if len(dest) == 1 { + return _newModel(db, dest[0], true) + } + + values := make([]reflect.Value, len(dest)) + + for i, el := range dest { + v := reflect.ValueOf(el) + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Scan(non-pointer %T)", dest) + } + + v = v.Elem() + if v.Kind() != reflect.Slice { + return newScanModel(db, dest), nil + } + + values[i] = v + } + + return newSliceModel(db, dest, values), nil +} + +func newSingleModel(db *DB, dest interface{}) (model, error) { + return _newModel(db, dest, false) +} + +func _newModel(db *DB, dest interface{}, scan bool) (model, error) { + switch dest := dest.(type) { + case nil: + return nil, errNilModel + case Model: + return dest, nil + case sql.Scanner: + if !scan { + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) + } + return newScanModel(db, []interface{}{dest}), nil + } + + v := reflect.ValueOf(dest) + if !v.IsValid() { + return nil, errNilModel + } + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Model(non-pointer %T)", dest) + } + + if v.IsNil() { + typ := v.Type().Elem() + if typ.Kind() == reflect.Struct { + return newStructTableModel(db, dest, db.Table(typ)), nil + } + return nil, fmt.Errorf("bun: Model(nil %T)", dest) + } + + v = v.Elem() + + switch v.Kind() { + case reflect.Map: + typ := v.Type() + if err := validMap(typ); err != nil { + return nil, err + } + mapPtr := v.Addr().Interface().(*map[string]interface{}) + return newMapModel(db, mapPtr), nil + case reflect.Struct: + if v.Type() != timeType { + return newStructTableModelValue(db, dest, v), nil + } + case reflect.Slice: + switch elemType := sliceElemType(v); elemType.Kind() { + case reflect.Struct: + if elemType != timeType { + return newSliceTableModel(db, dest, v, elemType), nil + } + case reflect.Map: + if err := validMap(elemType); err != nil { + return nil, err + } + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) + return newMapSliceModel(db, slicePtr), nil + } + return newSliceModel(db, []interface{}{dest}, []reflect.Value{v}), nil + } + + if scan { + return newScanModel(db, []interface{}{dest}), nil + } + + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) +} + +func newTableModelIndex( + db *DB, + table *schema.Table, + root reflect.Value, + index []int, + rel *schema.Relation, +) (tableModel, error) { + typ := typeByIndex(table.Type, index) + + if typ.Kind() == reflect.Struct { + return &structTableModel{ + db: db, + table: table.Dialect().Tables().Get(typ), + rel: rel, + + root: root, + index: index, + }, nil + } + + if typ.Kind() == reflect.Slice { + structType := indirectType(typ.Elem()) + if structType.Kind() == reflect.Struct { + m := sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: table.Dialect().Tables().Get(structType), + rel: rel, + + root: root, + index: index, + }, + } + m.init(typ) + return &m, nil + } + } + + return nil, fmt.Errorf("bun: NewModel(%s)", typ) +} + +func validMap(typ reflect.Type) error { + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { + return fmt.Errorf("bun: Model(unsupported %s) (expected *map[string]interface{})", + typ) + } + return nil +} + +//------------------------------------------------------------------------------ + +func isSingleRowModel(m model) bool { + switch m.(type) { + case *mapModel, + *structTableModel, + *scanModel: + return true + default: + return false + } +} diff --git a/vendor/github.com/uptrace/bun/model_map.go b/vendor/github.com/uptrace/bun/model_map.go new file mode 100644 index 000000000..81c1a4a3b --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map.go @@ -0,0 +1,183 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sort" + + "github.com/uptrace/bun/schema" +) + +type mapModel struct { + db *DB + + dest *map[string]interface{} + m map[string]interface{} + + rows *sql.Rows + columns []string + _columnTypes []*sql.ColumnType + scanIndex int +} + +var _ model = (*mapModel)(nil) + +func newMapModel(db *DB, dest *map[string]interface{}) *mapModel { + m := &mapModel{ + db: db, + dest: dest, + } + if dest != nil { + m.m = *dest + } + return m +} + +func (m *mapModel) Value() interface{} { + return m.dest +} + +func (m *mapModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.m == nil { + m.m = make(map[string]interface{}, len(m.columns)) + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + *m.dest = m.m + + return 1, nil +} + +func (m *mapModel) Scan(src interface{}) error { + if _, ok := src.([]byte); !ok { + return m.scanRaw(src) + } + + columnTypes, err := m.columnTypes() + if err != nil { + return err + } + + scanType := columnTypes[m.scanIndex].ScanType() + switch scanType.Kind() { + case reflect.Interface: + return m.scanRaw(src) + case reflect.Slice: + if scanType.Elem().Kind() == reflect.Uint8 { + return m.scanRaw(src) + } + } + + dest := reflect.New(scanType).Elem() + if err := schema.Scanner(scanType)(dest, src); err != nil { + return err + } + + return m.scanRaw(dest.Interface()) +} + +func (m *mapModel) columnTypes() ([]*sql.ColumnType, error) { + if m._columnTypes == nil { + columnTypes, err := m.rows.ColumnTypes() + if err != nil { + return nil, err + } + m._columnTypes = columnTypes + } + return m._columnTypes, nil +} + +func (m *mapModel) scanRaw(src interface{}) error { + columnName := m.columns[m.scanIndex] + m.scanIndex++ + m.m[columnName] = src + return nil +} + +func (m *mapModel) appendColumnsValues(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + b = append(b, " ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + b = append(b, ") VALUES ("...) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + b = append(b, ")"...) + + return b +} + +func (m *mapModel) appendSet(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, k) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + return b +} + +func makeDest(v interface{}, n int) []interface{} { + dest := make([]interface{}, n) + for i := range dest { + dest[i] = v + } + return dest +} diff --git a/vendor/github.com/uptrace/bun/model_map_slice.go b/vendor/github.com/uptrace/bun/model_map_slice.go new file mode 100644 index 000000000..5c6f48e44 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map_slice.go @@ -0,0 +1,162 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "sort" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type mapSliceModel struct { + mapModel + dest *[]map[string]interface{} + + keys []string +} + +var _ model = (*mapSliceModel)(nil) + +func newMapSliceModel(db *DB, dest *[]map[string]interface{}) *mapSliceModel { + return &mapSliceModel{ + mapModel: mapModel{ + db: db, + }, + dest: dest, + } +} + +func (m *mapSliceModel) Value() interface{} { + return m.dest +} + +func (m *mapSliceModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if slice := *m.dest; len(slice) < cap { + *m.dest = make([]map[string]interface{}, 0, cap) + } +} + +func (m *mapSliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + slice := *m.dest + if len(slice) > 0 { + slice = slice[:0] + } + + var n int + + for rows.Next() { + m.m = make(map[string]interface{}, len(m.columns)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + slice = append(slice, m.m) + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + *m.dest = slice + return n, nil +} + +func (m *mapSliceModel) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + + for i, k := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + return b, nil +} + +func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + slice := *m.dest + + b = append(b, "VALUES "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + if fmter.IsNop() { + for i := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, '?') + } + return b, nil + } + + for i, el := range slice { + if i > 0 { + b = append(b, "), "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + for j, key := range m.keys { + if j > 0 { + b = append(b, ", "...) + } + b = fmter.Dialect().Append(fmter, b, el[key]) + } + } + + b = append(b, ')') + + return b, nil +} + +func (m *mapSliceModel) initKeys() error { + if m.keys != nil { + return nil + } + + slice := *m.dest + if len(slice) == 0 { + return errors.New("bun: map slice is empty") + } + + first := slice[0] + keys := make([]string, 0, len(first)) + + for k := range first { + keys = append(keys, k) + } + + sort.Strings(keys) + m.keys = keys + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_scan.go b/vendor/github.com/uptrace/bun/model_scan.go new file mode 100644 index 000000000..6dd061fb2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_scan.go @@ -0,0 +1,54 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" +) + +type scanModel struct { + db *DB + + dest []interface{} + scanIndex int +} + +var _ model = (*scanModel)(nil) + +func newScanModel(db *DB, dest []interface{}) *scanModel { + return &scanModel{ + db: db, + dest: dest, + } +} + +func (m *scanModel) Value() interface{} { + return m.dest +} + +func (m *scanModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + dest := makeDest(m, len(m.dest)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + return 1, nil +} + +func (m *scanModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + return rows.Scan(m.dest...) +} + +func (m *scanModel) Scan(src interface{}) error { + dest := reflect.ValueOf(m.dest[m.scanIndex]) + m.scanIndex++ + + scanner := m.db.dialect.Scanner(dest.Type()) + return scanner(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_slice.go b/vendor/github.com/uptrace/bun/model_slice.go new file mode 100644 index 000000000..afe804382 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_slice.go @@ -0,0 +1,82 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type sliceInfo struct { + nextElem func() reflect.Value + scan schema.ScannerFunc +} + +type sliceModel struct { + dest []interface{} + values []reflect.Value + scanIndex int + info []sliceInfo +} + +var _ model = (*sliceModel)(nil) + +func newSliceModel(db *DB, dest []interface{}, values []reflect.Value) *sliceModel { + return &sliceModel{ + dest: dest, + values: values, + } +} + +func (m *sliceModel) Value() interface{} { + return m.dest +} + +func (m *sliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.info = make([]sliceInfo, len(m.values)) + for i, v := range m.values { + if v.IsValid() && v.Len() > 0 { + v.Set(v.Slice(0, 0)) + } + + m.info[i] = sliceInfo{ + nextElem: internal.MakeSliceNextElemFunc(v), + scan: schema.Scanner(v.Type().Elem()), + } + } + + if len(columns) == 0 { + return 0, nil + } + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *sliceModel) Scan(src interface{}) error { + info := m.info[m.scanIndex] + m.scanIndex++ + + dest := info.nextElem() + return info.scan(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go new file mode 100644 index 000000000..e64b7088d --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_has_many.go @@ -0,0 +1,149 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type hasManyModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*hasManyModel)(nil) + +func newHasManyModel(j *join) *hasManyModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, j.Relation.BaseFields) + if len(baseValues) == 0 { + return nil + } + m := hasManyModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return &m +} + +func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *hasManyModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, err := m.table.Field(column) + if err != nil { + return err + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, f := range m.rel.JoinFields { + if f.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *hasManyModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: has-many relation=%s does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for i, v := range baseValues { + if !m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct)) + continue + } + + if i == 0 { + v.Set(reflect.Append(v, m.strct.Addr())) + continue + } + + clone := reflect.New(m.strct.Type()).Elem() + clone.Set(m.strct) + v.Set(reflect.Append(v, clone.Addr())) + } + + return nil +} + +func baseValues(model tableModel, fields []*schema.Field) map[internal.MapKey][]reflect.Value { + fieldIndex := model.Relation().Field.Index + m := make(map[internal.MapKey][]reflect.Value) + key := make([]interface{}, 0, len(fields)) + walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { + key = modelKey(key[:0], v, fields) + mapKey := internal.NewMapKey(key) + m[mapKey] = append(m[mapKey], v.FieldByIndex(fieldIndex)) + }) + return m +} + +func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { + for _, f := range fields { + key = append(key, f.Value(strct).Interface()) + } + return key +} diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go new file mode 100644 index 000000000..4357e3a8e --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_m2m.go @@ -0,0 +1,138 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type m2mModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*m2mModel)(nil) + +func newM2MModel(j *join) *m2mModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, baseTable.PKs) + if len(baseValues) == 0 { + return nil + } + m := &m2mModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return m +} + +func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *m2mModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, ok := m.table.FieldMap[column] + if !ok { + return m.scanM2MColumn(column, src) + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, fk := range m.rel.M2MBaseFields { + if fk.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { + for _, field := range m.rel.M2MBaseFields { + if field.Name == column { + dest := reflect.New(field.IndirectType).Elem() + if err := field.Scan(dest, src); err != nil { + return err + } + m.structKey = append(m.structKey, dest.Interface()) + break + } + } + + _, err := m.scanColumn(column, src) + return err +} + +func (m *m2mModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: m2m relation=%s does not have base %s with key=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for _, v := range baseValues { + if m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct.Addr())) + } else { + v.Set(reflect.Append(v, m.strct)) + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_slice.go b/vendor/github.com/uptrace/bun/model_table_slice.go new file mode 100644 index 000000000..67e7c71e7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_slice.go @@ -0,0 +1,113 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type sliceTableModel struct { + structTableModel + + slice reflect.Value + sliceLen int + sliceOfPtr bool + nextElem func() reflect.Value +} + +var _ tableModel = (*sliceTableModel)(nil) + +func newSliceTableModel( + db *DB, dest interface{}, slice reflect.Value, elemType reflect.Type, +) *sliceTableModel { + m := &sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: db.Table(elemType), + dest: dest, + root: slice, + }, + + slice: slice, + sliceLen: slice.Len(), + nextElem: makeSliceNextElemFunc(slice), + } + m.init(slice.Type()) + return m +} + +func (m *sliceTableModel) init(sliceType reflect.Type) { + switch sliceType.Elem().Kind() { + case reflect.Ptr, reflect.Interface: + m.sliceOfPtr = true + } +} + +func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.slice, name, apply) +} + +func (m *sliceTableModel) Bind(bind reflect.Value) { + m.slice = bind.Field(m.index[len(m.index)-1]) +} + +func (m *sliceTableModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if m.slice.Cap() < cap { + m.slice.Set(reflect.MakeSlice(m.slice.Type(), 0, cap)) + } +} + +func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + + var n int + + for rows.Next() { + m.strct = m.nextElem() + m.structInited = false + + if err := m.scanRow(ctx, rows, dest); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +// Inherit these hooks from structTableModel. +var ( + _ schema.BeforeScanHook = (*sliceTableModel)(nil) + _ schema.AfterScanHook = (*sliceTableModel)(nil) +) + +func (m *sliceTableModel) updateSoftDeleteField() error { + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(m.slice.Index(i)) + fv := m.table.SoftDeleteField.Value(strct) + if err := m.table.UpdateSoftDeleteField(fv); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go new file mode 100644 index 000000000..3bb0c14dd --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -0,0 +1,345 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +type structTableModel struct { + db *DB + table *schema.Table + + rel *schema.Relation + joins []join + + dest interface{} + root reflect.Value + index []int + + strct reflect.Value + structInited bool + structInitErr error + + columns []string + scanIndex int +} + +var _ tableModel = (*structTableModel)(nil) + +func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel { + return &structTableModel{ + db: db, + table: table, + dest: dest, + } +} + +func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel { + return &structTableModel{ + db: db, + table: db.Table(v.Type()), + dest: dest, + root: v, + strct: v, + } +} + +func (m *structTableModel) Value() interface{} { + return m.dest +} + +func (m *structTableModel) Table() *schema.Table { + return m.table +} + +func (m *structTableModel) Relation() *schema.Relation { + return m.rel +} + +func (m *structTableModel) Root() reflect.Value { + return m.root +} + +func (m *structTableModel) Index() []int { + return m.index +} + +func (m *structTableModel) ParentIndex() []int { + return m.index[:len(m.index)-len(m.rel.Field.Index)] +} + +func (m *structTableModel) Mount(host reflect.Value) { + m.strct = host.FieldByIndex(m.rel.Field.Index) + m.structInited = false +} + +func (m *structTableModel) initStruct() error { + if m.structInited { + return m.structInitErr + } + m.structInited = true + + switch m.strct.Kind() { + case reflect.Invalid: + m.structInitErr = errNilModel + return m.structInitErr + case reflect.Interface: + m.strct = m.strct.Elem() + } + + if m.strct.Kind() == reflect.Ptr { + if m.strct.IsNil() { + m.strct.Set(reflect.New(m.strct.Type().Elem())) + m.strct = m.strct.Elem() + } else { + m.strct = m.strct.Elem() + } + } + + m.mountJoins() + + return nil +} + +func (m *structTableModel) mountJoins() { + for i := range m.joins { + j := &m.joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + j.JoinModel.Mount(m.strct) + } + } +} + +var _ schema.BeforeScanHook = (*structTableModel)(nil) + +func (m *structTableModel) BeforeScan(ctx context.Context) error { + if !m.table.HasBeforeScanHook() { + return nil + } + return callBeforeScanHook(ctx, m.strct.Addr()) +} + +var _ schema.AfterScanHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScan(ctx context.Context) error { + if !m.table.HasAfterScanHook() || !m.structInited { + return nil + } + + var firstErr error + + if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { + firstErr = err + } + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +func (m *structTableModel) GetJoin(name string) *join { + for i := range m.joins { + j := &m.joins[i] + if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { + return j + } + } + return nil +} + +func (m *structTableModel) GetJoins() []join { + return m.joins +} + +func (m *structTableModel) AddJoin(j join) *join { + m.joins = append(m.joins, j) + return &m.joins[len(m.joins)-1] +} + +func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.strct, name, apply) +} + +func (m *structTableModel) join( + bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery, +) *join { + path := strings.Split(name, ".") + index := make([]int, 0, len(path)) + + currJoin := join{ + BaseModel: m, + JoinModel: m, + } + var lastJoin *join + + for _, name := range path { + relation, ok := currJoin.JoinModel.Table().Relations[name] + if !ok { + return nil + } + + currJoin.Relation = relation + index = append(index, relation.Field.Index...) + + if j := currJoin.JoinModel.GetJoin(name); j != nil { + currJoin.BaseModel = j.BaseModel + currJoin.JoinModel = j.JoinModel + + lastJoin = j + } else { + model, err := newTableModelIndex(m.db, m.table, bind, index, relation) + if err != nil { + return nil + } + + currJoin.Parent = lastJoin + currJoin.BaseModel = currJoin.JoinModel + currJoin.JoinModel = model + + lastJoin = currJoin.BaseModel.AddJoin(currJoin) + } + } + + // No joins with such name. + if lastJoin == nil { + return nil + } + if apply != nil { + lastJoin.ApplyQueryFunc = apply + } + + return lastJoin +} + +func (m *structTableModel) updateSoftDeleteField() error { + fv := m.table.SoftDeleteField.Value(m.strct) + return m.table.UpdateSoftDeleteField(fv) +} + +func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + if err := m.ScanRow(ctx, rows); err != nil { + return 0, err + } + + // For inserts, SQLite3 can return a row like it was inserted sucessfully and then + // an actual error for the next row. See issues/100. + if m.db.dialect.Name() == dialect.SQLite { + _ = rows.Next() + if err := rows.Err(); err != nil { + return 0, err + } + } + + return 1, nil +} + +func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + columns, err := rows.Columns() + if err != nil { + return err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + return m.scanRow(ctx, rows, dest) +} + +func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { + if err := m.BeforeScan(ctx); err != nil { + return err + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return err + } + + if err := m.AfterScan(ctx); err != nil { + return err + } + + return nil +} + +func (m *structTableModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + return m.ScanColumn(unquote(column), src) +} + +func (m *structTableModel) ScanColumn(column string, src interface{}) error { + if ok, err := m.scanColumn(column, src); ok { + return err + } + if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) { + return nil + } + return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column) +} + +func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) { + if src != nil { + if err := m.initStruct(); err != nil { + return true, err + } + } + + if field, ok := m.table.FieldMap[column]; ok { + return true, field.ScanValue(m.strct, src) + } + + if joinName, column := splitColumn(column); joinName != "" { + if join := m.GetJoin(joinName); join != nil { + return true, join.JoinModel.ScanColumn(column, src) + } + if m.table.ModelName == joinName { + return true, m.ScanColumn(column, src) + } + } + + return false, nil +} + +func (m *structTableModel) AppendNamedArg( + fmter schema.Formatter, b []byte, name string, +) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} + +// sqlite3 sometimes does not unquote columns. +func unquote(s string) string { + if s == "" { + return s + } + if s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +func splitColumn(s string) (string, string) { + if i := strings.Index(s, "__"); i >= 0 { + return s[:i], s[i+2:] + } + return "", s +} diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go new file mode 100644 index 000000000..1a7c32720 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -0,0 +1,874 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + wherePKFlag internal.Flag = 1 << iota + forceDeleteFlag + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query schema.QueryAppender +} + +// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx. +type IConn interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +var ( + _ IConn = (*sql.DB)(nil) + _ IConn = (*sql.Conn)(nil) + _ IConn = (*sql.Tx)(nil) + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. +type IDB interface { + IConn + + NewValues(model interface{}) *ValuesQuery + NewSelect() *SelectQuery + NewInsert() *InsertQuery + NewUpdate() *UpdateQuery + NewDelete() *DeleteQuery + NewCreateTable() *CreateTableQuery + NewDropTable() *DropTableQuery + NewCreateIndex() *CreateIndexQuery + NewDropIndex() *DropIndexQuery + NewTruncateTable() *TruncateTableQuery + NewAddColumn() *AddColumnQuery + NewDropColumn() *DropColumnQuery +} + +var ( + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +type baseQuery struct { + db *DB + conn IConn + + model model + err error + + tableModel tableModel + table *schema.Table + + with []withQuery + modelTable schema.QueryWithArgs + tables []schema.QueryWithArgs + columns []schema.QueryWithArgs + + flags internal.Flag +} + +func (q *baseQuery) DB() *DB { + return q.db +} + +func (q *baseQuery) GetModel() Model { + return q.model +} + +func (q *baseQuery) setConn(db IConn) { + // Unwrap Bun wrappers to not call query hooks twice. + switch db := db.(type) { + case *DB: + q.conn = db.DB + case Conn: + q.conn = db.Conn + case Tx: + q.conn = db.Tx + default: + q.conn = db + } +} + +// TODO: rename to setModel +func (q *baseQuery) setTableModel(modeli interface{}) { + model, err := newSingleModel(q.db, modeli) + if err != nil { + q.setErr(err) + return + } + + q.model = model + if tm, ok := model.(tableModel); ok { + q.tableModel = tm + q.table = tm.Table() + } +} + +func (q *baseQuery) setErr(err error) { + if q.err == nil { + q.err = err + } +} + +func (q *baseQuery) getModel(dest []interface{}) (model, error) { + if len(dest) == 0 { + if q.model != nil { + return q.model, nil + } + return nil, errNilModel + } + return newModel(q.db, dest) +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) checkSoftDelete() error { + if q.table == nil { + return errors.New("bun: can't use soft deletes without a table") + } + if q.table.SoftDeleteField == nil { + return fmt.Errorf("%s does not have a soft delete field", q.table) + } + if q.tableModel == nil { + return errors.New("bun: can't use soft deletes without a table model") + } + return nil +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *baseQuery) whereDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(deletedFlag) + q.flags = q.flags.Remove(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *baseQuery) whereAllWithDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(allWithDeletedFlag) + q.flags = q.flags.Remove(deletedFlag) +} + +func (q *baseQuery) isSoftDelete() bool { + if q.table != nil { + return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addWith(name string, query schema.QueryAppender) { + q.with = append(q.with, withQuery{ + name: name, + query: query, + }) +} + +func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.with) == 0 { + return b, nil + } + + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, with.name) + if q, ok := with.query.(schema.ColumnsAppender); ok { + b = append(b, " ("...) + b, err = q.AppendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addTable(table schema.QueryWithArgs) { + q.tables = append(q.tables, table) +} + +func (q *baseQuery) addColumn(column schema.QueryWithArgs) { + q.columns = append(q.columns, column) +} + +func (q *baseQuery) excludeColumn(columns []string) { + if q.columns == nil { + for _, f := range q.table.Fields { + q.columns = append(q.columns, schema.UnsafeIdent(f.Name)) + } + } + + if len(columns) == 1 && columns[0] == "*" { + q.columns = make([]schema.QueryWithArgs, 0) + return + } + + for _, column := range columns { + if !q._excludeColumn(column) { + q.setErr(fmt.Errorf("bun: can't find column=%q", column)) + return + } + } +} + +func (q *baseQuery) _excludeColumn(column string) bool { + for i, col := range q.columns { + if col.Args == nil && col.Query == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) modelHasTableName() bool { + return !q.modelTable.IsZero() || q.table != nil +} + +func (q *baseQuery) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *baseQuery) appendTables( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, false) +} + +func (q *baseQuery) appendTablesWithAlias( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, true) +} + +func (q *baseQuery) _appendTables( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + startLen := len(b) + + if q.modelHasTableName() { + if !q.modelTable.IsZero() { + b, err = q.modelTable.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) + if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + } + } + + for _, table := range q.tables { + if len(b) > startLen { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) { + return q._appendFirstTable(fmter, b, false) +} + +func (q *baseQuery) appendFirstTableWithAlias( + fmter schema.Formatter, b []byte, +) ([]byte, error) { + return q._appendFirstTable(fmter, b, true) +} + +func (q *baseQuery) _appendFirstTable( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if !q.modelTable.IsZero() { + return q.modelTable.AppendQuery(fmter, b) + } + + if q.table != nil { + b = fmter.AppendQuery(b, string(q.table.SQLName)) + if withAlias { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + return b, nil + } + + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + + return nil, errors.New("bun: query does not have a table") +} + +func (q *baseQuery) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, table := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *baseQuery) getFields() ([]*schema.Field, error) { + table := q.tableModel.Table() + + if len(q.columns) == 0 { + return table.Fields, nil + } + + fields, err := q._getFields(false) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (q *baseQuery) getDataFields() ([]*schema.Field, error) { + if len(q.columns) == 0 { + return q.table.DataFields, nil + } + return q._getFields(true) +} + +func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { + fields := make([]*schema.Field, 0, len(q.columns)) + for _, col := range q.columns { + if col.Args != nil { + continue + } + + field, err := q.table.Field(col.Query) + if err != nil { + return nil, err + } + + if omitPK && field.IsPK { + continue + } + + fields = append(fields, field) + } + return fields, nil +} + +func (q *baseQuery) scan( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + model model, + hasDest bool, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + rows, err := q.conn.QueryContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + defer rows.Close() + + n, err := model.ScanRows(ctx, rows) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.n = n + if n == 0 && hasDest && isSingleRowModel(model) { + err = sql.ErrNoRows + } + + q.db.afterQuery(ctx, event, nil, err) + + return res, err +} + +func (q *baseQuery) exec( + ctx context.Context, + queryApp schema.QueryAppender, + query string, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + r, err := q.conn.ExecContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.r = r + + q.db.afterQuery(ctx, event, nil, err) + return res, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + if q.table == nil { + return b, false + } + + if m, ok := q.tableModel.(*structTableModel); ok { + if b, ok := m.AppendNamedArg(fmter, b, name); ok { + return b, ok + } + } + + switch name { + case "TableName": + b = fmter.AppendQuery(b, string(q.table.SQLName)) + return b, true + case "TableAlias": + b = fmter.AppendQuery(b, string(q.table.SQLAlias)) + return b, true + case "PKs": + b = appendColumns(b, "", q.table.PKs) + return b, true + case "TablePKs": + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + return b, true + case "Columns": + b = appendColumns(b, "", q.table.Fields) + return b, true + case "TableColumns": + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + return b, true + } + + return b, false +} + +func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + } + return b +} + +func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter { + if fmter.IsNop() { + return fmter + } + return fmter.WithArg(model) +} + +//------------------------------------------------------------------------------ + +type whereBaseQuery struct { + baseQuery + + where []schema.QueryWithSep +} + +func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { + q.where = append(q.where, where) +} + +func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) { + if len(where) == 0 { + return + } + + where[0].Sep = "" + + q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) +} + +func (q *whereBaseQuery) mustAppendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + err := errors.New("bun: Update and Delete queries require at least one Where") + return nil, err + } + return q.appendWhere(fmter, b, withAlias) +} + +func (q *whereBaseQuery) appendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + return b, nil + } + + b = append(b, " WHERE "...) + startLen := len(b) + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + if q.isSoftDelete() { + if len(b) > startLen { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.tableModel.Table().SQLAlias...) + b = append(b, '.') + } + b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...) + if q.flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + } + + if q.flags.Has(wherePKFlag) { + if len(b) > startLen { + b = append(b, " AND "...) + } + b, err = q.appendWherePK(fmter, b, withAlias) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func appendWhere( + fmter schema.Formatter, b []byte, where []schema.QueryWithSep, +) (_ []byte, err error) { + for i, where := range where { + if i > 0 || where.Sep == "(" { + b = append(b, where.Sep...) + } + + if where.Query == "" && where.Args == nil { + continue + } + + b = append(b, '(') + b, err = where.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + return b, nil +} + +func (q *whereBaseQuery) appendWherePK( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + return nil, err + } + if err := q.table.CheckPKs(); err != nil { + return nil, err + } + + switch model := q.tableModel.(type) { + case *structTableModel: + return q.appendWherePKStruct(fmter, b, model, withAlias) + case *sliceTableModel: + return q.appendWherePKSlice(fmter, b, model, withAlias) + } + + return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) +} + +func (q *whereBaseQuery) appendWherePKStruct( + fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +) (_ []byte, err error) { + if !model.strct.IsValid() { + return nil, errNilModel + } + + isTemplate := fmter.IsNop() + b = append(b, '(') + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + b = append(b, ')') + return b, nil +} + +func (q *whereBaseQuery) appendWherePKSlice( + fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +) (_ []byte, err error) { + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + if withAlias { + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + } else { + b = appendColumns(b, "", q.table.PKs) + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isTemplate := fmter.IsNop() + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + if isTemplate { + break + } + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, el) + } + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b, nil +} + +//------------------------------------------------------------------------------ + +type returningQuery struct { + returning []schema.QueryWithArgs + returningFields []*schema.Field +} + +func (q *returningQuery) addReturning(ret schema.QueryWithArgs) { + q.returning = append(q.returning, ret) +} + +func (q *returningQuery) addReturningField(field *schema.Field) { + if len(q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *returningQuery) hasReturning() bool { + if len(q.returning) == 1 { + switch q.returning[0].Query { + case "null", "NULL": + return false + } + } + return len(q.returning) > 0 || len(q.returningFields) > 0 +} + +func (q *returningQuery) appendReturning( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.returning) > 0 { + return b, nil + } + + b = appendColumns(b, "", q.returningFields) + return b, nil +} + +//------------------------------------------------------------------------------ + +type columnValue struct { + column string + value schema.QueryWithArgs +} + +type customValueQuery struct { + modelValues map[string]schema.QueryWithArgs + extraValues []columnValue +} + +func (q *customValueQuery) addValue( + table *schema.Table, column string, value string, args []interface{}, +) { + if _, ok := table.FieldMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]schema.QueryWithArgs) + } + q.modelValues[column] = schema.SafeQuery(value, args) + } else { + q.extraValues = append(q.extraValues, columnValue{ + column: column, + value: schema.SafeQuery(value, args), + }) + } +} + +//------------------------------------------------------------------------------ + +type setQuery struct { + set []schema.QueryWithArgs +} + +func (q *setQuery) addSet(set schema.QueryWithArgs) { + q.set = append(q.set, set) +} + +func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +type cascadeQuery struct { + restrict bool +} + +func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte { + if !fmter.HasFeature(feature.TableCascade) { + return b + } + if q.restrict { + b = append(b, " RESTRICT"...) + } else { + b = append(b, " CASCADE"...) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/query_column_add.go b/vendor/github.com/uptrace/bun/query_column_add.go new file mode 100644 index 000000000..ce2f60bf0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_add.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type AddColumnQuery struct { + baseQuery +} + +func NewAddColumnQuery(db *DB) *AddColumnQuery { + q := &AddColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *AddColumnQuery) Conn(db IConn) *AddColumnQuery { + q.setConn(db) + return q +} + +func (q *AddColumnQuery) Model(model interface{}) *AddColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Table(tables ...string) *AddColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *AddColumnQuery) TableExpr(query string, args ...interface{}) *AddColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *AddColumnQuery) ModelTableExpr(query string, args ...interface{}) *AddColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: AddColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ADD "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_column_drop.go b/vendor/github.com/uptrace/bun/query_column_drop.go new file mode 100644 index 000000000..5684beeb3 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_drop.go @@ -0,0 +1,112 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropColumnQuery struct { + baseQuery +} + +func NewDropColumnQuery(db *DB) *DropColumnQuery { + q := &DropColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropColumnQuery) Conn(db IConn) *DropColumnQuery { + q.setConn(db) + return q +} + +func (q *DropColumnQuery) Model(model interface{}) *DropColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Table(tables ...string) *DropColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropColumnQuery) TableExpr(query string, args ...interface{}) *DropColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropColumnQuery) ModelTableExpr(query string, args ...interface{}) *DropColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Column(columns ...string) *DropColumnQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *DropColumnQuery) ColumnExpr(query string, args ...interface{}) *DropColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: DropColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " DROP COLUMN "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_delete.go b/vendor/github.com/uptrace/bun/query_delete.go new file mode 100644 index 000000000..c0c5039c7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_delete.go @@ -0,0 +1,256 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DeleteQuery struct { + whereBaseQuery + returningQuery +} + +func NewDeleteQuery(db *DB) *DeleteQuery { + q := &DeleteQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *DeleteQuery) Conn(db IConn) *DeleteQuery { + q.setConn(db) + return q +} + +func (q *DeleteQuery) Model(model interface{}) *DeleteQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the DeleteQuery as an argument. +func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + return fn(q) +} + +func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { + q.addWith(name, query) + return q +} + +func (q *DeleteQuery) Table(tables ...string) *DeleteQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DeleteQuery) TableExpr(query string, args ...interface{}) *DeleteQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) WherePK() *DeleteQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *DeleteQuery) Where(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *DeleteQuery) WhereOr(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *DeleteQuery) WhereGroup(sep string, fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *DeleteQuery) WhereDeleted() *DeleteQuery { + q.whereDeleted() + return q +} + +func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery { + q.whereAllWithDeleted() + return q +} + +func (q *DeleteQuery) ForceDelete() *DeleteQuery { + q.flags = q.flags.Set(forceDeleteFlag) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + if q.isSoftDelete() { + if err := q.tableModel.updateSoftDeleteField(); err != nil { + return nil, err + } + + upd := UpdateQuery{ + whereBaseQuery: q.whereBaseQuery, + returningQuery: q.returningQuery, + } + upd.Column(q.table.SoftDeleteField.Name) + return upd.AppendQuery(fmter, b) + } + + q = q.WhereAllWithDeleted() + withAlias := q.db.features.Has(feature.DeleteTableAlias) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "DELETE FROM "...) + + if withAlias { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + if q.hasMultiTables() { + b = append(b, " USING "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *DeleteQuery) isSoftDelete() bool { + return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag) +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDeleteHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterDeleteHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DeleteQuery) beforeDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDeleteHook); ok { + if err := hook.BeforeDelete(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DeleteQuery) afterDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDeleteHook); ok { + if err := hook.AfterDelete(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_create.go b/vendor/github.com/uptrace/bun/query_index_create.go new file mode 100644 index 000000000..de7eb7aa0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_create.go @@ -0,0 +1,242 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateIndexQuery struct { + whereBaseQuery + + unique bool + fulltext bool + spatial bool + concurrently bool + ifNotExists bool + + index schema.QueryWithArgs + using schema.QueryWithArgs + include []schema.QueryWithArgs +} + +func NewCreateIndexQuery(db *DB) *CreateIndexQuery { + q := &CreateIndexQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *CreateIndexQuery) Conn(db IConn) *CreateIndexQuery { + q.setConn(db) + return q +} + +func (q *CreateIndexQuery) Model(model interface{}) *CreateIndexQuery { + q.setTableModel(model) + return q +} + +func (q *CreateIndexQuery) Unique() *CreateIndexQuery { + q.unique = true + return q +} + +func (q *CreateIndexQuery) Concurrently() *CreateIndexQuery { + q.concurrently = true + return q +} + +func (q *CreateIndexQuery) IfNotExists() *CreateIndexQuery { + q.ifNotExists = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Index(query string) *CreateIndexQuery { + q.index = schema.UnsafeIdent(query) + return q +} + +func (q *CreateIndexQuery) IndexExpr(query string, args ...interface{}) *CreateIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Table(tables ...string) *CreateIndexQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateIndexQuery) TableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ModelTableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +func (q *CreateIndexQuery) Using(query string, args ...interface{}) *CreateIndexQuery { + q.using = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Column(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) ColumnExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ExcludeColumn(columns ...string) *CreateIndexQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Include(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.include = append(q.include, schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) IncludeExpr(query string, args ...interface{}) *CreateIndexQuery { + q.include = append(q.include, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Where(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *CreateIndexQuery) WhereOr(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "CREATE "...) + + if q.unique { + b = append(b, "UNIQUE "...) + } + if q.fulltext { + b = append(b, "FULLTEXT "...) + } + if q.spatial { + b = append(b, "SPATIAL "...) + } + + b = append(b, "INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ON "...) + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + if !q.using.IsZero() { + b = append(b, " USING "...) + b, err = q.using.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " ("...) + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + + if len(q.include) > 0 { + b = append(b, " INCLUDE ("...) + for i, col := range q.include { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + } + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_drop.go b/vendor/github.com/uptrace/bun/query_index_drop.go new file mode 100644 index 000000000..c922ff04f --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_drop.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropIndexQuery struct { + baseQuery + cascadeQuery + + concurrently bool + ifExists bool + + index schema.QueryWithArgs +} + +func NewDropIndexQuery(db *DB) *DropIndexQuery { + q := &DropIndexQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropIndexQuery) Conn(db IConn) *DropIndexQuery { + q.setConn(db) + return q +} + +func (q *DropIndexQuery) Model(model interface{}) *DropIndexQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Concurrently() *DropIndexQuery { + q.concurrently = true + return q +} + +func (q *DropIndexQuery) IfExists() *DropIndexQuery { + q.ifExists = true + return q +} + +func (q *DropIndexQuery) Restrict() *DropIndexQuery { + q.restrict = true + return q +} + +func (q *DropIndexQuery) Index(query string, args ...interface{}) *DropIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go new file mode 100644 index 000000000..efddee407 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -0,0 +1,551 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type InsertQuery struct { + whereBaseQuery + returningQuery + customValueQuery + + onConflict schema.QueryWithArgs + setQuery + + ignore bool + replace bool +} + +func NewInsertQuery(db *DB) *InsertQuery { + q := &InsertQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *InsertQuery) Conn(db IConn) *InsertQuery { + q.setConn(db) + return q +} + +func (q *InsertQuery) Model(model interface{}) *InsertQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery { + return fn(q) +} + +func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Table(tables ...string) *InsertQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Column(columns ...string) *InsertQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { + q.excludeColumn(columns) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +// Ignore generates an `INSERT IGNORE INTO` query (MySQL). +func (q *InsertQuery) Ignore() *InsertQuery { + q.ignore = true + return q +} + +// Replaces generates a `REPLACE INTO` query (MySQL). +func (q *InsertQuery) Replace() *InsertQuery { + q.replace = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + if q.replace { + b = append(b, "REPLACE "...) + } else { + b = append(b, "INSERT "...) + if q.ignore { + b = append(b, "IGNORE "...) + } + } + b = append(b, "INTO "...) + + if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.appendColumnsValues(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.appendOn(fmter, b) + if err != nil { + return nil, err + } + + if q.hasReturning() { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendColumnsValues( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.hasMultiTables() { + if q.columns != nil { + b = append(b, " ("...) + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " SELECT * FROM "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendColumnsValues(fmter, b), nil + } + if _, ok := q.model.(*mapSliceModel); ok { + return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported") + } + + if q.model == nil { + return nil, errNilModel + } + + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = append(b, " ("...) + b = q.appendFields(fmter, b, fields) + b = append(b, ") VALUES ("...) + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendStructValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + case *sliceTableModel: + b, err = q.appendSliceValues(fmter, b, fields, model.slice) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel) + } + + b = append(b, ')') + + return b, nil +} + +func (q *InsertQuery) appendStructValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + q.addReturningField(f) + continue + } + + switch { + case isTemplate: + b = append(b, '?') + case f.NullZero && f.HasZeroValue(strct): + if q.db.features.Has(feature.DefaultPlaceholder) { + b = append(b, "DEFAULT"...) + } else if f.SQLDefault != "" { + b = append(b, f.SQLDefault...) + } else { + b = append(b, "NULL"...) + } + q.addReturningField(f) + default: + b = f.AppendValue(fmter, b, strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendSliceValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value, +) (_ []byte, err error) { + if fmter.IsNop() { + return q.appendStructValues(fmter, b, fields, reflect.Value{}) + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + el := indirect(slice.Index(i)) + b, err = q.appendStructValues(fmter, b, fields, el) + if err != nil { + return nil, err + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) getFields() ([]*schema.Field, error) { + if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 { + return q.baseQuery.getFields() + } + + var strct reflect.Value + + switch model := q.tableModel.(type) { + case *structTableModel: + strct = model.strct + case *sliceTableModel: + if model.sliceLen == 0 { + return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type()) + } + strct = indirect(model.slice.Index(0)) + } + + fields := make([]*schema.Field, 0, len(q.table.Fields)) + + for _, f := range q.table.Fields { + if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) { + q.addReturningField(f) + continue + } + fields = append(fields, f) + } + + return fields, nil +} + +func (q *InsertQuery) appendFields( + fmter schema.Formatter, b []byte, fields []*schema.Field, +) []byte { + b = appendColumns(b, "", fields) + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, v.column) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { + q.onConflict = schema.SafeQuery(s, args) + return q +} + +func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.onConflict.IsZero() { + return b, nil + } + + b = append(b, " ON "...) + b, err = q.onConflict.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(q.set) > 0 { + if fmter.HasFeature(feature.OnDuplicateKey) { + b = append(b, ' ') + } else { + b = append(b, " SET "...) + } + + b, err = q.appendSet(fmter, b) + if err != nil { + return nil, err + } + } else if len(q.columns) > 0 { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.tableModel.Table().DataFields + } + + b = q.appendSetExcluded(b, fields) + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { + b = append(b, " SET "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.SQLName...) + b = append(b, " = EXCLUDED."...) + b = append(b, f.SQLName...) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeInsertHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if err := q.tryLastInsertID(res, dest); err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterInsertHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *InsertQuery) beforeInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok { + if err := hook.BeforeInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) afterInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok { + if err := hook.AfterInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { + if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 { + return nil + } + + id, err := res.LastInsertId() + if err != nil { + return err + } + if id == 0 { + return nil + } + + model, err := q.getModel(dest) + if err != nil { + return err + } + + pk := q.table.PKs[0] + switch model := model.(type) { + case *structTableModel: + if err := pk.ScanValue(model.strct, id); err != nil { + return err + } + case *sliceTableModel: + sliceLen := model.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(model.slice.Index(i)) + if err := pk.ScanValue(strct, id); err != nil { + return err + } + id++ + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go new file mode 100644 index 000000000..1f63686ad --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -0,0 +1,830 @@ +package bun + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type union struct { + expr string + query *SelectQuery +} + +type SelectQuery struct { + whereBaseQuery + + distinctOn []schema.QueryWithArgs + joins []joinQuery + group []schema.QueryWithArgs + having []schema.QueryWithArgs + order []schema.QueryWithArgs + limit int32 + offset int32 + selFor schema.QueryWithArgs + + union []union +} + +func NewSelectQuery(db *DB) *SelectQuery { + return &SelectQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } +} + +func (q *SelectQuery) Conn(db IConn) *SelectQuery { + q.setConn(db) + return q +} + +func (q *SelectQuery) Model(model interface{}) *SelectQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery { + return fn(q) +} + +func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { + q.addWith(name, query) + return q +} + +func (q *SelectQuery) Distinct() *SelectQuery { + q.distinctOn = make([]schema.QueryWithArgs, 0) + return q +} + +func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery { + q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Table(tables ...string) *SelectQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Column(columns ...string) *SelectQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) WherePK() *SelectQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *SelectQuery) WhereDeleted() *SelectQuery { + q.whereDeleted() + return q +} + +func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Group(columns ...string) *SelectQuery { + for _, column := range columns { + q.group = append(q.group, schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery { + q.group = append(q.group, schema.SafeQuery(group, args)) + return q +} + +func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery { + q.having = append(q.having, schema.SafeQuery(having, args)) + return q +} + +func (q *SelectQuery) Order(orders ...string) *SelectQuery { + for _, order := range orders { + if order == "" { + continue + } + + index := strings.IndexByte(order, ' ') + if index == -1 { + q.order = append(q.order, schema.UnsafeIdent(order)) + continue + } + + field := order[:index] + sort := order[index+1:] + + switch strings.ToUpper(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{ + Ident(field), + Safe(sort), + })) + default: + q.order = append(q.order, schema.UnsafeIdent(order)) + } + } + return q +} + +func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery { + q.order = append(q.order, schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) Limit(n int) *SelectQuery { + q.limit = int32(n) + return q +} + +func (q *SelectQuery) Offset(n int) *SelectQuery { + q.offset = int32(n) + return q +} + +func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery { + q.selFor = schema.SafeQuery(s, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ", other) +} + +func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ALL ", other) +} + +func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ", other) +} + +func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ", other) +} + +func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery { + q.union = append(q.union, union{ + expr: expr, + query: other, + }) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery { + q.joins = append(q.joins, joinQuery{ + join: schema.SafeQuery(join, args), + }) + return q +} + +func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " AND ") +} + +func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " OR ") +} + +func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery { + if len(q.joins) == 0 { + q.err = errors.New("bun: query has no joins") + return q + } + j := &q.joins[len(q.joins)-1] + j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep)) + return q +} + +//------------------------------------------------------------------------------ + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { + if q.tableModel == nil { + q.setErr(errNilModel) + return q + } + + var fn func(*SelectQuery) *SelectQuery + + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) + return q + } + + return q +} + +func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := fn(j); err != nil { + return err + } + if err := q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()); err != nil { + return err + } + } + } + return nil +} + +func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error { + var err error + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + err = q.selectJoins(ctx, j.JoinModel.GetJoins()) + default: + err = j.Select(ctx, q.db.NewSelect()) + } + if err != nil { + return err + } + } + return nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, false) +} + +func (q *SelectQuery) appendQuery( + fmter schema.Formatter, b []byte, count bool, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) + if cteCount { + b = append(b, "WITH _count_wrapper AS ("...) + } + + if len(q.union) > 0 { + b = append(b, '(') + } + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "SELECT "...) + + if len(q.distinctOn) > 0 { + b = append(b, "DISTINCT ON ("...) + for i, app := range q.distinctOn { + if i > 0 { + b = append(b, ", "...) + } + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ") "...) + } else if q.distinctOn != nil { + b = append(b, "DISTINCT "...) + } + + if count && !cteCount { + b = append(b, "count(*)"...) + } else { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } + + if q.hasTables() { + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + } + + if err := q.forEachHasOneJoin(func(j *join) error { + b = append(b, ' ') + b, err = j.appendHasOneJoin(fmter, b, q) + return err + }); err != nil { + return nil, err + } + + for _, j := range q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + if len(q.group) > 0 { + b = append(b, " GROUP BY "...) + for i, f := range q.group { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.having) > 0 { + b = append(b, " HAVING "...) + for i, f := range q.having { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, '(') + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if !count { + b, err = q.appendOrder(fmter, b) + if err != nil { + return nil, err + } + + if q.limit != 0 { + b = append(b, " LIMIT "...) + b = strconv.AppendInt(b, int64(q.limit), 10) + } + + if q.offset != 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.offset), 10) + } + + if !q.selFor.IsZero() { + b = append(b, " FOR "...) + b, err = q.selFor.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.union) > 0 { + b = append(b, ')') + + for _, u := range q.union { + b = append(b, u.expr...) + b = append(b, '(') + b, err = u.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if cteCount { + b = append(b, ") SELECT count(*) FROM _count_wrapper"...) + } + + return b, nil +} + +func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + start := len(b) + + switch { + case q.columns != nil: + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = append(b, field.SQLName...) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + case q.table != nil: + if len(q.table.Fields) > 10 && fmter.IsNop() { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields))) + } else { + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + } + default: + b = append(b, '*') + } + + if err := q.forEachHasOneJoin(func(j *join) error { + if len(b) != start { + b = append(b, ", "...) + start = len(b) + } + + b, err = q.appendHasOneColumns(fmter, b, j) + if err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + + b = bytes.TrimSuffix(b, []byte(", ")) + + return b, nil +} + +func (q *SelectQuery) appendHasOneColumns( + fmter schema.Formatter, b []byte, join *join, +) (_ []byte, err error) { + join.applyQuery(q) + + if join.columns != nil { + for i, col := range join.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil + } + + for i, field := range join.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + } + return b, nil +} + +func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " FROM "...) + return q.appendTablesWithAlias(fmter, b) +} + +func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.order) > 0 { + b = append(b, " ORDER BY "...) + + for i, f := range q.order { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + return q.conn.QueryContext(ctx, query) +} + +func (q *SelectQuery) Exec(ctx context.Context) (res sql.Result, err error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} + +func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { + model, err := q.getModel(dest) + if err != nil { + return err + } + + if q.limit > 1 { + if model, ok := model.(interface{ SetCap(int) }); ok { + model.SetCap(int(q.limit)) + } + } + + if q.table != nil { + if err := q.beforeSelectHook(ctx); err != nil { + return err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return err + } + + query := internal.String(queryBytes) + + res, err := q.scan(ctx, q, query, model, true) + if err != nil { + return err + } + + if res.n > 0 { + if tableModel, ok := model.(tableModel); ok { + if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil { + return err + } + } + } + + if q.table != nil { + if err := q.afterSelectHook(ctx); err != nil { + return err + } + } + + return nil +} + +func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok { + if err := hook.BeforeSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) afterSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok { + if err := hook.AfterSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) Count(ctx context.Context) (int, error) { + qq := countQuery{q} + + queryBytes, err := qq.appendQuery(q.db.fmter, nil, true) + if err != nil { + return 0, err + } + + query := internal.String(queryBytes) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil) + + var num int + err = q.conn.QueryRowContext(ctx, query).Scan(&num) + + q.db.afterQuery(ctx, event, nil, err) + + return num, err +} + +func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + var count int + var wg sync.WaitGroup + var mu sync.Mutex + var firstErr error + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + + var err error + count, err = q.Count(ctx) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +//------------------------------------------------------------------------------ + +type joinQuery struct { + join schema.QueryWithArgs + on []schema.QueryWithSep +} + +func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, ' ') + + b, err = j.join.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(j.on) > 0 { + b = append(b, " ON "...) + for i, on := range j.on { + if i > 0 { + b = append(b, on.Sep...) + } + + b = append(b, '(') + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +type countQuery struct { + *SelectQuery +} + +func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, true) +} diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go new file mode 100644 index 000000000..0a4b3567c --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -0,0 +1,275 @@ +package bun + +import ( + "context" + "database/sql" + "sort" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateTableQuery struct { + baseQuery + + temp bool + ifNotExists bool + varchar int + + fks []schema.QueryWithArgs + partitionBy schema.QueryWithArgs + tablespace schema.QueryWithArgs +} + +func NewCreateTableQuery(db *DB) *CreateTableQuery { + q := &CreateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery { + q.setConn(db) + return q +} + +func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Temp() *CreateTableQuery { + q.temp = true + return q +} + +func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { + q.ifNotExists = true + return q +} + +func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { + q.varchar = n + return q +} + +func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery { + q.fks = append(q.fks, schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.table == nil { + return nil, errNilModel + } + + b = append(b, "CREATE "...) + if q.temp { + b = append(b, "TEMP "...) + } + b = append(b, "TABLE "...) + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ("...) + + for i, field := range q.table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.SQLName...) + b = append(b, " "...) + b = q.appendSQLType(b, field) + if field.NotNull { + b = append(b, " NOT NULL"...) + } + if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement { + b = append(b, " AUTO_INCREMENT"...) + } + if field.SQLDefault != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.SQLDefault...) + } + } + + b = q.appendPKConstraint(b, q.table.PKs) + b = q.appendUniqueConstraints(fmter, b) + b, err = q.appenFKConstraints(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ")"...) + + if !q.partitionBy.IsZero() { + b = append(b, " PARTITION BY "...) + b, err = q.partitionBy.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if !q.tablespace.IsZero() { + b = append(b, " TABLESPACE "...) + b, err = q.tablespace.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { + if field.CreateTableSQLType != field.DiscoveredSQLType { + return append(b, field.CreateTableSQLType...) + } + + if q.varchar > 0 && + field.CreateTableSQLType == sqltype.VarChar { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.varchar), 10) + b = append(b, ")"...) + return b + } + + return append(b, field.CreateTableSQLType...) +} + +func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { + unique := q.table.Unique + + keys := make([]string, 0, len(unique)) + for key := range unique { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + b = q.appendUniqueConstraint(fmter, b, key, unique[key]) + } + + return b +} + +func (q *CreateTableQuery) appendUniqueConstraint( + fmter schema.Formatter, b []byte, name string, fields []*schema.Field, +) []byte { + if name != "" { + b = append(b, ", CONSTRAINT "...) + b = fmter.AppendIdent(b, name) + } else { + b = append(b, ","...) + } + b = append(b, " UNIQUE ("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b +} + +func (q *CreateTableQuery) appenFKConstraints( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + for _, fk := range q.fks { + b = append(b, ", FOREIGN KEY "...) + b, err = fk.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte { + if len(pks) == 0 { + return b + } + + b = append(b, ", PRIMARY KEY ("...) + b = appendColumns(b, "", pks) + b = append(b, ")"...) + return b +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if err := q.beforeCreateTableHook(ctx); err != nil { + return nil, err + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterCreateTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok { + if err := hook.BeforeCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok { + if err := hook.AfterCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_drop.go b/vendor/github.com/uptrace/bun/query_table_drop.go new file mode 100644 index 000000000..2c30171c1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_drop.go @@ -0,0 +1,137 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropTableQuery struct { + baseQuery + cascadeQuery + + ifExists bool +} + +func NewDropTableQuery(db *DB) *DropTableQuery { + q := &DropTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropTableQuery) Conn(db IConn) *DropTableQuery { + q.setConn(db) + return q +} + +func (q *DropTableQuery) Model(model interface{}) *DropTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Table(tables ...string) *DropTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropTableQuery) TableExpr(query string, args ...interface{}) *DropTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropTableQuery) ModelTableExpr(query string, args ...interface{}) *DropTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) IfExists() *DropTableQuery { + q.ifExists = true + return q +} + +func (q *DropTableQuery) Restrict() *DropTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP TABLE "...) + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDropTableHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterDropTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DropTableQuery) beforeDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDropTableHook); ok { + if err := hook.BeforeDropTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDropTableHook); ok { + if err := hook.AfterDropTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_truncate.go b/vendor/github.com/uptrace/bun/query_table_truncate.go new file mode 100644 index 000000000..1e4bef7f6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_truncate.go @@ -0,0 +1,121 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type TruncateTableQuery struct { + baseQuery + cascadeQuery + + continueIdentity bool +} + +func NewTruncateTableQuery(db *DB) *TruncateTableQuery { + q := &TruncateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *TruncateTableQuery) Conn(db IConn) *TruncateTableQuery { + q.setConn(db) + return q +} + +func (q *TruncateTableQuery) Model(model interface{}) *TruncateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *TruncateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery { + q.continueIdentity = true + return q +} + +func (q *TruncateTableQuery) Restrict() *TruncateTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) AppendQuery( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + if !fmter.HasFeature(feature.TableTruncate) { + b = append(b, "DELETE FROM "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + b = append(b, "TRUNCATE TABLE "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + if q.db.features.Has(feature.TableIdentity) { + if q.continueIdentity { + b = append(b, " CONTINUE IDENTITY"...) + } else { + b = append(b, " RESTART IDENTITY"...) + } + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go new file mode 100644 index 000000000..ea74e1419 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_update.go @@ -0,0 +1,432 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type UpdateQuery struct { + whereBaseQuery + returningQuery + customValueQuery + setQuery + + omitZero bool +} + +func NewUpdateQuery(db *DB) *UpdateQuery { + q := &UpdateQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *UpdateQuery) Conn(db IConn) *UpdateQuery { + q.setConn(db) + return q +} + +func (q *UpdateQuery) Model(model interface{}) *UpdateQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + return fn(q) +} + +func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Table(tables ...string) *UpdateQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Column(columns ...string) *UpdateQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery { + q.excludeColumn(columns) + return q +} + +func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) WherePK() *UpdateQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *UpdateQuery) WhereDeleted() *UpdateQuery { + q.whereDeleted() + return q +} + +func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + withAlias := fmter.HasFeature(feature.UpdateMultiTable) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "UPDATE "...) + + if withAlias { + b, err = q.appendTablesWithAlias(fmter, b) + } else { + b, err = q.appendFirstTableWithAlias(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + if !fmter.HasFeature(feature.UpdateMultiTable) { + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + + if len(q.set) > 0 { + return q.appendSet(fmter, b) + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendSet(fmter, b), nil + } + + if q.tableModel == nil { + return nil, errNilModel + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendSetStruct(fmter, b, model) + if err != nil { + return nil, err + } + case *sliceTableModel: + return nil, errors.New("bun: to bulk Update, use CTE and VALUES") + default: + return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel) + } + + return b, nil +} + +func (q *UpdateQuery) appendSetStruct( + fmter schema.Formatter, b []byte, model *structTableModel, +) ([]byte, error) { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + isTemplate := fmter.IsNop() + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero && f.HasZeroValue(model.strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.SQLName...) + b = append(b, " = "...) + + if isTemplate { + b = append(b, '?') + continue + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if !q.hasMultiTables() { + return b, nil + } + + b = append(b, " FROM "...) + + b, err = q.whereBaseQuery.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Bulk() *UpdateQuery { + model, ok := q.model.(*sliceTableModel) + if !ok { + q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model)) + return q + } + + return q.With("_data", q.db.NewValues(model)). + Model(model). + TableExpr("_data"). + Set(q.updateSliceSet(model)). + Where(q.updateSliceWhere(model)) +} + +func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string { + var b []byte + for i, field := range model.table.DataFields { + if i > 0 { + b = append(b, ", "...) + } + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, field.SQLName...) + b = append(b, " = _data."...) + b = append(b, field.SQLName...) + } + return internal.String(b) +} + +func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string { + var b []byte + for i, pk := range model.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + b = append(b, pk.SQLName...) + b = append(b, " = _data."...) + b = append(b, pk.SQLName...) + } + return internal.String(b) +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeUpdateHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterUpdateHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok { + if err := hook.BeforeUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok { + if err := hook.AfterUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +// FQN returns a fully qualified column name. For MySQL, it returns the column name with +// the table alias. For other RDBMS, it returns just the column name. +func (q *UpdateQuery) FQN(name string) Ident { + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + return Ident(q.table.Alias + "." + name) + } + return Ident(name) +} diff --git a/vendor/github.com/uptrace/bun/query_values.go b/vendor/github.com/uptrace/bun/query_values.go new file mode 100644 index 000000000..323ac68ef --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_values.go @@ -0,0 +1,198 @@ +package bun + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type ValuesQuery struct { + baseQuery + customValueQuery + + withOrder bool +} + +var _ schema.NamedArgAppender = (*ValuesQuery)(nil) + +func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { + q := &ValuesQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + q.setTableModel(model) + return q +} + +func (q *ValuesQuery) Conn(db IConn) *ValuesQuery { + q.setConn(db) + return q +} + +func (q *ValuesQuery) WithOrder() *ValuesQuery { + q.withOrder = true + return q +} + +func (q *ValuesQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + switch name { + case "Columns": + bb, err := q.AppendColumns(fmter, b) + if err != nil { + q.setErr(err) + return b, true + } + return bb, true + } + return b, false +} + +// AppendColumns appends the table columns. It is used by CTE. +func (q *ValuesQuery) AppendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = appendColumns(b, "", fields) + + if q.withOrder { + b = append(b, ", _order"...) + } + + return b, nil + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendColumns(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + fmter = formatterWithModel(fmter, q) + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + return q.appendQuery(fmter, b, fields) + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendValues(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) appendQuery( + fmter schema.Formatter, + b []byte, + fields []*schema.Field, +) (_ []byte, err error) { + b = append(b, "VALUES "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, 0, 10) + } + case *sliceTableModel: + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + } + } + default: + return nil, fmt.Errorf("bun: Values does not support %T", q.model) + } + + b = append(b, ')') + + return b, nil +} + +func (q *ValuesQuery) appendValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, indirect(strct)) + } + + if fmter.HasFeature(feature.DoubleColonCast) { + b = append(b, "::"...) + b = append(b, f.UserSQLType...) + } + } + return b, nil +} diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go new file mode 100644 index 000000000..68f7071c8 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append.go @@ -0,0 +1,93 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +func FieldAppender(dialect Dialect, field *Field) AppenderFunc { + if field.Tag.HasOption("msgpack") { + return appendMsgpack + } + + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return AppendJSONValue + } + + return dialect.Appender(field.StructField.Type) +} + +func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendUint(b, uint64(v), 10) + case uint32: + return strconv.AppendUint(b, uint64(v), 10) + case uint64: + return strconv.AppendUint(b, v, 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case QueryAppender: + return AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := Appender(vv.Type(), custom) + return appender(fmter, b, vv) + } +} + +func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte { + hexEnc := internal.NewHexEncoder(b) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return dialect.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return dialect.AppendError(b, err) + } + + return hexEnc.Bytes() +} + +func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte { + bb, err := app.AppendQuery(fmter, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb +} diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go new file mode 100644 index 000000000..0c4677069 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -0,0 +1,237 @@ +package schema + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) + +type ( + AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte + CustomAppender func(typ reflect.Type) AppenderFunc +) + +var appenders = []AppenderFunc{ + reflect.Bool: AppendBoolValue, + reflect.Int: AppendIntValue, + reflect.Int8: AppendIntValue, + reflect.Int16: AppendIntValue, + reflect.Int32: AppendIntValue, + reflect.Int64: AppendIntValue, + reflect.Uint: AppendUintValue, + reflect.Uint8: AppendUintValue, + reflect.Uint16: AppendUintValue, + reflect.Uint32: AppendUintValue, + reflect.Uint64: AppendUintValue, + reflect.Uintptr: nil, + reflect.Float32: AppendFloat32Value, + reflect.Float64: AppendFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: AppendJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: nil, + reflect.Map: AppendJSONValue, + reflect.Ptr: nil, + reflect.Slice: AppendJSONValue, + reflect.String: AppendStringValue, + reflect.Struct: AppendJSONValue, + reflect.UnsafePointer: nil, +} + +func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { + switch typ { + case timeType: + return appendTimeValue + case ipType: + return appendIPValue + case ipNetType: + return appendIPNetValue + case jsonRawMessageType: + return appendJSONRawMessageValue + } + + if typ.Implements(queryAppenderType) { + return appendQueryAppenderValue + } + if typ.Implements(driverValuerType) { + return driverValueAppender(custom) + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(queryAppenderType) { + return addrAppender(appendQueryAppenderValue, custom) + } + if ptr.Implements(driverValuerType) { + return addrAppender(driverValueAppender(custom), custom) + } + } + + switch kind { + case reflect.Interface: + return ifaceAppenderFunc(typ, custom) + case reflect.Ptr: + return ptrAppenderFunc(typ, custom) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return appendBytesValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return appendArrayBytesValue + } + } + + if custom != nil { + if fn := custom(typ); fn != nil { + return fn + } + } + return appenders[typ.Kind()] +} + +func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + elem := v.Elem() + appender := Appender(elem.Type(), custom) + return appender(fmter, b, elem) + } +} + +func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + appender := Appender(typ.Elem(), custom) + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return appender(fmter, b, v.Elem()) + } +} + +func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBool(b, v.Bool()) +} + +func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, v.Int(), 10) +} + +func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendUint(b, v.Uint(), 10) +} + +func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat32(b, float32(v.Float())) +} + +func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat64(b, float64(v.Float())) +} + +func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBytes(b, v.Bytes()) +} + +func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.CanAddr() { + return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes()) + } + + tmp := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(tmp), v) + b = dialect.AppendBytes(b, tmp) + return b +} + +func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendString(b, v.String()) +} + +func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bb, err := bunjson.Marshal(v.Interface()) + if err != nil { + return dialect.AppendError(b, err) + } + + if len(bb) > 0 && bb[len(bb)-1] == '\n' { + bb = bb[:len(bb)-1] + } + + return dialect.AppendJSON(b, bb) +} + +func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte { + tm := v.Interface().(time.Time) + return dialect.AppendTime(b, tm) +} + +func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ip := v.Interface().(net.IP) + return dialect.AppendString(b, ip.String()) +} + +func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ipnet := v.Interface().(net.IPNet) + return dialect.AppendString(b, ipnet.String()) +} + +func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bytes := v.Bytes() + if bytes == nil { + return dialect.AppendNull(b) + } + return dialect.AppendString(b, internal.String(bytes)) +} + +func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender)) +} + +func driverValueAppender(custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom) + } +} + +func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte { + value, err := v.Value() + if err != nil { + return dialect.AppendError(b, err) + } + return Append(fmter, b, value, custom) +} + +func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if !v.CanAddr() { + err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface()) + return dialect.AppendError(b, err) + } + return fn(fmter, b, v.Addr()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go new file mode 100644 index 000000000..c50de715a --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/dialect.go @@ -0,0 +1,99 @@ +package schema + +import ( + "database/sql" + "reflect" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" +) + +type Dialect interface { + Init(db *sql.DB) + + Name() dialect.Name + Features() feature.Feature + + Tables() *Tables + OnTable(table *Table) + + IdentQuote() byte + Append(fmter Formatter, b []byte, v interface{}) []byte + Appender(typ reflect.Type) AppenderFunc + FieldAppender(field *Field) AppenderFunc + Scanner(typ reflect.Type) ScannerFunc +} + +//------------------------------------------------------------------------------ + +type nopDialect struct { + tables *Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func newNopDialect() *nopDialect { + d := new(nopDialect) + d.tables = NewTables(d) + d.features = feature.Returning + return d +} + +func (d *nopDialect) Init(*sql.DB) {} + +func (d *nopDialect) Name() dialect.Name { + return dialect.Invalid +} + +func (d *nopDialect) Features() feature.Feature { + return d.features +} + +func (d *nopDialect) Tables() *Tables { + return d.tables +} + +func (d *nopDialect) OnField(field *Field) {} + +func (d *nopDialect) OnTable(table *Table) {} + +func (d *nopDialect) IdentQuote() byte { + return '"' +} + +func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte { + return Append(fmter, b, v, nil) +} + +func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(AppenderFunc) + } + + fn := Appender(typ, nil) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(AppenderFunc) + } + return fn +} + +func (d *nopDialect) FieldAppender(field *Field) AppenderFunc { + return FieldAppender(d, field) +} + +func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(ScannerFunc) + } + + fn := Scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go new file mode 100644 index 000000000..1e069b82f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/field.go @@ -0,0 +1,117 @@ +package schema + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal/tagparser" +) + +type Field struct { + StructField reflect.StructField + + Tag tagparser.Tag + IndirectType reflect.Type + Index []int + + Name string // SQL name, .e.g. id + SQLName Safe // escaped SQL name, e.g. "id" + GoName string // struct field name, e.g. Id + + DiscoveredSQLType string + UserSQLType string + CreateTableSQLType string + SQLDefault string + + OnDelete string + OnUpdate string + + IsPK bool + NotNull bool + NullZero bool + AutoIncrement bool + + Append AppenderFunc + Scan ScannerFunc + IsZero IsZeroerFunc +} + +func (f *Field) String() string { + return f.Name +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(v reflect.Value) bool { + for _, idx := range f.Index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.IsZero(v) +} + +func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return dialect.AppendNull(b) + } + + if f.NullZero && f.IsZero(fv) { + return dialect.AppendNull(b) + } + if f.Append == nil { + panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type())) + } + return f.Append(fmter, b, fv) +} + +func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error { + if f.Scan == nil { + return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType) + } + return f.Scan(fv, src) +} + +func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { + if src == nil { + if fv, ok := fieldByIndex(strct, f.Index); ok { + return f.ScanWithCheck(fv, src) + } + return nil + } + + fv := fieldByIndexAlloc(strct, f.Index) + return f.ScanWithCheck(fv, src) +} + +func (f *Field) markAsPK() { + f.IsPK = true + f.NotNull = true + f.NullZero = true +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go new file mode 100644 index 000000000..7b26fbaca --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -0,0 +1,248 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +var nopFormatter = Formatter{ + dialect: newNopDialect(), +} + +type Formatter struct { + dialect Dialect + args *namedArgList +} + +func NewFormatter(dialect Dialect) Formatter { + return Formatter{ + dialect: dialect, + } +} + +func NewNopFormatter() Formatter { + return nopFormatter +} + +func (f Formatter) IsNop() bool { + return f.dialect.Name() == dialect.Invalid +} + +func (f Formatter) Dialect() Dialect { + return f.dialect +} + +func (f Formatter) IdentQuote() byte { + return f.dialect.IdentQuote() +} + +func (f Formatter) AppendIdent(b []byte, ident string) []byte { + return dialect.AppendIdent(b, ident, f.IdentQuote()) +} + +func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte { + if v.Kind() == reflect.Ptr && v.IsNil() { + return dialect.AppendNull(b) + } + appender := f.dialect.Appender(v.Type()) + return appender(f, b, v) +} + +func (f Formatter) HasFeature(feature feature.Feature) bool { + return f.dialect.Features().Has(feature) +} + +func (f Formatter) WithArg(arg NamedArgAppender) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(arg), + } +} + +func (f Formatter) WithNamedArg(name string, value interface{}) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(&namedArg{name: name, value: value}), + } +} + +func (f Formatter) FormatQuery(query string, args ...interface{}) string { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return query + } + return internal.String(f.AppendQuery(nil, query, args...)) +} + +func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), args) +} + +func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { + var namedArgs NamedArgAppender + if len(args) == 1 { + var ok bool + namedArgs, ok = args[0].(NamedArgAppender) + if !ok { + namedArgs, _ = newStructArgs(f, args[0]) + } + } + + var argIndex int + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + name, numeric := p.ReadIdentifier() + if name != "" { + if numeric { + idx, err := strconv.Atoi(name) + if err != nil { + goto restore_arg + } + + if idx >= len(args) { + goto restore_arg + } + + dst = f.appendArg(dst, args[idx]) + continue + } + + if namedArgs != nil { + dst, ok = namedArgs.AppendNamedArg(f, dst, name) + if ok { + continue + } + } + + dst, ok = f.args.AppendNamedArg(f, dst, name) + if ok { + continue + } + + restore_arg: + dst = append(dst, '?') + dst = append(dst, name...) + continue + } + + if argIndex >= len(args) { + dst = append(dst, '?') + continue + } + + arg := args[argIndex] + argIndex++ + + dst = f.appendArg(dst, arg) + } + + return dst +} + +func (f Formatter) appendArg(b []byte, arg interface{}) []byte { + switch arg := arg.(type) { + case QueryAppender: + bb, err := arg.AppendQuery(f, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb + default: + return f.dialect.Append(f, b, arg) + } +} + +//------------------------------------------------------------------------------ + +type NamedArgAppender interface { + AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) +} + +//------------------------------------------------------------------------------ + +type namedArgList struct { + arg NamedArgAppender + next *namedArgList +} + +func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList { + return &namedArgList{ + arg: arg, + next: l, + } +} + +func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + for l != nil && l.arg != nil { + if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok { + return b, true + } + l = l.next + } + return b, false +} + +//------------------------------------------------------------------------------ + +type namedArg struct { + name string + value interface{} +} + +var _ NamedArgAppender = (*namedArg)(nil) + +func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + if a.name == name { + return fmter.appendArg(b, a.value), true + } + return b, false +} + +//------------------------------------------------------------------------------ + +var _ NamedArgAppender = (*structArgs)(nil) + +type structArgs struct { + table *Table + strct reflect.Value +} + +func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &structArgs{ + table: fmter.Dialect().Tables().Get(v.Type()), + strct: v, + }, true +} + +func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go new file mode 100644 index 000000000..5391981d5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/hook.go @@ -0,0 +1,20 @@ +package schema + +import ( + "context" + "reflect" +) + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go new file mode 100644 index 000000000..8d1baeb3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/relation.go @@ -0,0 +1,32 @@ +package schema + +import ( + "fmt" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + ManyToManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFields []*Field + JoinFields []*Field + + PolymorphicField *Field + PolymorphicValue string + + M2MTable *Table + M2MBaseFields []*Field + M2MJoinFields []*Field +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go new file mode 100644 index 000000000..0e66a860f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -0,0 +1,392 @@ +package schema + +import ( + "bytes" + "database/sql" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +type ScannerFunc func(dest reflect.Value, src interface{}) error + +var scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, +} + +func FieldScanner(dialect Dialect, field *Field) ScannerFunc { + if field.Tag.HasOption("msgpack") { + return scanMsgpack + } + if field.Tag.HasOption("json_use_number") { + return scanJSONUseNumber + } + return dialect.Scanner(field.StructField.Type) +} + +func Scanner(typ reflect.Type) ScannerFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if fn := Scanner(typ.Elem()); fn != nil { + return ptrScanner(fn) + } + } + + if typ.Implements(scannerType) { + return scanScanner + } + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(scannerType) { + return addrScanner(scanScanner) + } + } + + switch typ { + case timeType: + return scanTime + case ipType: + return scanIP + case ipNetType: + return scanIPNet + case jsonRawMessageType: + return scanJSONRawMessage + } + + return scanners[kind] +} + +func scanBool(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBool(false) + return nil + case bool: + dest.SetBool(src) + return nil + case int64: + dest.SetBool(src != 0) + return nil + case []byte: + if len(src) == 1 { + dest.SetBool(src[0] != '0') + return nil + } + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInt64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetInt(0) + return nil + case int64: + dest.SetInt(src) + return nil + case uint64: + dest.SetInt(int64(src)) + return nil + case []byte: + n, err := strconv.ParseInt(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + case string: + n, err := strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanUint64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetUint(0) + return nil + case uint64: + dest.SetUint(src) + return nil + case int64: + dest.SetUint(uint64(src)) + return nil + case []byte: + n, err := strconv.ParseUint(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanFloat64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetFloat(0) + return nil + case float64: + dest.SetFloat(src) + return nil + case []byte: + f, err := strconv.ParseFloat(internal.String(src), 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanString(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetString("") + return nil + case string: + dest.SetString(src) + return nil + case []byte: + dest.SetString(string(src)) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanTime(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = time.Time{} + return nil + case time.Time: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = src + return nil + case string: + srcTime, err := internal.ParseTime(src) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + case []byte: + srcTime, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanScanner(dest reflect.Value, src interface{}) error { + return dest.Interface().(sql.Scanner).Scan(src) +} + +func scanMsgpack(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(bytes.NewReader(b)) + return dec.DecodeValue(dest) +} + +func scanJSON(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) +} + +func scanJSONUseNumber(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := bunjson.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + return dec.Decode(dest.Addr().Interface()) +} + +func scanIP(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + ip := net.ParseIP(internal.String(b)) + if ip == nil { + return fmt.Errorf("bun: invalid ip: %q", b) + } + + ptr := dest.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +func scanIPNet(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.String(b)) + if err != nil { + return err + } + + ptr := dest.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessage(dest reflect.Value, src interface{}) error { + if src == nil { + dest.SetBytes(nil) + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dest.SetBytes(b) + return nil +} + +func addrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if !dest.CanAddr() { + return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) + } + return fn(dest.Addr(), src) + } +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return internal.Bytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +func ptrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if src == nil { + if !dest.CanAddr() { + if dest.IsNil() { + return nil + } + return fn(dest.Elem(), src) + } + + if !dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return nil + } + + if dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return fn(dest.Elem(), src) + } +} + +func scanNull(dest reflect.Value) error { + if nilable(dest.Kind()) && dest.IsNil() { + return nil + } + dest.Set(reflect.New(dest.Type()).Elem()) + return nil +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go new file mode 100644 index 000000000..7b538cd0c --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -0,0 +1,76 @@ +package schema + +type QueryAppender interface { + AppendQuery(fmter Formatter, b []byte) ([]byte, error) +} + +type ColumnsAppender interface { + AppendColumns(fmter Formatter, b []byte) ([]byte, error) +} + +//------------------------------------------------------------------------------ + +// Safe represents a safe SQL query. +type Safe string + +var _ QueryAppender = (*Safe)(nil) + +func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return append(b, s...), nil +} + +//------------------------------------------------------------------------------ + +// Ident represents a SQL identifier, for example, table or column name. +type Ident string + +var _ QueryAppender = (*Ident)(nil) + +func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return fmter.AppendIdent(b, string(s)), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithArgs struct { + Query string + Args []interface{} +} + +var _ QueryAppender = QueryWithArgs{} + +func SafeQuery(query string, args []interface{}) QueryWithArgs { + if query != "" && args == nil { + args = make([]interface{}, 0) + } + return QueryWithArgs{Query: query, Args: args} +} + +func UnsafeIdent(ident string) QueryWithArgs { + return QueryWithArgs{Query: ident} +} + +func (q QueryWithArgs) IsZero() bool { + return q.Query == "" && q.Args == nil +} + +func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if q.Args == nil { + return fmter.AppendIdent(b, q.Query), nil + } + return fmter.AppendQuery(b, q.Query, q.Args...), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithSep struct { + QueryWithArgs + Sep string +} + +func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep { + return QueryWithSep{ + QueryWithArgs: SafeQuery(query, args), + Sep: sep, + } +} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go new file mode 100644 index 000000000..560f695c2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -0,0 +1,129 @@ +package schema + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +var ( + bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem() + nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() +) + +var sqlTypes = []string{ + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Chan: "", + reflect.Func: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, + reflect.UnsafePointer: "", +} + +func DiscoverSQLType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, bunNullTimeType: + return sqltype.Timestamp + case nullBoolType: + return sqltype.Boolean + case nullFloatType: + return sqltype.DoublePrecision + case nullIntType: + return sqltype.BigInt + case nullStringType: + return sqltype.VarChar + } + return sqlTypes[typ.Kind()] +} + +//------------------------------------------------------------------------------ + +var jsonNull = []byte("null") + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL. +type NullTime struct { + time.Time +} + +var ( + _ json.Marshaler = (*NullTime)(nil) + _ json.Unmarshaler = (*NullTime)(nil) + _ sql.Scanner = (*NullTime)(nil) + _ QueryAppender = (*NullTime)(nil) +) + +func (tm NullTime) MarshalJSON() ([]byte, error) { + if tm.IsZero() { + return jsonNull, nil + } + return tm.Time.MarshalJSON() +} + +func (tm *NullTime) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + tm.Time = time.Time{} + return nil + } + return tm.Time.UnmarshalJSON(b) +} + +func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if tm.IsZero() { + return dialect.AppendNull(b), nil + } + return dialect.AppendTime(b, tm.Time), nil +} + +func (tm *NullTime) Scan(src interface{}) error { + if src == nil { + tm.Time = time.Time{} + return nil + } + + switch src := src.(type) { + case []byte: + newtm, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + + tm.Time = newtm + return nil + case time.Time: + tm.Time = src + return nil + default: + return fmt.Errorf("bun: can't scan %#v into NullTime", src) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go new file mode 100644 index 000000000..eca18b781 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -0,0 +1,948 @@ +package schema + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/tagparser" +) + +const ( + beforeScanHookFlag internal.Flag = 1 << iota + afterScanHookFlag +) + +var ( + baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() + tableNameInflector = inflection.Plural +) + +type BaseModel struct{} + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + dialect Dialect + + Type reflect.Type + ZeroValue reflect.Value // reflect.Struct + ZeroIface interface{} // struct pointer + + TypeName string + ModelName string + + Name string + SQLName Safe + SQLNameForSelects Safe + Alias string + SQLAlias Safe + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + + fieldsMapMu sync.RWMutex + FieldMap map[string]*Field + + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + UpdateSoftDeleteField func(fv reflect.Value) error + + allFields []*Field // read only + skippedFields []*Field + + flags internal.Flag +} + +func newTable(dialect Dialect, typ reflect.Type) *Table { + t := new(Table) + t.dialect = dialect + t.Type = typ + t.ZeroValue = reflect.New(t.Type).Elem() + t.ZeroIface = reflect.New(t.Type).Interface() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(tableName) + t.Alias = t.ModelName + t.SQLAlias = t.quoteIdent(t.ModelName) + + hooks := []struct { + typ reflect.Type + flag internal.Flag + }{ + {beforeScanHookType, beforeScanHookFlag}, + {afterScanHookType, afterScanHookFlag}, + } + + typ = reflect.PtrTo(t.Type) + for _, hook := range hooks { + if typ.Implements(hook.typ) { + t.flags = t.flags.Set(hook.flag) + } + } + + return t +} + +func (t *Table) init1() { + t.initFields() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name string) { + t.Name = name + t.SQLName = t.quoteIdent(name) + t.SQLNameForSelects = t.quoteIdent(name) + if t.SQLAlias == "" { + t.Alias = name + t.SQLAlias = t.quoteIdent(name) + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) CheckPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("bun: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) addField(field *Field) { + t.Fields = append(t.Fields, field) + if field.IsPK { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldMap[field.Name] = field +} + +func (t *Table) removeField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.IsPK { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldMap, field.Name) +} + +func (t *Table) fieldWithLock(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldMap[name] + return ok +} + +func (t *Table) Field(name string) (*Field, error) { + field, ok := t.FieldMap[name] + if !ok { + return nil, fmt.Errorf("bun: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) fieldByGoName(name string) *Field { + for _, f := range t.allFields { + if f.GoName == name { + return f + } + } + return nil +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) + + if len(t.PKs) > 0 { + return + } + for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { + if field, ok := t.FieldMap[name]; ok { + field.markAsPK() + t.PKs = []*Field{field} + t.DataFields = removeField(t.DataFields, field) + break + } + } + if len(t.PKs) == 1 { + switch t.PKs[0].IndirectType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t.PKs[0].AutoIncrement = true + } + } +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("bun") == "-" { + continue + } + if f.Name == "BaseModel" && f.Type == baseModelType { + if len(index) == 0 { + t.processBaseModelField(f) + } + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + tag := tagparser.Parse(f.Tag.Get("bun")) + if _, inherit := tag.Options["inherit"]; inherit { + embeddedTable := t.dialect.Tables().Ref(fieldType) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.SQLAlias = embeddedTable.SQLAlias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.addField(field) + } + } +} + +func (t *Table) processBaseModelField(f reflect.StructField) { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if isKnownTableOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tag.Name != "" { + t.setName(tag.Name) + } + + if s, ok := tag.Options["select"]; ok { + t.SQLNameForSelects = t.quoteTableName(s) + } + + if s, ok := tag.Options["alias"]; ok { + t.Alias = s + t.SQLAlias = t.quoteIdent(s) + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if tag.Name != sqlName && isKnownFieldOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := tag.Name == "-" + if !skip && tag.Name != "" { + sqlName = tag.Name + } + + index = append(index, f.Index...) + if field := t.fieldWithLock(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.removeField(field) + } + + field := &Field{ + StructField: f, + + Tag: tag, + IndirectType: indirectType(f.Type), + Index: index, + + Name: sqlName, + GoName: f.Name, + SQLName: t.quoteIdent(sqlName), + } + + field.NotNull = tag.HasOption("notnull") + field.NullZero = tag.HasOption("nullzero") + field.AutoIncrement = tag.HasOption("autoincrement") + if tag.HasOption("pk") { + field.markAsPK() + } + if tag.HasOption("allowzero") { + if tag.HasOption("nullzero") { + internal.Warn.Printf( + "%s.%s: nullzero and allowzero options are mutually exclusive", + t.TypeName, f.Name, + ) + } + field.NullZero = false + } + + if v, ok := tag.Options["unique"]; ok { + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if s, ok := tag.Options["default"]; ok { + field.SQLDefault = s + } + if s, ok := field.Tag.Options["type"]; ok { + field.UserSQLType = s + } + field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) + field.Append = t.dialect.FieldAppender(field) + field.Scan = FieldScanner(t.dialect, field) + field.IsZero = FieldZeroChecker(field) + + if v, ok := tag.Options["alt"]; ok { + t.FieldMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldMap[field.Name] = field + return nil + } + + if _, ok := tag.Options["soft_delete"]; ok { + field.NullZero = true + t.SoftDeleteField = field + t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) + } + + return field +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +//--------------------------------------------------------------------------------------- + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + if rel, ok := field.Tag.Options["rel"]; ok { + t.initRelation(field, rel) + return true + } + if field.Tag.HasOption("m2m") { + t.addRelation(t.m2mRelation(field)) + return true + } + + if field.Tag.HasOption("join") { + internal.Warn.Printf( + `%s.%s option "join" requires a relation type`, + t.TypeName, field.GoName, + ) + } + + return false +} + +func (t *Table) initRelation(field *Field, rel string) { + switch rel { + case "belongs-to": + t.addRelation(t.belongsToRelation(field)) + case "has-one": + t.addRelation(t.hasOneRelation(field)) + case "has-many": + t.addRelation(t.hasManyRelation(field)) + default: + panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func (t *Table) belongsToRelation(field *Field) *Relation { + joinTable := t.dialect.Tables().Ref(field.IndirectType) + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + rel := &Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.JoinFields = joinTable.PKs + fkPrefix := internal.Underscore(field.GoName) + "_" + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.Name + if fk := t.fieldWithLock(fkName); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + if fk := t.fieldWithLock(joinPK.Name); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasOneRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + rel := &Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + + joinColumn := joinColumns[i] + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + for _, pk := range t.PKs { + fkName := fkPrefix + pk.Name + if f := joinTable.fieldWithLock(fkName); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + if f := joinTable.fieldWithLock(pk.Name); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasManyRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"] + rel := &Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + } + var polymorphicColumn string + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if isPolymorphic && baseColumn == "type" { + polymorphicColumn = joinColumn + continue + } + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + } else { + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + if isPolymorphic { + polymorphicColumn = fkPrefix + "type" + } + + for _, pk := range t.PKs { + joinColumn := fkPrefix + pk.Name + if fk := joinTable.fieldWithLock(joinColumn); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + if fk := joinTable.fieldWithLock(pk.Name); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on the field %s)", + t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName, + )) + } + } + + if isPolymorphic { + rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) + if rel.PolymorphicField == nil { + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn, + )) + } + + if polymorphicValue == "" { + polymorphicValue = t.ModelName + } + rel.PolymorphicValue = polymorphicValue + } + + return rel +} + +func (t *Table) m2mRelation(field *Field) *Relation { + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s m2m relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + + if err := t.CheckPKs(); err != nil { + panic(err) + } + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + m2mTableName, ok := field.Tag.Options["m2m"] + if !ok { + panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName)) + } + + m2mTable := t.dialect.Tables().ByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "bun: can't find m2m %s table (use db.RegisterModel)", + m2mTableName, + )) + } + + rel := &Relation{ + Type: ManyToManyRelation, + Field: field, + JoinTable: joinTable, + M2MTable: m2mTable, + } + var leftColumn, rightColumn string + + if join, ok := field.Tag.Options["join"]; ok { + left, right := parseRelationJoin(join) + leftColumn = left[0] + rightColumn = right[0] + } else { + leftColumn = t.TypeName + rightColumn = joinTable.TypeName + } + + leftField := m2mTable.fieldByGoName(leftColumn) + if leftField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName, + )) + } + + rightField := m2mTable.fieldByGoName(rightColumn) + if rightField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName, + )) + } + + leftRel := m2mTable.belongsToRelation(leftField) + rel.BaseFields = leftRel.JoinFields + rel.M2MBaseFields = leftRel.BaseFields + + rightRel := m2mTable.belongsToRelation(rightField) + rel.JoinFields = rightRel.JoinFields + rel.M2MJoinFields = rightRel.BaseFields + + return rel +} + +func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[field.IndirectType]; ok { + return + } + path[field.IndirectType] = struct{}{} + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = field.GoName + "_" + f.GoName + f.Name = field.Name + "__" + f.Name + f.SQLName = t.quoteIdent(f.Name) + f.Index = appendNew(field.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldMap[f.Name]; !ok { + t.FieldMap[f.Name] = f + } + t.fieldsMapMu.Unlock() + + if f.IndirectType.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.IndirectType]; !ok { + t.inlineFields(f, path) + } + } +} + +//------------------------------------------------------------------------------ + +func (t *Table) Dialect() Dialect { return t.dialect } + +//------------------------------------------------------------------------------ + +func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } +func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +//------------------------------------------------------------------------------ + +func (t *Table) AppendNamedArg( + fmter Formatter, b []byte, name string, strct reflect.Value, +) ([]byte, bool) { + if field, ok := t.FieldMap[name]; ok { + return fmter.appendArg(b, field.Value(strct).Interface()), true + } + return b, false +} + +func (t *Table) quoteTableName(s string) Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 || + strings.IndexByte(s, ')') >= 0 { + return Safe(s) + } + return t.quoteIdent(s) +} + +func (t *Table) quoteIdent(s string) Safe { + return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", "select": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "nullzero", + "allowzero", + "default", + "unique", + "soft_delete", + + "pk", + "autoincrement", + "rel", + "join", + "m2m", + "polymorphic": + return true + } + return false +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + return append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func parseRelationJoin(join string) ([]string, []string) { + ss := strings.Split(join, ",") + baseColumns := make([]string, len(ss)) + joinColumns := make([]string, len(ss)) + for i, s := range ss { + ss := strings.Split(strings.TrimSpace(s), "=") + if len(ss) != 2 { + panic(fmt.Errorf("can't parse relation join: %q", join)) + } + baseColumns[i] = ss[0] + joinColumns[i] = ss[1] + } + return baseColumns, joinColumns +} + +//------------------------------------------------------------------------------ + +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { + typ := field.StructField.Type + + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullTime) + *ptr = sql.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch field.IndirectType.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + typ = typ.Elem() + default: + return softDeleteFieldUpdaterFallback(field) + } + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return softDeleteFieldUpdaterFallback(field) +} + +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { + return func(fv reflect.Value) error { + return field.ScanWithCheck(fv, time.Now()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go new file mode 100644 index 000000000..d82d08f59 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -0,0 +1,148 @@ +package schema + +import ( + "fmt" + "reflect" + "sync" +) + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +type Tables struct { + dialect Dialect + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func NewTables(dialect Dialect) *Tables { + return &Tables{ + dialect: dialect, + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *Tables) Register(models ...interface{}) { + for _, model := range models { + _ = t.Get(reflect.TypeOf(model).Elem()) + } +} + +func (t *Tables) Get(typ reflect.Type) *Table { + return t.table(typ, false) +} + +func (t *Tables) Ref(typ reflect.Type) *Table { + return t.table(typ, true) +} + +func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(t.dialect, typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + t.dialect.OnTable(table) + + for _, field := range table.FieldMap { + if field.UserSQLType == "" { + field.UserSQLType = field.DiscoveredSQLType + } + if field.CreateTableSQLType == "" { + field.CreateTableSQLType = field.UserSQLType + } + } + + return table +} + +func (t *Tables) ByModel(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.TypeName == name { + found = t + return false + } + return true + }) + return found +} + +func (t *Tables) ByName(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.Name == name { + found = t + return false + } + return true + }) + return found +} diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go new file mode 100644 index 000000000..6d474e4cc --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/util.go @@ -0,0 +1,53 @@ +package schema + +import "reflect" + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go new file mode 100644 index 000000000..95efeee6b --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go @@ -0,0 +1,126 @@ +package schema + +import ( + "database/sql/driver" + "reflect" +) + +var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() + +type isZeroer interface { + IsZero() bool +} + +type IsZeroerFunc func(reflect.Value) bool + +func FieldZeroChecker(field *Field) IsZeroerFunc { + return zeroChecker(field.IndirectType) +} + +func zeroChecker(typ reflect.Type) IsZeroerFunc { + if typ.Implements(isZeroerType) { + return isZeroInterface + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(isZeroerType) { + return addrChecker(isZeroInterface) + } + } + + switch kind { + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return isZeroBytes + } + return isZeroLen + case reflect.String: + return isZeroLen + case reflect.Bool: + return isZeroBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return isZeroInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return isZeroUint + case reflect.Float32, reflect.Float64: + return isZeroFloat + case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map: + return isNil + } + + if typ.Implements(driverValuerType) { + return isZeroDriverValue + } + + return notZero +} + +func addrChecker(fn IsZeroerFunc) IsZeroerFunc { + return func(v reflect.Value) bool { + if !v.CanAddr() { + return false + } + return fn(v.Addr()) + } +} + +func isZeroInterface(v reflect.Value) bool { + if v.Kind() == reflect.Ptr && v.IsNil() { + return true + } + return v.Interface().(isZeroer).IsZero() +} + +func isZeroDriverValue(v reflect.Value) bool { + if v.Kind() == reflect.Ptr { + return v.IsNil() + } + + valuer := v.Interface().(driver.Valuer) + value, err := valuer.Value() + if err != nil { + return false + } + return value == nil +} + +func isZeroLen(v reflect.Value) bool { + return v.Len() == 0 +} + +func isNil(v reflect.Value) bool { + return v.IsNil() +} + +func isZeroBool(v reflect.Value) bool { + return !v.Bool() +} + +func isZeroInt(v reflect.Value) bool { + return v.Int() == 0 +} + +func isZeroUint(v reflect.Value) bool { + return v.Uint() == 0 +} + +func isZeroFloat(v reflect.Value) bool { + return v.Float() == 0 +} + +func isZeroBytes(v reflect.Value) bool { + b := v.Slice(0, v.Len()).Bytes() + for _, c := range b { + if c != 0 { + return false + } + } + return true +} + +func notZero(v reflect.Value) bool { + return false +} diff --git a/vendor/github.com/uptrace/bun/util.go b/vendor/github.com/uptrace/bun/util.go new file mode 100644 index 000000000..ce56be805 --- /dev/null +++ b/vendor/github.com/uptrace/bun/util.go @@ -0,0 +1,114 @@ +package bun + +import "reflect" + +func indirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface: + return indirect(v.Elem()) + case reflect.Ptr: + return v.Elem() + default: + return v + } +} + +func walk(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Slice: + sliceLen := v.Len() + for i := 0; i < sliceLen; i++ { + visitField(v.Index(i), index, fn) + } + default: + visitField(v, index, fn) + } +} + +func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + if len(index) > 0 { + v = v.Field(index[0]) + if v.Kind() == reflect.Ptr && v.IsNil() { + return + } + walk(v, index[1:], fn) + } else { + fn(v) + } +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, x := range index { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice: + t = indirectType(t.Elem()) + } + t = t.Field(x).Type + } + return indirectType(t) +} + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func sliceElemType(v reflect.Value) reflect.Type { + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Interface && v.Len() > 0 { + return indirect(v.Index(0).Elem()).Type() + } + return indirectType(elemType) +} + +func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + sliceType := v.Type() + elemType := sliceType.Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + l := v.Len() + c := v.Cap() + + if l < c { + v.Set(v.Slice(0, l+1)) + return v.Index(l) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(l) + } +} diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go new file mode 100644 index 000000000..1baf9a39c --- /dev/null +++ b/vendor/github.com/uptrace/bun/version.go @@ -0,0 +1,6 @@ +package bun + +// Version is the current release version. +func Version() string { + return "0.4.3" +} diff --git a/vendor/modules.txt b/vendor/modules.txt index da1a09c9f..b8a8a26f7 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -247,9 +247,6 @@ github.com/go-fed/activity/streams/vocab # github.com/go-fed/httpsig v1.1.0 ## explicit github.com/go-fed/httpsig -# github.com/go-pg/pg/extra/pgdebug v0.2.0 -## explicit -github.com/go-pg/pg/extra/pgdebug # github.com/go-pg/pg/v10 v10.10.3 ## explicit github.com/go-pg/pg/v10 @@ -383,13 +380,29 @@ github.com/tdewolff/parse/v2/strconv github.com/tmthrgd/go-hex # github.com/ugorji/go/codec v1.2.6 github.com/ugorji/go/codec +# github.com/uptrace/bun v0.4.3 +## explicit +github.com/uptrace/bun +github.com/uptrace/bun/dialect +github.com/uptrace/bun/dialect/feature +github.com/uptrace/bun/dialect/sqltype +github.com/uptrace/bun/extra/bunjson +github.com/uptrace/bun/internal +github.com/uptrace/bun/internal/parser +github.com/uptrace/bun/internal/tagparser +github.com/uptrace/bun/schema +# github.com/uptrace/bun/dialect/pgdialect v0.4.3 +## explicit +github.com/uptrace/bun/dialect/pgdialect +# github.com/uptrace/bun/driver/pgdriver v0.4.3 +## explicit +github.com/uptrace/bun/driver/pgdriver # github.com/urfave/cli/v2 v2.3.0 ## explicit github.com/urfave/cli/v2 # github.com/vmihailenco/bufpool v0.1.11 github.com/vmihailenco/bufpool # github.com/vmihailenco/msgpack/v5 v5.3.4 -## explicit github.com/vmihailenco/msgpack/v5 github.com/vmihailenco/msgpack/v5/msgpcode # github.com/vmihailenco/tagparser v0.1.2 @@ -450,6 +463,7 @@ golang.org/x/text/transform golang.org/x/text/unicode/bidi golang.org/x/text/unicode/norm # google.golang.org/appengine v1.6.7 +## explicit google.golang.org/appengine/internal google.golang.org/appengine/internal/base google.golang.org/appengine/internal/datastore