mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 08:22:25 -06:00 
			
		
		
		
	
		
			
	
	
		
			225 lines
		
	
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			225 lines
		
	
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								// Copyright 2024 The Go Authors. All rights reserved.
							 | 
						||
| 
								 | 
							
								// Use of this source code is governed by a BSD-style
							 | 
						||
| 
								 | 
							
								// license that can be found in the LICENSE file.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								package impl
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bytes"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"google.golang.org/protobuf/encoding/protowire"
							 | 
						||
| 
								 | 
							
									"google.golang.org/protobuf/reflect/protoreflect"
							 | 
						||
| 
								 | 
							
									"google.golang.org/protobuf/runtime/protoiface"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func equal(in protoiface.EqualInput) protoiface.EqualOutput {
							 | 
						||
| 
								 | 
							
									return protoiface.EqualOutput{Equal: equalMessage(in.MessageA, in.MessageB)}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// equalMessage is a fast-path variant of protoreflect.equalMessage.
							 | 
						||
| 
								 | 
							
								// It takes advantage of the internal messageState type to avoid
							 | 
						||
| 
								 | 
							
								// unnecessary allocations, type assertions.
							 | 
						||
| 
								 | 
							
								func equalMessage(mx, my protoreflect.Message) bool {
							 | 
						||
| 
								 | 
							
									if mx == nil || my == nil {
							 | 
						||
| 
								 | 
							
										return mx == my
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if mx.Descriptor() != my.Descriptor() {
							 | 
						||
| 
								 | 
							
										return false
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									msx, ok := mx.(*messageState)
							 | 
						||
| 
								 | 
							
									if !ok {
							 | 
						||
| 
								 | 
							
										return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									msy, ok := my.(*messageState)
							 | 
						||
| 
								 | 
							
									if !ok {
							 | 
						||
| 
								 | 
							
										return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									mi := msx.messageInfo()
							 | 
						||
| 
								 | 
							
									miy := msy.messageInfo()
							 | 
						||
| 
								 | 
							
									if mi != miy {
							 | 
						||
| 
								 | 
							
										return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									mi.init()
							 | 
						||
| 
								 | 
							
									// Compares regular fields
							 | 
						||
| 
								 | 
							
									// Modified Message.Range code that compares two messages of the same type
							 | 
						||
| 
								 | 
							
									// while going over the fields.
							 | 
						||
| 
								 | 
							
									for _, ri := range mi.rangeInfos {
							 | 
						||
| 
								 | 
							
										var fd protoreflect.FieldDescriptor
							 | 
						||
| 
								 | 
							
										var vx, vy protoreflect.Value
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
										switch ri := ri.(type) {
							 | 
						||
| 
								 | 
							
										case *fieldInfo:
							 | 
						||
| 
								 | 
							
											hx := ri.has(msx.pointer())
							 | 
						||
| 
								 | 
							
											hy := ri.has(msy.pointer())
							 | 
						||
| 
								 | 
							
											if hx != hy {
							 | 
						||
| 
								 | 
							
												return false
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											if !hx {
							 | 
						||
| 
								 | 
							
												continue
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											fd = ri.fieldDesc
							 | 
						||
| 
								 | 
							
											vx = ri.get(msx.pointer())
							 | 
						||
| 
								 | 
							
											vy = ri.get(msy.pointer())
							 | 
						||
| 
								 | 
							
										case *oneofInfo:
							 | 
						||
| 
								 | 
							
											fnx := ri.which(msx.pointer())
							 | 
						||
| 
								 | 
							
											fny := ri.which(msy.pointer())
							 | 
						||
| 
								 | 
							
											if fnx != fny {
							 | 
						||
| 
								 | 
							
												return false
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											if fnx <= 0 {
							 | 
						||
| 
								 | 
							
												continue
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											fi := mi.fields[fnx]
							 | 
						||
| 
								 | 
							
											fd = fi.fieldDesc
							 | 
						||
| 
								 | 
							
											vx = fi.get(msx.pointer())
							 | 
						||
| 
								 | 
							
											vy = fi.get(msy.pointer())
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
										if !equalValue(fd, vx, vy) {
							 | 
						||
| 
								 | 
							
											return false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// Compare extensions.
							 | 
						||
| 
								 | 
							
									// This is more complicated because mx or my could have empty/nil extension maps,
							 | 
						||
| 
								 | 
							
									// however some populated extension map values are equal to nil extension maps.
							 | 
						||
| 
								 | 
							
									emx := mi.extensionMap(msx.pointer())
							 | 
						||
| 
								 | 
							
									emy := mi.extensionMap(msy.pointer())
							 | 
						||
| 
								 | 
							
									if emx != nil {
							 | 
						||
| 
								 | 
							
										for k, x := range *emx {
							 | 
						||
| 
								 | 
							
											xd := x.Type().TypeDescriptor()
							 | 
						||
| 
								 | 
							
											xv := x.Value()
							 | 
						||
| 
								 | 
							
											var y ExtensionField
							 | 
						||
| 
								 | 
							
											ok := false
							 | 
						||
| 
								 | 
							
											if emy != nil {
							 | 
						||
| 
								 | 
							
												y, ok = (*emy)[k]
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											// We need to treat empty lists as equal to nil values
							 | 
						||
| 
								 | 
							
											if emy == nil || !ok {
							 | 
						||
| 
								 | 
							
												if xd.IsList() && xv.List().Len() == 0 {
							 | 
						||
| 
								 | 
							
													continue
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												return false
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											if !equalValue(xd, xv, y.Value()) {
							 | 
						||
| 
								 | 
							
												return false
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if emy != nil {
							 | 
						||
| 
								 | 
							
										// emy may have extensions emx does not have, need to check them as well
							 | 
						||
| 
								 | 
							
										for k, y := range *emy {
							 | 
						||
| 
								 | 
							
											if emx != nil {
							 | 
						||
| 
								 | 
							
												// emx has the field, so we already checked it
							 | 
						||
| 
								 | 
							
												if _, ok := (*emx)[k]; ok {
							 | 
						||
| 
								 | 
							
													continue
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											// Empty lists are equal to nil
							 | 
						||
| 
								 | 
							
											if y.Type().TypeDescriptor().IsList() && y.Value().List().Len() == 0 {
							 | 
						||
| 
								 | 
							
												continue
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											// Cant be equal if the extension is populated
							 | 
						||
| 
								 | 
							
											return false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return equalUnknown(mx.GetUnknown(), my.GetUnknown())
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func equalValue(fd protoreflect.FieldDescriptor, vx, vy protoreflect.Value) bool {
							 | 
						||
| 
								 | 
							
									// slow path
							 | 
						||
| 
								 | 
							
									if fd.Kind() != protoreflect.MessageKind {
							 | 
						||
| 
								 | 
							
										return vx.Equal(vy)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// fast path special cases
							 | 
						||
| 
								 | 
							
									if fd.IsMap() {
							 | 
						||
| 
								 | 
							
										if fd.MapValue().Kind() == protoreflect.MessageKind {
							 | 
						||
| 
								 | 
							
											return equalMessageMap(vx.Map(), vy.Map())
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										return vx.Equal(vy)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									if fd.IsList() {
							 | 
						||
| 
								 | 
							
										return equalMessageList(vx.List(), vy.List())
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return equalMessage(vx.Message(), vy.Message())
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Mostly copied from protoreflect.equalMap.
							 | 
						||
| 
								 | 
							
								// This variant only works for messages as map types.
							 | 
						||
| 
								 | 
							
								// All other map types should be handled via Value.Equal.
							 | 
						||
| 
								 | 
							
								func equalMessageMap(mx, my protoreflect.Map) bool {
							 | 
						||
| 
								 | 
							
									if mx.Len() != my.Len() {
							 | 
						||
| 
								 | 
							
										return false
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									equal := true
							 | 
						||
| 
								 | 
							
									mx.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
							 | 
						||
| 
								 | 
							
										if !my.Has(k) {
							 | 
						||
| 
								 | 
							
											equal = false
							 | 
						||
| 
								 | 
							
											return false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										vy := my.Get(k)
							 | 
						||
| 
								 | 
							
										equal = equalMessage(vx.Message(), vy.Message())
							 | 
						||
| 
								 | 
							
										return equal
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
									return equal
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Mostly copied from protoreflect.equalList.
							 | 
						||
| 
								 | 
							
								// The only change is the usage of equalImpl instead of protoreflect.equalValue.
							 | 
						||
| 
								 | 
							
								func equalMessageList(lx, ly protoreflect.List) bool {
							 | 
						||
| 
								 | 
							
									if lx.Len() != ly.Len() {
							 | 
						||
| 
								 | 
							
										return false
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									for i := 0; i < lx.Len(); i++ {
							 | 
						||
| 
								 | 
							
										// We only operate on messages here since equalImpl will not call us in any other case.
							 | 
						||
| 
								 | 
							
										if !equalMessage(lx.Get(i).Message(), ly.Get(i).Message()) {
							 | 
						||
| 
								 | 
							
											return false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return true
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// equalUnknown compares unknown fields by direct comparison on the raw bytes
							 | 
						||
| 
								 | 
							
								// of each individual field number.
							 | 
						||
| 
								 | 
							
								// Copied from protoreflect.equalUnknown.
							 | 
						||
| 
								 | 
							
								func equalUnknown(x, y protoreflect.RawFields) bool {
							 | 
						||
| 
								 | 
							
									if len(x) != len(y) {
							 | 
						||
| 
								 | 
							
										return false
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if bytes.Equal([]byte(x), []byte(y)) {
							 | 
						||
| 
								 | 
							
										return true
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
							 | 
						||
| 
								 | 
							
									my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
							 | 
						||
| 
								 | 
							
									for len(x) > 0 {
							 | 
						||
| 
								 | 
							
										fnum, _, n := protowire.ConsumeField(x)
							 | 
						||
| 
								 | 
							
										mx[fnum] = append(mx[fnum], x[:n]...)
							 | 
						||
| 
								 | 
							
										x = x[n:]
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									for len(y) > 0 {
							 | 
						||
| 
								 | 
							
										fnum, _, n := protowire.ConsumeField(y)
							 | 
						||
| 
								 | 
							
										my[fnum] = append(my[fnum], y[:n]...)
							 | 
						||
| 
								 | 
							
										y = y[n:]
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if len(mx) != len(my) {
							 | 
						||
| 
								 | 
							
										return false
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									for k, v1 := range mx {
							 | 
						||
| 
								 | 
							
										if v2, ok := my[k]; !ok || !bytes.Equal([]byte(v1), []byte(v2)) {
							 | 
						||
| 
								 | 
							
											return false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return true
							 | 
						||
| 
								 | 
							
								}
							 |