Nodes: cache socket identifier to index mapping

While this preprocessing does take some time upfront,
it avoids slower lookups later on, especially as nodes get
more sockets.

It's probably possible to make this more efficient in some cases,
but this is good enough for now.
This commit is contained in:
Jacques Lucke 2021-06-11 16:21:08 +02:00
parent 7b30a3e98d
commit 605ce623be
4 changed files with 130 additions and 31 deletions

View File

@ -1381,27 +1381,6 @@ LockedNode::~LockedNode()
}
}
/* Finds the input socket of `node` whose identifier equals `identifier` and returns it together
 * with the node's context. Returns an empty (null) socket when no input matches.
 * TODO: This is a linear scan over all inputs; a map-based lookup would be faster. */
static DInputSocket get_input_by_identifier(const DNode node, const StringRef identifier)
{
  for (const InputSocketRef *input_ref : node->inputs()) {
    if (input_ref->identifier() != identifier) {
      continue;
    }
    return {node.context(), input_ref};
  }
  return {};
}
/* Finds the output socket of `node` whose identifier equals `identifier` and returns it together
 * with the node's context. Returns an empty (null) socket when no output matches. */
static DOutputSocket get_output_by_identifier(const DNode node, const StringRef identifier)
{
  for (const OutputSocketRef *output_ref : node->outputs()) {
    if (output_ref->identifier() != identifier) {
      continue;
    }
    return {node.context(), output_ref};
  }
  return {};
}
NodeParamsProvider::NodeParamsProvider(GeometryNodesEvaluator &evaluator,
DNode dnode,
NodeState &node_state)
@ -1415,7 +1394,7 @@ NodeParamsProvider::NodeParamsProvider(GeometryNodesEvaluator &evaluator,
bool NodeParamsProvider::can_get_input(StringRef identifier) const
{
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
InputState &input_state = node_state_.inputs[socket->index()];
@ -1433,7 +1412,7 @@ bool NodeParamsProvider::can_get_input(StringRef identifier) const
bool NodeParamsProvider::can_set_output(StringRef identifier) const
{
const DOutputSocket socket = get_output_by_identifier(this->dnode, identifier);
const DOutputSocket socket = this->dnode.output_by_identifier(identifier);
BLI_assert(socket);
OutputState &output_state = node_state_.outputs[socket->index()];
@ -1442,7 +1421,7 @@ bool NodeParamsProvider::can_set_output(StringRef identifier) const
GMutablePointer NodeParamsProvider::extract_input(StringRef identifier)
{
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
BLI_assert(!socket->is_multi_input_socket());
BLI_assert(this->can_get_input(identifier));
@ -1456,7 +1435,7 @@ GMutablePointer NodeParamsProvider::extract_input(StringRef identifier)
Vector<GMutablePointer> NodeParamsProvider::extract_multi_input(StringRef identifier)
{
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
BLI_assert(socket->is_multi_input_socket());
BLI_assert(this->can_get_input(identifier));
@ -1487,7 +1466,7 @@ Vector<GMutablePointer> NodeParamsProvider::extract_multi_input(StringRef identi
GPointer NodeParamsProvider::get_input(StringRef identifier) const
{
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
BLI_assert(!socket->is_multi_input_socket());
BLI_assert(this->can_get_input(identifier));
@ -1505,7 +1484,7 @@ GMutablePointer NodeParamsProvider::alloc_output_value(const CPPType &type)
void NodeParamsProvider::set_output(StringRef identifier, GMutablePointer value)
{
const DOutputSocket socket = get_output_by_identifier(this->dnode, identifier);
const DOutputSocket socket = this->dnode.output_by_identifier(identifier);
BLI_assert(socket);
evaluator_.log_socket_value(socket, value);
@ -1519,7 +1498,7 @@ void NodeParamsProvider::set_output(StringRef identifier, GMutablePointer value)
bool NodeParamsProvider::lazy_require_input(StringRef identifier)
{
BLI_assert(node_supports_laziness(this->dnode));
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
InputState &input_state = node_state_.inputs[socket->index()];
@ -1533,7 +1512,7 @@ bool NodeParamsProvider::lazy_require_input(StringRef identifier)
void NodeParamsProvider::set_input_unused(StringRef identifier)
{
const DInputSocket socket = get_input_by_identifier(this->dnode, identifier);
const DInputSocket socket = this->dnode.input_by_identifier(identifier);
BLI_assert(socket);
LockedNode locked_node{evaluator_, this->dnode, node_state_};
@ -1542,7 +1521,7 @@ void NodeParamsProvider::set_input_unused(StringRef identifier)
bool NodeParamsProvider::output_is_required(StringRef identifier) const
{
const DOutputSocket socket = get_output_by_identifier(this->dnode, identifier);
const DOutputSocket socket = this->dnode.output_by_identifier(identifier);
BLI_assert(socket);
OutputState &output_state = node_state_.outputs[socket->index()];
@ -1555,7 +1534,7 @@ bool NodeParamsProvider::output_is_required(StringRef identifier) const
bool NodeParamsProvider::lazy_output_is_required(StringRef identifier) const
{
BLI_assert(node_supports_laziness(this->dnode));
const DOutputSocket socket = get_output_by_identifier(this->dnode, identifier);
const DOutputSocket socket = this->dnode.output_by_identifier(identifier);
BLI_assert(socket);
OutputState &output_state = node_state_.outputs[socket->index()];

View File

@ -95,6 +95,9 @@ class DNode {
DInputSocket input(int index) const;
DOutputSocket output(int index) const;
DInputSocket input_by_identifier(StringRef identifier) const;
DOutputSocket output_by_identifier(StringRef identifier) const;
};
/* A (nullable) reference to a socket and the context it is in. It is unique within an entire
@ -288,6 +291,16 @@ inline DOutputSocket DNode::output(int index) const
return {context_, &node_ref_->output(index)};
}
/* Returns the input socket with the given identifier, paired with this node's context. */
inline DInputSocket DNode::input_by_identifier(StringRef identifier) const
{
  const auto &input_ref = node_ref_->input_by_identifier(identifier);
  return {context_, &input_ref};
}
/* Returns the output socket with the given identifier, paired with this node's context. */
inline DOutputSocket DNode::output_by_identifier(StringRef identifier) const
{
  const auto &output_ref = node_ref_->output_by_identifier(identifier);
  return {context_, &output_ref};
}
/* --------------------------------------------------------------------
* DSocket inline methods.
*/

View File

@ -70,6 +70,8 @@ class NodeTreeRef;
class LinkRef;
class InternalLinkRef;
/* Maps a socket identifier to the index of that socket within its node's input or output list. */
using SocketIndexByIdentifierMap = Map<std::string, int>;
class SocketRef : NonCopyable, NonMovable {
protected:
NodeRef *node_;
@ -169,6 +171,8 @@ class NodeRef : NonCopyable, NonMovable {
Vector<InputSocketRef *> inputs_;
Vector<OutputSocketRef *> outputs_;
Vector<InternalLinkRef *> internal_links_;
SocketIndexByIdentifierMap *input_index_by_identifier_;
SocketIndexByIdentifierMap *output_index_by_identifier_;
friend NodeTreeRef;
@ -182,6 +186,9 @@ class NodeRef : NonCopyable, NonMovable {
const InputSocketRef &input(int index) const;
const OutputSocketRef &output(int index) const;
const InputSocketRef &input_by_identifier(StringRef identifier) const;
const OutputSocketRef &output_by_identifier(StringRef identifier) const;
bNode *bnode() const;
bNodeTree *btree() const;
@ -246,6 +253,7 @@ class NodeTreeRef : NonCopyable, NonMovable {
Vector<OutputSocketRef *> output_sockets_;
Vector<LinkRef *> links_;
MultiValueMap<const bNodeType *, NodeRef *> nodes_by_type_;
Vector<std::unique_ptr<SocketIndexByIdentifierMap>> owned_identifier_maps_;
public:
NodeTreeRef(bNodeTree *btree);
@ -279,6 +287,7 @@ class NodeTreeRef : NonCopyable, NonMovable {
bNodeSocket *bsocket);
void create_linked_socket_caches();
void create_socket_identifier_maps();
};
using NodeTreeRefMap = Map<bNodeTree *, std::unique_ptr<const NodeTreeRef>>;
@ -502,6 +511,18 @@ inline const OutputSocketRef &NodeRef::output(int index) const
return *outputs_[index];
}
/* Returns the input socket with the given identifier by looking up its index in the
 * precomputed identifier map.
 * NOTE(review): there is no not-found fallback here; callers are presumably expected to pass an
 * identifier that exists on this node — confirm `Map::lookup_as` semantics. */
inline const InputSocketRef &NodeRef::input_by_identifier(StringRef identifier) const
{
  return this->input(input_index_by_identifier_->lookup_as(identifier));
}
/* Returns the output socket with the given identifier by looking up its index in the
 * precomputed identifier map.
 * NOTE(review): there is no not-found fallback here; callers are presumably expected to pass an
 * identifier that exists on this node — confirm `Map::lookup_as` semantics. */
inline const OutputSocketRef &NodeRef::output_by_identifier(StringRef identifier) const
{
  return this->output(output_index_by_identifier_->lookup_as(identifier));
}
inline bNode *NodeRef::bnode() const
{
return bnode_;

View File

@ -14,6 +14,8 @@
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include <mutex>
#include "NOD_node_tree_ref.hh"
#include "BLI_dot_export.hh"
@ -105,6 +107,7 @@ NodeTreeRef::NodeTreeRef(bNodeTree *btree) : btree_(btree)
}
}
this->create_socket_identifier_maps();
this->create_linked_socket_caches();
for (NodeRef *node : nodes_by_id_) {
@ -316,6 +319,89 @@ void OutputSocketRef::foreach_logical_target(
}
}
namespace {
/* Result of resolving the identifier map for one socket list. */
struct SocketByIdentifierMap {
  /* The map to use for lookups. It may point to a shared/static map or to `owned_map`. */
  SocketIndexByIdentifierMap *map{nullptr};
  /* Set only when a fresh map was created for this socket list; the caller must keep it alive. */
  std::unique_ptr<SocketIndexByIdentifierMap> owned_map{};
};
}  // namespace
static std::unique_ptr<SocketIndexByIdentifierMap> create_identifier_map(const ListBase &sockets)
{
std::unique_ptr<SocketIndexByIdentifierMap> map = std::make_unique<SocketIndexByIdentifierMap>();
int index;
LISTBASE_FOREACH_INDEX (bNodeSocket *, socket, &sockets, index) {
map->add_new(socket->identifier, index);
}
return map;
}
/* Returns the identifier -> index map for one socket list (`sockets`) of `node`, reusing a
 * shared map whenever possible. A new map is created only for nodes without a socket template
 * that have a non-empty, non-reroute socket list; ownership of that map is handed to the
 * caller through `owned_map`.
 *
 * This function is not threadsafe: it lazily fills a function-local cache shared by all calls. */
static SocketByIdentifierMap get_or_create_identifier_map(
    const bNode &node, const ListBase &sockets, const bNodeSocketTemplate *sockets_template)
{
  SocketByIdentifierMap result;

  if (sockets_template != nullptr) {
    /* Nodes whose sockets come from the same template share one cached map, keyed by the
     * template pointer.
     * NOTE(review): the cached map is built from whichever node is seen first, so this assumes
     * every node using the template has a matching socket list — TODO confirm. */
    static Map<const bNodeSocketTemplate *, std::unique_ptr<SocketIndexByIdentifierMap>>
        maps_by_template;
    result.map = maps_by_template
                     .lookup_or_add_cb(sockets_template,
                                       [&]() { return create_identifier_map(sockets); })
                     .get();
    return result;
  }

  if (BLI_listbase_is_empty(&sockets)) {
    static SocketIndexByIdentifierMap empty_map;
    result.map = &empty_map;
    return result;
  }

  if (node.type == NODE_REROUTE) {
    /* Reroute sockets use fixed identifiers, so static maps can be shared by all reroutes. */
    if (&sockets == &node.inputs) {
      static SocketIndexByIdentifierMap reroute_input_map = [] {
        SocketIndexByIdentifierMap reroute_map;
        reroute_map.add_new("Input", 0);
        return reroute_map;
      }();
      result.map = &reroute_input_map;
    }
    else {
      static SocketIndexByIdentifierMap reroute_output_map = [] {
        SocketIndexByIdentifierMap reroute_map;
        reroute_map.add_new("Output", 0);
        return reroute_map;
      }();
      result.map = &reroute_output_map;
    }
    return result;
  }

  /* The node has a dynamic amount of sockets. Therefore we need to create a new map. */
  result.owned_map = create_identifier_map(sockets);
  result.map = result.owned_map.get();
  return result;
}
void NodeTreeRef::create_socket_identifier_maps()
{
/* `get_or_create_identifier_map` is not threadsafe, therefore we have to hold a lock here. */
static std::mutex mutex;
std::lock_guard lock{mutex};
for (NodeRef *node : nodes_by_id_) {
bNode &bnode = *node->bnode_;
SocketByIdentifierMap inputs_map = get_or_create_identifier_map(
bnode, bnode.inputs, bnode.typeinfo->inputs);
SocketByIdentifierMap outputs_map = get_or_create_identifier_map(
bnode, bnode.outputs, bnode.typeinfo->outputs);
node->input_index_by_identifier_ = inputs_map.map;
node->output_index_by_identifier_ = outputs_map.map;
if (inputs_map.owned_map) {
owned_identifier_maps_.append(std::move(inputs_map.owned_map));
}
if (outputs_map.owned_map) {
owned_identifier_maps_.append(std::move(outputs_map.owned_map));
}
}
}
static bool has_link_cycles_recursive(const NodeRef &node,
MutableSpan<bool> visited,
MutableSpan<bool> is_in_stack)