// Mirror of https://github.com/strongdm/comply (synced 2024-11-14 11:54:53 +00:00).
// 277 lines, 7.2 KiB, Go.

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package impl
|
|
|
|
import (
|
|
"math/bits"
|
|
|
|
"google.golang.org/protobuf/encoding/protowire"
|
|
"google.golang.org/protobuf/internal/errors"
|
|
"google.golang.org/protobuf/internal/flags"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
preg "google.golang.org/protobuf/reflect/protoregistry"
|
|
"google.golang.org/protobuf/runtime/protoiface"
|
|
piface "google.golang.org/protobuf/runtime/protoiface"
|
|
)
|
|
|
|
// errDecode is the error returned by the unmarshal functions in this file
// when the input cannot be parsed as valid wire-format data.
var errDecode = errors.New("cannot parse invalid wire-format data")
|
|
|
|
type unmarshalOptions struct {
|
|
flags protoiface.UnmarshalInputFlags
|
|
resolver interface {
|
|
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
|
|
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
|
|
}
|
|
}
|
|
|
|
func (o unmarshalOptions) Options() proto.UnmarshalOptions {
|
|
return proto.UnmarshalOptions{
|
|
Merge: true,
|
|
AllowPartial: true,
|
|
DiscardUnknown: o.DiscardUnknown(),
|
|
Resolver: o.resolver,
|
|
}
|
|
}
|
|
|
|
// DiscardUnknown reports whether the piface.UnmarshalDiscardUnknown bit is
// set in o.flags.
func (o unmarshalOptions) DiscardUnknown() bool {
	return o.flags&piface.UnmarshalDiscardUnknown != 0
}
|
|
|
|
func (o unmarshalOptions) IsDefault() bool {
|
|
return o.flags == 0 && o.resolver == preg.GlobalTypes
|
|
}
|
|
|
|
var lazyUnmarshalOptions = unmarshalOptions{
|
|
resolver: preg.GlobalTypes,
|
|
}
|
|
|
|
// unmarshalOutput is the result reported by the unmarshal functions in this
// file.
type unmarshalOutput struct {
	n           int // number of bytes consumed
	initialized bool
}
|
|
|
|
// unmarshal is protoreflect.Methods.Unmarshal.
|
|
func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
|
|
var p pointer
|
|
if ms, ok := in.Message.(*messageState); ok {
|
|
p = ms.pointer()
|
|
} else {
|
|
p = in.Message.(*messageReflectWrapper).pointer()
|
|
}
|
|
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
|
|
flags: in.Flags,
|
|
resolver: in.Resolver,
|
|
})
|
|
var flags piface.UnmarshalOutputFlags
|
|
if out.initialized {
|
|
flags |= piface.UnmarshalInitialized
|
|
}
|
|
return piface.UnmarshalOutput{
|
|
Flags: flags,
|
|
}, err
|
|
}
|
|
|
|
// errUnknown is returned during unmarshaling to indicate a parse error that
|
|
// should result in a field being placed in the unknown fields section (for example,
|
|
// when the wire type doesn't match) as opposed to the entire unmarshal operation
|
|
// failing (for example, when a field extends past the available input).
|
|
//
|
|
// This is a sentinel error which should never be visible to the user.
|
|
var errUnknown = errors.New("unknown")
|
|
|
|
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
mi.init()
|
|
if flags.ProtoLegacy && mi.isMessageSet {
|
|
return unmarshalMessageSet(mi, b, p, opts)
|
|
}
|
|
initialized := true
|
|
var requiredMask uint64
|
|
var exts *map[int32]ExtensionField
|
|
start := len(b)
|
|
for len(b) > 0 {
|
|
// Parse the tag (field number and wire type).
|
|
var tag uint64
|
|
if b[0] < 0x80 {
|
|
tag = uint64(b[0])
|
|
b = b[1:]
|
|
} else if len(b) >= 2 && b[1] < 128 {
|
|
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
|
|
b = b[2:]
|
|
} else {
|
|
var n int
|
|
tag, n = protowire.ConsumeVarint(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
b = b[n:]
|
|
}
|
|
var num protowire.Number
|
|
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
|
|
return out, errDecode
|
|
} else {
|
|
num = protowire.Number(n)
|
|
}
|
|
wtyp := protowire.Type(tag & 7)
|
|
|
|
if wtyp == protowire.EndGroupType {
|
|
if num != groupTag {
|
|
return out, errDecode
|
|
}
|
|
groupTag = 0
|
|
break
|
|
}
|
|
|
|
var f *coderFieldInfo
|
|
if int(num) < len(mi.denseCoderFields) {
|
|
f = mi.denseCoderFields[num]
|
|
} else {
|
|
f = mi.coderFields[num]
|
|
}
|
|
var n int
|
|
err := errUnknown
|
|
switch {
|
|
case f != nil:
|
|
if f.funcs.unmarshal == nil {
|
|
break
|
|
}
|
|
var o unmarshalOutput
|
|
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
|
|
n = o.n
|
|
if err != nil {
|
|
break
|
|
}
|
|
requiredMask |= f.validation.requiredBit
|
|
if f.funcs.isInit != nil && !o.initialized {
|
|
initialized = false
|
|
}
|
|
default:
|
|
// Possible extension.
|
|
if exts == nil && mi.extensionOffset.IsValid() {
|
|
exts = p.Apply(mi.extensionOffset).Extensions()
|
|
if *exts == nil {
|
|
*exts = make(map[int32]ExtensionField)
|
|
}
|
|
}
|
|
if exts == nil {
|
|
break
|
|
}
|
|
var o unmarshalOutput
|
|
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
n = o.n
|
|
if !o.initialized {
|
|
initialized = false
|
|
}
|
|
}
|
|
if err != nil {
|
|
if err != errUnknown {
|
|
return out, err
|
|
}
|
|
n = protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
|
|
u := mi.mutableUnknownBytes(p)
|
|
*u = protowire.AppendTag(*u, num, wtyp)
|
|
*u = append(*u, b[:n]...)
|
|
}
|
|
}
|
|
b = b[n:]
|
|
}
|
|
if groupTag != 0 {
|
|
return out, errDecode
|
|
}
|
|
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
|
|
initialized = false
|
|
}
|
|
if initialized {
|
|
out.initialized = true
|
|
}
|
|
out.n = start - len(b)
|
|
return out, nil
|
|
}
|
|
|
|
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
x := exts[int32(num)]
|
|
xt := x.Type()
|
|
if xt == nil {
|
|
var err error
|
|
xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
|
|
if err != nil {
|
|
if err == preg.NotFound {
|
|
return out, errUnknown
|
|
}
|
|
return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
|
|
}
|
|
}
|
|
xi := getExtensionFieldInfo(xt)
|
|
if xi.funcs.unmarshal == nil {
|
|
return out, errUnknown
|
|
}
|
|
if flags.LazyUnmarshalExtensions {
|
|
if opts.IsDefault() && x.canLazy(xt) {
|
|
out, valid := skipExtension(b, xi, num, wtyp, opts)
|
|
switch valid {
|
|
case ValidationValid:
|
|
if out.initialized {
|
|
x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
|
|
exts[int32(num)] = x
|
|
return out, nil
|
|
}
|
|
case ValidationInvalid:
|
|
return out, errDecode
|
|
case ValidationUnknown:
|
|
}
|
|
}
|
|
}
|
|
ival := x.Value()
|
|
if !ival.IsValid() && xi.unmarshalNeedsValue {
|
|
// Create a new message, list, or map value to fill in.
|
|
// For enums, create a prototype value to let the unmarshal func know the
|
|
// concrete type.
|
|
ival = xt.New()
|
|
}
|
|
v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
|
|
if err != nil {
|
|
return out, err
|
|
}
|
|
if xi.funcs.isInit == nil {
|
|
out.initialized = true
|
|
}
|
|
x.Set(xt, v)
|
|
exts[int32(num)] = x
|
|
return out, nil
|
|
}
|
|
|
|
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
|
|
if xi.validation.mi == nil {
|
|
return out, ValidationUnknown
|
|
}
|
|
xi.validation.mi.init()
|
|
switch xi.validation.typ {
|
|
case validationTypeMessage:
|
|
if wtyp != protowire.BytesType {
|
|
return out, ValidationUnknown
|
|
}
|
|
v, n := protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, ValidationUnknown
|
|
}
|
|
out, st := xi.validation.mi.validate(v, 0, opts)
|
|
out.n = n
|
|
return out, st
|
|
case validationTypeGroup:
|
|
if wtyp != protowire.StartGroupType {
|
|
return out, ValidationUnknown
|
|
}
|
|
out, st := xi.validation.mi.validate(b, num, opts)
|
|
return out, st
|
|
default:
|
|
return out, ValidationUnknown
|
|
}
|
|
}
|