validate csv headers before full read

This commit is contained in:
tobi 2025-01-07 18:43:27 +01:00
commit 28193df3d0
2 changed files with 27 additions and 24 deletions

View file

@ -523,26 +523,15 @@ func permsFromCSV(
permType gtsmodel.DomainPermissionType, permType gtsmodel.DomainPermissionType,
body io.ReadCloser, body io.ReadCloser,
) ([]gtsmodel.DomainPermission, error) { ) ([]gtsmodel.DomainPermission, error) {
// Read body into memory as slice of CSV records. csvReader := csv.NewReader(body)
records, err := csv.NewReader(body).ReadAll()
// Whatever happened, we're // Read and validate column headers.
// done with the body now. columnHeaders, err := csvReader.Read()
body.Close()
// Check if error reading body.
if err != nil { if err != nil {
return nil, gtserror.NewfAt(3, "error decoding into csv: %w", err) body.Close()
return nil, gtserror.NewfAt(3, "error decoding csv column headers: %w", err)
} }
// Make sure we actually
// have some records.
if len(records) == 0 {
return nil, nil
}
// Validate column headers.
columnHeaders := records[0]
if !slices.Equal( if !slices.Equal(
columnHeaders, columnHeaders,
[]string{ []string{
@ -554,15 +543,29 @@ func permsFromCSV(
"#obfuscate", "#obfuscate",
}, },
) { ) {
return nil, gtserror.Newf( body.Close()
"unexpected column headers in csv: %+v", err := gtserror.NewfAt(3, "unexpected column headers in csv: %+v", columnHeaders)
columnHeaders, return nil, err
)
} }
// Trim off column headers // Read remaining CSV records.
// now they're validated. records, err := csvReader.ReadAll()
records = records[1:]
// Totally done
// with body now.
body.Close()
// Check for decode error.
if err != nil {
err := gtserror.NewfAt(3, "error decoding body into csv: %w", err)
return nil, err
}
// Make sure we actually
// have some records.
if len(records) == 0 {
return nil, nil
}
// Convert records to permissions slice. // Convert records to permissions slice.
perms := make([]gtsmodel.DomainPermission, 0, len(records)) perms := make([]gtsmodel.DomainPermission, 0, len(records))

View file

@ -472,7 +472,7 @@ func (suite *SubscriptionsTestSuite) TestDomainBlocksWrongContentTypeCSV() {
suite.Zero(count) suite.Zero(count)
suite.WithinDuration(time.Now(), permSub.FetchedAt, 1*time.Minute) suite.WithinDuration(time.Now(), permSub.FetchedAt, 1*time.Minute)
suite.Zero(permSub.SuccessfullyFetchedAt) suite.Zero(permSub.SuccessfullyFetchedAt)
suite.Equal(`permsFromCSV: unexpected column headers in csv: [bumfaces.net]`, permSub.Error) suite.Equal(`ProcessDomainPermissionSubscription: unexpected column headers in csv: [bumfaces.net]`, permSub.Error)
} }
func (suite *SubscriptionsTestSuite) TestDomainBlocksWrongContentTypePlain() { func (suite *SubscriptionsTestSuite) TestDomainBlocksWrongContentTypePlain() {