diff --git a/internal/trans/decoders.go b/internal/trans/decoders.go
index a263adb78..33557b2b5 100644
--- a/internal/trans/decoders.go
+++ b/internal/trans/decoders.go
@@ -19,53 +19,75 @@
package trans
import (
- "crypto/rsa"
- "encoding"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
"fmt"
- "reflect"
"time"
"github.com/mitchellh/mapstructure"
- "github.com/sirupsen/logrus"
transmodel "github.com/superseriousbusiness/gotosocial/internal/trans/model"
)
+func newDecoder(target interface{}) (*mapstructure.Decoder, error) {
+ decoderConfig := &mapstructure.DecoderConfig{
+ DecodeHook: mapstructure.StringToTimeHookFunc(time.RFC3339), // this is needed to decode time.Time entries serialized as string
+ Result: target,
+ }
+ return mapstructure.NewDecoder(decoderConfig)
+}
+
func (i *importer) accountDecode(e transmodel.TransEntry) (*transmodel.Account, error) {
a := &transmodel.Account{}
-
- decoderConfig := &mapstructure.DecoderConfig{
- DecodeHook: mapstructure.ComposeDecodeHookFunc(
- mapstructure.StringToTimeHookFunc(time.RFC3339),
- keyHookFunc(i.log),
- ),
- Result: a,
+ if err := i.simpleDecode(e, a); err != nil {
+ return nil, err
}
- decoder, err := mapstructure.NewDecoder(decoderConfig)
+
+ // extract public key
+ publicKeyBlock, _ := pem.Decode([]byte(a.PublicKeyString))
+ if publicKeyBlock == nil {
+ return nil, errors.New("accountDecode: error decoding account public key")
+ }
+ publicKey, err := x509.ParsePKCS1PublicKey(publicKeyBlock.Bytes)
if err != nil {
- return nil, fmt.Errorf("accountDecode: error creating decoder: %s", err)
+ return nil, fmt.Errorf("accountDecode: error parsing account public key: %s", err)
}
+ a.PublicKey = publicKey
- if err := decoder.Decode(&e); err != nil {
- return nil, fmt.Errorf("accountDecode: error decoding account: %s", err)
+ if a.Domain == "" {
+ // extract private key (local account)
+ privateKeyBlock, _ := pem.Decode([]byte(a.PrivateKeyString))
+ if privateKeyBlock == nil {
+ return nil, errors.New("accountDecode: error decoding account private key")
+ }
+ privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("accountDecode: error parsing account private key: %s", err)
+ }
+ a.PrivateKey = privateKey
}
return a, nil
}
-func keyHookFunc(log *logrus.Logger) mapstructure.DecodeHookFunc {
- return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
- if t != reflect.TypeOf(rsa.PrivateKey{}) {
- return data, nil
- }
-
- result := reflect.New(t).Interface()
- unmarshaller, ok := result.(encoding.BinaryUnmarshaler)
- if !ok {
- return data, nil
- }
- if err := unmarshaller.UnmarshalBinary([]byte(data.(string))); err != nil {
- return nil, err
- }
- return result, nil
+func (i *importer) blockDecode(e transmodel.TransEntry) (*transmodel.Block, error) {
+ b := &transmodel.Block{}
+ if err := i.simpleDecode(e, b); err != nil {
+ return nil, err
}
+
+ return b, nil
+}
+
+func (i *importer) simpleDecode(entry transmodel.TransEntry, target interface{}) error {
+ decoder, err := newDecoder(target)
+ if err != nil {
+ return fmt.Errorf("simpleDecode: error creating decoder: %s", err)
+ }
+
+ if err := decoder.Decode(&entry); err != nil {
+ return fmt.Errorf("simpleDecode: error decoding: %s", err)
+ }
+
+ return nil
}
diff --git a/internal/trans/encoders.go b/internal/trans/encoders.go
new file mode 100644
index 000000000..89dba6eaa
--- /dev/null
+++ b/internal/trans/encoders.go
@@ -0,0 +1,83 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package trans
+
+import (
+ "context"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "os"
+
+ transmodel "github.com/superseriousbusiness/gotosocial/internal/trans/model"
+)
+
+// accountEncode handles special fields like private + public keys on accounts
+func (e *exporter) accountEncode(ctx context.Context, f *os.File, a *transmodel.Account) error {
+ a.Type = transmodel.TransAccount
+
+ // marshal public key
+ encodedPublicKey := x509.MarshalPKCS1PublicKey(a.PublicKey)
+ if encodedPublicKey == nil {
+ return errors.New("could not MarshalPKCS1PublicKey")
+ }
+ publicKeyBytes := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PUBLIC KEY",
+ Bytes: encodedPublicKey,
+ })
+ a.PublicKeyString = string(publicKeyBytes)
+
+ if a.Domain == "" {
+ // marshal private key for local account
+ encodedPrivateKey := x509.MarshalPKCS1PrivateKey(a.PrivateKey)
+ if encodedPrivateKey == nil {
+ return errors.New("could not MarshalPKCS1PrivateKey")
+ }
+ privateKeyBytes := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: encodedPrivateKey,
+ })
+ a.PrivateKeyString = string(privateKeyBytes)
+ }
+
+ return e.simpleEncode(ctx, f, a, a.ID)
+}
+
+// simpleEncode can be used for any type that doesn't have special keys which need handling differently,
+// or for types where special keys have already been handled.
+//
+// The 'type' key on the passed interface should already have been set, since simpleEncode won't know
+// what type it is!
+func (e *exporter) simpleEncode(ctx context.Context, f *os.File, i interface{}, id string) error {
+ _, alreadyWritten := e.writtenIDs[id]
+ if alreadyWritten {
+ // this exporter has already exported an entry with this ID, no need to do it twice
+ return nil
+ }
+
+ err := json.NewEncoder(f).Encode(i)
+ if err != nil {
+ return fmt.Errorf("simpleEncode: error encoding entry with id %s: %s", id, err)
+ }
+
+ e.writtenIDs[id] = true
+ return nil
+}
diff --git a/internal/trans/exporter.go b/internal/trans/exporter.go
index 602eaee2a..1cd1b38ff 100644
--- a/internal/trans/exporter.go
+++ b/internal/trans/exporter.go
@@ -30,13 +30,15 @@ type Exporter interface {
}
type exporter struct {
- db db.DB
- log *logrus.Logger
+ db db.DB
+ log *logrus.Logger
+ writtenIDs map[string]bool
}
func NewExporter(db db.DB, log *logrus.Logger) Exporter {
return &exporter{
- db: db,
- log: log,
+ db: db,
+ log: log,
+ writtenIDs: make(map[string]bool),
}
}
diff --git a/internal/trans/exportminimal.go b/internal/trans/exportminimal.go
index 9be561947..dd5a9995c 100644
--- a/internal/trans/exportminimal.go
+++ b/internal/trans/exportminimal.go
@@ -20,7 +20,6 @@ package trans
import (
"context"
- "encoding/json"
"fmt"
"os"
@@ -34,28 +33,96 @@ func (e *exporter) ExportMinimal(ctx context.Context, path string) error {
return err
}
- encoder := json.NewEncoder(f)
-
- accounts := []*transmodel.Account{}
- if err := e.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: nil}}, &accounts); err != nil {
- return fmt.Errorf("ExportMinimal: error selecting accounts: %s", err)
+ // export all local accounts we have in the database
+ localAccounts, err := e.exportAccounts(ctx, []db.Where{{Key: "domain", Value: nil}}, f)
+ if err != nil {
+ return fmt.Errorf("ExportMinimal: error exporting accounts: %s", err)
}
- for _, a := range accounts {
- a.Type = transmodel.TransAccount
- if err := encoder.Encode(a); err != nil {
- return fmt.Errorf("ExportMinimal: error encoding account: %s", err)
+ // export all blocks that relate to those accounts
+ blocks, err := e.exportBlocks(ctx, localAccounts, f)
+ if err != nil {
+ return fmt.Errorf("ExportMinimal: error exporting blocks: %s", err)
+ }
+
+ // for each block, make sure we've written out the account owning it, or targeted by it
+ for _, b := range blocks {
+ _, alreadyWritten := e.writtenIDs[b.AccountID]
+ if !alreadyWritten {
+ _, err := e.exportAccounts(ctx, []db.Where{{Key: "id", Value: b.AccountID}}, f)
+ if err != nil {
+ return fmt.Errorf("ExportMinimal: error exporting block owner account: %s", err)
+ }
+ }
+
+ _, alreadyWritten = e.writtenIDs[b.TargetAccountID]
+ if !alreadyWritten {
+ _, err := e.exportAccounts(ctx, []db.Where{{Key: "id", Value: b.TargetAccountID}}, f)
+ if err != nil {
+ return fmt.Errorf("ExportMinimal: error exporting block target account: %s", err)
+ }
}
- e.log.Infof("ExportMinimal: exported account %s to %s", a.ID, path)
}
return neatClose(f)
}
-func neatClose(f *os.File) error {
- if err := f.Close(); err != nil {
- return fmt.Errorf("error closing file: %s", err)
+func (e *exporter) exportAccounts(ctx context.Context, where []db.Where, f *os.File) ([]*transmodel.Account, error) {
+ // select using the 'where' we've been provided
+ accounts := []*transmodel.Account{}
+ if err := e.db.GetWhere(ctx, where, &accounts); err != nil {
+ return nil, fmt.Errorf("exportAccounts: error selecting accounts: %s", err)
}
- return nil
+ // write any accounts found to file
+ for _, a := range accounts {
+ if err := e.accountEncode(ctx, f, a); err != nil {
+ return nil, fmt.Errorf("exportAccounts: error encoding account: %s", err)
+ }
+ }
+
+ return accounts, nil
+}
+
+func (e *exporter) exportBlocks(ctx context.Context, accounts []*transmodel.Account, f *os.File) ([]*transmodel.Block, error) {
+ blocksUnique := make(map[string]*transmodel.Block)
+
+ // for each account we want to export both where it's blocking and where it's blocked
+ for _, a := range accounts {
+ // 1. export blocks owned by given account
+ whereBlocking := []db.Where{{Key: "account_id", Value: a.ID}}
+ blocking := []*transmodel.Block{}
+ if err := e.db.GetWhere(ctx, whereBlocking, &blocking); err != nil {
+ return nil, fmt.Errorf("exportBlocks: error selecting blocks owned by account %s: %s", a.ID, err)
+ }
+ for _, b := range blocking {
+ b.Type = transmodel.TransBlock
+ if err := e.simpleEncode(ctx, f, b, b.ID); err != nil {
+ return nil, fmt.Errorf("exportBlocks: error encoding block owned by account %s: %s", a.ID, err)
+ }
+ blocksUnique[b.ID] = b
+ }
+
+ // 2. export blocks that target given account
+ whereBlocked := []db.Where{{Key: "target_account_id", Value: a.ID}}
+ blocked := []*transmodel.Block{}
+ if err := e.db.GetWhere(ctx, whereBlocked, &blocked); err != nil {
+ return nil, fmt.Errorf("exportBlocks: error selecting blocks targeting account %s: %s", a.ID, err)
+ }
+ for _, b := range blocked {
+ b.Type = transmodel.TransBlock
+ if err := e.simpleEncode(ctx, f, b, b.ID); err != nil {
+ return nil, fmt.Errorf("exportBlocks: error encoding block targeting account %s: %s", a.ID, err)
+ }
+ blocksUnique[b.ID] = b
+ }
+ }
+
+ // now return all the blocks we found
+ blocks := []*transmodel.Block{}
+ for _, b := range blocksUnique {
+ blocks = append(blocks, b)
+ }
+
+ return blocks, nil
}
diff --git a/internal/trans/exportminimal_test.go b/internal/trans/exportminimal_test.go
index 08171585f..2bffffcfe 100644
--- a/internal/trans/exportminimal_test.go
+++ b/internal/trans/exportminimal_test.go
@@ -46,7 +46,7 @@ func (suite *ExportMinimalTestSuite) TestExportMinimalOK() {
b, err := os.ReadFile(tempFilePath)
suite.NoError(err)
suite.NotEmpty(b)
- suite.T().Log(string(b))
+ fmt.Println(string(b))
}
func TestExportMinimalTestSuite(t *testing.T) {
diff --git a/internal/trans/importminimal_test.go b/internal/trans/importminimal_test.go
index e6af94186..af6e25a34 100644
--- a/internal/trans/importminimal_test.go
+++ b/internal/trans/importminimal_test.go
@@ -40,7 +40,7 @@ func (suite *ImportMinimalTestSuite) TestImportMinimalOK() {
ctx := context.Background()
// use a temporary file path
- tempFilePath := fmt.Sprintf("%s/%s", os.TempDir(), uuid.NewString())
+ tempFilePath := fmt.Sprintf("%s/%s", suite.T().TempDir(), uuid.NewString())
// export to the tempFilePath
exporter := trans.NewExporter(suite.db, suite.log)
@@ -53,18 +53,18 @@ func (suite *ImportMinimalTestSuite) TestImportMinimalOK() {
suite.NotEmpty(b)
suite.T().Log(string(b))
- // now that the file is stored, tear down the database...
+ // create a new database with just the tables created, no entries
testrig.StandardDBTeardown(suite.db)
- // and create just the tables -- no entries!
- testrig.CreateTestTables(suite.db)
+ newDB := testrig.NewTestDB()
+ testrig.CreateTestTables(newDB)
- importer := trans.NewImporter(suite.db, suite.log)
+ importer := trans.NewImporter(newDB, suite.log)
err = importer.ImportMinimal(ctx, tempFilePath)
suite.NoError(err)
// we should now have some accounts in the database
accounts := []*gtsmodel.Account{}
- err = suite.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: nil}}, &accounts)
+ err = newDB.GetWhere(ctx, []db.Where{{Key: "domain", Value: nil}}, &accounts)
suite.NoError(err)
suite.NotEmpty(accounts)
}
diff --git a/internal/trans/model/account.go b/internal/trans/model/account.go
index 03e238cfb..64cd3b1a3 100644
--- a/internal/trans/model/account.go
+++ b/internal/trans/model/account.go
@@ -30,7 +30,7 @@ type Account struct {
CreatedAt *time.Time `json:"createdAt"`
UpdatedAt *time.Time `json:"updatedAt"`
Username string `json:"username"`
- Domain string `json:"domain,omitempty"`
+ Domain string `json:"domain,omitempty" bun:",nullzero"`
Locked bool `json:"locked"`
Language string `json:"language,omitempty"`
URI string `json:"uri"`
@@ -41,9 +41,11 @@ type Account struct {
FollowersURI string `json:"followersUri"`
FeaturedCollectionURI string `json:"featuredCollectionUri"`
ActorType string `json:"actorType"`
- PrivateKey *rsa.PrivateKey `json:"privateKey,omitempty"`
- PublicKey *rsa.PublicKey `json:"publicKey"`
+ PrivateKey *rsa.PrivateKey `json:"-" mapstructure:"-"`
+ PrivateKeyString string `json:"privateKey,omitempty" bun:"-" mapstructure:"privateKey"`
+ PublicKey *rsa.PublicKey `json:"-" mapstructure:"-"`
+ PublicKeyString string `json:"publicKey,omitempty" bun:"-" mapstructure:"publicKey"`
PublicKeyURI string `json:"publicKeyUri"`
SuspendedAt *time.Time `json:"suspendedAt,omitempty"`
- SuspensionOrigin string `json:"suspensionOrigin,omitempty"`
+ SuspensionOrigin string `json:"suspensionOrigin,omitempty" bun:",nullzero"`
}
diff --git a/internal/trans/model/account_test.go b/internal/trans/model/account_test.go
deleted file mode 100644
index aa1f37f85..000000000
--- a/internal/trans/model/account_test.go
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package trans_test
-
-import (
- "context"
- "encoding/json"
- "testing"
-
- "github.com/stretchr/testify/suite"
- trans "github.com/superseriousbusiness/gotosocial/internal/trans/model"
-)
-
-type AccountTestSuite struct {
- ModelTestSuite
-}
-
-func (suite *AccountTestSuite) TestAccountsIdempotent() {
- // we should be able to get all accounts with the simple trans.Account struct
- accounts := []*trans.Account{}
- err := suite.db.GetAll(context.Background(), &accounts)
- suite.NoError(err)
- suite.NotEmpty(accounts)
-
- // we should be able to marshal the accounts to json with no problems
- b, err := json.Marshal(&accounts)
- suite.NoError(err)
- suite.NotNil(b)
- suite.T().Log(string(b))
-
- // the json should be idempotent
- mAccounts := []*trans.Account{}
- err = json.Unmarshal(b, &mAccounts)
- suite.NoError(err)
- suite.NotEmpty(mAccounts)
- suite.EqualValues(accounts, mAccounts)
-}
-
-func TestAccountTestSuite(t *testing.T) {
- suite.Run(t, &AccountTestSuite{})
-}
diff --git a/internal/trans/model/block_test.go b/internal/trans/model/block_test.go
deleted file mode 100644
index 9a3b78e1a..000000000
--- a/internal/trans/model/block_test.go
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package trans_test
-
-import (
- "context"
- "encoding/json"
- "testing"
-
- "github.com/stretchr/testify/suite"
- trans "github.com/superseriousbusiness/gotosocial/internal/trans/model"
-)
-
-type BlockTestSuite struct {
- ModelTestSuite
-}
-
-func (suite *AccountTestSuite) TestBlocksIdempotent() {
- // we should be able to get all blocks with the simple trans.Block struct
- blocks := []*trans.Block{}
- err := suite.db.GetAll(context.Background(), &blocks)
- suite.NoError(err)
- suite.NotEmpty(blocks)
-
- // we should be able to marshal the blocks to json with no problems
- b, err := json.Marshal(&blocks)
- suite.NoError(err)
- suite.NotNil(b)
- suite.T().Log(string(b))
-
- // the json should be idempotent
- mBlocks := []*trans.Block{}
- err = json.Unmarshal(b, &mBlocks)
- suite.NoError(err)
- suite.NotEmpty(mBlocks)
- suite.EqualValues(blocks, mBlocks)
-}
-
-func TestBlockTestSuite(t *testing.T) {
- suite.Run(t, &BlockTestSuite{})
-}
diff --git a/internal/trans/model/follow.go b/internal/trans/model/follow.go
new file mode 100644
index 000000000..854cb4372
--- /dev/null
+++ b/internal/trans/model/follow.go
@@ -0,0 +1,31 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package trans
+
+import "time"
+
+type Follow struct {
+ Type TransType `json:"type" bun:"-"`
+ ID string `json:"id"`
+ CreatedAt time.Time `json:"createdAt"`
+ UpdatedAt time.Time `json:"updatedAt"`
+ URI string `json:"uri"`
+ AccountID string `json:"accountId"`
+ TargetAccountID string `json:"targetAccountId"`
+}
diff --git a/internal/trans/model/type.go b/internal/trans/model/type.go
index 0a9f3aac0..2372203e8 100644
--- a/internal/trans/model/type.go
+++ b/internal/trans/model/type.go
@@ -27,6 +27,7 @@ type TransType string
const (
TransAccount TransType = "account"
TransBlock TransType = "block"
+ TransFollow TransType = "follow"
)
type TransEntry map[string]interface{}
diff --git a/internal/trans/model/model_test.go b/internal/trans/util.go
similarity index 64%
rename from internal/trans/model/model_test.go
rename to internal/trans/util.go
index 278ef4a9d..4ccc1a4b6 100644
--- a/internal/trans/model/model_test.go
+++ b/internal/trans/util.go
@@ -16,24 +16,17 @@
along with this program. If not, see .
*/
-package trans_test
+package trans
import (
- "github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/testrig"
+ "fmt"
+ "os"
)
-type ModelTestSuite struct {
- suite.Suite
- db db.DB
-}
+func neatClose(f *os.File) error {
+ if err := f.Close(); err != nil {
+ return fmt.Errorf("error closing file: %s", err)
+ }
-func (suite *ModelTestSuite) SetupTest() {
- suite.db = testrig.NewTestDB()
- testrig.StandardDBSetup(suite.db, nil)
-}
-
-func (suite *ModelTestSuite) TearDownTest() {
- testrig.StandardDBTeardown(suite.db)
+ return nil
}