From 11d64db5a37b3dac128411f0d81b438384020cda Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Tue, 7 Sep 2021 14:07:13 +0200 Subject: [PATCH] some more fiddling --- internal/trans/decoders.go | 33 +++++++++++++++++++-------------- internal/trans/input.go | 2 +- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/internal/trans/decoders.go b/internal/trans/decoders.go index 00c93328b..a263adb78 100644 --- a/internal/trans/decoders.go +++ b/internal/trans/decoders.go @@ -20,22 +20,23 @@ package trans import ( "crypto/rsa" + "encoding" "fmt" - "net" "reflect" "time" "github.com/mitchellh/mapstructure" + "github.com/sirupsen/logrus" transmodel "github.com/superseriousbusiness/gotosocial/internal/trans/model" ) -func accountDecode(e transmodel.TransEntry) (*transmodel.Account, error) { +func (i *importer) accountDecode(e transmodel.TransEntry) (*transmodel.Account, error) { a := &transmodel.Account{} decoderConfig := &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeHookFunc(time.RFC3339), - PrivateKeyHookFunc(), + keyHookFunc(i.log), ), Result: a, } @@ -51,16 +52,20 @@ func accountDecode(e transmodel.TransEntry) (*transmodel.Account, error) { return a, nil } -var PrivateKeyHookFunc mapstructure.DecodeHookFunc = func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { - if t != reflect.TypeOf(rsa.PrivateKey{}) { - return data, nil +func keyHookFunc(log *logrus.Logger) mapstructure.DecodeHookFunc { + return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + if t != reflect.TypeOf(rsa.PrivateKey{}) { + return data, nil + } + + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.BinaryUnmarshaler) + if !ok { + return data, nil + } + if err := unmarshaller.UnmarshalBinary([]byte(data.(string))); err != nil { + return nil, err + } + return result, nil } - - - rsa. - - // Convert it by parsing - _, net, err := net.ParseCIDR(data.(string)) - return net, err - } diff --git a/internal/trans/input.go b/internal/trans/input.go index 472d2c2b3..70ece0ac8 100644 --- a/internal/trans/input.go +++ b/internal/trans/input.go @@ -34,7 +34,7 @@ func (i *importer) inputEntry(ctx context.Context, entry transmodel.TransEntry) switch transmodel.TransType(t) { case transmodel.TransAccount: - account, err := accountDecode(entry) + account, err := i.accountDecode(entry) if err != nil { return fmt.Errorf("inputEntry: error decoding entry into account: %s", err) }