mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 08:52:26 -06:00 
			
		
		
		
	
		
			
	
	
		
			231 lines
		
	
	
	
		
			7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			231 lines
		
	
	
	
		
			7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								// GoToSocial
							 | 
						||
| 
								 | 
							
								// Copyright (C) GoToSocial Authors admin@gotosocial.org
							 | 
						||
| 
								 | 
							
								// SPDX-License-Identifier: AGPL-3.0-or-later
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// This program is free software: you can redistribute it and/or modify
							 | 
						||
| 
								 | 
							
								// it under the terms of the GNU Affero General Public License as published by
							 | 
						||
| 
								 | 
							
								// the Free Software Foundation, either version 3 of the License, or
							 | 
						||
| 
								 | 
							
								// (at your option) any later version.
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// This program is distributed in the hope that it will be useful,
							 | 
						||
| 
								 | 
							
								// but WITHOUT ANY WARRANTY; without even the implied warranty of
							 | 
						||
| 
								 | 
							
								// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
							 | 
						||
| 
								 | 
							
								// GNU Affero General Public License for more details.
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// You should have received a copy of the GNU Affero General Public License
							 | 
						||
| 
								 | 
							
								// along with this program.  If not, see <http://www.gnu.org/licenses/>.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								package bundb
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"context"
							 | 
						||
| 
								 | 
							
									"database/sql"
							 | 
						||
| 
								 | 
							
									"reflect"
							 | 
						||
| 
								 | 
							
									"strings"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"github.com/superseriousbusiness/gotosocial/internal/gtserror"
							 | 
						||
| 
								 | 
							
									"github.com/uptrace/bun"
							 | 
						||
| 
								 | 
							
									"github.com/uptrace/bun/dialect"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// UpsertQuery is a wrapper around an insert query that can update if an insert fails.
							 | 
						||
| 
								 | 
							
								// Doesn't implement the full set of Bun query methods, but we can add more if we need them.
							 | 
						||
| 
								 | 
							
								// See https://bun.uptrace.dev/guide/query-insert.html#upsert
							 | 
						||
