Compare commits


15 commits

Author SHA1 Message Date
kim
5fd52369c9 [performance] handle emoji refreshes asynchronously when fetched as part of account|status dereferences (#4486)
# Description

Updates our dereferencer emoji handling to work asynchronously when emoji are fetched as part of account or status dereferencing.

closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4485
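
As a rough sketch of the pattern (illustrative types and names only, not the actual GtS dereferencer code), the idea is to hand emoji refreshes to background workers so that the account/status dereference itself never blocks on remote fetches:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// Emoji stands in for the real emoji model (assumption).
type Emoji struct{ Shortcode string }

// refresher fans emoji refreshes out to background workers so
// that dereferences can return without waiting on remote fetches.
type refresher struct {
	queue chan *Emoji
	wg    sync.WaitGroup
}

func newRefresher(workers int) *refresher {
	r := &refresher{queue: make(chan *Emoji, 64)}
	for i := 0; i < workers; i++ {
		r.wg.Add(1)
		go func() {
			defer r.wg.Done()
			for e := range r.queue {
				// Stand-in for the real remote fetch + storage update.
				time.Sleep(10 * time.Millisecond)
				fmt.Println("refreshed", e.Shortcode)
			}
		}()
	}
	return r
}

// enqueue returns immediately; callers keep using the stale
// emoji until the background refresh lands.
func (r *refresher) enqueue(e *Emoji) {
	select {
	case r.queue <- e:
	default:
		// Queue full: skip this refresh rather than block the deref path.
	}
}

func main() {
	r := newRefresher(4)
	r.enqueue(&Emoji{Shortcode: "blobcat"})
	close(r.queue)
	r.wg.Wait()
}
```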

## Checklist

- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [x] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [ ] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [ ] I/we have added tests that cover new code.
- [ ] I/we have run tests and they pass locally with the changes.
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4486
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-10-08 14:13:40 +02:00
kim
baf2c54730 [performance] add benchmarks for native Go imaging code, small tweaks to reduce nil and boundary checks, some loop unrolling (#4482)
Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4482
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-10-08 11:12:12 +02:00
tobi
b012a81f66 [bugfix] Log a warning when clientIP could not be parsed during rate limiting (#4481)
# Description


Fixes a panic when the clientIP cannot be parsed in the rate limiting middleware; instead of panicking, the middleware now logs a warning with the derived clientIP and a hint that the reverse proxy may be misconfigured.

Closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4479
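
A minimal sketch of the shape of such a fix (assumptions: gin for the HTTP layer and `net/netip` for parsing; this is not the actual GtS middleware):

```go
package main

import (
	"log"
	"net/http"
	"net/netip"

	"github.com/gin-gonic/gin"
)

// rateLimit wraps a limit check; limitExceeded is a placeholder.
func rateLimit(limitExceeded func(netip.Addr) bool) gin.HandlerFunc {
	return func(c *gin.Context) {
		clientIP := c.ClientIP()
		ip, err := netip.ParseAddr(clientIP)
		if err != nil {
			// Don't panic: log the derived value and a hint,
			// then let the request through unlimited.
			log.Printf("invalid clientIP %q, skipping rate limit "+
				"(is your reverse proxy misconfigured?): %v", clientIP, err)
			c.Next()
			return
		}
		if limitExceeded(ip) {
			c.AbortWithStatus(http.StatusTooManyRequests)
			return
		}
		c.Next()
	}
}

func main() {
	r := gin.New()
	r.Use(rateLimit(func(netip.Addr) bool { return false }))
	r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "ok") })
	_ = r.Run(":8080")
}
```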

## Checklist


- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [x] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [x] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [ ] I/we have added tests that cover new code.
- [x] I/we have run tests and they pass locally with the changes.
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4481
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-10-07 16:02:57 +02:00
tobi
c6044d0142 [bugfix] Fix db error checking for int req: sql: no rows in result set (#4478)
Fixes `sql: no rows in result set` when trying to append approvedByURI to a reply that was sent impolitely and approved impolitely.
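
The general pattern behind this kind of fix (a sketch, not the actual GtS query): `database/sql` and wrappers around it may return a wrapped `sql.ErrNoRows`, so the check belongs in `errors.Is`, with an explicit decision about whether "no rows" is an error at all:

```go
package status

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
)

// getApprovedByURI is a sketch; the table and column names are
// illustrative, not the exact GtS schema usage.
func getApprovedByURI(ctx context.Context, db *sql.DB, statusID string) (string, error) {
	var uri sql.NullString
	err := db.QueryRowContext(ctx,
		`SELECT approved_by_uri FROM statuses WHERE id = ?`, statusID,
	).Scan(&uri)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Row genuinely absent: treat as "not set",
		// not as a hard database failure.
		return "", nil
	case err != nil:
		return "", fmt.Errorf("db error getting approved_by_uri: %w", err)
	}
	return uri.String, nil
}
```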

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4478
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-10-06 13:11:23 +02:00
tobi
03fc6eaf39 [bugfix] Fix nil ptr in DifferentFrom func (#4477)
Closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4476
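
One plausible shape for a fix like this (hypothetical, simplified types; the real `DifferentFrom` lives on GtS's own models): guard nil values before dereferencing anything nested:

```go
package model

// InteractionPolicyForm is a simplified stand-in type.
type InteractionPolicyForm struct {
	CanReply *string
}

// DifferentFrom reports whether two forms differ, treating nil
// pointers explicitly so neither side is ever dereferenced blindly.
func (f *InteractionPolicyForm) DifferentFrom(o *InteractionPolicyForm) bool {
	if f == nil || o == nil {
		// Differ only if exactly one side is nil.
		return (f == nil) != (o == nil)
	}
	if (f.CanReply == nil) != (o.CanReply == nil) {
		return true
	}
	return f.CanReply != nil && *f.CanReply != *o.CanReply
}
```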

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4477
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-10-06 11:45:40 +02:00
Daniël Franke
5b95636993 [docs] Add db migration tip for slow hardware instances. (#4457)
This PR adds a new documentation section collecting workarounds for running GtS on slow hardware. For now it contains a single procedure: running a migration against a copy of the database on a faster machine, for cases where the original database's storage is too slow to finish the migration in a timely manner.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4457
Co-authored-by: Daniël Franke <df@ponc.tech>
Co-committed-by: Daniël Franke <df@ponc.tech>
2025-10-05 14:43:09 +02:00
tobi
259fa1ffac [bugfix] Update interaction policies of freshly dereffed statuses if different from last deref (#4474)
# Description


This pull request adds a check for whether the interaction policy on a freshly refreshed status differs from the policy previously stored for that status, and updates the status with the new policy if it has changed.

Should fix a pesky issue where folks on v0.19.2 and above still can't interact with statuses they dereferenced before updating.
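
A sketch of the described check, using stand-in types (the real code compares GtS's interaction policy models during status refresh):

```go
package deref

import "reflect"

// InteractionPolicy and Status are stand-ins for the real models.
type InteractionPolicy struct{ CanReply, CanLike, CanAnnounce string }

type Status struct {
	InteractionPolicy *InteractionPolicy
}

// refreshColumns returns the columns to write back after a refresh;
// interaction_policy is included only when the policy actually changed.
func refreshColumns(stored, fetched *Status) []string {
	cols := []string{"fetched_at"}
	if !reflect.DeepEqual(stored.InteractionPolicy, fetched.InteractionPolicy) {
		stored.InteractionPolicy = fetched.InteractionPolicy
		cols = append(cols, "interaction_policy")
	}
	return cols
}
```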

## Checklist


- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [x] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [x] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [ ] I/we have added tests that cover new code.
- [x] I/we have run tests and they pass locally with the changes.
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4474
Reviewed-by: kim <gruf@noreply.codeberg.org>
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-10-05 13:33:16 +02:00
kim
57cb4fe748 [bugfix] status refresh race condition causing double edit notifications (#4470)
# Description

Fixes a possible race condition in enrichStatus() where the existing status could be out of date, leading to double edit notifications.
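
The usual cure for this class of race, sketched with stand-in types (illustrative only; the real fix is inside GtS's enrichStatus): re-read the stored status after taking the per-status lock, so two concurrent refreshes can't both act on the same stale copy and both emit an edit notification:

```go
package deref

import "sync"

// Status is a stand-in for the real status model.
type Status struct {
	ID       string
	EditedAt int64
}

var locks sync.Map // status ID -> *sync.Mutex

// enrich re-loads the stored status *inside* the lock: the copy the
// caller fetched before locking may already have been refreshed by a
// concurrent call, in which case there is no new edit to notify about.
func enrich(id string, load func(string) *Status, latestEdit int64, notifyEdit func(*Status)) {
	v, _ := locks.LoadOrStore(id, &sync.Mutex{})
	mu := v.(*sync.Mutex)
	mu.Lock()
	defer mu.Unlock()

	existing := load(id)
	if existing.EditedAt >= latestEdit {
		return // a concurrent refresh already processed this edit
	}
	existing.EditedAt = latestEdit
	notifyEdit(existing)
}
```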

## Checklist

- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [x] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [x] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [x] I/we have added tests that cover new code.
- [x] I/we have run tests and they pass locally with the changes.
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4470
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-10-03 15:50:57 +02:00
kim
ff950e94bb [chore] update dependencies (#4468)
- github.com/ncruces/go-sqlite3
- codeberg.org/gruf/go-mempool
- codeberg.org/gruf/go-structr (changes related to the above) *
- codeberg.org/gruf/go-mutexes (changes related to the above) *

\* This is largely just fiddling around with package internals in structr and mutexes to rely on the changes in mempool, which added a new concurrency-safe pool.
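
The go-mempool change itself isn't shown here; as a rough standard-library analogy (an assumption, not the go-mempool API), a concurrency-safe pool lets hot paths reuse allocations without extra locking:

```go
package pool

import (
	"bytes"
	"sync"
)

// bufPool hands out reusable buffers; sync.Pool is safe for
// concurrent Get/Put, which is the property the commit refers to.
var bufPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}

func process(data []byte) string {
	buf := bufPool.Get().(*bytes.Buffer)
	defer func() {
		buf.Reset() // reset before returning to the pool
		bufPool.Put(buf)
	}()
	buf.Write(data)
	return buf.String()
}
```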

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4468
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-10-03 15:29:41 +02:00
tobi
e7cd8bb43e [chore] Use bulk updates + fewer loops in status rethreading migration (#4459)
This pull request optimizes our status rethreading migration by using bulk updates and avoiding unnecessary writes, doing the migration in one top-level loop and one stragglers loop, without the extra loop to copy thread_id over.

On my machine it now runs at about 2400 rows per second on Postgres, and about 9000 rows per second on SQLite.

Tried *many* different ways of doing this, with and without temporary indexes, with different batch and transaction sizes, etc., and this seems to be just about the most performant way of getting stuff done.

With the changes, a few minutes were shaved off migration time on my development machine. *Hopefully* this will translate to more time saved when running on a VPS with slower read/write speeds and less processor power.

SQLite before:

```
real	20m58,446s
user	16m26,635s
sys	5m53,648s
```

SQLite after:

```
real	14m25,435s
user	12m47,449s
sys	2m27,898s
```

Postgres before:

```
real	28m25,307s
user	3m40,005s
sys	4m45,018s
```

Postgres after:

```
real	22m31,999s
user	3m46,674s
sys	4m39,592s
```

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4459
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-10-03 12:28:55 +02:00
Zoë Bijl
bd1c43d55e [bugfix/frontend] restore blockquote “block” margin (#4465)
stripping `<blockquote>` of all the margin looks a bit funky. this only removes the inline margin. in English this generally means that it won’t have horizontal margin but will still have vertical margin.

Closes #4466

![before the change any content after the blockquote is flush against it without space](/attachments/7cc808ee-a999-435d-9235-60651a3d9bca)

![after the changes there is vertical rhythm](/attachments/3240480a-14ee-4739-a497-14237879993c)

## Checklist


- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [ ] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [ ] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [ ] I/we have added tests that cover new code.
- [x] I/we have run tests and they pass locally with the changes. (I ran `go test ./...` from the main dir, they passed with one exception related to thumbnail file size, most likely caused by testing on macOS)
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4465
Co-authored-by: Zoë Bijl <code@moiety.me>
Co-committed-by: Zoë Bijl <code@moiety.me>
2025-10-01 19:04:44 +02:00
kim
dfdf06e4ad [chore] update dependencies (#4458)
- codeberg.org/gruf/go-ffmpreg: v0.6.11 -> v0.6.12

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4458
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-09-25 16:38:19 +02:00
kim
3db2d42247 [chore] ffmpeg webassembly fiddling (#4454)
This disables ffmpeg / ffprobe support on platforms where the wazero compiler is not available. The slowness introduced is hard to pin down for admins (and us!), so it's easier to just return an error message linking to the docs when media processing is attempted. The instance can still run; it just errors when anything other than a JPEG is processed. This should hopefully make it easier for users to notice these issues.
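
An illustrative sketch of the gating logic (names are assumptions; the real check lives in GtS's media processing code): let JPEGs through the native decoding path, and return a descriptive error for anything that would need ffmpeg when the wasm compiler isn't available:

```go
package media

import "fmt"

// compilerAvailable would be set per-platform at build or init time
// (illustrative; the real detection is wazero/platform specific).
var compilerAvailable bool

func processMedia(mimeType string) error {
	if mimeType == "image/jpeg" {
		return nil // decoded natively, no ffmpeg required
	}
	if !compilerAvailable {
		return fmt.Errorf(
			"cannot process %s: ffmpeg/ffprobe unavailable on this platform "+
				"(wazero compiler unsupported), see the GoToSocial docs for details",
			mimeType,
		)
	}
	// ... run ffmpeg/ffprobe via wazero ...
	return nil
}
```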

Also further locks down our wazero 'allowFiles' fs and other media code to address: https://codeberg.org/superseriousbusiness/gotosocial/issues/4408

relates to: https://codeberg.org/superseriousbusiness/gotosocial/issues/4427
also relates to issues raised in #gotosocial-help on Matrix

closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4408

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4454
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
2025-09-24 15:12:25 +02:00
tobi
121677754c [docs] Update tracing.md with up-to-date way of doing things (#4452)
Updates tracing docs with the latest stufffff

Closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4446

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4452
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-09-22 15:32:04 +02:00
tobi
602022701b [chore] Update config to remove unnecessary square brackets around ipv6 addresses (#4451)
Tweak the example config for `bind-address`, as square brackets around IPv6 addresses were causing issues at launch.

Closes https://codeberg.org/superseriousbusiness/gotosocial/issues/4450

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4451
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
2025-09-22 13:10:30 +02:00
68 changed files with 2444 additions and 1113 deletions


@@ -33,6 +33,8 @@ These contribution guidelines were adapted from / inspired by those of Gitea (ht
 - [Federation](#federation)
 - [Updating Swagger docs](#updating-swagger-docs)
 - [CI/CD configuration](#ci-cd-configuration)
+- [Other Useful Stuff](#other-useful-stuff)
+  - [Running migrations on a Postgres DB backup locally](#running-migrations-on-a-postgres-db-backup-locally)

 ## Introduction
@@ -525,3 +527,38 @@ The `woodpecker` pipeline files are in the `.woodpecker` directory of this repos
 The Woodpecker instance for GoToSocial is [here](https://woodpecker.superseriousbusiness.org/repos/2).
 Documentation for Woodpecker is [here](https://woodpecker-ci.org/docs/intro).
+
+## Other Useful Stuff
+
+Various bits and bobs.
+
+### Running migrations on a Postgres DB backup locally
+
+It may be useful when testing or debugging migrations to be able to run them against a copy of a real instance's Postgres database locally.
+
+Basic steps for this:
+
+First dump the Postgres database on the remote machine, and copy the dump over to your development machine.
+
+Now create a local Postgres container and mount the dump into it with, for example:
+
+```bash
+docker run -it --name postgres --network host -e POSTGRES_PASSWORD=postgres -v /path/to/db_dump:/db_dump postgres
+```
+
+In a separate terminal window, execute a command inside the running container to load the dump into the "postgres" database:
+
+```bash
+docker exec -it --user postgres postgres psql -X -f /db_dump postgres
+```
+
+With the Postgres container still running, run GoToSocial and point it towards the container. Use the appropriate `GTS_HOST` (and `GTS_ACCOUNT_DOMAIN`) values for the instance you dumped:
+
+```bash
+GTS_HOST=example.org \
+GTS_DB_TYPE=postgres \
+GTS_DB_POSTGRES_CONNECTION_STRING=postgres://postgres:postgres@localhost:5432/postgres \
+./gotosocial migrations run
+```
+
+When you're done messing around, don't forget to remove any containers that you started up, and remove any lingering volumes with `docker volume prune`, else you might end up filling your disk with unused temporary volumes.


@@ -328,7 +328,7 @@ This is the current status of support offered by GoToSocial for different platfo
 Notes on 64-bit CPU feature requirements:

-- x86_64 requires the SSE4.1 instruction set. (CPUs manufactured after ~2010)
+- x86_64 requires the [x86-64-v2](https://en.wikipedia.org/wiki/X86-64-v2) level instruction sets. (CPUs manufactured after ~2010)
 - ARM64 requires no specific features, ARMv8 CPUs (and later) have all required features.


@@ -0,0 +1,43 @@
+# Managing GtS on slow hardware
+
+While GoToSocial runs great on lower-end hardware, some operations are not practical on it, especially
+instances with the database on slow storage (think anything that is not an SSD). This document
+offers some suggestions on how to work around common issues when running GtS on slow hardware.
+
+## Running database migrations on a different machine
+
+Sometimes a database migration will need to do operations that are taxing on the database's storage.
+These operations can take days if the database resides on a hard disk or SD card. If your
+database is on slow storage, it can save a lot of time to use the following procedure:
+
+!!! danger
+    It might seem tempting to keep GtS running while you run the migrations on another machine, but
+    doing this will lead to all the posts received during the migration disappearing
+    once the migrated database is re-imported.
+
+1. Shut down GtS
+2. Take a [backup](backup_and_restore.md#what-to-backup-database) of the database
+3. Import the database on faster hardware
+4. Run the GtS migration on the faster hardware
+5. Take a backup of the resultant database
+6. Import the resultant backup and overwrite the old database
+7. Start GtS with the new version
+
+### Running GtS migrations separately
+
+After you import the database on the faster hardware, you can run the migration without starting
+GtS by downloading the *target* GtS version from the [releases](https://codeberg.org/superseriousbusiness/gotosocial/releases) page.
+For instance, if you are running `v0.19.2` and you want to upgrade to `v0.20.0-rc1`, download the
+latter version. Once you have the binary, make it executable by running `chmod u+x /path/to/gotosocial`.
+Afterwards, copy the configuration of the original server and alter it with the location of the
+new database. We copy the configuration in case variables like the hostname are used in the
+migration; we want to keep those consistent.
+
+Once everything is in place, you can run the migration like this:
+
+```sh
+$ /path/to/gotosocial --config-path /path/to/config migrations run
+```
+
+This will run all the migrations, just like GtS would if it were started normally. Once this is done,
+you can copy the result to the original instance and start the new GtS version there as well; it
+will see that everything is migrated and that there's nothing to do except run as expected.


@@ -1,25 +1,20 @@
 # Tracing

-GoToSocial comes with [OpenTelemetry][otel] based tracing built-in. It's not wired through every function, but our HTTP handlers and database library will create spans. How to configure tracing is explained in the [Observability configuration reference][obs].
+GoToSocial comes with [OpenTelemetry][otel] based tracing built-in. It's not wired through every function, but our HTTP handlers and database library will create spans that may help you debug issues.
+
+## Enabling tracing
+
+To enable tracing on your instance, you must set `tracing-enabled` to `true` in your config.yaml file. Then, you must set the environment variable `OTEL_TRACES_EXPORTER` to your desired tracing format. A list of available options is available [here](https://opentelemetry.io/docs/languages/sdk-configuration/general/#otel_traces_exporter). Once you have changed your config and set the environment variable, restart your instance.
+
+If necessary, you can do further configuration of tracing using the other environment variables listed [here](https://opentelemetry.io/docs/languages/sdk-configuration/general/).
+
+## Ingesting traces

 In order to receive the traces, you need something to ingest them and then visualise them. There are many options available including self-hosted and commercial options.

-We provide an example of how to do this using [Grafana Tempo][tempo] to ingest the spans and [Grafana][grafana] to explore them. Please beware that the configuration we provide is not suitable for a production setup. It can be used safely for local development and can provide a good starting point for setting up your own tracing infrastructure.
-
-You'll need the files in [`example/tracing`][ext]. Once you have those you can run `docker-compose up -d` to get Tempo and Grafana running. With both services running, you can add the following to your GoToSocial configuration and restart your instance:
-
-```yaml
-tracing-enabled: true
-tracing-transport: "grpc"
-tracing-endpoint: "localhost:4317"
-tracing-insecure-transport: true
-```
-
-[otel]: https://opentelemetry.io/
-[obs]: ../configuration/observability_and_metrics.md
-[tempo]: https://grafana.com/oss/tempo/
-[grafana]: https://grafana.com/oss/grafana/
-[ext]: https://codeberg.org/superseriousbusiness/gotosocial/tree/main/example/tracing
+In [`example/tracing`][ext] we provide an example of how to do this using [Grafana Tempo][tempo] to ingest the spans and [Grafana][grafana] to explore them. You can use the files with `docker-compose up -d` to get Tempo and Grafana running.
+
+Please be aware that while the example configuration we provide can be used safely for local development and can provide a good starting point for setting up your own tracing infrastructure, it is not suitable for a so-called "production" setup.

 ## Querying and visualising traces
@@ -27,18 +22,23 @@ Once you execute a few queries against your instance, you'll be able to find the
 Using TraceQL, a simple query to find all traces related to requests to `/api/v1/instance` would look like this:

-```
+```traceql
 {.http.route = "/api/v1/instance"}
 ```

 If you wanted to see all GoToSocial traces, you could instead run:

-```
+```traceql
 {.service.name = "GoToSocial"}
 ```

 Once you select a trace, a second panel will open up visualising the span. You can drill down from there, by clicking into every sub-span to see what it was doing.

-![Grafana showing a trace for the /api/v1/instance endpoint](../public/tracing.png)
+![Grafana showing a trace for the /api/v1/instance endpoint](../overrides/public/tracing.png)

 [traceql]: https://grafana.com/docs/tempo/latest/traceql/
+[otel]: https://opentelemetry.io/
+[obs]: ../configuration/observability_and_metrics.md
+[tempo]: https://grafana.com/oss/tempo/
+[grafana]: https://grafana.com/oss/grafana/
+[ext]: https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/example/tracing


@@ -107,14 +107,16 @@ account-domain: ""
 # Default: "https"
 protocol: "https"

-# String. Address to bind the GoToSocial server to.
-# This can be an IPv4 address or an IPv6 address (surrounded in square brackets), or a hostname.
+# String. Address to bind the GoToSocial HTTP server to.
+# This can be an IPv4 address, an IPv6 address, or a hostname.
+#
 # The default value will bind to all interfaces, which makes the server
-# accessible by other machines. For most setups there is no need to change this.
-# If you are using GoToSocial in a reverse proxy setup with the proxy running on
-# the same machine, you will want to set this to "localhost" or an equivalent,
-# so that the proxy can't be bypassed.
-# Examples: ["0.0.0.0", "172.128.0.16", "localhost", "[::]", "[2001:db8::fed1]"]
+# accessible by other machines. For most setups you won't need to change this.
+# However, if you are using GoToSocial in a reverse proxy setup with the proxy
+# running on the same machine, you may want to set this to "localhost" or equivalent,
+# so that the proxy definitely can't be bypassed.
+#
+# Examples: ["0.0.0.0", "172.128.0.16", "localhost", "::1", "2001:db8::fed1"]
 # Default: "0.0.0.0"
 bind-address: "0.0.0.0"


@@ -117,14 +117,16 @@ account-domain: ""
 # Default: "https"
 protocol: "https"

-# String. Address to bind the GoToSocial server to.
-# This can be an IPv4 address or an IPv6 address (surrounded in square brackets), or a hostname.
+# String. Address to bind the GoToSocial HTTP server to.
+# This can be an IPv4 address, an IPv6 address, or a hostname.
+#
 # The default value will bind to all interfaces, which makes the server
-# accessible by other machines. For most setups there is no need to change this.
-# If you are using GoToSocial in a reverse proxy setup with the proxy running on
-# the same machine, you will want to set this to "localhost" or an equivalent,
-# so that the proxy can't be bypassed.
-# Examples: ["0.0.0.0", "172.128.0.16", "localhost", "[::]", "[2001:db8::fed1]"]
+# accessible by other machines. For most setups you won't need to change this.
+# However, if you are using GoToSocial in a reverse proxy setup with the proxy
+# running on the same machine, you may want to set this to "localhost" or equivalent,
+# so that the proxy definitely can't be bypassed.
+#
+# Examples: ["0.0.0.0", "172.128.0.16", "localhost", "::1", "2001:db8::fed1"]
 # Default: "0.0.0.0"
 bind-address: "0.0.0.0"

go.mod

@@ -21,17 +21,17 @@ require (
 	codeberg.org/gruf/go-errors/v2 v2.3.2
 	codeberg.org/gruf/go-fastcopy v1.1.3
 	codeberg.org/gruf/go-fastpath/v2 v2.0.0
-	codeberg.org/gruf/go-ffmpreg v0.6.11
+	codeberg.org/gruf/go-ffmpreg v0.6.12
 	codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf
 	codeberg.org/gruf/go-kv/v2 v2.0.7
 	codeberg.org/gruf/go-list v0.0.0-20240425093752-494db03d641f
-	codeberg.org/gruf/go-mempool v0.0.0-20240507125005-cef10d64a760
-	codeberg.org/gruf/go-mutexes v1.5.3
+	codeberg.org/gruf/go-mempool v0.0.0-20251003110531-b54adae66253
+	codeberg.org/gruf/go-mutexes v1.5.8
 	codeberg.org/gruf/go-runners v1.6.3
 	codeberg.org/gruf/go-sched v1.2.4
 	codeberg.org/gruf/go-split v1.2.0
 	codeberg.org/gruf/go-storage v0.3.1
-	codeberg.org/gruf/go-structr v0.9.9
+	codeberg.org/gruf/go-structr v0.9.12
 	github.com/DmitriyVTitov/size v1.5.0
 	github.com/KimMachineGun/automemlimit v0.7.4
 	github.com/SherClockHolmes/webpush-go v1.4.0
@@ -53,7 +53,7 @@ require (
 	github.com/miekg/dns v1.1.68
 	github.com/minio/minio-go/v7 v7.0.95
 	github.com/mitchellh/mapstructure v1.5.0
-	github.com/ncruces/go-sqlite3 v0.29.0
+	github.com/ncruces/go-sqlite3 v0.29.1
 	github.com/oklog/ulid v1.3.1
 	github.com/pquerna/otp v1.5.0
 	github.com/rivo/uniseg v0.4.7

go.sum (generated)

@@ -26,8 +26,8 @@ codeberg.org/gruf/go-fastcopy v1.1.3 h1:Jo9VTQjI6KYimlw25PPc7YLA3Xm+XMQhaHwKnM7x
 codeberg.org/gruf/go-fastcopy v1.1.3/go.mod h1:GDDYR0Cnb3U/AIfGM3983V/L+GN+vuwVMvrmVABo21s=
 codeberg.org/gruf/go-fastpath/v2 v2.0.0 h1:iAS9GZahFhyWEH0KLhFEJR+txx1ZhMXxYzu2q5Qo9c0=
 codeberg.org/gruf/go-fastpath/v2 v2.0.0/go.mod h1:3pPqu5nZjpbRrOqvLyAK7puS1OfEtQvjd6342Cwz56Q=
-codeberg.org/gruf/go-ffmpreg v0.6.11 h1:+lvB5Loy0KUAKfv6nOZRWHFVgN08cpHhUlYcZxL8M20=
-codeberg.org/gruf/go-ffmpreg v0.6.11/go.mod h1:tGqIMh/I2cizqauxxNAN+WGkICI0j5G3xwF1uBkyw1E=
+codeberg.org/gruf/go-ffmpreg v0.6.12 h1:mPdRx1TAQJQPhRkTOOHnRSY6omNCLJ7M6ajjuEMNNvE=
+codeberg.org/gruf/go-ffmpreg v0.6.12/go.mod h1:tGqIMh/I2cizqauxxNAN+WGkICI0j5G3xwF1uBkyw1E=
 codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf h1:84s/ii8N6lYlskZjHH+DG6jyia8w2mXMZlRwFn8Gs3A=
 codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf/go.mod h1:zZAICsp5rY7+hxnws2V0ePrWxE0Z2Z/KXcN3p/RQCfk=
 codeberg.org/gruf/go-kv v1.6.5 h1:ttPf0NA8F79pDqBttSudPTVCZmGncumeNIxmeM9ztz0=
@@ -42,10 +42,10 @@ codeberg.org/gruf/go-mangler/v2 v2.0.6 h1:c3cwnI6Mi17EAwGSYGNMN6+9PMzaIj2GLAKx9D
 codeberg.org/gruf/go-mangler/v2 v2.0.6/go.mod h1:CXIm7zAWPdNmZVAGM1NRiF/ekJTPE7YTb8kiRxiEFaQ=
 codeberg.org/gruf/go-maps v1.0.4 h1:K+Ww4vvR3TZqm5jqrKVirmguZwa3v1VUvmig2SE8uxY=
 codeberg.org/gruf/go-maps v1.0.4/go.mod h1:ASX7osM7kFwt5O8GfGflcFjrwYGD8eIuRLl/oMjhEi8=
-codeberg.org/gruf/go-mempool v0.0.0-20240507125005-cef10d64a760 h1:m2/UCRXhjDwAg4vyji6iKCpomKw6P4PmBOUi5DvAMH4=
-codeberg.org/gruf/go-mempool v0.0.0-20240507125005-cef10d64a760/go.mod h1:E3RcaCFNq4zXpvaJb8lfpPqdUAmSkP5F1VmMiEUYTEk=
-codeberg.org/gruf/go-mutexes v1.5.3 h1:RIEy1UuDxKgAiINRMrPxTUWSGW6pFx9DzeJN4WPqra8=
-codeberg.org/gruf/go-mutexes v1.5.3/go.mod h1:AnhagsMzUISL/nBVwhnHwDwTZOAxMILwCOG8/wKOblg=
+codeberg.org/gruf/go-mempool v0.0.0-20251003110531-b54adae66253 h1:qPAY72xCWlySVROSNZecfLGAyeV/SiXmPmfhUU+o3Xw=
+codeberg.org/gruf/go-mempool v0.0.0-20251003110531-b54adae66253/go.mod h1:761koiXmqfgzvu5mez2Rk7YlwWilpqJ/zv5hIA6NoNI=
+codeberg.org/gruf/go-mutexes v1.5.8 h1:HRGnvT4COb3jX9xdeoSUUbjPgmk5kXPuDfld9ksUJKA=
+codeberg.org/gruf/go-mutexes v1.5.8/go.mod h1:21sy/hWH8dDQBk7ocsxqo2GNpWiIir+e82RG3hjnN20=
 codeberg.org/gruf/go-runners v1.6.3 h1:To/AX7eTrWuXrTkA3RA01YTP5zha1VZ68LQ+0D4RY7E=
 codeberg.org/gruf/go-runners v1.6.3/go.mod h1:oXAaUmG2VxoKttpCqZGv5nQBeSvZSR2BzIk7h1yTRlU=
 codeberg.org/gruf/go-sched v1.2.4 h1:ddBB9o0D/2oU8NbQ0ldN5aWxogpXPRBATWi58+p++Hw=
@@ -54,8 +54,8 @@ codeberg.org/gruf/go-split v1.2.0 h1:PmzL23nVEVHm8VxjsJmv4m4wGQz2bGgQw52dgSSj65c
 codeberg.org/gruf/go-split v1.2.0/go.mod h1:0rejWJpqvOoFAd7nwm5tIXYKaAqjtFGOXmTqQV+VO38=
 codeberg.org/gruf/go-storage v0.3.1 h1:g66UIM/xXnEk9ejT+W0T9s/PODBZhXa/8ajzeY/MELI=
 codeberg.org/gruf/go-storage v0.3.1/go.mod h1:r43n/zi7YGOCl2iSl7AMI27D1zcWS65Bi2+5xDzypeo=
-codeberg.org/gruf/go-structr v0.9.9 h1:fwIzi/94yBNSWleXZIfVW/QyNK5+/xxI2reVYzu5V/c=
-codeberg.org/gruf/go-structr v0.9.9/go.mod h1:5dsazOsIeJyV8Dl2DdSXqCDEZUx3e3dc41N6f2mPtgw=
+codeberg.org/gruf/go-structr v0.9.12 h1:yMopvexnuKgZme9WgvIhrJaAuAjfper/x38xsVuJOOo=
+codeberg.org/gruf/go-structr v0.9.12/go.mod h1:sP2ZSjM5X5XKlxuhAbTKuVQm9DWbHsrQRuTl3MUwbHw=
 codeberg.org/gruf/go-xunsafe v0.0.0-20250809104800-512a9df57d73 h1:pRaOwIOS1WSZoPCAvE0H1zpv+D4gF37OVppybffqdI8=
 codeberg.org/gruf/go-xunsafe v0.0.0-20250809104800-512a9df57d73/go.mod h1:9wkq+dmHjUhB/0ZxDUWAwsWuXwwGyx5N1dDCB9hpWs8=
 codeberg.org/superseriousbusiness/go-swagger v0.32.3-gts-go1.23-fix h1:k76/Th+bruqU/d+dB0Ru466ctTF2aVjKpisy/471ILE=
@@ -338,8 +338,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
 github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/ncruces/go-sqlite3 v0.29.0 h1:1tsLiagCoqZEfcHDeKsNSv5jvrY/Iu393pAnw2wLNJU=
-github.com/ncruces/go-sqlite3 v0.29.0/go.mod h1:r1hSvYKPNJ+OlUA1O3r8o9LAawzPAlqeZiIdxTBBBJ0=
+github.com/ncruces/go-sqlite3 v0.29.1 h1:NIi8AISWBToRHyoz01FXiTNvU147Tqdibgj2tFzJCqM=
+github.com/ncruces/go-sqlite3 v0.29.1/go.mod h1:PpccBNNhvjwUOwDQEn2gXQPFPTWdlromj0+fSkd5KSg=
 github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
 github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
 github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=


@ -24,13 +24,16 @@ import (
"reflect" "reflect"
"slices" "slices"
"strings" "strings"
"time"
"code.superseriousbusiness.org/gotosocial/internal/db" "code.superseriousbusiness.org/gotosocial/internal/db"
newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new" newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new"
oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old" oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old"
"code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util"
"code.superseriousbusiness.org/gotosocial/internal/gtserror" "code.superseriousbusiness.org/gotosocial/internal/gtserror"
"code.superseriousbusiness.org/gotosocial/internal/id" "code.superseriousbusiness.org/gotosocial/internal/id"
"code.superseriousbusiness.org/gotosocial/internal/log" "code.superseriousbusiness.org/gotosocial/internal/log"
"code.superseriousbusiness.org/gotosocial/internal/util/xslices"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -49,10 +52,26 @@ func init() {
"thread_id", "thread_id_new", 1) "thread_id", "thread_id_new", 1)
var sr statusRethreader var sr statusRethreader
var count int var updatedTotal int64
var maxID string var maxID string
var statuses []*oldmodel.Status var statuses []*oldmodel.Status
// Create thread_id_new already
// so we can populate it as we go.
log.Info(ctx, "creating statuses column thread_id_new")
if _, err := db.NewAddColumn().
Table("statuses").
ColumnExpr(newColDef).
Exec(ctx); err != nil {
return gtserror.Newf("error adding statuses column thread_id_new: %w", err)
}
// Try to merge the wal so we're
// not working on the wal file.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Get a total count of all statuses before migration. // Get a total count of all statuses before migration.
total, err := db.NewSelect().Table("statuses").Count(ctx) total, err := db.NewSelect().Table("statuses").Count(ctx)
if err != nil { if err != nil {
@ -63,74 +82,129 @@ func init() {
// possible ULID value. // possible ULID value.
maxID = id.Highest maxID = id.Highest
log.Warn(ctx, "rethreading top-level statuses, this will take a *long* time") log.Warnf(ctx, "rethreading %d statuses, this will take a *long* time", total)
for /* TOP LEVEL STATUS LOOP */ {
// Open initial transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
for i := 1; ; i++ {
// Reset slice. // Reset slice.
clear(statuses) clear(statuses)
statuses = statuses[:0] statuses = statuses[:0]
// Select top-level statuses. batchStart := time.Now()
if err := db.NewSelect().
Model(&statuses).
Column("id", "thread_id").
// Select top-level statuses.
if err := tx.NewSelect().
Model(&statuses).
Column("id").
// We specifically use in_reply_to_account_id instead of in_reply_to_id as // We specifically use in_reply_to_account_id instead of in_reply_to_id as
// they should both be set / unset in unison, but we specifically have an // they should both be set / unset in unison, but we specifically have an
// index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id. // index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id.
Where("? IS NULL", bun.Ident("in_reply_to_account_id")). Where("? IS NULL", bun.Ident("in_reply_to_account_id")).
Where("? < ?", bun.Ident("id"), maxID). Where("? < ?", bun.Ident("id"), maxID).
OrderExpr("? DESC", bun.Ident("id")). OrderExpr("? DESC", bun.Ident("id")).
Limit(5000). Limit(500).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting top level statuses: %w", err) return gtserror.Newf("error selecting top level statuses: %w", err)
} }
// Reached end of block. l := len(statuses)
if len(statuses) == 0 { if l == 0 {
// No more statuses!
//
// Transaction will be closed
// after leaving the loop.
break break
} else if i%200 == 0 {
// Begin a new transaction every
// 200 batches (~100,000 statuses),
// to avoid massive commits.
// Close existing transaction.
if err := tx.Commit(); err != nil {
return err
}
// Try to flush the wal
// to avoid silly wal sizes.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Open new transaction.
tx, err = db.BeginTx(ctx, nil)
if err != nil {
return err
}
} }
// Set next maxID value from statuses. // Set next maxID value from statuses.
maxID = statuses[len(statuses)-1].ID maxID = statuses[len(statuses)-1].ID
// Rethread each selected batch of top-level statuses in a transaction. // Rethread using the
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // open transaction.
var updatedInBatch int64
// Rethread each top-level status. for _, status := range statuses {
for _, status := range statuses { n, err := sr.rethreadStatus(ctx, tx, status, false)
n, err := sr.rethreadStatus(ctx, tx, status) if err != nil {
if err != nil { return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
count += n
} }
updatedInBatch += n
return nil updatedTotal += n
}); err != nil {
return err
} }
log.Infof(ctx, "[approx %d of %d] rethreading statuses (top-level)", count, total) // Show speed for this batch.
timeTaken := time.Since(batchStart).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedInBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
// Show percent migrated overall.
totalDone := (float64(updatedTotal) / float64(total)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] migrating threads",
totalDone, rowsPerSecond,
)
} }
// Attempt to merge any sqlite write-ahead-log. // Close transaction.
if err := doWALCheckpoint(ctx, db); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }
log.Warn(ctx, "rethreading straggler statuses, this will take a *long* time") // Create a partial index on thread_id_new to find stragglers.
for /* STRAGGLER STATUS LOOP */ { // This index will be removed at the end of the migration.
log.Info(ctx, "creating temporary statuses thread_id_new index")
if _, err := db.NewCreateIndex().
Table("statuses").
Index("statuses_thread_id_new_idx").
Column("thread_id_new").
Where("? = ?", bun.Ident("thread_id_new"), id.Lowest).
Exec(ctx); err != nil {
return gtserror.Newf("error creating new thread_id index: %w", err)
}
for i := 1; ; i++ {
// Reset slice. // Reset slice.
clear(statuses) clear(statuses)
statuses = statuses[:0] statuses = statuses[:0]
batchStart := time.Now()
// Select straggler statuses. // Select straggler statuses.
if err := db.NewSelect(). if err := db.NewSelect().
Model(&statuses). Model(&statuses).
Column("id", "in_reply_to_id", "thread_id"). Column("id").
Where("? IS NULL", bun.Ident("thread_id")). Where("? = ?", bun.Ident("thread_id_new"), id.Lowest).
// We select in smaller batches for this part // We select in smaller batches for this part
// of the migration as there is a chance that // of the migration as there is a chance that
@ -138,7 +212,7 @@ func init() {
// part of the same thread, i.e. one call to // part of the same thread, i.e. one call to
// rethreadStatus() may effect other statuses // rethreadStatus() may effect other statuses
// later in the slice. // later in the slice.
Limit(1000). Limit(250).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting straggler statuses: %w", err) return gtserror.Newf("error selecting straggler statuses: %w", err)
} }
@ -149,23 +223,35 @@ func init() {
} }
// Rethread each selected batch of straggler statuses in a transaction. // Rethread each selected batch of straggler statuses in a transaction.
var updatedInBatch int64
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Rethread each top-level status.
for _, status := range statuses { for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status) n, err := sr.rethreadStatus(ctx, tx, status, true)
if err != nil { if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err) return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
} }
count += n updatedInBatch += n
updatedTotal += n
} }
return nil return nil
}); err != nil { }); err != nil {
return err return err
} }
log.Infof(ctx, "[approx %d of %d] rethreading statuses (stragglers)", count, total) // Show speed for this batch.
timeTaken := time.Since(batchStart).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedInBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
// Show percent migrated overall.
totalDone := (float64(updatedTotal) / float64(total)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] migrating stragglers",
totalDone, rowsPerSecond,
)
} }
// Attempt to merge any sqlite write-ahead-log. // Attempt to merge any sqlite write-ahead-log.
@ -173,6 +259,13 @@ func init() {
return err return err
} }
log.Info(ctx, "dropping temporary thread_id_new index")
if _, err := db.NewDropIndex().
Index("statuses_thread_id_new_idx").
Exec(ctx); err != nil {
return gtserror.Newf("error dropping temporary thread_id_new index: %w", err)
}
log.Info(ctx, "dropping old thread_to_statuses table") log.Info(ctx, "dropping old thread_to_statuses table")
if _, err := db.NewDropTable(). if _, err := db.NewDropTable().
Table("thread_to_statuses"). Table("thread_to_statuses").
@ -180,33 +273,6 @@ func init() {
return gtserror.Newf("error dropping old thread_to_statuses table: %w", err) return gtserror.Newf("error dropping old thread_to_statuses table: %w", err)
} }
log.Info(ctx, "creating new statuses thread_id column")
if _, err := db.NewAddColumn().
Table("statuses").
ColumnExpr(newColDef).
Exec(ctx); err != nil {
return gtserror.Newf("error adding new thread_id column: %w", err)
}
log.Info(ctx, "setting thread_id_new = thread_id (this may take a while...)")
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return batchUpdateByID(ctx, tx,
"statuses", // table
"id", // batchByCol
"UPDATE ? SET ? = ?", // updateQuery
[]any{bun.Ident("statuses"),
bun.Ident("thread_id_new"),
bun.Ident("thread_id")},
)
}); err != nil {
return err
}
// Attempt to merge any sqlite write-ahead-log.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
log.Info(ctx, "dropping old statuses thread_id index") log.Info(ctx, "dropping old statuses thread_id index")
if _, err := db.NewDropIndex(). if _, err := db.NewDropIndex().
Index("statuses_thread_id_idx"). Index("statuses_thread_id_idx").
@ -274,6 +340,11 @@ type statusRethreader struct {
// its contents are ephemeral. // its contents are ephemeral.
statuses []*oldmodel.Status statuses []*oldmodel.Status
// newThreadIDSet is used to track whether
// statuses in statusIDs have already have
// thread_id_new set on them.
newThreadIDSet map[string]struct{}
// seenIDs tracks the unique status and // seenIDs tracks the unique status and
// thread IDs we have seen, ensuring we // thread IDs we have seen, ensuring we
// don't append duplicates to statusIDs // don't append duplicates to statusIDs
@ -289,14 +360,15 @@ type statusRethreader struct {
} }
// rethreadStatus is the main logic handler for statusRethreader{}. this is what gets called from the migration // rethreadStatus is the main logic handler for statusRethreader{}. this is what gets called from the migration
// in order to trigger a status rethreading operation for the given status, returning total number rethreaded. // in order to trigger a status rethreading operation for the given status, returning total number of rows changed.
func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status) (int, error) { func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status, straggler bool) (int64, error) {
// Zero slice and // Zero slice and
// map ptr values. // map ptr values.
clear(sr.statusIDs) clear(sr.statusIDs)
clear(sr.threadIDs) clear(sr.threadIDs)
clear(sr.statuses) clear(sr.statuses)
clear(sr.newThreadIDSet)
clear(sr.seenIDs) clear(sr.seenIDs)
// Reset slices and values for use. // Reset slices and values for use.
@ -305,6 +377,11 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
sr.statuses = sr.statuses[:0] sr.statuses = sr.statuses[:0]
sr.allThreaded = true sr.allThreaded = true
if sr.newThreadIDSet == nil {
// Allocate new hash set for newThreadIDSet.
sr.newThreadIDSet = make(map[string]struct{})
}
if sr.seenIDs == nil { if sr.seenIDs == nil {
// Allocate new hash set for status IDs. // Allocate new hash set for status IDs.
sr.seenIDs = make(map[string]struct{}) sr.seenIDs = make(map[string]struct{})
@ -317,12 +394,22 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
// to the rethreadStatus() call. // to the rethreadStatus() call.
if err := tx.NewSelect(). if err := tx.NewSelect().
Model(status). Model(status).
Column("in_reply_to_id", "thread_id"). Column("in_reply_to_id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("id"), status.ID). Where("? = ?", bun.Ident("id"), status.ID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return 0, gtserror.Newf("error selecting status: %w", err) return 0, gtserror.Newf("error selecting status: %w", err)
} }
// If we've just threaded this status by setting
// thread_id_new, then by definition anything we
// could find from the entire thread must now be
// threaded, so we can save some database calls
// by skipping iterating up + down from here.
if status.ThreadIDNew != id.Lowest {
log.Debugf(ctx, "skipping just rethreaded status: %s", status.ID)
return 0, nil
}
// status and thread ID cursor // status and thread ID cursor
// index values. these are used // index values. these are used
// to keep track of newly loaded // to keep track of newly loaded
@ -371,14 +458,14 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
threadIdx = len(sr.threadIDs) threadIdx = len(sr.threadIDs)
} }
// Total number of
// statuses threaded.
total := len(sr.statusIDs)
// Check for the case where the entire // Check for the case where the entire
// batch of statuses is already correctly // batch of statuses is already correctly
// threaded. Then we have nothing to do! // threaded. Then we have nothing to do!
if sr.allThreaded && len(sr.threadIDs) == 1 { //
// Skip this check for straggler statuses
// that are part of broken threads.
if !straggler && sr.allThreaded && len(sr.threadIDs) == 1 {
log.Debug(ctx, "skipping just rethreaded thread")
return 0, nil return 0, nil
} }
@ -417,36 +504,120 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
} }
} }
// Update all the statuses to var (
// use determined thread_id. res sql.Result
if _, err := tx.NewUpdate(). err error
Table("statuses"). )
Where("? IN (?)", bun.Ident("id"), bun.In(sr.statusIDs)).
Set("? = ?", bun.Ident("thread_id"), threadID). if len(sr.statusIDs) == 1 {
Exec(ctx); err != nil {
// If we're only updating one status
// we can use a simple update query.
res, err = tx.NewUpdate().
// Update the status model.
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Set the new thread ID, which we can use as
// an indication that we've migrated this batch.
Set("? = ?", bun.Ident("thread_id_new"), threadID).
// While we're here, also set old thread_id, as
// we'll use it for further rethreading purposes.
Set("? = ?", bun.Ident("thread_id"), threadID).
Where("? = ?", bun.Ident("status.id"), sr.statusIDs[0]).
Exec(ctx)
} else {
// If we're updating multiple statuses at once,
// build up a common table expression to update
// all statuses in this thread to use threadID.
//
// This ought to be a little more readable than
// using an "IN(*)" query, and PG or SQLite *may*
// be able to optimize it better.
//
// See:
//
// - https://sqlite.org/lang_with.html
// - https://www.postgresql.org/docs/current/queries-with.html
// - https://bun.uptrace.dev/guide/query-update.html#bulk-update
values := make([]*util.Status, 0, len(sr.statusIDs))
for _, statusID := range sr.statusIDs {
// Filter out statusIDs that have already had
// thread_id_new set, to avoid spurious writes.
if _, set := sr.newThreadIDSet[statusID]; !set {
values = append(values, &util.Status{
ID: statusID,
})
}
}
// Resulting query will look something like this:
//
// WITH "_data" ("id") AS (
// VALUES
// ('01JR6PZED0DDR2VZHQ8H87ZW98'),
// ('01JR6PZED0J91MJCAFDTCCCG8Q')
// )
// UPDATE "statuses" AS "status"
// SET
// "thread_id_new" = '01K6MGKX54BBJ3Y1FBPQY45E5P',
// "thread_id" = '01K6MGKX54BBJ3Y1FBPQY45E5P'
// FROM _data
// WHERE "status"."id" = "_data"."id"
res, err = tx.NewUpdate().
// Update the status model.
Model((*oldmodel.Status)(nil)).
// Provide the CTE values as "_data".
With("_data", tx.NewValues(&values)).
// Include `FROM _data` statement so we can use
// `_data` table in SET and WHERE components.
TableExpr("_data").
// Set the new thread ID, which we can use as
// an indication that we've migrated this batch.
Set("? = ?", bun.Ident("thread_id_new"), threadID).
// While we're here, also set old thread_id, as
// we'll use it for further rethreading purposes.
Set("? = ?", bun.Ident("thread_id"), threadID).
// "Join" to the CTE on status ID.
Where("? = ?", bun.Ident("status.id"), bun.Ident("_data.id")).
Exec(ctx)
}
if err != nil {
return 0, gtserror.Newf("error updating status thread ids: %w", err) return 0, gtserror.Newf("error updating status thread ids: %w", err)
} }
rowsAffected, err := res.RowsAffected()
if err != nil {
return 0, gtserror.Newf("error counting rows affected: %w", err)
}
if len(sr.threadIDs) > 0 { if len(sr.threadIDs) > 0 {
// Update any existing thread // Update any existing thread
// mutes to use latest thread_id. // mutes to use latest thread_id.
// Dedupe thread IDs before query
// to avoid ludicrous "IN" clause.
threadIDs := sr.threadIDs
threadIDs = xslices.Deduplicate(threadIDs)
if _, err := tx.NewUpdate(). if _, err := tx.NewUpdate().
Table("thread_mutes"). Table("thread_mutes").
Where("? IN (?)", bun.Ident("thread_id"), bun.In(sr.threadIDs)). Where("? IN (?)", bun.Ident("thread_id"), bun.In(threadIDs)).
Set("? = ?", bun.Ident("thread_id"), threadID). Set("? = ?", bun.Ident("thread_id"), threadID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return 0, gtserror.Newf("error updating mute thread ids: %w", err) return 0, gtserror.Newf("error updating mute thread ids: %w", err)
} }
} }
return total, nil return rowsAffected, nil
} }
// append will append the given status to the internal tracking of statusRethreader{} for // append will append the given status to the internal tracking of statusRethreader{} for
// potential future operations, checking for uniqueness. it tracks the inReplyToID value // potential future operations, checking for uniqueness. it tracks the inReplyToID value
// for the next call to getParents(), it tracks the status ID for list of statuses that // for the next call to getParents(), it tracks the status ID for list of statuses that
// need updating, the thread ID for the list of thread links and mutes that need updating, // may need updating, whether a new thread ID has been set for each status, the thread ID
// and whether all the statuses all have a provided thread ID (i.e. allThreaded). // for the list of thread links and mutes that need updating, and whether all the statuses
// all have a provided thread ID (i.e. allThreaded).
func (sr *statusRethreader) append(status *oldmodel.Status) { func (sr *statusRethreader) append(status *oldmodel.Status) {
// Check if status already seen before. // Check if status already seen before.
@ -479,7 +650,14 @@ func (sr *statusRethreader) append(status *oldmodel.Status) {
} }
// Add status ID to map of seen IDs. // Add status ID to map of seen IDs.
sr.seenIDs[status.ID] = struct{}{} mark := struct{}{}
sr.seenIDs[status.ID] = mark
// If new thread ID has already been
// set, add status ID to map of set IDs.
if status.ThreadIDNew != id.Lowest {
sr.newThreadIDSet[status.ID] = mark
}
} }
func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
@ -496,7 +674,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
// Select next parent status.
if err := tx.NewSelect().
Model(&parent).
Column("id", "in_reply_to_id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
return err
@ -535,7 +713,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int)
// Select children of ID.
if err := tx.NewSelect().
Model(&sr.statuses).
Column("id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("in_reply_to_id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
return err
@ -560,14 +738,19 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in
clear(sr.statuses)
sr.statuses = sr.statuses[:0]
// Dedupe thread IDs before query
// to avoid ludicrous "IN" clause.
threadIDs := sr.threadIDs[idx:]
threadIDs = xslices.Deduplicate(threadIDs)
// Select stragglers that
// also have thread IDs.
if err := tx.NewSelect().
Model(&sr.statuses).
Column("id", "thread_id", "in_reply_to_id", "thread_id_new").
Where("? IN (?) AND ? NOT IN (?)",
bun.Ident("thread_id"),
bun.In(threadIDs),
bun.Ident("id"),
bun.In(sr.statusIDs),
).

View file

@ -23,45 +23,45 @@ import (
// Status represents a user-created 'post' or 'status' in the database, either remote or local
type Status struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
EditedAt time.Time `bun:"type:timestamptz,nullzero"` // when this status was last edited (if set)
FetchedAt time.Time `bun:"type:timestamptz,nullzero"` // when was item (remote) last fetched.
PinnedAt time.Time `bun:"type:timestamptz,nullzero"` // Status was pinned by owning account at this time.
URI string `bun:",unique,nullzero,notnull"` // activitypub URI of this status
URL string `bun:",nullzero"` // web url for viewing this status
Content string `bun:""` // Content HTML for this status.
AttachmentIDs []string `bun:"attachments,array"` // Database IDs of any media attachments associated with this status
TagIDs []string `bun:"tags,array"` // Database IDs of any tags used in this status
MentionIDs []string `bun:"mentions,array"` // Database IDs of any mentions in this status
EmojiIDs []string `bun:"emojis,array"` // Database IDs of any emojis used in this status
Local *bool `bun:",nullzero,notnull,default:false"` // is this status from a local account?
AccountID string `bun:"type:CHAR(26),nullzero,notnull"` // which account posted this status?
AccountURI string `bun:",nullzero,notnull"` // activitypub uri of the owner of this status
InReplyToID string `bun:"type:CHAR(26),nullzero"` // id of the status this status replies to
InReplyToURI string `bun:",nullzero"` // activitypub uri of the status this status is a reply to
InReplyToAccountID string `bun:"type:CHAR(26),nullzero"` // id of the account that this status replies to
InReplyTo *Status `bun:"-"` // status corresponding to inReplyToID
BoostOfID string `bun:"type:CHAR(26),nullzero"` // id of the status this status is a boost of
BoostOfURI string `bun:"-"` // URI of the status this status is a boost of; field not inserted in the db, just for dereferencing purposes.
BoostOfAccountID string `bun:"type:CHAR(26),nullzero"` // id of the account that owns the boosted status
BoostOf *Status `bun:"-"` // status that corresponds to boostOfID
ThreadID string `bun:"type:CHAR(26),nullzero,notnull,default:'00000000000000000000000000'"` // id of the thread to which this status belongs
EditIDs []string `bun:"edits,array"` //
PollID string `bun:"type:CHAR(26),nullzero"` //
ContentWarning string `bun:",nullzero"` // Content warning HTML for this status.
ContentWarningText string `bun:""` // Original text of the content warning without formatting
Visibility Visibility `bun:",nullzero,notnull"` // visibility entry for this status
Sensitive *bool `bun:",nullzero,notnull,default:false"` // mark the status as sensitive?
Language string `bun:",nullzero"` // what language is this status written in?
CreatedWithApplicationID string `bun:"type:CHAR(26),nullzero"` // Which application was used to create this status?
ActivityStreamsType string `bun:",nullzero,notnull"` // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types. Will probably almost always be Note but who knows!.
Text string `bun:""` // Original text of the status without formatting
ContentType StatusContentType `bun:",nullzero"` // Content type used to process the original text of the status
Federated *bool `bun:",notnull"` // This status will be federated beyond the local timeline(s)
PendingApproval *bool `bun:",nullzero,notnull,default:false"` // If true then status is a reply or boost wrapper that must be Approved by the reply-ee or boost-ee before being fully distributed.
PreApproved bool `bun:"-"` // If true, then status is a reply to or boost wrapper of a status on our instance, has permission to do the interaction, and an Accept should be sent out for it immediately. Field not stored in the DB.
ApprovedByURI string `bun:",nullzero"` // URI of an Accept Activity that approves the Announce or Create Activity that this status was/will be attached to.
}
// enumType is the type we (at least, should) use

View file

@ -21,7 +21,10 @@ import (
"time" "time"
) )
// Status represents a user-created 'post' or 'status' in the database, either remote or local // Status represents a user-created 'post' or 'status' in the database, either remote or local.
//
// Note: this model differs from an exact representation of the old model at the time of migration,
// as it includes the intermediate field "ThreadIDNew", which is only used during the migration.
type Status struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
@ -60,6 +63,9 @@ type Status struct {
PendingApproval *bool `bun:",nullzero,notnull,default:false"` // If true then status is a reply or boost wrapper that must be Approved by the reply-ee or boost-ee before being fully distributed.
PreApproved bool `bun:"-"` // If true, then status is a reply to or boost wrapper of a status on our instance, has permission to do the interaction, and an Accept should be sent out for it immediately. Field not stored in the DB.
ApprovedByURI string `bun:",nullzero"` // URI of an Accept Activity that approves the Announce or Create Activity that this status was/will be attached to.
// This field is *only* used during the migration, it was not on the original status model.
ThreadIDNew string `bun:"type:CHAR(26),nullzero,notnull,default:'00000000000000000000000000'"`
}
// enumType is the type we (at least, should) use
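The `id.Lowest` sentinel checked in `append()` earlier presumably corresponds to this all-zeroes column default, letting the migration tell rows it has already rethreaded apart from rows still pending. A small sketch of that check; the constant value is an assumption drawn from the column default shown above:

```go
package main

import "fmt"

// idLowest mirrors the all-zeroes ULID used as the thread_id_new
// column default; treating it as "not yet set" is an assumption
// based on the append() comparison against id.Lowest.
const idLowest = "00000000000000000000000000"

type status struct {
	ID          string
	ThreadIDNew string
}

func main() {
	statuses := []status{
		{ID: "01AAAAAAAAAAAAAAAAAAAAAAAA", ThreadIDNew: idLowest},                     // still pending
		{ID: "01BBBBBBBBBBBBBBBBBBBBBBBB", ThreadIDNew: "01CCCCCCCCCCCCCCCCCCCCCCCC"}, // already set
	}
	for _, s := range statuses {
		fmt.Printf("%s: new thread ID set = %v\n", s.ID, s.ThreadIDNew != idLowest)
	}
}
```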

View file

@ -0,0 +1,24 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package util
// Status is a helper type specifically
// for updating the thread ID of a status.
type Status struct {
ID string `bun:"type:CHAR(26)"`
}

View file

@ -66,98 +66,6 @@ func doWALCheckpoint(ctx context.Context, db *bun.DB) error {
return nil
}
// batchUpdateByID performs the given updateQuery with updateArgs
// over the entire given table, batching by the ID of batchByCol.
func batchUpdateByID(
ctx context.Context,
tx bun.Tx,
table string,
batchByCol string,
updateQuery string,
updateArgs []any,
) error {
// Get a count of all in table.
total, err := tx.NewSelect().
Table(table).
Count(ctx)
if err != nil {
return gtserror.Newf("error selecting total count: %w", err)
}
// Query batch size
// in number of rows.
const batchsz = 5000
// Stores highest batch value
// used in iterate queries,
// starting at highest possible.
highest := id.Highest
// Total updated rows.
var updated int
for {
// Limit to batchsz
// items at once.
batchQ := tx.
NewSelect().
Table(table).
Column(batchByCol).
Where("? < ?", bun.Ident(batchByCol), highest).
OrderExpr("? DESC", bun.Ident(batchByCol)).
Limit(batchsz)
// Finalize UPDATE to act only on batch.
qStr := updateQuery + " WHERE ? IN (?)"
args := append(slices.Clone(updateArgs),
bun.Ident(batchByCol),
batchQ,
)
// Execute the prepared raw query with arguments.
res, err := tx.NewRaw(qStr, args...).Exec(ctx)
if err != nil {
return gtserror.Newf("error updating old column values: %w", err)
}
// Check how many items we updated.
thisUpdated, err := res.RowsAffected()
if err != nil {
return gtserror.Newf("error counting affected rows: %w", err)
}
if thisUpdated == 0 {
// Nothing updated
// means we're done.
break
}
// Update the overall count.
updated += int(thisUpdated)
// Log helpful message to admin.
log.Infof(ctx, "migrated %d of %d %s (up to %s)",
updated, total, table, highest)
// Get next highest
// id for next batch.
if err := tx.
NewSelect().
With("batch_query", batchQ).
ColumnExpr("min(?) FROM ?", bun.Ident(batchByCol), bun.Ident("batch_query")).
Scan(ctx, &highest); err != nil {
return gtserror.Newf("error selecting next highest: %w", err)
}
}
if total != int(updated) {
// Return error here in order to rollback the whole transaction.
return fmt.Errorf("total=%d does not match updated=%d", total, updated)
}
return nil
}
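For reference, the removed helper implemented keyset batching: rather than paging with OFFSET, it walked the ID column downward from the highest possible value, updating one bounded batch per query and moving the cursor to the smallest key seen. A self-contained sketch of that cursor pattern on a plain slice (no database, hypothetical names):

```go
package main

import (
	"fmt"
	"sort"
)

// batchDescending mimics the removed batchUpdateByID cursor logic:
// starting from a "highest" sentinel, take up to batchsz items
// strictly below the cursor, apply the work, then move the cursor
// to the smallest key in the batch. Purely illustrative.
func batchDescending(ids []int, batchsz int, apply func(batch []int)) {
	sort.Sort(sort.Reverse(sort.IntSlice(ids)))
	cursor := int(^uint(0) >> 1) // highest possible value
	for i := 0; i < len(ids); {
		j := i
		var batch []int
		for ; j < len(ids) && j-i < batchsz; j++ {
			if ids[j] < cursor {
				batch = append(batch, ids[j])
			}
		}
		if len(batch) == 0 {
			break // nothing updated means we're done
		}
		apply(batch)
		cursor = batch[len(batch)-1]
		i = j
	}
}

func main() {
	ids := []int{5, 9, 1, 7, 3, 8, 2}
	batchDescending(ids, 3, func(b []int) { fmt.Println("batch:", b) })
	// batch: [9 8 7]
	// batch: [5 3 2]
	// batch: [1]
}
```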
// convertEnums performs a transaction that converts
// a table's column of our old-style enums (strings) to
// more performant and space-saving integer types.

View file

@ -51,6 +51,7 @@ func (d *Dereferencer) GetEmoji(
remoteURL string,
info media.AdditionalEmojiInfo,
refresh bool,
async bool,
) (
*gtsmodel.Emoji,
error,
@ -66,7 +67,7 @@ func (d *Dereferencer) GetEmoji(
if emoji != nil {
// This was an existing emoji, pass to refresh func.
return d.RefreshEmoji(ctx, emoji, info, refresh, async)
}
if domain == "" {
@ -112,6 +113,7 @@ func (d *Dereferencer) GetEmoji(
info,
)
},
async,
)
}
@ -130,6 +132,7 @@ func (d *Dereferencer) RefreshEmoji(
emoji *gtsmodel.Emoji,
info media.AdditionalEmojiInfo,
force bool,
async bool,
) (
*gtsmodel.Emoji,
error,
@ -162,7 +165,7 @@ func (d *Dereferencer) RefreshEmoji(
// We still want to make sure
// the emoji is cached. Simply
// check whether emoji is cached.
return d.RecacheEmoji(ctx, emoji, async)
}
// Can't refresh local.
@ -207,6 +210,7 @@ func (d *Dereferencer) RefreshEmoji(
info,
)
},
async,
)
}
@ -222,6 +226,7 @@ func (d *Dereferencer) RefreshEmoji(
func (d *Dereferencer) RecacheEmoji(
ctx context.Context,
emoji *gtsmodel.Emoji,
async bool,
) (
*gtsmodel.Emoji,
error,
@ -272,23 +277,24 @@ func (d *Dereferencer) RecacheEmoji(
data,
)
},
async,
)
}
// processingEmojiSafely provides concurrency-safe processing of
// an emoji with given shortcode+domain. if a copy of the emoji is
// not already being processed, the given 'process' callback will
// be used to generate new *media.ProcessingEmoji{} instance. async
// determines whether to load it immediately, or in the background.
func (d *Dereferencer) processEmojiSafely(
ctx context.Context,
shortcodeDomain string,
process func() (*media.ProcessingEmoji, error),
async bool,
) (
emoji *gtsmodel.Emoji,
err error,
) {
// Acquire map lock.
d.derefEmojisMu.Lock()
@ -309,25 +315,34 @@ func (d *Dereferencer) processEmojiSafely(
// Add processing emoji media to hash map.
d.derefEmojis[shortcodeDomain] = processing
}
// Unlock map.
unlock()
if async {
emoji = processing.LoadAsync(func() {
// Remove on finish.
d.derefEmojisMu.Lock()
delete(d.derefEmojis, shortcodeDomain)
d.derefEmojisMu.Unlock()
})
} else {
defer func() {
// Remove on finish.
d.derefEmojisMu.Lock()
delete(d.derefEmojis, shortcodeDomain)
d.derefEmojisMu.Unlock()
}()
// Perform emoji load operation.
emoji, err = processing.Load(ctx)
if err != nil {
err = gtserror.Newf("error loading emoji %s: %w", shortcodeDomain, err)
// TODO: in time we should return checkable flags by gtserror.Is___()
// which can determine if loading error should allow remaining placeholder.
}
}
return
@ -364,7 +379,7 @@ func (d *Dereferencer) fetchEmojis(
URI: &placeholder.URI,
ImageRemoteURL: &placeholder.ImageRemoteURL,
ImageStaticRemoteURL: &placeholder.ImageStaticRemoteURL,
}, force, true)
if err != nil {
log.Errorf(ctx, "error refreshing emoji: %v", err)
@ -396,6 +411,7 @@ func (d *Dereferencer) fetchEmojis(
ImageStaticRemoteURL: &placeholder.ImageStaticRemoteURL,
},
false,
true,
)
if err != nil {
if emoji == nil {
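The new `async` flag threaded through these signatures chooses between a blocking `Load` and a fire-and-forget `LoadAsync` that hands back a placeholder immediately. A rough sketch of that API shape; the method names follow the diff, but the bodies here are stand-ins, not the real `media` package:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// processing is a stand-in for *media.ProcessingEmoji. Load blocks
// until the emoji is fetched; LoadAsync returns what is known right
// away and finishes the work in the background, invoking onFinish so
// the caller can drop its dedupe map entry. Bodies are illustrative.
type processing struct{ shortcode string }

func (p *processing) Load(ctx context.Context) (string, error) {
	time.Sleep(10 * time.Millisecond) // pretend remote fetch
	return p.shortcode + " (fully cached)", nil
}

func (p *processing) LoadAsync(onFinish func()) string {
	go func() {
		defer onFinish()
		time.Sleep(10 * time.Millisecond) // pretend remote fetch
	}()
	return p.shortcode + " (placeholder)"
}

func main() {
	p := &processing{shortcode: ":blobcat:"}

	// Synchronous path: dereferencing used to always block here.
	emoji, _ := p.Load(context.Background())
	fmt.Println(emoji)

	// Asynchronous path: return immediately, clean up when done.
	var wg sync.WaitGroup
	wg.Add(1)
	fmt.Println(p.LoadAsync(func() { wg.Done() }))
	wg.Wait()
}
```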

View file

@ -54,6 +54,7 @@ func (suite *EmojiTestSuite) TestDereferenceEmojiBlocking() {
VisibleInPicker: &emojiVisibleInPicker,
},
false,
false,
)
suite.NoError(err)
suite.NotNil(emoji)

View file

@ -277,18 +277,6 @@ func (d *Dereferencer) enrichStatusSafely(
) (*gtsmodel.Status, ap.Statusable, bool, error) {
uriStr := status.URI
var isNew bool
// Check if this is a new status (to us).
if isNew = (status.ID == ""); !isNew {
// This is an existing status, first try to populate it. This
// is required by the checks below for existing tags, media etc.
if err := d.state.DB.PopulateStatus(ctx, status); err != nil {
log.Errorf(ctx, "error populating existing status %s: %v", uriStr, err)
}
}
// Acquire per-URI deref lock, wrapping unlock
// to safely defer in case of panic, while still
// performing more granular unlocks when needed.
@ -296,6 +284,23 @@ func (d *Dereferencer) enrichStatusSafely(
unlock = util.DoOnce(unlock)
defer unlock()
var err error
var isNew bool
// Check if this is a new status (to us).
if isNew = (status.ID == ""); !isNew {
// We reload the existing status, just to ensure we have the
// latest version of it. e.g. another racing thread might have
// just input a change but we still have an old status copy.
//
// Note: returned status will be fully populated, required below.
status, err = d.state.DB.GetStatusByID(ctx, status.ID)
if err != nil {
return nil, nil, false, gtserror.Newf("error getting up-to-date existing status: %w", err)
}
}
// Perform status enrichment with passed vars.
latest, statusable, err := d.enrichStatus(ctx,
requestUser,
@ -479,12 +484,10 @@ func (d *Dereferencer) enrichStatus(
// Ensure the final parsed status URI or URL matches
// the input URI we fetched (or received) it as.
matches, err := util.URIMatches(uri, append(
ap.GetURL(statusable), // status URL(s)
ap.GetJSONLDId(statusable), // status URI
)...)
if err != nil {
return nil, nil, gtserror.Newf(
"error checking dereferenced status uri %s: %w",
@ -605,6 +608,9 @@ func (d *Dereferencer) enrichStatus(
return nil, nil, gtserror.Newf("error populating emojis for status %s: %w", uri, err) return nil, nil, gtserror.Newf("error populating emojis for status %s: %w", uri, err)
} }
// Check if interaction policy has changed between status and latestStatus.
interactionPolicyChanged := status.InteractionPolicy.DifferentFrom(latestStatus.InteractionPolicy)
if isNew {
// Simplest case, insert this new remote status into the database.
if err := d.state.DB.PutStatus(ctx, latestStatus); err != nil {
@ -622,6 +628,7 @@ func (d *Dereferencer) enrichStatus(
tagsChanged,
mediaChanged,
emojiChanged,
interactionPolicyChanged,
)
if err != nil {
return nil, nil, gtserror.Newf("error handling edit for status %s: %w", uri, err)
@ -1054,6 +1061,7 @@ func (d *Dereferencer) handleStatusEdit(
tagsChanged bool,
mediaChanged bool,
emojiChanged bool,
interactionPolicyChanged bool,
) (
cols []string,
err error,
@ -1138,6 +1146,15 @@ func (d *Dereferencer) handleStatusEdit(
// been previously populated properly.
}
if interactionPolicyChanged {
// Interaction policy changed.
cols = append(cols, "interaction_policy")
// Int pol changed doesn't necessarily
// indicate an edit, it may just not have
// been previously populated properly.
}
if edited {
// Get previous-most-recent modified time,
// which will be this edit's creation time.

View file

@ -18,7 +18,6 @@
package dereferencing_test
import (
"fmt"
"testing"
"time"
@ -237,9 +236,7 @@ func (suite *StatusTestSuite) TestDereferenceStatusWithNonMatchingURI() {
}
func (suite *StatusTestSuite) TestDereferencerRefreshStatusUpdated() {
ctx := suite.T().Context()
// The local account we will be fetching statuses as.
fetchingAccount := suite.testAccounts["local_account_1"]
@ -343,6 +340,104 @@ func (suite *StatusTestSuite) TestDereferencerRefreshStatusUpdated() {
}
}
func (suite *StatusTestSuite) TestDereferencerRefreshStatusRace() {
ctx := suite.T().Context()
// The local account we will be fetching statuses as.
fetchingAccount := suite.testAccounts["local_account_1"]
// The test status in question that we will be dereferencing from "remote".
testURIStr := "https://unknown-instance.com/users/brand_new_person/statuses/01FE4NTHKWW7THT67EF10EB839"
testURI := testrig.URLMustParse(testURIStr)
testStatusable := suite.client.TestRemoteStatuses[testURIStr]
// Fetch the remote status first to load it into instance.
testStatus, statusable, err := suite.dereferencer.GetStatusByURI(ctx,
fetchingAccount.Username,
testURI,
)
suite.NotNil(statusable)
suite.NoError(err)
// Take a snapshot of current
// state of the test status.
beforeEdit := copyStatus(testStatus)
// Edit the "remote" statusable obj.
suite.editStatusable(testStatusable,
"updated status content!",
"CW: edited status content",
beforeEdit.Language, // no change
*beforeEdit.Sensitive, // no change
beforeEdit.AttachmentIDs, // no change
getPollOptions(beforeEdit), // no change
getPollVotes(beforeEdit), // no change
time.Now(),
)
// Refresh with a given statusable to update to the edited copy.
afterEdit, statusable, err := suite.dereferencer.RefreshStatus(ctx,
fetchingAccount.Username,
testStatus,
testStatusable,
instantFreshness,
)
suite.NotNil(statusable)
suite.NoError(err)
// verify updated status details.
suite.verifyEditedStatusUpdate(
// the original status
// before any changes.
beforeEdit,
// latest status
// being tested.
afterEdit,
// expected current state.
&gtsmodel.StatusEdit{
Content: "updated status content!",
ContentWarning: "CW: edited status content",
Language: beforeEdit.Language,
Sensitive: beforeEdit.Sensitive,
AttachmentIDs: beforeEdit.AttachmentIDs,
PollOptions: getPollOptions(beforeEdit),
PollVotes: getPollVotes(beforeEdit),
// createdAt never changes
},
// expected historic edit.
&gtsmodel.StatusEdit{
Content: beforeEdit.Content,
ContentWarning: beforeEdit.ContentWarning,
Language: beforeEdit.Language,
Sensitive: beforeEdit.Sensitive,
AttachmentIDs: beforeEdit.AttachmentIDs,
PollOptions: getPollOptions(beforeEdit),
PollVotes: getPollVotes(beforeEdit),
CreatedAt: beforeEdit.UpdatedAt(),
},
)
// Now make another attempt to refresh, using the old copy of the
// status. This should still successfully update based on our passed
// freshness window, but it *should* refetch the provided status to
// check for race shenanigans and realize that no edit has occurred.
afterBodge, statusable, err := suite.dereferencer.RefreshStatus(ctx,
fetchingAccount.Username,
beforeEdit,
testStatusable,
instantFreshness,
)
suite.NotNil(statusable)
suite.NoError(err)
// Check that no further edit occurred on status.
suite.Equal(afterEdit.EditIDs, afterBodge.EditIDs)
}
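The race this test guards against is the classic stale-copy problem: the caller enters `enrichStatusSafely` holding a status loaded before the per-URI lock, while a racing thread commits a newer version first. The fix shown earlier reloads the status once inside the lock. A minimal sketch of the pattern, with all names hypothetical:

```go
package main

import (
	"fmt"
	"sync"
)

// db is a stand-in for the status database.
type db struct {
	mu       sync.Mutex
	statuses map[string]string
}

func (d *db) get(id string) string { d.mu.Lock(); defer d.mu.Unlock(); return d.statuses[id] }
func (d *db) put(id, v string)     { d.mu.Lock(); defer d.mu.Unlock(); d.statuses[id] = v }

func main() {
	store := &db{statuses: map[string]string{"s1": "v1"}}
	var derefMu sync.Mutex // stands in for the per-URI deref lock

	// Caller enters holding a possibly stale copy...
	stale := store.get("s1")

	// ...while a racing goroutine commits a newer version first.
	done := make(chan struct{})
	go func() { store.put("s1", "v2"); close(done) }()
	<-done

	// Reloading *inside* the critical section guarantees we work
	// on the latest copy, so no spurious "edit" gets recorded.
	derefMu.Lock()
	latest := store.get("s1")
	derefMu.Unlock()

	fmt.Println(stale, "->", latest) // v1 -> v2
}
```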
// editStatusable updates the given statusable attributes.
// note that this acts on the original object, no copying.
func (suite *StatusTestSuite) editStatusable(

View file

@ -17,6 +17,8 @@
package gtsmodel
import "slices"
// A policy URI is GoToSocial's internal representation of
// one ActivityPub URI for an Actor or a Collection of Actors,
// specific to the domain of enforcing interaction policies.
@ -232,6 +234,45 @@ type PolicyRules struct {
ManualApproval PolicyValues `json:"WithApproval,omitempty"`
}
// DifferentFrom returns true if pr1 and pr2
// are not equal in terms of nilness or content.
func (pr1 *PolicyRules) DifferentFrom(pr2 *PolicyRules) bool {
// If one PolicyRules is nil and
// the other isn't, they're different.
if pr1 == nil && pr2 != nil ||
pr1 != nil && pr2 == nil {
return true
}
// If they're both nil we don't
// need to check anything else.
if pr1 == nil && pr2 == nil {
return false
}
// Check if AutomaticApproval
// differs between the two.
if slices.Compare(
pr1.AutomaticApproval,
pr2.AutomaticApproval,
) != 0 {
return true
}
// Check if ManualApproval
// differs between the two.
if slices.Compare(
pr1.ManualApproval,
pr2.ManualApproval,
) != 0 {
return true
}
// They're the
// same picture.
return false
}
// Returns the default interaction policy
// for the given visibility level.
func DefaultInteractionPolicyFor(v Visibility) *InteractionPolicy {
@ -422,3 +463,41 @@ func DefaultInteractionPolicyDirect() *InteractionPolicy {
*c = *defaultPolicyDirect
return c
}
// DifferentFrom returns true if ip1 and ip2 are different.
func (ip1 *InteractionPolicy) DifferentFrom(ip2 *InteractionPolicy) bool {
// If one policy is null and the
// other isn't, they're different.
if ip1 == nil && ip2 != nil ||
ip1 != nil && ip2 == nil {
return true
}
// If they're both nil we don't
// need to check anything else.
if ip1 == nil && ip2 == nil {
return false
}
// If CanLike differs from one policy
// to the next, they're different.
if ip1.CanLike.DifferentFrom(ip2.CanLike) {
return true
}
// If CanReply differs from one policy
// to the next, they're different.
if ip1.CanReply.DifferentFrom(ip2.CanReply) {
return true
}
// If CanAnnounce differs from one policy
// to the next, they're different.
if ip1.CanAnnounce.DifferentFrom(ip2.CanAnnounce) {
return true
}
// Looks the
// same chief.
return false
}
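The two `DifferentFrom` helpers follow the same shape: treat mismatched nilness as different, both-nil as equal, then compare contents. A condensed, runnable sketch against a cut-down stand-in for `PolicyRules`; note that `(r1 == nil) != (r2 == nil)` is equivalent to the paired nil checks above:

```go
package main

import (
	"fmt"
	"slices"
)

// rules is a cut-down stand-in for gtsmodel.PolicyRules.
type rules struct {
	AutomaticApproval []string
	ManualApproval    []string
}

// differentFrom mirrors the nil-then-content comparison above.
func (r1 *rules) differentFrom(r2 *rules) bool {
	if (r1 == nil) != (r2 == nil) {
		return true // one nil, one not: different
	}
	if r1 == nil {
		return false // both nil: same
	}
	return slices.Compare(r1.AutomaticApproval, r2.AutomaticApproval) != 0 ||
		slices.Compare(r1.ManualApproval, r2.ManualApproval) != 0
}

func main() {
	a := &rules{AutomaticApproval: []string{"followers"}}
	b := &rules{AutomaticApproval: []string{"public"}}
	fmt.Println(a.differentFrom(b))               // true
	fmt.Println(a.differentFrom(a))               // false
	fmt.Println((*rules)(nil).differentFrom(nil)) // false
}
```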

View file

@ -27,56 +27,56 @@ import (
// Status represents a user-created 'post' or 'status' in the database, either remote or local
type Status struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
EditedAt time.Time `bun:"type:timestamptz,nullzero"` // when this status was last edited (if set)
FetchedAt time.Time `bun:"type:timestamptz,nullzero"` // when was item (remote) last fetched.
PinnedAt time.Time `bun:"type:timestamptz,nullzero"` // Status was pinned by owning account at this time.
URI string `bun:",unique,nullzero,notnull"` // activitypub URI of this status
URL string `bun:",nullzero"` // web url for viewing this status
Content string `bun:""` // Content HTML for this status.
AttachmentIDs []string `bun:"attachments,array"` // Database IDs of any media attachments associated with this status
Attachments []*MediaAttachment `bun:"attached_media,rel:has-many"` // Attachments corresponding to attachmentIDs
TagIDs []string `bun:"tags,array"` // Database IDs of any tags used in this status
Tags []*Tag `bun:"attached_tags,m2m:status_to_tags"` // Tags corresponding to tagIDs. https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
MentionIDs []string `bun:"mentions,array"` // Database IDs of any mentions in this status
Mentions []*Mention `bun:"attached_mentions,rel:has-many"` // Mentions corresponding to mentionIDs
EmojiIDs []string `bun:"emojis,array"` // Database IDs of any emojis used in this status
Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // Emojis corresponding to emojiIDs. https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
Local *bool `bun:",nullzero,notnull,default:false"` // is this status from a local account?
AccountID string `bun:"type:CHAR(26),nullzero,notnull"` // which account posted this status?
Account *Account `bun:"rel:belongs-to"` // account corresponding to accountID
AccountURI string `bun:",nullzero,notnull"` // activitypub uri of the owner of this status
InReplyToID string `bun:"type:CHAR(26),nullzero"` // id of the status this status replies to
InReplyToURI string `bun:",nullzero"` // activitypub uri of the status this status is a reply to
InReplyToAccountID string `bun:"type:CHAR(26),nullzero"` // id of the account that this status replies to
InReplyTo *Status `bun:"-"` // status corresponding to inReplyToID
InReplyToAccount *Account `bun:"rel:belongs-to"` // account corresponding to inReplyToAccountID
BoostOfID string `bun:"type:CHAR(26),nullzero"` // id of the status this status is a boost of
BoostOfURI string `bun:"-"` // URI of the status this status is a boost of; field not inserted in the db, just for dereferencing purposes.
BoostOfAccountID string `bun:"type:CHAR(26),nullzero"` // id of the account that owns the boosted status
BoostOf *Status `bun:"-"` // status that corresponds to boostOfID
BoostOfAccount *Account `bun:"rel:belongs-to"` // account that corresponds to boostOfAccountID
ThreadID string `bun:"type:CHAR(26),nullzero,notnull,default:'00000000000000000000000000'"` // id of the thread to which this status belongs
EditIDs []string `bun:"edits,array"` // IDs of status edits for this status, ordered from smallest (oldest) -> largest (newest) ID.
Edits []*StatusEdit `bun:"-"` // Edits of this status, ordered from oldest -> newest edit.
PollID string `bun:"type:CHAR(26),nullzero"` //
Poll *Poll `bun:"-"` //
ContentWarning string `bun:",nullzero"` // Content warning HTML for this status.
ContentWarningText string `bun:""` // Original text of the content warning without formatting
Visibility Visibility `bun:",nullzero,notnull"` // visibility entry for this status
Sensitive *bool `bun:",nullzero,notnull,default:false"` // mark the status as sensitive?
Language string `bun:",nullzero"` // what language is this status written in?
CreatedWithApplicationID string `bun:"type:CHAR(26),nullzero"` // Which application was used to create this status?
CreatedWithApplication *Application `bun:"rel:belongs-to"` // application corresponding to createdWithApplicationID
ActivityStreamsType string `bun:",nullzero,notnull"` // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types. Will probably almost always be Note but who knows!.
Text string `bun:""` // Original text of the status without formatting
ContentType StatusContentType `bun:",nullzero"` // Content type used to process the original text of the status
Federated *bool `bun:",notnull"` // This status will be federated beyond the local timeline(s)
InteractionPolicy *InteractionPolicy `bun:""` // InteractionPolicy for this status. If null then the default InteractionPolicy should be assumed for this status's Visibility. Always null for boost wrappers.
PendingApproval *bool `bun:",nullzero,notnull,default:false"` // If true then status is a reply or boost wrapper that must be Approved by the reply-ee or boost-ee before being fully distributed.
PreApproved bool `bun:"-"` // If true, then status is a reply to or boost wrapper of a status on our instance, has permission to do the interaction, and an Accept should be sent out for it immediately. Field not stored in the DB.
ApprovedByURI string `bun:",nullzero"` // URI of *either* an Accept Activity, or a ReplyAuthorization or AnnounceAuthorization, which approves the Announce, Create or interaction request Activity that this status was/will be attached to.
}
// GetID implements timeline.Timelineable{}.

View file

@ -21,8 +21,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"os"
"path"
"strconv" "strconv"
"strings" "strings"
@ -158,34 +156,20 @@ func ffmpeg(ctx context.Context, inpath string, outpath string, args ...string)
Config: func(modcfg wazero.ModuleConfig) wazero.ModuleConfig {
fscfg := wazero.NewFSConfig()
// Needs read-only access to /dev/urandom,
// required by some ffmpeg operations.
fscfg = fscfg.WithFSMount(&allowFiles{
allowRead("/dev/urandom"),
}, "/dev")
// In+out dirs are always the same (tmp),
// so we can share one file system for
// both + grant different perms to inpath
// (read only) and outpath (read+write).
fscfg = fscfg.WithFSMount(&allowFiles{
allowCreate(outpath),
allowRead(inpath),
}, tmpdir)
// Set anonymous module name.
modcfg = modcfg.WithName("")
@ -246,16 +230,10 @@ func ffprobe(ctx context.Context, filepath string) (*result, error) {
Config: func(modcfg wazero.ModuleConfig) wazero.ModuleConfig {
fscfg := wazero.NewFSConfig()
// Needs read-only access to probed file.
fscfg = fscfg.WithFSMount(&allowFiles{
allowRead(filepath),
}, tmpdir)
// Set anonymous module name.
modcfg = modcfg.WithName("")

View file

@ -21,12 +21,12 @@ package ffmpeg
import (
"context"
"errors"
"os"
"runtime"
"sync/atomic"
"unsafe"
"codeberg.org/gruf/go-ffmpreg/embed"
"codeberg.org/gruf/go-ffmpreg/wasm"
"github.com/tetratelabs/wazero"
@ -49,24 +49,19 @@ func initWASM(ctx context.Context) error {
return nil
}
// Check at runtime whether Wazero compiler support is available,
// interpreter mode is too slow for a usable gotosocial experience.
if reason, supported := isCompilerSupported(); !supported {
return errors.New("!!! WAZERO COMPILER SUPPORT NOT AVAILABLE !!!" +
" Reason: " + reason + "." +
" Wazero in interpreter mode is too slow to use ffmpeg" +
" (this will also affect SQLite if in use)." +
" For more info and possible workarounds, please check: https://docs.gotosocial.org/en/latest/getting_started/releases/#supported-platforms")
}
// Allocate new runtime compiler config.
cfg := wazero.NewRuntimeConfigCompiler()
if dir := os.Getenv("GTS_WAZERO_COMPILATION_CACHE"); dir != "" { if dir := os.Getenv("GTS_WAZERO_COMPILATION_CACHE"); dir != "" {
// Use on-filesystem compilation cache given by env. // Use on-filesystem compilation cache given by env.
cache, err := wazero.NewCompilationCacheWithDir(dir) cache, err := wazero.NewCompilationCacheWithDir(dir)
@ -88,7 +83,7 @@ func initWASM(ctx context.Context) error {
defer func() {
if err == nil && set {
// Drop binary.
embed.Free()
return
}
@ -110,7 +105,7 @@ func initWASM(ctx context.Context) error {
}
// Compile ffmpreg WebAssembly into memory.
mod, err = run.CompileModule(ctx, embed.B())
if err != nil {
return err
}
@ -128,7 +123,7 @@ func initWASM(ctx context.Context) error {
return nil
}
func isCompilerSupported() (string, bool) {
switch runtime.GOOS {
case "linux", "android",
"windows", "darwin",
@ -141,10 +136,11 @@ func compilerSupported() (string, bool) {
switch runtime.GOARCH {
case "amd64":
// NOTE: wazero in the future may decouple the
// requirement of simd (sse4_1+2) from requirements
// for compiler support in the future, but even
// still our module go-ffmpreg makes use of them.
return "amd64 x86-64-v2 required (see: https://en.wikipedia.org/wiki/X86-64-v2)",
cpu.Initialized && cpu.X86.HasSSE3 && cpu.X86.HasSSE41 && cpu.X86.HasSSE42
case "arm64":
// NOTE: this particular check may change if we
// later update go-ffmpreg to a version that makes
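The feature probe relies on `golang.org/x/sys/cpu`, which exposes CPUID-derived flags as plain booleans. A trimmed, runnable version of the check; the arm64 branch is simplified here, since the real one has its own conditions per the truncated note above:

```go
package main

import (
	"fmt"
	"runtime"

	"golang.org/x/sys/cpu"
)

// isCompilerSupported is a cut-down copy of the check above: on amd64,
// wazero's compiler (as exercised by go-ffmpreg) wants roughly the
// x86-64-v2 feature level, approximated via the SSE3/4.1/4.2 bits.
func isCompilerSupported() (reason string, ok bool) {
	switch runtime.GOARCH {
	case "amd64":
		return "amd64 x86-64-v2 required",
			cpu.Initialized && cpu.X86.HasSSE3 && cpu.X86.HasSSE41 && cpu.X86.HasSSE42
	case "arm64":
		// Simplified: the real check has its own arm64 conditions.
		return "", true
	default:
		return "unsupported GOARCH", false
	}
}

func main() {
	if reason, ok := isCompilerSupported(); !ok {
		fmt.Println("wazero compiler unavailable:", reason)
		return
	}
	fmt.Println("wazero compiler supported on", runtime.GOARCH)
}
```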

View file

@ -21,6 +21,8 @@ import (
"image" "image"
"image/color" "image/color"
"math" "math"
"code.superseriousbusiness.org/gotosocial/internal/gtserror"
)
// NOTE:
@ -73,15 +75,15 @@ func resizeDownLinear(img image.Image, width, height int) image.Image {
// flipH flips the image horizontally (left to right).
func flipH(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcW
dstH := srcH
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcY := y
scanImage(img, 0, srcY, srcW, srcY+1, dst.Pix[i:i+rowSize])
reverse(dst.Pix[i : i+rowSize])
}
return dst
@ -89,45 +91,45 @@ func flipH(img image.Image) image.Image {
// flipV flips the image vertically (from top to bottom).
func flipV(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcW
dstH := srcH
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcY := dstH - y - 1
scanImage(img, 0, srcY, srcW, srcY+1, dst.Pix[i:i+rowSize])
}
return dst
}
// rotate90 rotates the image 90 counter-clockwise.
func rotate90(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcH
dstH := srcW
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcX := dstH - y - 1
scanImage(img, srcX, 0, srcX+1, srcH, dst.Pix[i:i+rowSize])
}
return dst
}
// rotate180 rotates the image 180 counter-clockwise.
func rotate180(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcW
dstH := srcH
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcY := dstH - y - 1
scanImage(img, 0, srcY, srcW, srcY+1, dst.Pix[i:i+rowSize])
reverse(dst.Pix[i : i+rowSize])
}
return dst
@ -135,15 +137,15 @@ func rotate180(img image.Image) image.Image {
// rotate270 rotates the image 270 counter-clockwise.
func rotate270(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcH
dstH := srcW
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcX := y
scanImage(img, srcX, 0, srcX+1, srcH, dst.Pix[i:i+rowSize])
reverse(dst.Pix[i : i+rowSize])
}
return dst
@ -151,30 +153,30 @@ func rotate270(img image.Image) image.Image {
// transpose flips the image horizontally and rotates 90 counter-clockwise.
func transpose(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcH
dstH := srcW
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcX := y
scanImage(img, srcX, 0, srcX+1, srcH, dst.Pix[i:i+rowSize])
}
return dst
}
// transverse flips the image vertically and rotates 90 counter-clockwise.
func transverse(img image.Image) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dstW := srcH
dstH := srcW
rowSize := dstW * 4
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
for y := 0; y < dstH; y++ {
i := y * dst.Stride
srcX := dstH - y - 1
scanImage(img, srcX, 0, srcX+1, srcH, dst.Pix[i:i+rowSize])
reverse(dst.Pix[i : i+rowSize])
}
return dst
@ -182,12 +184,12 @@ func transverse(img image.Image) image.Image {
// resizeHorizontalLinear resizes image to given width using linear resampling.
func resizeHorizontalLinear(img image.Image, dstWidth int) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dst := image.NewRGBA(image.Rect(0, 0, dstWidth, srcH))
weights := precomputeWeightsLinear(dstWidth, srcW)
scanLine := make([]uint8, srcW*4)
for y := 0; y < srcH; y++ {
scanImage(img, 0, y, srcW, y+1, scanLine)
j0 := y * dst.Stride
for x := range weights {
var r, g, b, a float64
@ -201,13 +203,12 @@ func resizeHorizontalLinear(img image.Image, dstWidth int) image.Image {
a += aw
}
if a != 0 {
j := j0 + x*4
d := dst.Pix[j : j+4 : j+4]
d[0] = clampFloatTo8(r / a)
d[1] = clampFloatTo8(g / a)
d[2] = clampFloatTo8(b / a)
d[3] = clampFloatTo8(a)
}
}
}
@ -216,12 +217,12 @@ func resizeHorizontalLinear(img image.Image, dstWidth int) image.Image {
// resizeVerticalLinear resizes image to given height using linear resampling.
func resizeVerticalLinear(img image.Image, height int) image.Image {
srcW, srcH := img.Bounds().Dx(), img.Bounds().Dy()
dst := image.NewNRGBA(image.Rect(0, 0, srcW, height))
weights := precomputeWeightsLinear(height, srcH)
scanLine := make([]uint8, srcH*4)
for x := 0; x < srcW; x++ {
scanImage(img, x, 0, x+1, srcH, scanLine)
for y := range weights {
var r, g, b, a float64
for _, w := range weights[y] {
@ -234,13 +235,12 @@ func resizeVerticalLinear(img image.Image, height int) image.Image {
a += aw a += aw
} }
if a != 0 { if a != 0 {
aInv := 1 / a
j := y*dst.Stride + x*4 j := y*dst.Stride + x*4
d := dst.Pix[j : j+4 : j+4] d := dst.Pix[j : j+4 : j+4]
d[0] = clampFloat(r * aInv) d[0] = clampFloatTo8(r / a)
d[1] = clampFloat(g * aInv) d[1] = clampFloatTo8(g / a)
d[2] = clampFloat(b * aInv) d[2] = clampFloatTo8(b / a)
d[3] = clampFloat(a) d[3] = clampFloatTo8(a)
} }
} }
} }
@@ -263,13 +263,14 @@ func precomputeWeightsLinear(dstSize, srcSize int) [][]indexWeight {
 	out := make([][]indexWeight, dstSize)
 	tmp := make([]indexWeight, 0, dstSize*int(ru+2)*2)

-	for v := 0; v < dstSize; v++ {
+	for v := 0; v < len(out); v++ {
 		fu := (float64(v)+0.5)*du - 0.5
 		begin := int(math.Ceil(fu - ru))
 		if begin < 0 {
 			begin = 0
 		}
 		end := int(math.Floor(fu + ru))
 		if end > srcSize-1 {
 			end = srcSize - 1
@@ -280,9 +281,13 @@ func precomputeWeightsLinear(dstSize, srcSize int) [][]indexWeight {
 			w := resampleLinear((float64(u) - fu) / scale)
 			if w != 0 {
 				sum += w
-				tmp = append(tmp, indexWeight{index: u, weight: w})
+				tmp = append(tmp, indexWeight{
+					index:  u,
+					weight: w,
+				})
 			}
 		}
 		if sum != 0 {
 			for i := range tmp {
 				tmp[i].weight /= sum
@@ -305,204 +310,209 @@ func resampleLinear(x float64) float64 {
 	return 0
 }
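The resampler uses a triangle ("tent") kernel; a self-contained sketch of the same curve follows, with the understanding that the weights collected for each destination pixel are afterwards divided by their sum (the `tmp[i].weight /= sum` step above) so they total 1.

	package sketch

	// Triangle ("tent") kernel as used by the linear resampler above:
	// full weight at distance 0, falling off linearly to 0 at distance 1.
	func resampleLinearSketch(x float64) float64 {
		if x < 0 {
			x = -x
		}
		if x < 1 {
			return 1 - x
		}
		return 0
	}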
-// scanner wraps an image.Image for
-// easier size access and image type
-// agnostic access to data at coords.
-type scanner struct {
-	image   image.Image
-	w, h    int
-	palette []color.NRGBA
-}
-
-// newScanner wraps an image.Image in scanner{} type.
-func newScanner(img image.Image) *scanner {
-	b := img.Bounds()
-	s := &scanner{
-		image: img,
-		w:     b.Dx(),
-		h:     b.Dy(),
-	}
-	if img, ok := img.(*image.Paletted); ok {
-		s.palette = make([]color.NRGBA, len(img.Palette))
-		for i := 0; i < len(img.Palette); i++ {
-			s.palette[i] = color.NRGBAModel.Convert(img.Palette[i]).(color.NRGBA)
-		}
-	}
-	return s
-}
-
-// scan scans the given rectangular region of the image into dst.
-func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
-	switch img := s.image.(type) {
-	case *image.NRGBA:
-		size := (x2 - x1) * 4
-		j := 0
-		i := y1*img.Stride + x1*4
-		if size == 4 {
-			for y := y1; y < y2; y++ {
-				d := dst[j : j+4 : j+4]
-				s := img.Pix[i : i+4 : i+4]
-				d[0] = s[0]
-				d[1] = s[1]
-				d[2] = s[2]
-				d[3] = s[3]
-				j += size
-				i += img.Stride
-			}
-		} else {
-			for y := y1; y < y2; y++ {
-				copy(dst[j:j+size], img.Pix[i:i+size])
-				j += size
-				i += img.Stride
-			}
-		}
-	case *image.NRGBA64:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1*8
-			for x := x1; x < x2; x++ {
-				s := img.Pix[i : i+8 : i+8]
-				d := dst[j : j+4 : j+4]
-				d[0] = s[0]
-				d[1] = s[2]
-				d[2] = s[4]
-				d[3] = s[6]
-				j += 4
-				i += 8
-			}
-		}
-	case *image.RGBA:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1*4
-			for x := x1; x < x2; x++ {
-				d := dst[j : j+4 : j+4]
-				a := img.Pix[i+3]
-				switch a {
-				case 0:
-					d[0] = 0
-					d[1] = 0
-					d[2] = 0
-					d[3] = a
-				case 0xff:
-					s := img.Pix[i : i+4 : i+4]
-					d[0] = s[0]
-					d[1] = s[1]
-					d[2] = s[2]
-					d[3] = a
-				default:
-					s := img.Pix[i : i+4 : i+4]
-					r16 := uint16(s[0])
-					g16 := uint16(s[1])
-					b16 := uint16(s[2])
-					a16 := uint16(a)
-					d[0] = uint8(r16 * 0xff / a16) // #nosec G115 -- Overflow desired.
-					d[1] = uint8(g16 * 0xff / a16) // #nosec G115 -- Overflow desired.
-					d[2] = uint8(b16 * 0xff / a16) // #nosec G115 -- Overflow desired.
-					d[3] = a
-				}
-				j += 4
-				i += 4
-			}
-		}
-	case *image.RGBA64:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1*8
-			for x := x1; x < x2; x++ {
-				s := img.Pix[i : i+8 : i+8]
-				d := dst[j : j+4 : j+4]
-				a := s[6]
-				switch a {
-				case 0:
-					d[0] = 0
-					d[1] = 0
-					d[2] = 0
-				case 0xff:
-					d[0] = s[0]
-					d[1] = s[2]
-					d[2] = s[4]
-				default:
-					r32 := uint32(s[0])<<8 | uint32(s[1])
-					g32 := uint32(s[2])<<8 | uint32(s[3])
-					b32 := uint32(s[4])<<8 | uint32(s[5])
-					a32 := uint32(s[6])<<8 | uint32(s[7])
-					d[0] = uint8((r32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
-					d[1] = uint8((g32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
-					d[2] = uint8((b32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
-				}
-				d[3] = a
-				j += 4
-				i += 8
-			}
-		}
-	case *image.Gray:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1
-			for x := x1; x < x2; x++ {
-				c := img.Pix[i]
-				d := dst[j : j+4 : j+4]
-				d[0] = c
-				d[1] = c
-				d[2] = c
-				d[3] = 0xff
-				j += 4
-				i++
-			}
-		}
-	case *image.Gray16:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1*2
-			for x := x1; x < x2; x++ {
-				c := img.Pix[i]
-				d := dst[j : j+4 : j+4]
-				d[0] = c
-				d[1] = c
-				d[2] = c
-				d[3] = 0xff
-				j += 4
-				i += 2
-			}
-		}
-	case *image.YCbCr:
-		j := 0
-		x1 += img.Rect.Min.X
-		x2 += img.Rect.Min.X
-		y1 += img.Rect.Min.Y
-		y2 += img.Rect.Min.Y
-		hy := img.Rect.Min.Y / 2
-		hx := img.Rect.Min.X / 2
-		for y := y1; y < y2; y++ {
-			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
-			var yBase int
-			switch img.SubsampleRatio {
-			case image.YCbCrSubsampleRatio444, image.YCbCrSubsampleRatio422:
-				yBase = (y - img.Rect.Min.Y) * img.CStride
-			case image.YCbCrSubsampleRatio420, image.YCbCrSubsampleRatio440:
-				yBase = (y/2 - hy) * img.CStride
-			}
-			for x := x1; x < x2; x++ {
-				var ic int
-				switch img.SubsampleRatio {
-				case image.YCbCrSubsampleRatio444, image.YCbCrSubsampleRatio440:
-					ic = yBase + (x - img.Rect.Min.X)
-				case image.YCbCrSubsampleRatio422, image.YCbCrSubsampleRatio420:
-					ic = yBase + (x/2 - hx)
-				default:
-					ic = img.COffset(x, y)
-				}
+// scan scans the given rectangular region of the image into dst.
+func scanImage(img image.Image, x1, y1, x2, y2 int, dst []uint8) {
+	switch img := img.(type) {
+	case *image.NRGBA:
+		scanNRGBA(img, x1, y1, x2, y2, dst)
+	case *image.NRGBA64:
+		scanNRGBA64(img, x1, y1, x2, y2, dst)
+	case *image.RGBA:
+		scanRGBA(img, x1, y1, x2, y2, dst)
+	case *image.RGBA64:
+		scanRGBA64(img, x1, y1, x2, y2, dst)
+	case *image.Gray:
+		scanGray(img, x1, y1, x2, y2, dst)
+	case *image.Gray16:
+		scanGray16(img, x1, y1, x2, y2, dst)
+	case *image.YCbCr:
+		scanYCbCr(img, x1, y1, x2, y2, dst)
+	case *image.Paletted:
+		scanPaletted(img, x1, y1, x2, y2, dst)
+	default:
+		scanAny(img, x1, y1, x2, y2, dst)
+	}
+}
+
+func scanNRGBA(img *image.NRGBA, x1, y1, x2, y2 int, dst []uint8) {
+	size := (x2 - x1) * 4
+	j := 0
+	i := y1*img.Stride + x1*4
+	if size == 4 {
+		for y := y1; y < y2; y++ {
+			d := dst[j : j+4 : j+4]
+			s := img.Pix[i : i+4 : i+4]
+			d[0] = s[0]
+			d[1] = s[1]
+			d[2] = s[2]
+			d[3] = s[3]
+			j += size
+			i += img.Stride
+		}
+	} else {
+		for y := y1; y < y2; y++ {
+			copy(dst[j:j+size], img.Pix[i:i+size])
+			j += size
+			i += img.Stride
+		}
+	}
+}
+
+func scanNRGBA64(img *image.NRGBA64, x1, y1, x2, y2 int, dst []uint8) {
+	if img == nil {
+		panic(gtserror.New("nil check elimination"))
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1*8
+		for x := x1; x < x2; x++ {
+			s := img.Pix[i : i+8 : i+8]
+			d := dst[j : j+4 : j+4]
+			d[0] = s[0]
+			d[1] = s[2]
+			d[2] = s[4]
+			d[3] = s[6]
+			j += 4
+			i += 8
+		}
+	}
+}
+
+func scanRGBA(img *image.RGBA, x1, y1, x2, y2 int, dst []uint8) {
+	if img == nil {
+		panic(gtserror.New("nil check elimination"))
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1*4
+		for x := x1; x < x2; x++ {
+			d := dst[j : j+4 : j+4]
+			a := img.Pix[i+3]
+			switch a {
+			case 0:
+				d[0] = 0
+				d[1] = 0
+				d[2] = 0
+				d[3] = a
+			case 0xff:
+				s := img.Pix[i : i+4 : i+4]
+				d[0] = s[0]
+				d[1] = s[1]
+				d[2] = s[2]
+				d[3] = a
+			default:
+				s := img.Pix[i : i+4 : i+4]
+				r16 := uint16(s[0])
+				g16 := uint16(s[1])
+				b16 := uint16(s[2])
+				a16 := uint16(a)
+				d[0] = uint8(r16 * 0xff / a16) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g16 * 0xff / a16) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b16 * 0xff / a16) // #nosec G115 -- Overflow desired.
+				d[3] = a
+			}
+			j += 4
+			i += 4
+		}
+	}
+}
+
+func scanRGBA64(img *image.RGBA64, x1, y1, x2, y2 int, dst []uint8) {
+	if img == nil {
+		panic(gtserror.New("nil check elimination"))
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1*8
+		for x := x1; x < x2; x++ {
+			s := img.Pix[i : i+8 : i+8]
+			d := dst[j : j+4 : j+4]
+			a := s[6]
+			switch a {
+			case 0:
+				d[0] = 0
+				d[1] = 0
+				d[2] = 0
+			case 0xff:
+				d[0] = s[0]
+				d[1] = s[2]
+				d[2] = s[4]
+			default:
+				r32 := uint32(s[0])<<8 | uint32(s[1])
+				g32 := uint32(s[2])<<8 | uint32(s[3])
+				b32 := uint32(s[4])<<8 | uint32(s[5])
+				a32 := uint32(s[6])<<8 | uint32(s[7])
+				d[0] = uint8((r32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
+				d[1] = uint8((g32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
+				d[2] = uint8((b32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
+			}
+			d[3] = a
+			j += 4
+			i += 8
+		}
+	}
+}
+
+func scanGray(img *image.Gray, x1, y1, x2, y2 int, dst []uint8) {
+	if img == nil {
+		panic(gtserror.New("nil check elimination"))
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1
+		for x := x1; x < x2; x++ {
+			c := img.Pix[i]
+			d := dst[j : j+4 : j+4]
+			d[0] = c
+			d[1] = c
+			d[2] = c
+			d[3] = 0xff
+			j += 4
+			i++
+		}
+	}
+}
+
+func scanGray16(img *image.Gray16, x1, y1, x2, y2 int, dst []uint8) {
+	if img == nil {
+		panic(gtserror.New("nil check elimination"))
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1*2
+		for x := x1; x < x2; x++ {
+			c := img.Pix[i]
+			d := dst[j : j+4 : j+4]
+			d[0] = c
+			d[1] = c
+			d[2] = c
+			d[3] = 0xff
+			j += 4
+			i += 2
+		}
+	}
+}
+
+func scanYCbCr(img *image.YCbCr, x1, y1, x2, y2 int, dst []uint8) {
+	j := 0
+
+	x1 += img.Rect.Min.X
+	x2 += img.Rect.Min.X
+	y1 += img.Rect.Min.Y
+	y2 += img.Rect.Min.Y
+
+	hy := img.Rect.Min.Y / 2
+	hx := img.Rect.Min.X / 2
+
+	switch img.SubsampleRatio {
+	case image.YCbCrSubsampleRatio420:
+		for y := y1; y < y2; y++ {
+			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
+			yBase := (y/2 - hy) * img.CStride
+			for x := x1; x < x2; x++ {
+				ic := yBase + (x/2 - hx)
 				yy1 := int32(img.Y[iy]) * 0x10101
 				cb1 := int32(img.Cb[ic]) - 128
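The payoff of this refactor is easiest to see in miniature: switching once on the concrete image type lets each per-pixel loop run over a concrete struct, so the compiler can inline field accesses and avoid per-pixel interface-method calls. A hedged, self-contained sketch of the same dispatch pattern (not the project's code):

	package sketch

	import "image"

	// sumPix demonstrates type-switch dispatch: fast paths index Pix
	// directly on the concrete type; the fallback pays for At() calls.
	func sumPix(img image.Image) (n uint64) {
		switch img := img.(type) {
		case *image.NRGBA:
			for _, p := range img.Pix {
				n += uint64(p)
			}
		case *image.Gray:
			for _, p := range img.Pix {
				n += uint64(p)
			}
		default:
			b := img.Bounds()
			for y := b.Min.Y; y < b.Max.Y; y++ {
				for x := b.Min.X; x < b.Max.X; x++ {
					r, g, bl, a := img.At(x, y).RGBA() // interface call per pixel
					n += uint64(r + g + bl + a)
				}
			}
		}
		return n
	}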
@@ -540,78 +550,296 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
 			}
 		}
-	case *image.Paletted:
-		j := 0
-		for y := y1; y < y2; y++ {
-			i := y*img.Stride + x1
-			for x := x1; x < x2; x++ {
-				c := s.palette[img.Pix[i]]
-				d := dst[j : j+4 : j+4]
-				d[0] = c.R
-				d[1] = c.G
-				d[2] = c.B
-				d[3] = c.A
-				j += 4
-				i++
-			}
-		}
-	default:
-		j := 0
-		b := s.image.Bounds()
-		x1 += b.Min.X
-		x2 += b.Min.X
-		y1 += b.Min.Y
-		y2 += b.Min.Y
-		for y := y1; y < y2; y++ {
-			for x := x1; x < x2; x++ {
-				r16, g16, b16, a16 := s.image.At(x, y).RGBA()
-				d := dst[j : j+4 : j+4]
-				switch a16 {
-				case 0xffff:
-					d[0] = uint8(r16 >> 8) // #nosec G115 -- Overflow desired.
-					d[1] = uint8(g16 >> 8) // #nosec G115 -- Overflow desired.
-					d[2] = uint8(b16 >> 8) // #nosec G115 -- Overflow desired.
-					d[3] = 0xff
-				case 0:
-					d[0] = 0
-					d[1] = 0
-					d[2] = 0
-					d[3] = 0
-				default:
-					d[0] = uint8(((r16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
-					d[1] = uint8(((g16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
-					d[2] = uint8(((b16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
-					d[3] = uint8(a16 >> 8) // #nosec G115 -- Overflow desired.
-				}
-				j += 4
-			}
-		}
-	}
-}
+	case image.YCbCrSubsampleRatio422:
+		for y := y1; y < y2; y++ {
+			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
+			yBase := (y - img.Rect.Min.Y) * img.CStride
+			for x := x1; x < x2; x++ {
+				ic := yBase + (x/2 - hx)
+				yy1 := int32(img.Y[iy]) * 0x10101
+				cb1 := int32(img.Cb[ic]) - 128
+				cr1 := int32(img.Cr[ic]) - 128
+				r := yy1 + 91881*cr1
+				if uint32(r)&0xff000000 == 0 { //nolint:gosec
+					r >>= 16
+				} else {
+					r = ^(r >> 31)
+				}
+				g := yy1 - 22554*cb1 - 46802*cr1
+				if uint32(g)&0xff000000 == 0 { //nolint:gosec
+					g >>= 16
+				} else {
+					g = ^(g >> 31)
+				}
+				b := yy1 + 116130*cb1
+				if uint32(b)&0xff000000 == 0 { //nolint:gosec
+					b >>= 16
+				} else {
+					b = ^(b >> 31)
+				}
+				d := dst[j : j+4 : j+4]
+				d[0] = uint8(r) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b) // #nosec G115 -- Overflow desired.
+				d[3] = 0xff
+				iy++
+				j += 4
+			}
+		}
+	case image.YCbCrSubsampleRatio440:
+		for y := y1; y < y2; y++ {
+			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
+			yBase := (y/2 - hy) * img.CStride
+			for x := x1; x < x2; x++ {
+				ic := yBase + (x - img.Rect.Min.X)
+				yy1 := int32(img.Y[iy]) * 0x10101
+				cb1 := int32(img.Cb[ic]) - 128
+				cr1 := int32(img.Cr[ic]) - 128
+				r := yy1 + 91881*cr1
+				if uint32(r)&0xff000000 == 0 { //nolint:gosec
+					r >>= 16
+				} else {
+					r = ^(r >> 31)
+				}
+				g := yy1 - 22554*cb1 - 46802*cr1
+				if uint32(g)&0xff000000 == 0 { //nolint:gosec
+					g >>= 16
+				} else {
+					g = ^(g >> 31)
+				}
+				b := yy1 + 116130*cb1
+				if uint32(b)&0xff000000 == 0 { //nolint:gosec
+					b >>= 16
+				} else {
+					b = ^(b >> 31)
+				}
+				d := dst[j : j+4 : j+4]
+				d[0] = uint8(r) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b) // #nosec G115 -- Overflow desired.
+				d[3] = 0xff
+				iy++
+				j += 4
+			}
+		}
+	case image.YCbCrSubsampleRatio444:
+		for y := y1; y < y2; y++ {
+			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
+			yBase := (y - img.Rect.Min.Y) * img.CStride
+			for x := x1; x < x2; x++ {
+				ic := yBase + (x - img.Rect.Min.X)
+				yy1 := int32(img.Y[iy]) * 0x10101
+				cb1 := int32(img.Cb[ic]) - 128
+				cr1 := int32(img.Cr[ic]) - 128
+				r := yy1 + 91881*cr1
+				if uint32(r)&0xff000000 == 0 { //nolint:gosec
+					r >>= 16
+				} else {
+					r = ^(r >> 31)
+				}
+				g := yy1 - 22554*cb1 - 46802*cr1
+				if uint32(g)&0xff000000 == 0 { //nolint:gosec
+					g >>= 16
+				} else {
+					g = ^(g >> 31)
+				}
+				b := yy1 + 116130*cb1
+				if uint32(b)&0xff000000 == 0 { //nolint:gosec
+					b >>= 16
+				} else {
+					b = ^(b >> 31)
+				}
+				d := dst[j : j+4 : j+4]
+				d[0] = uint8(r) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b) // #nosec G115 -- Overflow desired.
+				d[3] = 0xff
+				iy++
+				j += 4
+			}
+		}
+	default:
+		for y := y1; y < y2; y++ {
+			iy := (y-img.Rect.Min.Y)*img.YStride + (x1 - img.Rect.Min.X)
+			for x := x1; x < x2; x++ {
+				ic := img.COffset(x, y)
+				yy1 := int32(img.Y[iy]) * 0x10101
+				cb1 := int32(img.Cb[ic]) - 128
+				cr1 := int32(img.Cr[ic]) - 128
+				r := yy1 + 91881*cr1
+				if uint32(r)&0xff000000 == 0 { //nolint:gosec
+					r >>= 16
+				} else {
+					r = ^(r >> 31)
+				}
+				g := yy1 - 22554*cb1 - 46802*cr1
+				if uint32(g)&0xff000000 == 0 { //nolint:gosec
+					g >>= 16
+				} else {
+					g = ^(g >> 31)
+				}
+				b := yy1 + 116130*cb1
+				if uint32(b)&0xff000000 == 0 { //nolint:gosec
+					b >>= 16
+				} else {
+					b = ^(b >> 31)
+				}
+				d := dst[j : j+4 : j+4]
+				d[0] = uint8(r) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b) // #nosec G115 -- Overflow desired.
+				d[3] = 0xff
+				iy++
+				j += 4
+			}
+		}
+	}
+}
+
+func scanPaletted(img *image.Paletted, x1, y1, x2, y2 int, dst []uint8) {
+	var palette [256]color.NRGBA
+	if len(palette) < len(img.Palette) {
+		panic(gtserror.New("bound check elimination"))
+	}
+	for i := 0; i < len(img.Palette); i++ {
+		palette[i] = colorToNRGBA(img.Palette[i])
+	}
+	j := 0
+	for y := y1; y < y2; y++ {
+		i := y*img.Stride + x1
+		for x := x1; x < x2; x++ {
+			c := palette[img.Pix[i]]
+			d := dst[j : j+4 : j+4]
+			d[0] = c.R
+			d[1] = c.G
+			d[2] = c.B
+			d[3] = c.A
+			j += 4
+			i++
+		}
+	}
+}
+
+// inlined from: image/color.NRGBAModel.Convert()
+func colorToNRGBA(c color.Color) color.NRGBA {
+	if c, ok := c.(color.NRGBA); ok {
+		return c
+	}
+	r, g, b, a := c.RGBA()
+	if a == 0xffff {
+		return color.NRGBA{
+			uint8(r >> 8), // #nosec G115 -- from stdlib
+			uint8(g >> 8), // #nosec G115 -- from stdlib
+			uint8(b >> 8), // #nosec G115 -- from stdlib
+			0xff,
+		}
+	}
+	if a == 0 {
+		return color.NRGBA{
+			0,
+			0,
+			0,
+			0,
+		}
+	}
+	// Since Color.RGBA returns an alpha-premultiplied color,
+	// we should have r <= a && g <= a && b <= a.
+	r = (r * 0xffff) / a
+	g = (g * 0xffff) / a
+	b = (b * 0xffff) / a
+	return color.NRGBA{
+		uint8(r >> 8), // #nosec G115 -- from stdlib
+		uint8(g >> 8), // #nosec G115 -- from stdlib
+		uint8(b >> 8), // #nosec G115 -- from stdlib
+		uint8(a >> 8), // #nosec G115 -- from stdlib
+	}
+}
+
+func scanAny(img image.Image, x1, y1, x2, y2 int, dst []uint8) {
+	j := 0
+	b := img.Bounds()
+	x1 += b.Min.X
+	x2 += b.Min.X
+	y1 += b.Min.Y
+	y2 += b.Min.Y
+	for y := y1; y < y2; y++ {
+		for x := x1; x < x2; x++ {
+			r16, g16, b16, a16 := img.At(x, y).RGBA()
+			d := dst[j : j+4 : j+4]
+			switch a16 {
+			case 0xffff:
+				d[0] = uint8(r16 >> 8) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(g16 >> 8) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(b16 >> 8) // #nosec G115 -- Overflow desired.
+				d[3] = 0xff
+			case 0:
+				d[0] = 0
+				d[1] = 0
+				d[2] = 0
+				d[3] = 0
+			default:
+				d[0] = uint8(((r16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
+				d[1] = uint8(((g16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
+				d[2] = uint8(((b16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
+				d[3] = uint8(a16 >> 8) // #nosec G115 -- Overflow desired.
+			}
+			j += 4
+		}
+	}
+}

 // reverse reverses the data
 // in contained pixel slice.
-func reverse(pix []uint8) {
-	if len(pix) <= 4 {
+func reverse(pix8 []uint8) {
+	if len(pix8) <= 4 || len(pix8)%4 != 0 {
 		return
 	}
-	i := 0
-	j := len(pix) - 4
-	for i < j {
-		pi := pix[i : i+4 : i+4]
-		pj := pix[j : j+4 : j+4]
-		pi[0], pj[0] = pj[0], pi[0]
-		pi[1], pj[1] = pj[1], pi[1]
-		pi[2], pj[2] = pj[2], pi[2]
-		pi[3], pj[3] = pj[3], pi[3]
-		i += 4
-		j -= 4
+	for i, j := 0, len(pix8)-4; i < j; i, j = i+4, j-4 {
+		di := pix8[i : i+4 : i+4]
+		dj := pix8[j : j+4 : j+4]
+		di[0], dj[0] = dj[0], di[0]
+		di[1], dj[1] = dj[1], di[1]
+		di[2], dj[2] = dj[2], di[2]
+		di[3], dj[3] = dj[3], di[3]
 	}
 }

-// clampFloat rounds and clamps float64 value to fit into uint8.
-func clampFloat(x float64) uint8 {
+// clampFloatTo8 rounds and clamps
+// float64 value to fit into uint8.
+func clampFloatTo8(x float64) uint8 {
 	v := int64(x + 0.5)
 	if v > 255 {
 		return 255
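A worked example of the alpha un-premultiplication these scan functions perform (a self-contained sketch, not the project's code): for a premultiplied channel s with alpha a, the 8-bit straight-alpha value is s*0xff/a, so 50%-opaque red stored premultiplied as (127, 0, 0, 127) recovers to 127*255/127 = 255.

	package sketch

	// unpremultiply8 mirrors scanRGBA's default branch above.
	func unpremultiply8(s, a uint8) uint8 {
		if a == 0 {
			return 0 // fully transparent: no colour to recover
		}
		// uint16 widening: 255*255 = 65025 still fits, and s <= a
		// for premultiplied input, so the result fits in a uint8.
		return uint8(uint16(s) * 0xff / uint16(a))
	}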

@@ -0,0 +1,157 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package media
import (
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"io"
"path"
"reflect"
"strings"
"testing"
"golang.org/x/image/webp"
)
func BenchmarkFlipH(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = flipH(img)
})
}
func BenchmarkFlipV(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = flipV(img)
})
}
func BenchmarkRotate90(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = rotate90(img)
})
}
func BenchmarkRotate180(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = rotate180(img)
})
}
func BenchmarkRotate270(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = rotate270(img)
})
}
func BenchmarkTranspose(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = transpose(img)
})
}
func BenchmarkTransverse(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = transverse(img)
})
}
func BenchmarkResizeHorizontalLinear(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = resizeHorizontalLinear(img, 64)
})
}
func BenchmarkResizeVerticalLinear(b *testing.B) {
benchmarkFunc(b, func(img image.Image) {
_ = resizeVerticalLinear(img, 64)
})
}
func benchmarkFunc(b *testing.B, fn func(image.Image)) {
b.Helper()
for _, testcase := range []struct {
Path string
Decode func(io.Reader) (image.Image, error)
}{
{
Path: "./test/big-panda.gif",
Decode: gif.Decode,
},
{
Path: "./test/clock-original.gif",
Decode: gif.Decode,
},
{
Path: "./test/test-jpeg.jpg",
Decode: jpeg.Decode,
},
{
Path: "./test/test-png-noalphachannel.png",
Decode: png.Decode,
},
{
Path: "./test/test-png-alphachannel.png",
Decode: png.Decode,
},
{
Path: "./test/rainbow-original.png",
Decode: png.Decode,
},
{
Path: "./test/nb-flag-original.webp",
Decode: webp.Decode,
},
} {
file, err := openRead(testcase.Path)
if err != nil {
panic(err)
}
img, err := testcase.Decode(file)
if err != nil {
panic(err)
}
info, err := file.Stat()
if err != nil {
panic(err)
}
file.Close()
testname := fmt.Sprintf("ext=%s type=%s size=%d",
strings.TrimPrefix(path.Ext(testcase.Path), "."),
strings.TrimPrefix(reflect.TypeOf(img).String(), "*image."),
info.Size(),
)
b.Run(testname, func(b *testing.B) {
b.Helper()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
fn(img)
}
})
})
}
}
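These benchmarks run with the standard tooling; the package path below is an assumption based on where the media code lives, and -run '^$' skips the regular tests:

	go test -run '^$' -bench 'Benchmark(Flip|Rotate|Trans|Resize)' -benchmem ./internal/media/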

@@ -74,20 +74,28 @@ func clearMetadata(ctx context.Context, filepath string) error {
 // terminateExif cleans exif data from file at input path, into file
 // at output path, using given file extension to determine cleaning type.
-func terminateExif(outpath, inpath string, ext string) error {
+func terminateExif(outpath, inpath string, ext string) (err error) {
+	var inFile *os.File
+	var outFile *os.File
+
+	// Ensure handles
+	// closed on return.
+	defer func() {
+		outFile.Close()
+		inFile.Close()
+	}()
+
 	// Open input file at given path.
-	inFile, err := os.Open(inpath)
+	inFile, err = openRead(inpath)
 	if err != nil {
 		return gtserror.Newf("error opening input file %s: %w", inpath, err)
 	}
-	defer inFile.Close()

-	// Open output file at given path.
-	outFile, err := os.Create(outpath)
+	// Create output file at given path.
+	outFile, err = openWrite(outpath)
 	if err != nil {
 		return gtserror.Newf("error opening output file %s: %w", outpath, err)
 	}
-	defer outFile.Close()

 	// Terminate EXIF data from 'inFile' -> 'outFile'.
 	err = terminator.TerminateInto(outFile, inFile, ext)
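The single deferred closer above is safe even when the first open fails, because Close on a nil *os.File returns os.ErrInvalid rather than panicking. A minimal sketch of the same idea:

	package sketch

	import "os"

	// closeQuietly closes any number of handles, opened or not.
	func closeQuietly(files ...*os.File) {
		for _, f := range files {
			_ = f.Close() // nil-safe: returns os.ErrInvalid for nil
		}
	}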

@@ -38,8 +38,9 @@ const (
 // probe will first attempt to probe the file at path using native Go code
 // (for performance), but falls back to using ffprobe to retrieve media details.
 func probe(ctx context.Context, filepath string) (*result, error) {
+
 	// Open input file at given path.
-	file, err := os.Open(filepath)
+	file, err := openRead(filepath)
 	if err != nil {
 		return nil, gtserror.Newf("error opening file %s: %w", filepath, err)
 	}
@@ -80,6 +81,7 @@ func probe(ctx context.Context, filepath string) (*result, error) {
 // probeJPEG decodes the given file as JPEG and determines
 // image details from the decoded JPEG using native Go code.
 func probeJPEG(file *os.File) (*result, error) {
+
 	// Attempt to decode JPEG, adding back hdr magic.
 	cfg, err := jpeg.DecodeConfig(io.MultiReader(
 		strings.NewReader(magicJPEG),
@@ -136,17 +138,29 @@ func readOrientation(r *os.File) int {
 		orientationTag = 0x0112
 	)

-	// Setup a discard read buffer.
-	buf := new(byteutil.Buffer)
-	buf.Guarantee(32)
+	// Setup a read buffer.
+	var buf byteutil.Buffer
+	buf.B = make([]byte, 0, 64)

 	// discard simply reads into buf.
 	discard := func(n int) error {
-		buf.Guarantee(n) // ensure big enough
+		buf.Guarantee(n)
 		_, err := io.ReadFull(r, buf.B[:n])
 		return err
 	}

+	// readUint16 reads uint16 bytes into buffer then parses.
+	readUint16 := func(b binary.ByteOrder) (uint16, error) {
+		_, err := io.ReadFull(r, buf.B[:2])
+		return b.Uint16(buf.B[:2]), err
+	}
+
+	// readUint32 reads uint32 bytes into buffer then parses.
+	readUint32 := func(b binary.ByteOrder) (uint32, error) {
+		_, err := io.ReadFull(r, buf.B[:4])
+		return b.Uint32(buf.B[:4]), err
+	}
+
 	// Skip past JPEG SOI marker.
 	if err := discard(2); err != nil {
 		return orientationUnspecified
@@ -155,13 +169,13 @@ func readOrientation(r *os.File) int {
 	// Find JPEG
 	// APP1 marker.
 	for {
-		var marker, size uint16
-		if err := binary.Read(r, binary.BigEndian, &marker); err != nil {
+		marker, err := readUint16(binary.BigEndian)
+		if err != nil {
 			return orientationUnspecified
 		}
-		if err := binary.Read(r, binary.BigEndian, &size); err != nil {
+		size, err := readUint16(binary.BigEndian)
+		if err != nil {
 			return orientationUnspecified
 		}
@@ -182,11 +196,9 @@ func readOrientation(r *os.File) int {
 		}
 	}

-	// Check if EXIF
-	// header is present.
-	var header uint32
-	if err := binary.Read(r, binary.BigEndian, &header); err != nil {
+	// Check if EXIF header is present.
+	header, err := readUint32(binary.BigEndian)
+	if err != nil {
 		return orientationUnspecified
 	}
@@ -198,17 +210,13 @@ func readOrientation(r *os.File) int {
 		return orientationUnspecified
 	}

-	// Read byte
-	// order info.
-	var (
-		byteOrderTag uint16
-		byteOrder    binary.ByteOrder
-	)
-	if err := binary.Read(r, binary.BigEndian, &byteOrderTag); err != nil {
+	// Read byte order info.
+	byteOrderTag, err := readUint16(binary.BigEndian)
+	if err != nil {
 		return orientationUnspecified
 	}

+	var byteOrder binary.ByteOrder
 	switch byteOrderTag {
 	case byteOrderBE:
 		byteOrder = binary.BigEndian
@@ -222,11 +230,9 @@ func readOrientation(r *os.File) int {
 		return orientationUnspecified
 	}

-	// Skip the
-	// EXIF offset.
-	var offset uint32
-	if err := binary.Read(r, byteOrder, &offset); err != nil {
+	// Skip the EXIF offset.
+	offset, err := readUint32(byteOrder)
+	if err != nil {
 		return orientationUnspecified
 	}
@@ -238,19 +244,16 @@ func readOrientation(r *os.File) int {
 		return orientationUnspecified
 	}

-	// Read the
-	// number of tags.
-	var numTags uint16
-	if err := binary.Read(r, byteOrder, &numTags); err != nil {
+	// Read the number of tags.
+	numTags, err := readUint16(byteOrder)
+	if err != nil {
 		return orientationUnspecified
 	}

 	// Find the orientation tag.
 	for i := 0; i < int(numTags); i++ {
-		var tag uint16
-		if err := binary.Read(r, byteOrder, &tag); err != nil {
+		tag, err := readUint16(byteOrder)
+		if err != nil {
 			return orientationUnspecified
 		}
@@ -265,9 +268,8 @@ func readOrientation(r *os.File) int {
 			return orientationUnspecified
 		}

-		var val uint16
-		if err := binary.Read(r, byteOrder, &val); err != nil {
+		val, err := readUint16(byteOrder)
+		if err != nil {
 			return orientationUnspecified
 		}

@@ -44,11 +44,6 @@ type ProcessingEmoji struct {
 	mgr    *Manager // mgr instance (access to db / storage)
 }

-// ID returns the ID of the underlying emoji.
-func (p *ProcessingEmoji) ID() string {
-	return p.emoji.ID // immutable, safe outside mutex.
-}
-
 // LoadEmoji blocks until the static and fullsize image has been processed, and then returns the completed emoji.
 func (p *ProcessingEmoji) Load(ctx context.Context) (*gtsmodel.Emoji, error) {
 	emoji, done, err := p.load(ctx)
@@ -63,6 +58,33 @@ func (p *ProcessingEmoji) Load(ctx context.Context) (*gtsmodel.Emoji, error) {
 	return emoji, err
 }

+func (p *ProcessingEmoji) LoadAsync(deferred func()) *gtsmodel.Emoji {
+	p.mgr.state.Workers.Dereference.Queue.Push(func(ctx context.Context) {
+		if deferred != nil {
+			defer deferred()
+		}
+		if _, _, err := p.load(ctx); err != nil {
+			log.Errorf(ctx, "error loading emoji: %v", err)
+		}
+	})
+
+	// Placeholder returns a copy of internally stored processing placeholder,
+	// returning only the fields that may be known *before* completion,
+	// and as such all fields which are safe to concurrently read.
+	placeholder := new(gtsmodel.Emoji)
+	placeholder.ID = p.emoji.ID
+	placeholder.Shortcode = p.emoji.Shortcode
+	placeholder.Domain = p.emoji.Domain
+	placeholder.Cached = new(bool)
+	placeholder.ImageRemoteURL = p.emoji.ImageRemoteURL
+	placeholder.ImageStaticRemoteURL = p.emoji.ImageStaticRemoteURL
+	placeholder.Disabled = p.emoji.Disabled
+	placeholder.VisibleInPicker = p.emoji.VisibleInPicker
+	placeholder.CategoryID = p.emoji.CategoryID
+	return placeholder
+}
+
 // load is the package private form of load() that is wrapped to catch context canceled.
 func (p *ProcessingEmoji) load(ctx context.Context) (
 	emoji *gtsmodel.Emoji,
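The LoadAsync pattern in miniature, with stand-in types (the real worker queue and emoji model are GoToSocial internals): enqueue the blocking load on a worker, then immediately hand back a copy containing only the fields that are safe to read before processing completes.

	package sketch

	import (
		"context"
		"log"
	)

	type placeholder struct{ ID, Shortcode string }

	// loadAsync queues the slow load and returns a safe snapshot now.
	func loadAsync(queue chan<- func(context.Context), p placeholder, load func(context.Context) error) placeholder {
		queue <- func(ctx context.Context) {
			if err := load(ctx); err != nil {
				log.Printf("error loading emoji %s: %v", p.Shortcode, err)
			}
		}
		return p // a copy: callers never touch the in-flight original
	}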

@@ -24,7 +24,6 @@ import (
 	"image/jpeg"
 	"image/png"
 	"io"
-	"os"
 	"strings"

 	"code.superseriousbusiness.org/gotosocial/internal/gtserror"
@@ -89,8 +88,8 @@ func generateThumb(
 	// Default type is webp.
 	mimeType = "image/webp"

-	// Generate thumb output path REPLACING extension.
-	if i := strings.IndexByte(filepath, '.'); i != -1 {
+	// Generate thumb output path REPLACING file extension.
+	if i := strings.LastIndexByte(filepath, '.'); i != -1 {
 		outpath = filepath[:i] + "_thumb.webp"
 		ext = filepath[i+1:] // old extension
 	} else {
@@ -231,7 +230,7 @@ func generateNativeThumb(
 	error,
 ) {
 	// Open input file at given path.
-	infile, err := os.Open(inpath)
+	infile, err := openRead(inpath)
 	if err != nil {
 		return "", gtserror.Newf("error opening input file %s: %w", inpath, err)
 	}
@@ -272,7 +271,7 @@ func generateNativeThumb(
 	)

 	// Open output file at given path.
-	outfile, err := os.Create(outpath)
+	outfile, err := openWrite(outpath)
 	if err != nil {
 		return "", gtserror.Newf("error opening output file %s: %w", outpath, err)
 	}
@@ -313,8 +312,9 @@ func generateNativeThumb(
 // generateWebpBlurhash generates a blurhash for Webp at filepath.
 func generateWebpBlurhash(filepath string) (string, error) {
+
 	// Open the file at given path.
-	file, err := os.Open(filepath)
+	file, err := openRead(filepath)
 	if err != nil {
 		return "", gtserror.Newf("error opening input file %s: %w", filepath, err)
 	}

@@ -30,14 +30,41 @@ import (
 	"codeberg.org/gruf/go-iotools"
 )

+// media processing tmpdir.
+var tmpdir = os.TempDir()
+
 // file represents one file
 // with the given flag and perms.
 type file struct {
-	abs  string
+	abs  string // absolute file path, including root
+	dir  string // containing directory of abs
+	rel  string // relative to root, i.e. trim_prefix(abs, dir)
 	flag int
 	perm os.FileMode
 }

+// allowRead returns a new file{} for filepath permitted only to read.
+func allowRead(filepath string) file {
+	return newFile(filepath, os.O_RDONLY, 0)
+}
+
+// allowCreate returns a new file{} for filepath permitted to read / write / create.
+func allowCreate(filepath string) file {
+	return newFile(filepath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
+}
+
+// newFile returns a new instance of file{} for given path and open args.
+func newFile(filepath string, flag int, perms os.FileMode) file {
+	dir, rel := path.Split(filepath)
+	return file{
+		abs:  filepath,
+		rel:  rel,
+		dir:  dir,
+		flag: flag,
+		perm: perms,
+	}
+}
+
 // allowFiles implements fs.FS to allow
 // access to a specified slice of files.
 type allowFiles []file
@@ -45,36 +72,32 @@ type allowFiles []file
 // Open implements fs.FS.
 func (af allowFiles) Open(name string) (fs.File, error) {
 	for _, file := range af {
-		var (
-			abs  = file.abs
-			flag = file.flag
-			perm = file.perm
-		)
-
-		// Allowed to open file
-		// at absolute path.
-		if name == file.abs {
-			return os.OpenFile(abs, flag, perm)
-		}
-
-		// Check for other valid reads.
-		thisDir, thisFile := path.Split(file.abs)
-
-		// Allowed to read directory itself.
-		if name == thisDir || name == "." {
-			return os.OpenFile(thisDir, flag, perm)
-		}
-
-		// Allowed to read file
-		// itself (at relative path).
-		if name == thisFile {
-			return os.OpenFile(abs, flag, perm)
+		switch name {
+		// Allowed to open file
+		// at absolute path, or
+		// relative as ffmpeg likes.
+		case file.abs, file.rel:
+			return os.OpenFile(file.abs, file.flag, file.perm)
+
+		// Ffmpeg likes to read containing
+		// dir as '.'. Allow RO access here.
+		case ".":
+			return openRead(file.dir)
 		}
 	}

 	return nil, os.ErrPermission
 }

+// openRead opens the existing file at path for reads only.
+func openRead(path string) (*os.File, error) {
+	return os.OpenFile(path, os.O_RDONLY, 0)
+}
+
+// openWrite opens the (new!) file at path for read / writes.
+func openWrite(path string) (*os.File, error) {
+	return os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
+}
+
 // getExtension splits file extension from path.
 func getExtension(path string) string {
 	for i := len(path) - 1; i >= 0 && path[i] != '/'; i-- {
@@ -93,17 +116,24 @@ func getExtension(path string) string {
 // chance that Linux's sendfile syscall can be utilised for optimal
 // draining of data source to temporary file storage.
 func drainToTmp(rc io.ReadCloser) (string, error) {
-	defer rc.Close()
+	var tmp *os.File
+	var err error
+
+	// Close handles
+	// on func return.
+	defer func() {
+		tmp.Close()
+		rc.Close()
+	}()

 	// Open new temporary file.
-	tmp, err := os.CreateTemp(
-		os.TempDir(),
+	tmp, err = os.CreateTemp(
+		tmpdir,
 		"gotosocial-*",
 	)
 	if err != nil {
 		return "", err
 	}
-	defer tmp.Close()

 	// Extract file path.
 	path := tmp.Name()
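A minimal stand-in for the allowFiles idea above (type and names are illustrative, not the project's code): permit opening one file by absolute path, by bare filename (how ffmpeg refers to it), or its containing directory as ".".

	package sketch

	import (
		"io/fs"
		"os"
		"path"
	)

	// allowOne is an fs.FS exposing exactly one permitted file.
	type allowOne string

	func (a allowOne) Open(name string) (fs.File, error) {
		dir, rel := path.Split(string(a))
		switch name {
		case string(a), rel: // absolute, or relative as ffmpeg likes
			return os.OpenFile(string(a), os.O_RDONLY, 0)
		case ".": // containing dir, read-only
			return os.OpenFile(dir, os.O_RDONLY, 0)
		}
		return nil, os.ErrPermission
	}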

@@ -25,6 +25,7 @@ import (
 	"time"

 	"code.superseriousbusiness.org/gotosocial/internal/gtserror"
+	"code.superseriousbusiness.org/gotosocial/internal/log"
 	"code.superseriousbusiness.org/gotosocial/internal/util"

 	"github.com/gin-gonic/gin"
 	"github.com/ulule/limiter/v3"
@@ -78,30 +79,43 @@ func RateLimit(limit int, except []netip.Prefix) gin.HandlerFunc {
 		// Use Gin's heuristic for determining
 		// clientIP, which accounts for reverse
 		// proxies and trusted proxies setting.
-		clientIP := netip.MustParseAddr(c.ClientIP())
+		clientIP := c.ClientIP()
+
+		// ClientIP must be parseable.
+		ip, err := netip.ParseAddr(clientIP)
+		if err != nil {
+			log.Warnf(
+				c.Request.Context(),
+				"cannot do rate limiting for this request as client IP %s could not be parsed;"+
+					" your upstream reverse proxy may be misconfigured: %v",
+				clientIP, err,
+			)
+			c.Next()
+			return
+		}

 		// Check if this IP is exempt from rate
 		// limits and skip further checks if so.
 		for _, prefix := range except {
-			if prefix.Contains(clientIP) {
+			if prefix.Contains(ip) {
 				c.Next()
 				return
 			}
 		}

-		if clientIP.Is6() {
+		if ip.Is6() {
 			// Convert to "net" package IP for mask.
-			asIP := net.IP(clientIP.AsSlice())
+			asIP := net.IP(ip.AsSlice())

 			// Apply coarse IPv6 mask.
 			asIP = asIP.Mask(ipv6Mask)

 			// Convert back to netip.Addr from net.IP.
-			clientIP, _ = netip.AddrFromSlice(asIP)
+			ip, _ = netip.AddrFromSlice(asIP)
 		}

 		// Fetch rate limit info for this (masked) clientIP.
-		context, err := limiter.Get(c, clientIP.String())
+		context, err := limiter.Get(c, ip.String())
 		if err != nil {
 			// Since we use an in-memory cache now,
 			// it's actually impossible for this to
View file

@ -293,6 +293,7 @@ func (p *Processor) emojiUpdateCopy(
// Ensure target emoji is locally cached. // Ensure target emoji is locally cached.
target, err := p.federator.RecacheEmoji(ctx, target, err := p.federator.RecacheEmoji(ctx,
target, target,
false,
) )
if err != nil { if err != nil {
err := gtserror.Newf("error recaching emoji %s: %w", target.ImageRemoteURL, err) err := gtserror.Newf("error recaching emoji %s: %w", target.ImageRemoteURL, err)

@@ -247,6 +247,7 @@ func (p *Processor) getEmojiContent(
 	emoji, err = p.federator.RecacheEmoji(
 		ctx,
 		emoji,
+		false,
 	)
 	if err != nil {
 		err := gtserror.Newf("error recaching emoji: %w", err)

@@ -31,7 +31,7 @@ import (
 // elements to reduce overall memory usage.
 type SimpleQueue[T any] struct {
 	l list.List[T]
-	p mempool.UnsafePool
+	p mempool.UnsafeSimplePool
 	w chan struct{}
 	m sync.Mutex
 }

@@ -2438,8 +2438,8 @@ func (c *Converter) InteractionReqToASAuthorization(
 }

 // appendASInteractionAuthorization is a utility function
-// that sets `approvedBy`, and `likeAuthorization`,
-// `replyAuthorization`, or `announceAuthorization`.
+// that sets `approvedBy`, and (if possible) `likeAuthorization`,
+// `replyAuthorization`, and/or `announceAuthorization`.
 func (c *Converter) appendASInteractionAuthorization(
 	ctx context.Context,
 	approvedByURIStr string,
@@ -2458,11 +2458,28 @@ func (c *Converter) appendASInteractionAuthorization(
 		gtscontext.SetBarebones(ctx),
 		approvedByURIStr,
 	)
-	if err != nil {
+	if err != nil && !errors.Is(err, db.ErrNoEntries) {
 		return gtserror.Newf("db error checking for int req: %w", err)
 	}

-	// Make sure it's actually accepted.
+	// If the interaction request is nil,
+	// that means we originally sent out
+	// the interaction request impolitely,
+	// and it was accepted impolitely.
+	// Ie., behavior from <= v0.20.0.
+	//
+	// If this is so, just set `approvedBy`
+	// to given approvedByURIStr and bail,
+	// as there's nothing else we can do.
+	if intReq == nil {
+		if wap, ok := t.(ap.WithApprovedBy); ok {
+			ap.SetApprovedBy(wap, approvedByURI)
+		}
+		return nil
+	}
+
+	// Make sure interaction request
+	// has actually been accepted.
 	if !intReq.IsAccepted() {
 		return gtserror.Newf(
 			"approvedByURIStr %s corresponded to not-accepted interaction request %s",

@@ -149,6 +149,7 @@ nav:
         - "admin/spam.md"
         - "admin/database_maintenance.md"
         - "admin/themes.md"
+        - "admin/slow_hardware.md"
     - "Federation":
         - "federation/index.md"
         - "federation/http_signatures.md"

@@ -1,39 +1,46 @@
 package embed

 import (
-	"bytes"
 	"compress/gzip"
 	_ "embed"
 	"io"
-	"os"
+	"strings"
 )

 func init() {
 	var err error

-	if path := os.Getenv("FFMPREG_WASM"); path != "" {
-		// Read file into memory.
-		B, err = os.ReadFile(path)
-		if err != nil {
-			panic(err)
-		}
-	}
-
 	// Wrap bytes in reader.
-	b := bytes.NewReader(B)
+	r := strings.NewReader(s)

 	// Create unzipper from reader.
-	gz, err := gzip.NewReader(b)
+	gz, err := gzip.NewReader(r)
 	if err != nil {
 		panic(err)
 	}

 	// Extract gzipped binary.
-	B, err = io.ReadAll(gz)
+	b, err := io.ReadAll(gz)
 	if err != nil {
 		panic(err)
 	}
+
+	// Set binary.
+	s = string(b)
 }

+// B returns a copy of
+// embedded binary data.
+func B() []byte {
+	if s == "" {
+		panic("binary already dropped from memory")
+	}
+	return []byte(s)
+}
+
+// Free will drop embedded
+// binary from runtime mem.
+func Free() { s = "" }
+
 //go:embed ffmpreg.wasm.gz
-var B []byte
+var s string
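A hedged sketch of the intended lifecycle (the helper name and parameters here are assumptions, not the project's API): copy the binary out with B(), compile it once, then Free() the decompressed bytes so they can be reclaimed.

	package sketch

	import (
		"context"

		"github.com/tetratelabs/wazero"
	)

	// compileOnce compiles the embedded binary, then drops it from memory.
	func compileOnce(ctx context.Context, rt wazero.Runtime, b func() []byte, free func()) (wazero.CompiledModule, error) {
		mod, err := rt.CompileModule(ctx, b())
		if err != nil {
			return nil, err
		}
		free() // e.g. ffmpreg.Free(): release the decompressed bytes
		return mod, nil
	}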

@@ -9,22 +9,93 @@ import (
 type snapshotskey struct{}

+type snapshotctx struct {
+	context.Context
+	snaps *snapshots
+}
+
+func (ctx snapshotctx) Value(key any) any {
+	if _, ok := key.(snapshotskey); ok {
+		return ctx.snaps
+	}
+	return ctx.Context.Value(key)
+}
+
+const ringsz uint = 8
+
+type snapshots struct {
+	r [ringsz]struct {
+		eptr uint32
+		snap experimental.Snapshot
+	}
+	n uint
+}
+
+func (s *snapshots) get(envptr uint32) experimental.Snapshot {
+	start := (s.n % ringsz)
+	for i := start; i != ^uint(0); i-- {
+		if s.r[i].eptr == envptr {
+			snap := s.r[i].snap
+			s.r[i].eptr = 0
+			s.r[i].snap = nil
+			s.n = i - 1
+			return snap
+		}
+	}
+	for i := ringsz - 1; i > start; i-- {
+		if s.r[i].eptr == envptr {
+			snap := s.r[i].snap
+			s.r[i].eptr = 0
+			s.r[i].snap = nil
+			s.n = i - 1
+			return snap
+		}
+	}
+	panic("snapshot not found")
+}
+
+func (s *snapshots) set(envptr uint32, snapshot experimental.Snapshot) {
+	start := (s.n % ringsz)
+	for i := start; i < ringsz; i++ {
+		switch s.r[i].eptr {
+		case 0, envptr:
+			s.r[i].eptr = envptr
+			s.r[i].snap = snapshot
+			s.n = i
+			return
+		}
+	}
+	for i := uint(0); i < start; i++ {
+		switch s.r[i].eptr {
+		case 0, envptr:
+			s.r[i].eptr = envptr
+			s.r[i].snap = snapshot
+			s.n = i
+			return
+		}
+	}
+	panic("snapshots full")
+}
+
 // withSetjmpLongjmp updates the context to contain wazero/experimental.Snapshotter{} support,
 // and embeds the necessary snapshots map required for later calls to Setjmp() / Longjmp().
 func withSetjmpLongjmp(ctx context.Context) context.Context {
-	snapshots := make(map[uint32]experimental.Snapshot, 10)
-	ctx = experimental.WithSnapshotter(ctx)
-	ctx = context.WithValue(ctx, snapshotskey{}, snapshots)
-	return ctx
+	return snapshotctx{Context: experimental.WithSnapshotter(ctx), snaps: new(snapshots)}
 }

-func getSnapshots(ctx context.Context) map[uint32]experimental.Snapshot {
-	v, _ := ctx.Value(snapshotskey{}).(map[uint32]experimental.Snapshot)
+func getSnapshots(ctx context.Context) *snapshots {
+	v, _ := ctx.Value(snapshotskey{}).(*snapshots)
 	return v
 }

 // setjmp implements the C function: setjmp(env jmp_buf)
-func setjmp(ctx context.Context, mod api.Module, stack []uint64) {
+func setjmp(ctx context.Context, _ api.Module, stack []uint64) {

 	// Input arguments.
 	envptr := api.DecodeU32(stack[0])
@@ -35,19 +106,16 @@ func setjmp(ctx context.Context, mod api.Module, stack []uint64) {

 	// Get stored snapshots map.
 	snapshots := getSnapshots(ctx)
-	if snapshots == nil {
-		panic("setjmp / longjmp not supported")
-	}

 	// Set latest snapshot in map.
-	snapshots[envptr] = snapshot
+	snapshots.set(envptr, snapshot)

 	// Set return.
 	stack[0] = 0
 }

 // longjmp implements the C function: int longjmp(env jmp_buf, value int)
-func longjmp(ctx context.Context, mod api.Module, stack []uint64) {
+func longjmp(ctx context.Context, _ api.Module, stack []uint64) {

 	// Input arguments.
 	envptr := api.DecodeU32(stack[0])
@@ -60,10 +128,7 @@ func longjmp(ctx context.Context, mod api.Module, stack []uint64) {
 	}

 	// Get snapshot stored in map.
-	snapshot := snapshots[envptr]
-	if snapshot == nil {
-		panic("must first call setjmp")
-	}
+	snapshot := snapshots.get(envptr)

 	// Set return.
 	stack[0] = 0
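Illustrative only (snapshots is unexported and these calls assume the same package): setjmp/longjmp pairs nest like a stack, which the fixed-size ring exploits. set() scans forward from the cursor for a free or matching slot; get() scans backward from it, so properly nested jumps hit their slot in O(1).

	var s snapshots
	s.set(0x10, outer) // outer setjmp
	s.set(0x20, inner) // inner setjmp
	_ = s.get(0x20)    // inner longjmp: found at the cursor
	_ = s.get(0x10)    // outer longjmp: one step back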

@@ -53,6 +53,7 @@ func Run(
 	modcfg = modcfg.WithStdin(args.Stdin)
 	modcfg = modcfg.WithStdout(args.Stdout)
 	modcfg = modcfg.WithStderr(args.Stderr)
+	modcfg = modcfg.WithName("")

 	if args.Config != nil {
 		// Pass through config fn.

@@ -28,6 +28,7 @@ func NewRuntime(ctx context.Context, cfg wazero.RuntimeConfig) (wazero.Runtime, error) {
 	// Set core features ffmpeg compiled with.
 	cfg = cfg.WithCoreFeatures(CoreFeatures)
+	cfg = cfg.WithDebugInfoEnabled(false)

 	// Instantiate runtime with prepared config.
 	rt := wazero.NewRuntimeWithConfig(ctx, cfg)
View file

@ -1,17 +1,17 @@
package mempool package mempool
import ( import (
"sync"
"sync/atomic"
"unsafe" "unsafe"
"golang.org/x/sys/cpu"
) )
const DefaultDirtyFactor = 128 // Pool provides a form of SimplePool
// with the addition of concurrency safety.
// Pool provides a type-safe form
// of UnsafePool using generics.
//
// Note it is NOT safe for concurrent
// use, you must protect it yourself!
type Pool[T any] struct { type Pool[T any] struct {
UnsafePool
// New is an optionally provided // New is an optionally provided
// allocator used when no value // allocator used when no value
@ -21,79 +21,119 @@ type Pool[T any] struct {
// Reset is an optionally provided // Reset is an optionally provided
// value resetting function called // value resetting function called
// on passed value to Put(). // on passed value to Put().
Reset func(T) Reset func(T) bool
}
UnsafePool func NewPool[T any](new func() T, reset func(T) bool, check func(current, victim int) bool) Pool[T] {
return Pool[T]{
New: new,
Reset: reset,
UnsafePool: NewUnsafePool(check),
}
} }
func (p *Pool[T]) Get() T { func (p *Pool[T]) Get() T {
if ptr := p.UnsafePool.Get(); ptr != nil { if ptr := p.UnsafePool.Get(); ptr != nil {
return *(*T)(ptr) return *(*T)(ptr)
} else if p.New != nil {
return p.New()
} }
var z T var t T
return z if p.New != nil {
t = p.New()
}
return t
} }
func (p *Pool[T]) Put(t T) { func (p *Pool[T]) Put(t T) {
if p.Reset != nil { if p.Reset != nil && !p.Reset(t) {
p.Reset(t) return
} }
ptr := unsafe.Pointer(&t) ptr := unsafe.Pointer(&t)
p.UnsafePool.Put(ptr) p.UnsafePool.Put(ptr)
} }
// UnsafePool provides an incredibly // UnsafePool provides a form of UnsafeSimplePool
// simple memory pool implementation // with the addition of concurrency safety.
// that stores ptrs to memory values,
// and regularly flushes internal pool
// structures according to DirtyFactor.
//
// Note it is NOT safe for concurrent
// use, you must protect it yourself!
type UnsafePool struct { type UnsafePool struct {
internal
// DirtyFactor determines the max _ [cache_line_size - unsafe.Sizeof(internal{})%cache_line_size]byte
// number of $dirty count before
// pool is garbage collected. Where:
// $dirty = len(current) - len(victim)
DirtyFactor int
current []unsafe.Pointer
victim []unsafe.Pointer
} }
func (p *UnsafePool) Get() unsafe.Pointer { func NewUnsafePool(check func(current, victim int) bool) UnsafePool {
// First try current list. return UnsafePool{internal: internal{
if len(p.current) > 0 { pool: UnsafeSimplePool{Check: check},
ptr := p.current[len(p.current)-1] }}
p.current = p.current[:len(p.current)-1] }
const (
// current platform integer size.
int_size = 32 << (^uint(0) >> 63)
// platform CPU cache line size to avoid false sharing.
cache_line_size = unsafe.Sizeof(cpu.CacheLinePad{})
)
type internal struct {
// fast-access ring-buffer of
// pointers accessible by index.
//
// if Go ever exposes goroutine IDs
// to us we can make this a lot faster.
ring [int_size / 4]unsafe.Pointer
index atomic.Uint64
// underlying pool and
// slow mutex protection.
pool UnsafeSimplePool
mutex sync.Mutex
}
func (p *internal) Check(fn func(current, victim int) bool) func(current, victim int) bool {
p.mutex.Lock()
if fn == nil {
if p.pool.Check == nil {
fn = defaultCheck
} else {
fn = p.pool.Check
}
} else {
p.pool.Check = fn
}
p.mutex.Unlock()
return fn
}
func (p *internal) Get() unsafe.Pointer {
if ptr := atomic.SwapPointer(&p.ring[p.index.Load()%uint64(cap(p.ring))], nil); ptr != nil {
p.index.Add(^uint64(0)) // i.e. -1
return ptr return ptr
} }
p.mutex.Lock()
// Fallback to victim. ptr := p.pool.Get()
if len(p.victim) > 0 { p.mutex.Unlock()
ptr := p.victim[len(p.victim)-1] return ptr
p.victim = p.victim[:len(p.victim)-1]
return ptr
}
return nil
} }
func (p *UnsafePool) Put(ptr unsafe.Pointer) { func (p *internal) Put(ptr unsafe.Pointer) {
p.current = append(p.current, ptr) if atomic.CompareAndSwapPointer(&p.ring[p.index.Add(1)%uint64(cap(p.ring))], nil, ptr) {
return
// Get dirty factor.
df := p.DirtyFactor
if df == 0 {
df = DefaultDirtyFactor
}
if len(p.current)-len(p.victim) > df {
// Garbage collection!
p.victim = p.current
p.current = nil
} }
p.mutex.Lock()
p.pool.Put(ptr)
p.mutex.Unlock()
}
func (p *internal) GC() {
for i := range p.ring {
atomic.StorePointer(&p.ring[i], nil)
}
p.mutex.Lock()
p.pool.GC()
p.mutex.Unlock()
}
func (p *internal) Size() int {
p.mutex.Lock()
sz := p.pool.Size()
p.mutex.Unlock()
return sz
} }
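A hedged usage sketch of the new concurrency-safe Pool (parameter choices here are assumptions): Reset now returns a bool, so Put can drop values, such as oversized buffers, instead of recycling them; a nil check keeps the default flush behaviour.

	package sketch

	import (
		"bytes"

		"codeberg.org/gruf/go-mempool"
	)

	var bufPool = mempool.NewPool(
		func() *bytes.Buffer { return new(bytes.Buffer) },
		func(b *bytes.Buffer) bool { b.Reset(); return b.Cap() < 1<<20 },
		nil, // nil check: use the pool's default flush policy
	)

	func roundTrip() {
		b := bufPool.Get()
		b.WriteString("hello")
		bufPool.Put(b) // dropped rather than pooled if it grew past 1 MiB
	}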

vendor/codeberg.org/gruf/go-mempool/simple.go (new file, 111 lines)
@@ -0,0 +1,111 @@
package mempool
import (
"unsafe"
)
// SimplePool provides a type-safe form
// of UnsafePool using generics.
//
// Note it is NOT safe for concurrent
// use, you must protect it yourself!
type SimplePool[T any] struct {
UnsafeSimplePool
// New is an optionally provided
// allocator used when no value
// is available for use in pool.
New func() T
// Reset is an optionally provided
// value resetting function called
// on passed value to Put().
Reset func(T) bool
}
func (p *SimplePool[T]) Get() T {
if ptr := p.UnsafeSimplePool.Get(); ptr != nil {
return *(*T)(ptr)
}
var t T
if p.New != nil {
t = p.New()
}
return t
}
func (p *SimplePool[T]) Put(t T) {
if p.Reset != nil && !p.Reset(t) {
return
}
ptr := unsafe.Pointer(&t)
p.UnsafeSimplePool.Put(ptr)
}
// UnsafeSimplePool provides an incredibly
// simple memory pool implementation
// that stores ptrs to memory values,
// and regularly flushes internal pool
// structures according to CheckGC().
//
// Note it is NOT safe for concurrent
// use, you must protect it yourself!
type UnsafeSimplePool struct {
// Check determines how often to flush
// internal pools based on underlying
// current and victim pool sizes. It gets
// called on every pool Put() operation.
//
// A flush will start a new current
// pool, make victim the old current,
// and drop the existing victim pool.
Check func(current, victim int) bool
current []unsafe.Pointer
victim []unsafe.Pointer
}
func (p *UnsafeSimplePool) Get() unsafe.Pointer {
// First try current list.
if len(p.current) > 0 {
ptr := p.current[len(p.current)-1]
p.current = p.current[:len(p.current)-1]
return ptr
}
// Fallback to victim.
if len(p.victim) > 0 {
ptr := p.victim[len(p.victim)-1]
p.victim = p.victim[:len(p.victim)-1]
return ptr
}
return nil
}
func (p *UnsafeSimplePool) Put(ptr unsafe.Pointer) {
p.current = append(p.current, ptr)
// Get GC check func.
if p.Check == nil {
p.Check = defaultCheck
}
if p.Check(len(p.current), len(p.victim)) {
p.GC() // garbage collection time!
}
}
func (p *UnsafeSimplePool) GC() {
p.victim = p.current
p.current = nil
}
func (p *UnsafeSimplePool) Size() int {
return len(p.current) + len(p.victim)
}
func defaultCheck(current, victim int) bool {
return current-victim > 128 || victim > 256
}
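The two-generation flush in miniature: when Check fires, the current slice becomes the victim and the old victim is dropped, so an idle entry survives at most two flush cycles before being released to the garbage collector. A hedged usage sketch of the single-threaded SimplePool:

	package sketch

	import "codeberg.org/gruf/go-mempool"

	var p = mempool.SimplePool[[]byte]{
		UnsafeSimplePool: mempool.UnsafeSimplePool{
			Check: func(current, victim int) bool {
				return current-victim > 4 // flush far sooner than the default 128
			},
		},
		New:   func() []byte { return make([]byte, 0, 512) },
		Reset: func(b []byte) bool { return cap(b) <= 1<<16 }, // drop huge slices
	}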

@@ -26,14 +26,13 @@ const (
 type MutexMap struct {
 	mapmu  sync.Mutex
 	mumap  hashmap
-	mupool mempool.UnsafePool
+	mupool mempool.UnsafeSimplePool
 }

 // checkInit ensures MutexMap is initialized (UNSAFE).
 func (mm *MutexMap) checkInit() {
 	if mm.mumap.m == nil {
 		mm.mumap.init(0)
-		mm.mupool.DirtyFactor = 256
 	}
 }
@@ -175,13 +174,9 @@ func (mu *rwmutex) Lock(lt uint8) bool {
 // sleeping goroutines waiting on this mutex.
 func (mu *rwmutex) Unlock() bool {
 	switch mu.l--; {
-	case mu.l > 0 && mu.t == lockTypeWrite:
-		panic("BUG: multiple writer locks")
-	case mu.l < 0:
-		panic("BUG: negative lock count")
 	case mu.l == 0:
-		// Fully unlocked.
+		// Fully
+		// unlock.
 		mu.t = 0

 		// Awake all blocked goroutines and check
@@ -197,11 +192,15 @@ func (mu *rwmutex) Unlock() bool {
 		// (before == after) => (waiters = 0)
 		return (before == after)

-	default:
-		// i.e. mutex still
-		// locked by others.
-		return false
+	case mu.l < 0:
+		panic("BUG: negative lock count")
+	case mu.t == lockTypeWrite:
+		panic("BUG: multiple write locks")
 	}
+
+	// i.e. mutex still
+	// locked by others.
+	return false
 }

 // WaitRelock expects a mutex to be passed in, already in the

@@ -4,10 +4,10 @@ import (
 	"os"
 	"reflect"
 	"strings"
-	"sync"
 	"unsafe"

 	"codeberg.org/gruf/go-byteutil"
+	"codeberg.org/gruf/go-mempool"
 	"codeberg.org/gruf/go-xunsafe"
 )
@@ -371,17 +371,15 @@ type index_entry struct {
 	key string
 }

-var index_entry_pool sync.Pool
+var index_entry_pool mempool.UnsafePool

 // new_index_entry returns a new prepared index_entry.
 func new_index_entry() *index_entry {
-	v := index_entry_pool.Get()
-	if v == nil {
-		e := new(index_entry)
-		e.elem.data = unsafe.Pointer(e)
-		v = e
+	if ptr := index_entry_pool.Get(); ptr != nil {
+		return (*index_entry)(ptr)
 	}
-	entry := v.(*index_entry)
+	entry := new(index_entry)
+	entry.elem.data = unsafe.Pointer(entry)
 	return entry
 }
@@ -396,7 +394,8 @@ func free_index_entry(entry *index_entry) {
 	entry.key = ""
 	entry.index = nil
 	entry.item = nil
-	index_entry_pool.Put(entry)
+	ptr := unsafe.Pointer(entry)
+	index_entry_pool.Put(ptr)
 }

 func is_unique(f uint8) bool {

@@ -2,8 +2,9 @@ package structr
 import (
 	"os"
-	"sync"
 	"unsafe"

+	"codeberg.org/gruf/go-mempool"
 )
@@ -19,17 +20,15 @@ type indexed_item struct {
 	indexed []*index_entry
 }

-var indexed_item_pool sync.Pool
+var indexed_item_pool mempool.UnsafePool

 // new_indexed_item returns a new prepared indexed_item.
 func new_indexed_item() *indexed_item {
-	v := indexed_item_pool.Get()
-	if v == nil {
-		i := new(indexed_item)
-		i.elem.data = unsafe.Pointer(i)
-		v = i
+	if ptr := indexed_item_pool.Get(); ptr != nil {
+		return (*indexed_item)(ptr)
 	}
-	item := v.(*indexed_item)
+	item := new(indexed_item)
+	item.elem.data = unsafe.Pointer(item)
 	return item
 }
@@ -43,7 +42,8 @@ func free_indexed_item(item *indexed_item) {
 		return
 	}
 	item.data = nil
-	indexed_item_pool.Put(item)
+	ptr := unsafe.Pointer(item)
+	indexed_item_pool.Put(ptr)
 }

 // drop_index will drop the given index entry from item's indexed.


@@ -2,8 +2,9 @@ package structr
 import (
 	"os"
-	"sync"
 	"unsafe"
+
+	"codeberg.org/gruf/go-mempool"
 )
 
 // elem represents an elem
@@ -27,16 +28,14 @@ type list struct {
 	len int
 }
 
-var list_pool sync.Pool
+var list_pool mempool.UnsafePool
 
 // new_list returns a new prepared list.
 func new_list() *list {
-	v := list_pool.Get()
-	if v == nil {
-		v = new(list)
+	if ptr := list_pool.Get(); ptr != nil {
+		return (*list)(ptr)
 	}
-	list := v.(*list)
-	return list
+	return new(list)
 }
 
 // free_list releases the list.
@@ -48,11 +47,13 @@ func free_list(list *list) {
 		os.Stderr.WriteString(msg + "\n")
 		return
 	}
-	list_pool.Put(list)
+	ptr := unsafe.Pointer(list)
+	list_pool.Put(ptr)
 }
 
 // push_front will push the given elem to front (head) of list.
 func (l *list) push_front(elem *list_elem) {
 	// Set new head.
 	oldHead := l.head
 	l.head = elem
@@ -66,12 +67,14 @@ func (l *list) push_front(elem *list_elem) {
 		l.tail = elem
 	}
 
-	// Incr count
+	// Incr
+	// count
 	l.len++
 }
 
 // push_back will push the given elem to back (tail) of list.
 func (l *list) push_back(elem *list_elem) {
 	// Set new tail.
 	oldTail := l.tail
 	l.tail = elem
@@ -85,7 +88,8 @@ func (l *list) push_back(elem *list_elem) {
 		l.head = elem
 	}
 
-	// Incr count
+	// Incr
+	// count
 	l.len++
 }
@@ -131,7 +135,8 @@ func (l *list) insert(elem *list_elem, at *list_elem) {
 		elem.next = oldNext
 	}
 
-	// Incr count
+	// Incr
+	// count
 	l.len++
 }
@@ -174,6 +179,7 @@ func (l *list) remove(elem *list_elem) {
 		prev.next = next
 	}
 
-	// Decr count
+	// Decr
+	// count
 	l.len--
 }


@@ -146,7 +146,7 @@ func find_field(t xunsafe.TypeIter, names []string) (sfield struct_field, ftype
 	sfield.mangle = mangler.Get(t)
 
 	// Calculate zero value string.
-	zptr := zero_value_field(o, sfield.offsets)
+	zptr := zero_value_ptr(o, sfield.offsets)
 	zstr := string(sfield.mangle(nil, zptr))
 	sfield.zerostr = zstr
 	sfield.zero = zptr
@@ -154,7 +154,9 @@ func find_field(t xunsafe.TypeIter, names []string) (sfield struct_field, ftype
 	return
 }
 
-// zero_value ...
+// zero_value iterates the type contained in TypeIter{} along the given
+// next_offset{} values, creating new ptrs where necessary, returning the
+// zero reflect.Value{} after fully iterating the next_offset{} slice.
 func zero_value(t xunsafe.TypeIter, offsets []next_offset) reflect.Value {
 	v := reflect.New(t.Type).Elem()
 	for _, offset := range offsets {
@@ -175,8 +177,8 @@ func zero_value(t xunsafe.TypeIter, offsets []next_offset) reflect.Value {
 	return v
 }
 
-// zero_value_field ...
-func zero_value_field(t xunsafe.TypeIter, offsets []next_offset) unsafe.Pointer {
+// zero_value_ptr returns the unsafe pointer address of the result of zero_value().
+func zero_value_ptr(t xunsafe.TypeIter, offsets []next_offset) unsafe.Pointer {
 	return zero_value(t, offsets).Addr().UnsafePointer()
 }


@@ -8,6 +8,8 @@ import (
 	"strings"
 	"sync"
 	"unsafe"
+
+	"codeberg.org/gruf/go-mempool"
 )
 
 // Direction defines a direction
@@ -1133,18 +1135,16 @@ func to_timeline_item(item *indexed_item) *timeline_item {
 	return to
 }
 
-var timeline_item_pool sync.Pool
+var timeline_item_pool mempool.UnsafePool
 
 // new_timeline_item returns a new prepared timeline_item.
 func new_timeline_item() *timeline_item {
-	v := timeline_item_pool.Get()
-	if v == nil {
-		i := new(timeline_item)
-		i.elem.data = unsafe.Pointer(i)
-		i.ck = ^uint(0)
-		v = i
+	if ptr := timeline_item_pool.Get(); ptr != nil {
+		return (*timeline_item)(ptr)
 	}
-	item := v.(*timeline_item)
+	item := new(timeline_item)
+	item.elem.data = unsafe.Pointer(item)
+	item.ck = ^uint(0)
 	return item
 }
 
@@ -1159,5 +1159,6 @@ func free_timeline_item(item *timeline_item) {
 	}
 	item.data = nil
 	item.pk = nil
-	timeline_item_pool.Put(item)
+	ptr := unsafe.Pointer(item)
+	timeline_item_pool.Put(ptr)
 }


@@ -444,20 +444,27 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro
 // https://sqlite.org/c3ref/table_column_metadata.html
 func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) {
 	defer c.arena.mark()()
-	var schemaPtr, columnPtr ptr_t
-	declTypePtr := c.arena.new(ptrlen)
-	collSeqPtr := c.arena.new(ptrlen)
-	notNullPtr := c.arena.new(ptrlen)
-	autoIncPtr := c.arena.new(ptrlen)
-	primaryKeyPtr := c.arena.new(ptrlen)
+	var (
+		declTypePtr   ptr_t
+		collSeqPtr    ptr_t
+		notNullPtr    ptr_t
+		primaryKeyPtr ptr_t
+		autoIncPtr    ptr_t
+		columnPtr     ptr_t
+		schemaPtr     ptr_t
+	)
+	if column != "" {
+		declTypePtr = c.arena.new(ptrlen)
+		collSeqPtr = c.arena.new(ptrlen)
+		notNullPtr = c.arena.new(ptrlen)
+		primaryKeyPtr = c.arena.new(ptrlen)
+		autoIncPtr = c.arena.new(ptrlen)
+		columnPtr = c.arena.string(column)
+	}
 	if schema != "" {
 		schemaPtr = c.arena.string(schema)
 	}
 	tablePtr := c.arena.string(table)
-	if column != "" {
-		columnPtr = c.arena.string(column)
-	}
 
 	rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle),
 		stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr),


@@ -1,7 +1,6 @@
 package sqlite3
 
 import (
-	"encoding/json"
 	"errors"
 	"math"
 	"time"
@@ -173,21 +172,6 @@ func (ctx Context) ResultPointer(ptr any) {
 		stk_t(ctx.handle), stk_t(valPtr))
 }
 
-// ResultJSON sets the result of the function to the JSON encoding of value.
-//
-// https://sqlite.org/c3ref/result_blob.html
-func (ctx Context) ResultJSON(value any) {
-	err := json.NewEncoder(callbackWriter(func(p []byte) (int, error) {
-		ctx.ResultRawText(p[:len(p)-1]) // remove the newline
-		return 0, nil
-	})).Encode(value)
-	if err != nil {
-		ctx.ResultError(err)
-		return // notest
-	}
-}
-
 // ResultValue sets the result of the function to a copy of [Value].
 //
 // https://sqlite.org/c3ref/result_blob.html


@@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
 type scantype byte
 
 const (
 	_ANY scantype = iota
-	_INT  scantype = scantype(sqlite3.INTEGER)
-	_REAL scantype = scantype(sqlite3.FLOAT)
-	_TEXT scantype = scantype(sqlite3.TEXT)
-	_BLOB scantype = scantype(sqlite3.BLOB)
-	_NULL scantype = scantype(sqlite3.NULL)
-	_BOOL scantype = iota
+	_INT
+	_REAL
+	_TEXT
+	_BLOB
+	_NULL
+	_BOOL
 	_TIME
+	_NOT_NULL
+)
+
+var (
+	_ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
+	_ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
+	_ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
+	_ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
+	_ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
+	_ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
 )
 
 func scanFromDecl(decl string) scantype {
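The const block now relies on plain iota, and the new var block makes that safe with compile-time assertions: assigning `[N]struct{}{}` to a `[0]struct{}` variable only type-checks when the constant N is exactly zero (a negative N is an outright compile error), so the build breaks if the iota order ever drifts from SQLite's type codes, or if `_NOT_NULL` stops being a single bit. The same trick in isolation (a sketch, not from this diff):

package main

const (
	answer   = 42
	expected = 42
)

// Compiles only while answer == expected: any drift makes the
// array length non-zero (or negative), failing the build.
var _ [0]struct{} = [answer - expected]struct{}{}

// Compiles only while flag has at most a single bit set:
// x&(x-1) == 0 exactly when x is zero or a power of two.
const flag = 1 << 7

var _ [0]struct{} = [flag & (flag - 1)]struct{}{}

func main() {}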
@@ -644,8 +654,8 @@ type rows struct {
 	*stmt
 	names []string
 	types []string
-	nulls []bool
 	scans []scantype
+	dest  []driver.Value
 }
 
 var (
@@ -675,34 +685,36 @@ func (r *rows) Columns() []string {
 func (r *rows) scanType(index int) scantype {
 	if r.scans == nil {
-		count := r.Stmt.ColumnCount()
+		count := len(r.names)
 		scans := make([]scantype, count)
 		for i := range scans {
 			scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
 		}
 		r.scans = scans
 	}
-	return r.scans[index]
+	return r.scans[index] &^ _NOT_NULL
 }
 
 func (r *rows) loadColumnMetadata() {
-	if r.nulls == nil {
+	if r.types == nil {
 		c := r.Stmt.Conn()
-		count := r.Stmt.ColumnCount()
-		nulls := make([]bool, count)
+		count := len(r.names)
 		types := make([]string, count)
 		scans := make([]scantype, count)
-		for i := range nulls {
+		for i := range types {
+			var notnull bool
 			if col := r.Stmt.ColumnOriginName(i); col != "" {
-				types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata(
+				types[i], _, notnull, _, _, _ = c.TableColumnMetadata(
 					r.Stmt.ColumnDatabaseName(i),
 					r.Stmt.ColumnTableName(i),
 					col)
 				types[i] = strings.ToUpper(types[i])
 				scans[i] = scanFromDecl(types[i])
+				if notnull {
+					scans[i] |= _NOT_NULL
+				}
 			}
 		}
-		r.nulls = nulls
 		r.types = types
 		r.scans = scans
 	}
@@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
 func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
 	r.loadColumnMetadata()
-	if r.nulls[index] {
-		return false, true
-	}
-	return true, false
+	nullable = r.scans[index]&_NOT_NULL == 0
+	return nullable, !nullable
 }
 
 func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
 	r.loadColumnMetadata()
-	scan := r.scans[index]
+	scan := r.scans[index] &^ _NOT_NULL
 
 	if r.Stmt.Busy() {
 		// SQLite is dynamically typed and we now have a row.
@@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
 }
 
 func (r *rows) Next(dest []driver.Value) error {
+	r.dest = nil
 	c := r.Stmt.Conn()
 	if old := c.SetInterrupt(r.ctx); old != r.ctx {
 		defer c.SetInterrupt(old)
@@ -790,18 +801,7 @@ func (r *rows) Next(dest []driver.Value) error {
 	}
 	for i := range dest {
 		scan := r.scanType(i)
-		switch v := dest[i].(type) {
-		case int64:
-			if scan == _BOOL {
-				switch v {
-				case 1:
-					dest[i] = true
-				case 0:
-					dest[i] = false
-				}
-				continue
-			}
-		case []byte:
+		if v, ok := dest[i].([]byte); ok {
 			if len(v) == cap(v) { // a BLOB
 				continue
 			}
@@ -816,38 +816,49 @@ func (r *rows) Next(dest []driver.Value) error {
 			}
 			dest[i] = string(v)
-		case float64:
-			break
-		default:
-			continue
 		}
-		if scan == _TIME {
+		switch scan {
+		case _TIME:
 			t, err := r.tmRead.Decode(dest[i])
 			if err == nil {
 				dest[i] = t
-				continue
+			}
+		case _BOOL:
+			switch dest[i] {
+			case int64(0):
+				dest[i] = false
+			case int64(1):
+				dest[i] = true
 			}
 		}
 	}
+	r.dest = dest
 	return nil
 }
 
-func (r *rows) ScanColumn(dest any, index int) error {
+func (r *rows) ScanColumn(dest any, index int) (err error) {
 	// notest // Go 1.26
-	var ptr *time.Time
+	var tm *time.Time
+	var ok *bool
 	switch d := dest.(type) {
 	case *time.Time:
-		ptr = d
+		tm = d
 	case *sql.NullTime:
-		ptr = &d.Time
+		tm = &d.Time
+		ok = &d.Valid
 	case *sql.Null[time.Time]:
-		ptr = &d.V
+		tm = &d.V
+		ok = &d.Valid
 	default:
 		return driver.ErrSkip
 	}
 
-	if t := r.Stmt.ColumnTime(index, r.tmRead); !t.IsZero() {
-		*ptr = t
-		return nil
+	value := r.dest[index]
+	*tm, err = r.tmRead.Decode(value)
+	if ok != nil {
+		*ok = err == nil
+		if value == nil {
+			return nil
+		}
 	}
-	return driver.ErrSkip
+	return err
 }


@@ -1,3 +1,5 @@
+//go:build !goexperiment.jsonv2
+
 package util
 
 import (


@@ -0,0 +1,52 @@
//go:build goexperiment.jsonv2

package util

import (
	"encoding/json/v2"
	"math"
	"strconv"
	"time"
	"unsafe"
)

type JSON struct{ Value any }

func (j JSON) Scan(value any) error {
	var buf []byte

	switch v := value.(type) {
	case []byte:
		buf = v
	case string:
		buf = unsafe.Slice(unsafe.StringData(v), len(v))
	case int64:
		buf = strconv.AppendInt(nil, v, 10)
	case float64:
		buf = AppendNumber(nil, v)
	case time.Time:
		buf = append(buf, '"')
		buf = v.AppendFormat(buf, time.RFC3339Nano)
		buf = append(buf, '"')
	case nil:
		buf = []byte("null")
	default:
		panic(AssertErr())
	}

	return json.Unmarshal(buf, j.Value)
}

func AppendNumber(dst []byte, f float64) []byte {
	switch {
	case math.IsNaN(f):
		dst = append(dst, "null"...)
	case math.IsInf(f, 1):
		dst = append(dst, "9.0e999"...)
	case math.IsInf(f, -1):
		dst = append(dst, "-9.0e999"...)
	default:
		return strconv.AppendFloat(dst, f, 'g', -1, 64)
	}
	return dst
}
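Both build flavours funnel through `util.JSON`, whose `Scan` method satisfies the `database/sql.Scanner` contract. A hedged usage sketch of the public wrapper (the `posts` table and query are illustrative only):

package example

import (
	"database/sql"

	"github.com/ncruces/go-sqlite3"
	_ "github.com/ncruces/go-sqlite3/driver"
	_ "github.com/ncruces/go-sqlite3/embed"
)

// loadTags decodes a JSON-encoded column directly into a Go slice.
func loadTags(db *sql.DB, id int64) ([]string, error) {
	var tags []string
	err := db.QueryRow(`SELECT tags FROM posts WHERE id = ?`, id).
		Scan(sqlite3.JSON(&tags))
	return tags, err
}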


@@ -1,6 +1,13 @@
+//go:build !goexperiment.jsonv2
+
 package sqlite3
 
-import "github.com/ncruces/go-sqlite3/internal/util"
+import (
+	"encoding/json"
+	"strconv"
+
+	"github.com/ncruces/go-sqlite3/internal/util"
+)
 
 // JSON returns a value that can be used as an argument to
 // [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to
@@ -10,3 +17,77 @@ import "github.com/ncruces/go-sqlite3/internal/util"
 func JSON(value any) any {
 	return util.JSON{Value: value}
 }
+
+// ResultJSON sets the result of the function to the JSON encoding of value.
+//
+// https://sqlite.org/c3ref/result_blob.html
+func (ctx Context) ResultJSON(value any) {
+	err := json.NewEncoder(callbackWriter(func(p []byte) (int, error) {
+		ctx.ResultRawText(p[:len(p)-1]) // remove the newline
+		return 0, nil
+	})).Encode(value)
+	if err != nil {
+		ctx.ResultError(err)
+		return // notest
+	}
+}
+
+// BindJSON binds the JSON encoding of value to the prepared statement.
+// The leftmost SQL parameter has an index of 1.
+//
+// https://sqlite.org/c3ref/bind_blob.html
+func (s *Stmt) BindJSON(param int, value any) error {
+	return json.NewEncoder(callbackWriter(func(p []byte) (int, error) {
+		return 0, s.BindRawText(param, p[:len(p)-1]) // remove the newline
+	})).Encode(value)
+}
+
+// ColumnJSON parses the JSON-encoded value of the result column
+// and stores it in the value pointed to by ptr.
+// The leftmost column of the result set has the index 0.
+//
+// https://sqlite.org/c3ref/column_blob.html
+func (s *Stmt) ColumnJSON(col int, ptr any) error {
+	var data []byte
+	switch s.ColumnType(col) {
+	case NULL:
+		data = []byte("null")
+	case TEXT:
+		data = s.ColumnRawText(col)
+	case BLOB:
+		data = s.ColumnRawBlob(col)
+	case INTEGER:
+		data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
+	case FLOAT:
+		data = util.AppendNumber(nil, s.ColumnFloat(col))
+	default:
+		panic(util.AssertErr())
+	}
+	return json.Unmarshal(data, ptr)
+}
+
+// JSON parses a JSON-encoded value
+// and stores the result in the value pointed to by ptr.
+func (v Value) JSON(ptr any) error {
+	var data []byte
+	switch v.Type() {
+	case NULL:
+		data = []byte("null")
+	case TEXT:
+		data = v.RawText()
+	case BLOB:
+		data = v.RawBlob()
+	case INTEGER:
+		data = strconv.AppendInt(nil, v.Int64(), 10)
+	case FLOAT:
+		data = util.AppendNumber(nil, v.Float())
+	default:
+		panic(util.AssertErr())
+	}
+	return json.Unmarshal(data, ptr)
+}
+
+type callbackWriter func(p []byte) (int, error)
+
+func (fn callbackWriter) Write(p []byte) (int, error) { return fn(p) }
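`callbackWriter`, kept at the bottom of this file, is a function type that satisfies `io.Writer`, letting `json.Encoder` stream directly into SQLite's raw-text setters without an intermediate buffer. The same adapter in isolation (a standalone sketch):

package main

import (
	"encoding/json"
	"fmt"
)

// callbackWriter adapts a plain function into an io.Writer.
type callbackWriter func(p []byte) (int, error)

func (fn callbackWriter) Write(p []byte) (int, error) { return fn(p) }

func main() {
	w := callbackWriter(func(p []byte) (int, error) {
		fmt.Printf("chunk: %q\n", p) // receives the encoder's output
		return len(p), nil
	})
	_ = json.NewEncoder(w).Encode(map[string]int{"a": 1})
}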

vendor/github.com/ncruces/go-sqlite3/json_v2.go (new file, 113 lines)

@@ -0,0 +1,113 @@
//go:build goexperiment.jsonv2

package sqlite3

import (
	"encoding/json/v2"
	"strconv"

	"github.com/ncruces/go-sqlite3/internal/util"
)

// JSON returns a value that can be used as an argument to
// [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to
// store value as JSON, or decode JSON into value.
// JSON should NOT be used with [Stmt.BindJSON], [Stmt.ColumnJSON],
// [Value.JSON], or [Context.ResultJSON].
func JSON(value any) any {
	return util.JSON{Value: value}
}

// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultJSON(value any) {
	w := bytesWriter{sqlite: ctx.c.sqlite}
	if err := json.MarshalWrite(&w, value); err != nil {
		ctx.c.free(w.ptr)
		ctx.ResultError(err)
		return // notest
	}
	ctx.c.call("sqlite3_result_text_go",
		stk_t(ctx.handle), stk_t(w.ptr), stk_t(len(w.buf)))
}

// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
	w := bytesWriter{sqlite: s.c.sqlite}
	if err := json.MarshalWrite(&w, value); err != nil {
		s.c.free(w.ptr)
		return err
	}
	rc := res_t(s.c.call("sqlite3_bind_text_go",
		stk_t(s.handle), stk_t(param),
		stk_t(w.ptr), stk_t(len(w.buf))))
	return s.c.error(rc)
}

// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
	var data []byte
	switch s.ColumnType(col) {
	case NULL:
		data = []byte("null")
	case TEXT:
		data = s.ColumnRawText(col)
	case BLOB:
		data = s.ColumnRawBlob(col)
	case INTEGER:
		data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
	case FLOAT:
		data = util.AppendNumber(nil, s.ColumnFloat(col))
	default:
		panic(util.AssertErr())
	}
	return json.Unmarshal(data, ptr)
}

// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
	var data []byte
	switch v.Type() {
	case NULL:
		data = []byte("null")
	case TEXT:
		data = v.RawText()
	case BLOB:
		data = v.RawBlob()
	case INTEGER:
		data = strconv.AppendInt(nil, v.Int64(), 10)
	case FLOAT:
		data = util.AppendNumber(nil, v.Float())
	default:
		panic(util.AssertErr())
	}
	return json.Unmarshal(data, ptr)
}

type bytesWriter struct {
	*sqlite
	buf []byte
	ptr ptr_t
}

func (b *bytesWriter) Write(p []byte) (n int, err error) {
	if len(p) > cap(b.buf)-len(b.buf) {
		want := int64(len(b.buf)) + int64(len(p))
		grow := int64(cap(b.buf))
		grow += grow >> 1
		want = max(want, grow)
		b.ptr = b.realloc(b.ptr, want)
		b.buf = util.View(b.mod, b.ptr, want)[:len(b.buf)]
	}
	b.buf = append(b.buf, p...)
	return len(p), nil
}
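`bytesWriter.Write` appends into wasm guest memory, reallocating by at least 1.5x whenever an incoming chunk outgrows spare capacity, which keeps repeated small writes amortized O(1). The same growth arithmetic on an ordinary slice (an illustrative fragment; the real code grows guest memory via realloc, not a Go slice):

// grow ensures room for n more bytes, expanding capacity by at
// least half again, mirroring the want/grow arithmetic above.
func grow(buf []byte, n int) []byte {
	if n > cap(buf)-len(buf) {
		want := len(buf) + n
		if half := cap(buf) + cap(buf)/2; half > want {
			want = half
		}
		next := make([]byte, len(buf), want)
		copy(next, buf)
		return next
	}
	return buf
}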


@@ -5,6 +5,7 @@ import (
 	"context"
 	"math/bits"
 	"os"
+	"strings"
 	"sync"
 	"unsafe"
 
@@ -128,11 +129,10 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
 	var msg, query string
 	if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
 		msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
-		switch {
-		case msg == "not an error":
-			msg = ""
-		case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
+		if msg == "not an error" {
 			msg = ""
+		} else {
+			msg = strings.TrimPrefix(msg, util.ErrorCodeString(uint32(rc))[len("sqlite3: "):])
 		}
 	}


@@ -1,9 +1,7 @@
 package sqlite3
 
 import (
-	"encoding/json"
 	"math"
-	"strconv"
 	"time"
 
 	"github.com/ncruces/go-sqlite3/internal/util"
@@ -362,16 +360,6 @@ func (s *Stmt) BindPointer(param int, ptr any) error {
 	return s.c.error(rc)
 }
 
-// BindJSON binds the JSON encoding of value to the prepared statement.
-// The leftmost SQL parameter has an index of 1.
-//
-// https://sqlite.org/c3ref/bind_blob.html
-func (s *Stmt) BindJSON(param int, value any) error {
-	return json.NewEncoder(callbackWriter(func(p []byte) (int, error) {
-		return 0, s.BindRawText(param, p[:len(p)-1]) // remove the newline
-	})).Encode(value)
-}
-
 // BindValue binds a copy of value to the prepared statement.
 // The leftmost SQL parameter has an index of 1.
 //
@@ -598,30 +586,6 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
 	return util.View(s.c.mod, ptr, int64(n+nul))[:n]
 }
 
-// ColumnJSON parses the JSON-encoded value of the result column
-// and stores it in the value pointed to by ptr.
-// The leftmost column of the result set has the index 0.
-//
-// https://sqlite.org/c3ref/column_blob.html
-func (s *Stmt) ColumnJSON(col int, ptr any) error {
-	var data []byte
-	switch s.ColumnType(col) {
-	case NULL:
-		data = []byte("null")
-	case TEXT:
-		data = s.ColumnRawText(col)
-	case BLOB:
-		data = s.ColumnRawBlob(col)
-	case INTEGER:
-		data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
-	case FLOAT:
-		data = util.AppendNumber(nil, s.ColumnFloat(col))
-	default:
-		panic(util.AssertErr())
-	}
-	return json.Unmarshal(data, ptr)
-}
-
 // ColumnValue returns the unprotected value of the result column.
 // The leftmost column of the result set has the index 0.
 //
@@ -748,7 +712,3 @@ func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
 	return util.View(s.c.mod, typePtr, count), dataPtr, nil
 }
-
-type callbackWriter func(p []byte) (int, error)
-
-func (fn callbackWriter) Write(p []byte) (int, error) { return fn(p) }


@@ -94,7 +94,7 @@ func (f TimeFormat) Encode(t time.Time) any {
 	case TimeFormatUnix:
 		return t.Unix()
 	case TimeFormatUnixFrac:
-		return float64(t.Unix()) + float64(t.Nanosecond())*1e-9
+		return math.FMA(1e-9, float64(t.Nanosecond()), float64(t.Unix()))
 	case TimeFormatUnixMilli:
 		return t.UnixMilli()
 	case TimeFormatUnixMicro:
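`math.FMA(a, b, c)` computes a*b+c with a single rounding, so the fused form avoids the second rounding step of the multiply-then-add it replaces when encoding fractional Unix timestamps. A quick comparison (hedged sketch):

package main

import (
	"fmt"
	"math"
	"time"
)

func main() {
	t := time.Unix(1_700_000_000, 123_456_789)
	naive := float64(t.Unix()) + float64(t.Nanosecond())*1e-9
	fused := math.FMA(1e-9, float64(t.Nanosecond()), float64(t.Unix()))
	fmt.Println(naive, fused) // may differ in the final bit
}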


@@ -1,9 +1,7 @@
 package sqlite3
 
 import (
-	"encoding/json"
 	"math"
-	"strconv"
 	"time"
 
 	"github.com/ncruces/go-sqlite3/internal/util"
@@ -162,27 +160,6 @@ func (v Value) Pointer() any {
 	return util.GetHandle(v.c.ctx, ptr)
 }
 
-// JSON parses a JSON-encoded value
-// and stores the result in the value pointed to by ptr.
-func (v Value) JSON(ptr any) error {
-	var data []byte
-	switch v.Type() {
-	case NULL:
-		data = []byte("null")
-	case TEXT:
-		data = v.RawText()
-	case BLOB:
-		data = v.RawBlob()
-	case INTEGER:
-		data = strconv.AppendInt(nil, v.Int64(), 10)
-	case FLOAT:
-		data = util.AppendNumber(nil, v.Float())
-	default:
-		panic(util.AssertErr())
-	}
-	return json.Unmarshal(data, ptr)
-}
-
 // NoChange returns true if and only if the value is unchanged
 // in a virtual table update operatiom.
 //


@@ -94,6 +94,10 @@ const (
 	OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
 	OPEN_WAL          OpenFlag = 0x00080000 /* VFS only */
 	OPEN_NOFOLLOW     OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
+	_FLAG_ATOMIC      OpenFlag = 0x10000000
+	_FLAG_KEEP_WAL    OpenFlag = 0x20000000
+	_FLAG_PSOW        OpenFlag = 0x40000000
+	_FLAG_SYNC_DIR    OpenFlag = 0x80000000
 )
 
 // AccessFlag is a flag for the [VFS] Access method.


@@ -51,7 +51,7 @@ func (vfsOS) Delete(path string, syncDir bool) error {
 		return _OK
 	}
 	defer f.Close()
-	err = osSync(f, false, false)
+	err = osSync(f, 0, SYNC_FULL)
 	if err != nil {
 		return _IOERR_DIR_FSYNC
 	}
@@ -131,27 +131,24 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
 	}
 
 	file := vfsFile{
 		File: f,
-		psow:     true,
-		atomic:   osBatchAtomic(f),
-		readOnly: flags&OPEN_READONLY != 0,
-		syncDir:  isUnix && isCreate && isJournl,
-		delete:   !isUnix && flags&OPEN_DELETEONCLOSE != 0,
-		shm:      NewSharedMemory(name.String()+"-shm", flags),
+		flags: flags | _FLAG_PSOW,
+		shm:   NewSharedMemory(name.String()+"-shm", flags),
+	}
+	if osBatchAtomic(f) {
+		file.flags |= _FLAG_ATOMIC
+	}
+	if isUnix && isCreate && isJournl {
+		file.flags |= _FLAG_SYNC_DIR
 	}
 	return &file, flags, nil
 }
 
 type vfsFile struct {
 	*os.File
 	shm  SharedMemory
 	lock LockLevel
-	readOnly bool
-	keepWAL  bool
-	syncDir  bool
-	atomic   bool
-	delete   bool
-	psow     bool
+	flags OpenFlag
 }
 
 var (
@@ -164,7 +161,7 @@ var (
 )
 
 func (f *vfsFile) Close() error {
-	if f.delete {
+	if !isUnix && f.flags&OPEN_DELETEONCLOSE != 0 {
 		defer os.Remove(f.Name())
 	}
 	if f.shm != nil {
@@ -183,21 +180,18 @@ func (f *vfsFile) WriteAt(p []byte, off int64) (n int, err error) {
 }
 
 func (f *vfsFile) Sync(flags SyncFlag) error {
-	dataonly := (flags & SYNC_DATAONLY) != 0
-	fullsync := (flags & 0x0f) == SYNC_FULL
-	err := osSync(f.File, fullsync, dataonly)
+	err := osSync(f.File, f.flags, flags)
 	if err != nil {
 		return err
 	}
-	if isUnix && f.syncDir {
-		f.syncDir = false
+	if isUnix && f.flags&_FLAG_SYNC_DIR != 0 {
+		f.flags ^= _FLAG_SYNC_DIR
 		d, err := os.Open(filepath.Dir(f.File.Name()))
 		if err != nil {
 			return nil
 		}
 		defer d.Close()
-		err = osSync(d, false, false)
+		err = osSync(f.File, f.flags, flags)
 		if err != nil {
 			return _IOERR_DIR_FSYNC
 		}
@@ -215,10 +209,10 @@ func (f *vfsFile) SectorSize() int {
 
 func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
 	ret := IOCAP_SUBPAGE_READ
-	if f.atomic {
+	if f.flags&_FLAG_ATOMIC != 0 {
 		ret |= IOCAP_BATCH_ATOMIC
 	}
-	if f.psow {
+	if f.flags&_FLAG_PSOW != 0 {
 		ret |= IOCAP_POWERSAFE_OVERWRITE
 	}
 	if runtime.GOOS == "windows" {
@@ -249,8 +243,20 @@ func (f *vfsFile) HasMoved() (bool, error) {
 	return !os.SameFile(fi, pi), nil
 }
 
 func (f *vfsFile) LockState() LockLevel { return f.lock }
-func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
-func (f *vfsFile) PersistWAL() bool { return f.keepWAL }
-func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }
-func (f *vfsFile) SetPersistWAL(keepWAL bool) { f.keepWAL = keepWAL }
+func (f *vfsFile) PowersafeOverwrite() bool { return f.flags&_FLAG_PSOW != 0 }
+func (f *vfsFile) PersistWAL() bool { return f.flags&_FLAG_KEEP_WAL != 0 }
+
+func (f *vfsFile) SetPowersafeOverwrite(psow bool) {
+	f.flags &^= _FLAG_PSOW
+	if psow {
+		f.flags |= _FLAG_PSOW
+	}
+}
+
+func (f *vfsFile) SetPersistWAL(keepWAL bool) {
+	f.flags &^= _FLAG_KEEP_WAL
+	if keepWAL {
+		f.flags |= _FLAG_KEEP_WAL
+	}
+}


@@ -41,7 +41,7 @@ func (f *vfsFile) Lock(lock LockLevel) error {
 	}
 
 	// Do not allow any kind of write-lock on a read-only database.
-	if f.readOnly && lock >= LOCK_RESERVED {
+	if lock >= LOCK_RESERVED && f.flags&OPEN_READONLY != 0 {
 		return _IOERR_LOCK
 	}

@@ -7,3 +7,6 @@ It has some benefits over the C version:
 - the memory backing the database needs not be contiguous,
 - the database can grow/shrink incrementally without copying,
 - reader-writer concurrency is slightly improved.
+
+[`memdb.TestDB`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/memdb#TestDB)
+is the preferred way to setup an in-memory database for testing.


@@ -10,6 +10,7 @@
 package memdb
 
 import (
+	"crypto/rand"
 	"fmt"
 	"net/url"
 	"sync"
@@ -74,11 +75,27 @@ func Delete(name string) {
 
 // TestDB creates an empty shared memory database for the test to use.
 // The database is automatically deleted when the test and all its subtests complete.
+// Returns a URI filename appropriate to call Open with.
 // Each subsequent call to TestDB returns a unique database.
+//
+//	func Test_something(t *testing.T) {
+//		t.Parallel()
+//		dsn := memdb.TestDB(t, url.Values{
+//			"_pragma": {"busy_timeout(1000)"},
+//		})
+//
+//		db, err := sql.Open("sqlite3", dsn)
+//		if err != nil {
+//			t.Fatal(err)
+//		}
+//		defer db.Close()
+//
+//		// ...
+//	}
 func TestDB(tb testing.TB, params ...url.Values) string {
 	tb.Helper()
-	name := fmt.Sprintf("%s_%p", tb.Name(), tb)
+	name := fmt.Sprintf("%s_%s", tb.Name(), rand.Text())
 	tb.Cleanup(func() { Delete(name) })
 	Create(name, nil)


@@ -23,12 +23,26 @@ type flocktimeout_t struct {
 	timeout unix.Timespec
 }
 
-func osSync(file *os.File, fullsync, _ /*dataonly*/ bool) error {
-	if fullsync {
-		return file.Sync()
+func osSync(file *os.File, open OpenFlag, sync SyncFlag) error {
+	var cmd int
+	if sync&SYNC_FULL == SYNC_FULL {
+		// For rollback journals all we really need is a barrier.
+		if open&OPEN_MAIN_JOURNAL != 0 {
+			cmd = unix.F_BARRIERFSYNC
+		} else {
+			cmd = unix.F_FULLFSYNC
+		}
 	}
+
+	fd := file.Fd()
 	for {
-		err := unix.Fsync(int(file.Fd()))
+		err := error(unix.ENOTSUP)
+		if cmd != 0 {
+			_, err = unix.FcntlInt(fd, cmd, 0)
+		}
+		if err == unix.ENOTSUP {
+			err = unix.Fsync(int(fd))
+		}
 		if err != unix.EINTR {
 			return err
 		}
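On Darwin, plain fsync only pushes data toward the drive; F_FULLFSYNC forces the drive cache to flush, while F_BARRIERFSYNC merely orders writes, which is all a rollback journal needs. The ladder above degrades to Fsync when the fcntl reports ENOTSUP. A standalone sketch of the full-flush path (assumes darwin and golang.org/x/sys/unix; the EINTR retry loop is omitted):

//go:build darwin

package main

import (
	"os"

	"golang.org/x/sys/unix"
)

// fullSync asks the drive to flush its cache, falling back to plain
// fsync on filesystems where F_FULLFSYNC is unsupported.
func fullSync(f *os.File) error {
	if _, err := unix.FcntlInt(f.Fd(), unix.F_FULLFSYNC, 0); err != unix.ENOTSUP {
		return err
	}
	return unix.Fsync(int(f.Fd()))
}

func main() {}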


@@ -10,7 +10,7 @@ import (
 	"golang.org/x/sys/unix"
 )
 
-func osSync(file *os.File, _ /*fullsync*/, _ /*dataonly*/ bool) error {
+func osSync(file *os.File, _ OpenFlag, _ SyncFlag) error {
 	// SQLite trusts Linux's fdatasync for all fsync's.
 	for {
 		err := unix.Fdatasync(int(file.Fd()))

@@ -4,6 +4,6 @@ package vfs
 
 import "os"
 
-func osSync(file *os.File, _ /*fullsync*/, _ /*dataonly*/ bool) error {
+func osSync(file *os.File, _ OpenFlag, _ SyncFlag) error {
 	return file.Sync()
 }

vendor/modules.txt (14 changes)

@@ -247,7 +247,7 @@ codeberg.org/gruf/go-fastcopy
 # codeberg.org/gruf/go-fastpath/v2 v2.0.0
 ## explicit; go 1.14
 codeberg.org/gruf/go-fastpath/v2
-# codeberg.org/gruf/go-ffmpreg v0.6.11
+# codeberg.org/gruf/go-ffmpreg v0.6.12
 ## explicit; go 1.22.0
 codeberg.org/gruf/go-ffmpreg/embed
 codeberg.org/gruf/go-ffmpreg/wasm
@@ -271,11 +271,11 @@ codeberg.org/gruf/go-mangler/v2
 # codeberg.org/gruf/go-maps v1.0.4
 ## explicit; go 1.20
 codeberg.org/gruf/go-maps
-# codeberg.org/gruf/go-mempool v0.0.0-20240507125005-cef10d64a760
-## explicit; go 1.22.2
+# codeberg.org/gruf/go-mempool v0.0.0-20251003110531-b54adae66253
+## explicit; go 1.24.0
 codeberg.org/gruf/go-mempool
-# codeberg.org/gruf/go-mutexes v1.5.3
-## explicit; go 1.22.2
+# codeberg.org/gruf/go-mutexes v1.5.8
+## explicit; go 1.24.0
 codeberg.org/gruf/go-mutexes
 # codeberg.org/gruf/go-runners v1.6.3
 ## explicit; go 1.19
@@ -293,7 +293,7 @@ codeberg.org/gruf/go-storage/disk
 codeberg.org/gruf/go-storage/internal
 codeberg.org/gruf/go-storage/memory
 codeberg.org/gruf/go-storage/s3
-# codeberg.org/gruf/go-structr v0.9.9
+# codeberg.org/gruf/go-structr v0.9.12
 ## explicit; go 1.24.5
 codeberg.org/gruf/go-structr
 # codeberg.org/gruf/go-xunsafe v0.0.0-20250809104800-512a9df57d73
@@ -727,7 +727,7 @@ github.com/modern-go/reflect2
 # github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
 ## explicit
 github.com/munnerz/goautoneg
-# github.com/ncruces/go-sqlite3 v0.29.0
+# github.com/ncruces/go-sqlite3 v0.29.1
 ## explicit; go 1.24.0
 github.com/ncruces/go-sqlite3
 github.com/ncruces/go-sqlite3/driver


@@ -269,7 +269,7 @@ ol {
 blockquote {
 	padding: 0.5rem;
 	border-left: 0.2rem solid $border-accent;
-	margin: 0;
+	margin-inline: 0;
 	font-style: normal;
 	/*