From 0b48d64d7fcfd34514b6fa9b046d40457ed3e4b9 Mon Sep 17 00:00:00 2001 From: shaod2 Date: Wed, 21 May 2025 14:30:53 -0400 Subject: [PATCH] Manually backport recursion limit enforcement to 25.x CVE: CVE-2025-4565 Upstream-Status: Backport [d31100c9195819edb0a12f44705dfc2da111ea9b] Adjusted for the 3.19.6 version, resolving conflicts and removing unused testing codes. Signed-off-by: Chen Qi --- python/google/protobuf/internal/decoder.py | 110 ++++++++++++++---- .../protobuf/internal/python_message.py | 6 +- .../protobuf/internal/self_recursive.proto | 17 +++ 3 files changed, 106 insertions(+), 27 deletions(-) create mode 100644 python/google/protobuf/internal/self_recursive.proto diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index bc1b7b785..445c1d0d2 100644 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -195,7 +195,10 @@ def _SimpleDecoder(wire_type, decode_value): clear_if_default=False): if is_packed: local_DecodeVarint = _DecodeVarint - def DecodePackedField(buffer, pos, end, message, field_dict): + def DecodePackedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -214,7 +217,10 @@ def _SimpleDecoder(wire_type, decode_value): elif is_repeated: tag_bytes = encoder.TagBytes(field_number, wire_type) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -231,7 +237,8 @@ def _SimpleDecoder(wire_type, decode_value): return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (new_value, pos) = decode_value(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') @@ -375,7 +382,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, enum_type = key.enum_type if is_packed: local_DecodeVarint = _DecodeVarint - def DecodePackedField(buffer, pos, end, message, field_dict): + def DecodePackedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): """Decode serialized packed enum to its value and a new position. Args: @@ -388,6 +397,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -428,7 +438,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, elif is_repeated: tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): """Decode serialized repeated enum to its value and a new position. Args: @@ -441,6 +453,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -469,7 +482,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): """Decode serialized repeated enum to its value and a new position. Args: @@ -482,6 +495,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value_start_pos = pos (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: @@ -563,7 +577,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default, tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -580,7 +597,8 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default, return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: @@ -604,7 +622,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -621,7 +642,8 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: @@ -646,7 +668,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_START_GROUP) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -655,7 +679,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) # Read sub-message. - pos = value.add()._InternalParse(buffer, pos, end) + current_depth += 1 + if current_depth > _recursion_limit: + raise _DecodeError( + 'Error parsing message: too many levels of nesting.' + ) + pos = value.add()._InternalParse(buffer, pos, end, current_depth) + current_depth -= 1 # Read end tag. new_pos = pos+end_tag_len if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: @@ -667,12 +697,16 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) # Read sub-message. - pos = value._InternalParse(buffer, pos, end) + current_depth += 1 + if current_depth > _recursion_limit: + raise _DecodeError('Error parsing message: too many levels of nesting.') + pos = value._InternalParse(buffer, pos, end, current_depth) + current_depth -= 1 # Read end tag. new_pos = pos+end_tag_len if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: @@ -691,7 +725,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -702,18 +738,27 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if new_pos > end: raise _DecodeError('Truncated message.') # Read sub-message. - if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: + current_depth += 1 + if current_depth > _recursion_limit: + raise _DecodeError( + 'Error parsing message: too many levels of nesting.' + ) + if ( + value.add()._InternalParse(buffer, pos, new_pos, current_depth) + != new_pos + ): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise _DecodeError('Unexpected end-group tag.') # Predict that the next tag is another copy of the same repeated field. + current_depth -= 1 pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: # Prediction failed. Return. return new_pos return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -722,11 +767,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated message.') - # Read sub-message. - if value._InternalParse(buffer, pos, new_pos) != new_pos: + current_depth += 1 + if current_depth > _recursion_limit: + raise _DecodeError('Error parsing message: too many levels of nesting.') + if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos: # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') + current_depth -= 1 return new_pos return DecodeField @@ -844,7 +892,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map): # Can't read _concrete_class yet; might not be initialized. message_type = field_descriptor.message_type - def DecodeMap(buffer, pos, end, message, field_dict): + def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused submsg = message_type._concrete_class() value = field_dict.get(key) if value is None: @@ -926,8 +975,16 @@ def _SkipGroup(buffer, pos, end): return pos pos = new_pos +DEFAULT_RECURSION_LIMIT = 100 +_recursion_limit = DEFAULT_RECURSION_LIMIT + + +def SetRecursionLimit(new_limit): + global _recursion_limit + _recursion_limit = new_limit + -def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): +def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0): """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" unknown_field_set = containers.UnknownFieldSet() @@ -937,14 +994,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): field_number, wire_type = wire_format.UnpackTag(tag) if wire_type == wire_format.WIRETYPE_END_GROUP: break - (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth) # pylint: disable=protected-access unknown_field_set._add(field_number, wire_type, data) return (unknown_field_set, pos) -def _DecodeUnknownField(buffer, pos, wire_type): +def _DecodeUnknownField(buffer, pos, wire_type, current_depth=0): """Decode a unknown field. Returns the UnknownField and new position.""" if wire_type == wire_format.WIRETYPE_VARINT: @@ -958,7 +1015,12 @@ def _DecodeUnknownField(buffer, pos, wire_type): data = buffer[pos:pos+size].tobytes() pos += size elif wire_type == wire_format.WIRETYPE_START_GROUP: - (data, pos) = _DecodeUnknownFieldSet(buffer, pos) + print("MMP " + str(current_depth)) + current_depth += 1 + if current_depth >= _recursion_limit: + raise _DecodeError('Error parsing message: too many levels of nesting.') + (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth) + current_depth -= 1 elif wire_type == wire_format.WIRETYPE_END_GROUP: return (0, -1) else: diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 2921d5cb6..c7fec8c1c 100644 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -1141,7 +1141,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag - def InternalParse(self, buffer, pos, end): + def InternalParse(self, buffer, pos, end, current_depth=0): """Create a message from serialized bytes. Args: @@ -1179,7 +1179,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): # TODO(jieluo): remove old_pos. old_pos = new_pos (data, new_pos) = decoder._DecodeUnknownField( - buffer, new_pos, wire_type) # pylint: disable=protected-access + buffer, new_pos, wire_type, current_depth) # pylint: disable=protected-access if new_pos == -1: return pos # pylint: disable=protected-access @@ -1192,7 +1192,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): (tag_bytes, buffer[old_pos:new_pos].tobytes())) pos = new_pos else: - pos = field_decoder(buffer, new_pos, end, self, field_dict) + pos = field_decoder(buffer, new_pos, end, self, field_dict, current_depth) if field_desc: self._UpdateOneofState(field_desc) return pos diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto new file mode 100644 index 000000000..2a7aacb0b --- /dev/null +++ b/python/google/protobuf/internal/self_recursive.proto @@ -0,0 +1,17 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2024 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +syntax = "proto2"; + +package google.protobuf.python.internal; + +message SelfRecursive { + optional group RecursiveGroup = 1 { + optional RecursiveGroup sub_group = 2; + optional int32 i = 3; + }; +} \ No newline at end of file -- 2.34.1