blob: 35b6ec69228523342f9a851aab7e77b91f0f56a7 [file] [log] [blame]
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include <google/protobuf/pyext/extension_dict.h>
#include <memory>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
? -1 \
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
namespace protobuf {
namespace python {
namespace extension_dict {
static Py_ssize_t len(ExtensionDict* self) {
Py_ssize_t size = 0;
std::vector<const FieldDescriptor*> fields;
self->parent->message->GetReflection()->ListFields(*self->parent->message,
&fields);
for (size_t i = 0; i < fields.size(); ++i) {
if (fields[i]->is_extension()) {
// With C++ descriptors, the field can always be retrieved, but for
// unknown extensions which have not been imported in Python code, there
// is no message class and we cannot retrieve the value.
// ListFields() has the same behavior.
if (fields[i]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
fields[i]->message_type()) == nullptr) {
PyErr_Clear();
continue;
}
++size;
}
}
return size;
}
struct ExtensionIterator {
PyObject_HEAD;
Py_ssize_t index;
std::vector<const FieldDescriptor*> fields;
// Owned reference, to keep the FieldDescriptors alive.
ExtensionDict* extension_dict;
};
PyObject* GetIter(PyObject* _self) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
if (obj == nullptr) {
return PyErr_Format(PyExc_MemoryError,
"Could not allocate extension iterator");
}
ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
// Call "placement new" to initialize. So the constructor of
// std::vector<...> fields will be called.
new (iter) ExtensionIterator;
self->parent->message->GetReflection()->ListFields(*self->parent->message,
&iter->fields);
iter->index = 0;
Py_INCREF(self);
iter->extension_dict = self;
return obj.release();
}
static void DeallocExtensionIterator(PyObject* _self) {
ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
self->fields.clear();
Py_XDECREF(self->extension_dict);
self->~ExtensionIterator();
Py_TYPE(_self)->tp_free(_self);
}
PyObject* subscript(ExtensionDict* self, PyObject* key) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == NULL) {
return NULL;
}
if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return NULL;
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
return cmessage::InternalGetScalar(self->parent->message, descriptor);
}
CMessage::CompositeFieldsMap::iterator iterator =
self->parent->composite_fields->find(descriptor);
if (iterator != self->parent->composite_fields->end()) {
Py_INCREF(iterator->second);
return iterator->second->AsPyObject();
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
// TODO(plabatut): consider building the class on the fly!
ContainerBase* sub_message = cmessage::InternalGetSubMessage(
self->parent, descriptor);
if (sub_message == NULL) {
return NULL;
}
(*self->parent->composite_fields)[descriptor] = sub_message;
return sub_message->AsPyObject();
}
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
// On the fly message class creation is needed to support the following
// situation:
// 1- add FileDescriptor to the pool that contains extensions of a message
// defined by another proto file. Do not create any message classes.
// 2- instantiate an extended message, and access the extension using
// the field descriptor.
// 3- the extension submessage fails to be returned, because no class has
// been created.
// It happens when deserializing text proto format, or when enumerating
// fields of a deserialized message.
CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
ScopedPyObjectPtr message_class_handler(
reinterpret_cast<PyObject*>(message_class));
if (message_class == NULL) {
return NULL;
}
ContainerBase* py_container = repeated_composite_container::NewContainer(
self->parent, descriptor, message_class);
if (py_container == NULL) {
return NULL;
}
(*self->parent->composite_fields)[descriptor] = py_container;
return py_container->AsPyObject();
} else {
ContainerBase* py_container = repeated_scalar_container::NewContainer(
self->parent, descriptor);
if (py_container == NULL) {
return NULL;
}
(*self->parent->composite_fields)[descriptor] = py_container;
return py_container->AsPyObject();
}
}
PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
return NULL;
}
int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == NULL) {
return -1;
}
if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return -1;
}
if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
"type");
return -1;
}
cmessage::AssureWritable(self->parent);
if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
return -1;
}
return 0;
}
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
char* name;
Py_ssize_t name_size;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return NULL;
}
PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
const FieldDescriptor* message_extension =
pool->pool->FindExtensionByName(string(name, name_size));
if (message_extension == NULL) {
// Is is the name of a message set extension?
const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
string(name, name_size));
if (message_descriptor && message_descriptor->extension_count() > 0) {
const FieldDescriptor* extension = message_descriptor->extension(0);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
message_extension = extension;
}
}
}
if (message_extension == NULL) {
Py_RETURN_NONE;
}
return PyFieldDescriptor_FromDescriptor(message_extension);
}
PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
int64 number = PyLong_AsLong(arg);
if (number == -1 && PyErr_Occurred()) {
return NULL;
}
PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
self->parent->message->GetDescriptor(), number);
if (message_extension == NULL) {
Py_RETURN_NONE;
}
return PyFieldDescriptor_FromDescriptor(message_extension);
}
static int Contains(PyObject* _self, PyObject* key) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
const FieldDescriptor* field_descriptor =
cmessage::GetExtensionDescriptor(key);
if (field_descriptor == nullptr) {
return -1;
}
if (!field_descriptor->is_extension()) {
PyErr_Format(PyExc_KeyError, "%s is not an extension",
field_descriptor->full_name().c_str());
return -1;
}
const Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
if (field_descriptor->is_repeated()) {
if (reflection->FieldSize(*message, field_descriptor) > 0) {
return 1;
}
} else {
if (reflection->HasField(*message, field_descriptor)) {
return 1;
}
}
return 0;
}
ExtensionDict* NewExtensionDict(CMessage *parent) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
PyType_GenericAlloc(&ExtensionDict_Type, 0));
if (self == NULL) {
return NULL;
}
Py_INCREF(parent);
self->parent = parent;
return self;
}
void dealloc(PyObject* pself) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
Py_CLEAR(self->parent);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
// Only equality comparisons are implemented.
if (opid != Py_EQ && opid != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
bool equals = false;
if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;;
}
if (equals ^ (opid == Py_EQ)) {
Py_RETURN_FALSE;
} else {
Py_RETURN_TRUE;
}
}
static PySequenceMethods SeqMethods = {
(lenfunc)len, // sq_length
0, // sq_concat
0, // sq_repeat
0, // sq_item
0, // sq_slice
0, // sq_ass_item
0, // sq_ass_slice
(objobjproc)Contains, // sq_contains
};
static PyMappingMethods MpMethods = {
(lenfunc)len, /* mp_length */
(binaryfunc)subscript, /* mp_subscript */
(objobjargproc)ass_subscript,/* mp_ass_subscript */
};
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
static PyMethodDef Methods[] = {
EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
EDMETHOD(_FindExtensionByNumber, METH_O,
"Finds an extension by field number."),
{NULL, NULL},
};
} // namespace extension_dict
PyTypeObject ExtensionDict_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) //
FULL_MODULE_NAME ".ExtensionDict", // tp_name
sizeof(ExtensionDict), // tp_basicsize
0, // tp_itemsize
(destructor)extension_dict::dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
0, // tp_repr
0, // tp_as_number
&extension_dict::SeqMethods, // tp_as_sequence
&extension_dict::MpMethods, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"An extension dict", // tp_doc
0, // tp_traverse
0, // tp_clear
(richcmpfunc)extension_dict::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
extension_dict::GetIter, // tp_iter
0, // tp_iternext
extension_dict::Methods, // tp_methods
0, // tp_members
0, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
0, // tp_init
};
PyObject* IterNext(PyObject* _self) {
extension_dict::ExtensionIterator* self =
reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
Py_ssize_t total_size = self->fields.size();
Py_ssize_t index = self->index;
while (self->index < total_size) {
index = self->index;
++self->index;
if (self->fields[index]->is_extension()) {
// With C++ descriptors, the field can always be retrieved, but for
// unknown extensions which have not been imported in Python code, there
// is no message class and we cannot retrieve the value.
// ListFields() has the same behavior.
if (self->fields[index]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->extension_dict->parent),
self->fields[index]->message_type()) == nullptr) {
PyErr_Clear();
continue;
}
return PyFieldDescriptor_FromDescriptor(self->fields[index]);
}
}
return nullptr;
}
PyTypeObject ExtensionIterator_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) //
FULL_MODULE_NAME ".ExtensionIterator", // tp_name
sizeof(extension_dict::ExtensionIterator), // tp_basicsize
0, // tp_itemsize
extension_dict::DeallocExtensionIterator, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
0, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
0, // tp_as_mapping
0, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"A scalar map iterator", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
0, // tp_weaklistoffset
PyObject_SelfIter, // tp_iter
IterNext, // tp_iternext
0, // tp_methods
0, // tp_members
0, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
0, // tp_init
};
} // namespace python
} // namespace protobuf
} // namespace google