| 
								 | 
							
								type UpsertQuery struct {
							 | 
						||
| 
								 | 
							
									db          bun.IDB
							 | 
						||
| 
								 | 
							
									model       interface{}
							 | 
						||
| 
								 | 
							
									constraints []string
							 | 
						||
| 
								 | 
							
									columns     []string
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func NewUpsert(idb bun.IDB) *UpsertQuery {
							 | 
						||
| 
								 | 
							
									// note: passing in rawtx as conn iface so no double query-hook
							 | 
						||
| 
								 | 
							
									// firing when passed through the bun.Tx.Query___() functions.
							 | 
						||
| 
								 | 
							
									return &UpsertQuery{db: idb}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Model sets the model or models to upsert.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) Model(model interface{}) *UpsertQuery {
							 | 
						||
| 
								 | 
							
									u.model = model
							 | 
						||
| 
								 | 
							
									return u
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Constraint sets the columns or indexes that are used to check for conflicts.
							 | 
						||
| 
								 | 
							
								// This is required.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery {
							 | 
						||
| 
								 | 
							
									u.constraints = constraints
							 | 
						||
| 
								 | 
							
									return u
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Column sets the columns to update if an insert does't happen.
							 | 
						||
| 
								 | 
							
								// If empty, all columns not being used for constraints will be updated.
							 | 
						||
| 
								 | 
							
								// Cannot overlap with Constraint.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) Column(columns ...string) *UpsertQuery {
							 | 
						||
| 
								 | 
							
									u.columns = columns
							 | 
						||
| 
								 | 
							
									return u
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insertDialect errors if we're using a dialect in which we don't know how to upsert.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) insertDialect() error {
							 | 
						||
| 
								 | 
							
									dialectName := u.db.Dialect().Name()
							 | 
						||
| 
								 | 
							
									switch dialectName {
							 | 
						||
| 
								 | 
							
									case dialect.PG, dialect.SQLite:
							 | 
						||
| 
								 | 
							
										return nil
							 | 
						||
| 
								 | 
							
									default:
							 | 
						||
| 
								 | 
							
										// FUTURE: MySQL has its own variation on upserts, but the syntax is different.
							 | 
						||
| 
								 | 
							
										return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insertConstraints checks that we have constraints and returns them.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) insertConstraints() ([]string, error) {
							 | 
						||
| 
								 | 
							
									if len(u.constraints) == 0 {
							 | 
						||
| 
								 | 
							
										return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided")
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return u.constraints, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insertColumns returns the non-constraint columns we'll be updating.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) {
							 | 
						||
| 
								 | 
							
									// Constraints as a set.
							 | 
						||
| 
								 | 
							
									constraintSet := make(map[string]struct{}, len(constraints))
							 | 
						||
| 
								 | 
							
									for _, constraint := range constraints {
							 | 
						||
| 
								 | 
							
										constraintSet[constraint] = struct{}{}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									var columns []string
							 | 
						||
| 
								 | 
							
									var err error
							 | 
						||
| 
								 | 
							
									if len(u.columns) == 0 {
							 | 
						||
| 
								 | 
							
										columns, err = u.insertColumnsDefault(constraintSet)
							 | 
						||
| 
								 | 
							
									} else {
							 | 
						||
| 
								 | 
							
										columns, err = u.insertColumnsSpecified(constraintSet)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if len(columns) == 0 {
							 | 
						||
| 
								 | 
							
										return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting")
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return columns, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// hasElem reports whether modelType has an element type, that is, whether
								// [reflect.Type.Elem] can be called on it without panicking.
								// A nil type (what reflect.TypeOf(nil) returns) has no element.
								func hasElem(modelType reflect.Type) bool {
									if modelType == nil {
										// Calling Kind on a nil reflect.Type would panic.
										return false
									}

									switch modelType.Kind() {
									case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice:
										return true
									default:
										return false
									}
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insertColumnsDefault returns every column of the model's table schema
								// that is not one of the conflict constraints, in schema field order.
								//
								// It returns an error, rather than panicking, when no model has been set,
								// when the model does not resolve to a struct type, or when Bun has no table
								// schema for it.
								func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) {
									// Get the underlying struct type by peeling off pointers, slices, etc.
									modelType := reflect.TypeOf(u.model)
									if modelType == nil {
										// reflect.TypeOf(nil) is nil; any method call on it would panic.
										return nil, gtserror.New("UpsertQuery: no model provided")
									}
									for hasElem(modelType) {
										modelType = modelType.Elem()
									}
									if modelType.Kind() != reflect.Struct {
										// Only struct models can have a Bun table schema.
										return nil, gtserror.Newf("UpsertQuery: model must be a struct or a container of structs, got: %T", u.model)
									}

									table := u.db.Dialect().Tables().Get(modelType)
									if table == nil {
										return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model)
									}

									// Size by the schema's fields: u.columns is always empty on this path.
									columns := make([]string, 0, len(table.Fields))
									for _, field := range table.Fields {
										column := field.Name
										if _, overlaps := constraintSet[column]; !overlaps {
											columns = append(columns, column)
										}
									}

									return columns, nil
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insertColumnsSpecified ensures constraints and specified columns to update don't overlap.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) {
							 | 
						||
| 
								 | 
							
									overlapping := make([]string, 0, min(len(u.constraints), len(u.columns)))
							 | 
						||
| 
								 | 
							
									for _, column := range u.columns {
							 | 
						||
| 
								 | 
							
										if _, overlaps := constraintSet[column]; overlaps {
							 | 
						||
| 
								 | 
							
											overlapping = append(overlapping, column)
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									if len(overlapping) > 0 {
							 | 
						||
| 
								 | 
							
										return nil, gtserror.Newf(
							 | 
						||
| 
								 | 
							
											"UpsertQuery: the following columns can't be used for both constraints and columns to update: %s",
							 | 
						||
| 
								 | 
							
											strings.Join(overlapping, ", "),
							 | 
						||
| 
								 | 
							
										)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return u.columns, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// insert tries to create a Bun insert query from an upsert query.
							 | 
						||
| 
								 | 
							
								func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
							 | 
						||
| 
								 | 
							
									var err error
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									err = u.insertDialect()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									constraints, err := u.insertConstraints()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									columns, err := u.insertColumns(constraints)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// Build the parts of the query that need us to generate SQL.
							 | 
						||
| 
								 | 
							
									constraintIDPlaceholders := make([]string, 0, len(constraints))
							 | 
						||
| 
								 | 
							
									constraintIDs := make([]interface{}, 0, len(constraints))
							 | 
						||
| 
								 | 
							
									for _, constraint := range constraints {
							 | 
						||
| 
								 | 
							
										constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
							 | 
						||
| 
								 | 
							
										constraintIDs = append(constraintIDs, bun.Ident(constraint))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									setClauses := make([]string, 0, len(columns))
							 | 
						||
| 
								 | 
							
									setIDs := make([]interface{}, 0, 2*len(columns))
							 | 
						||
| 
								 | 
							
									for _, column := range columns {
							 | 
						||
| 
								 | 
							
										// "excluded" is a special table that contains only the row involved in a conflict.
							 | 
						||
| 
								 | 
							
										setClauses = append(setClauses, "? = excluded.?")
							 | 
						||
| 
								 | 
							
										setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									setSQL := strings.Join(setClauses, ", ")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									insertQuery := u.db.
							 | 
						||
| 
								 | 
							
										NewInsert().
							 | 
						||
| 
								 | 
							
										Model(u.model).
							 | 
						||
| 
								 | 
							
										On(onSQL, constraintIDs...).
							 | 
						||
| 
								 | 
							
										Set(setSQL, setIDs...)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return insertQuery, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Exec builds the Bun insert query for this upsert and executes it,
								// forwarding dest to Bun's InsertQuery.Exec. Any error from building the
								// query is returned as-is, with a nil result.
								func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
									q, err := u.insertQuery()
									if err == nil {
										return q.Exec(ctx, dest...)
									}

									return nil, err
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Scan builds the Bun insert query for this upsert and runs it, scanning
								// the results via Bun's InsertQuery.Scan into dest. Any error from building
								// the query is returned as-is.
								func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
									q, err := u.insertQuery()
									if err == nil {
										return q.Scan(ctx, dest...)
									}

									return err
								}
							 |