Geometry Nodes: avoid some overhead during field inferencing

Previously, the same `FieldInferencingInterface` was built multiple times for every node.
Now it is built only once per node during inferencing. Going forward,
it would be even better to store the inferencing interface as part of the
node declaration to avoid building it during inferencing at all.
This commit is contained in:
Jacques Lucke 2022-12-20 14:36:00 +01:00
parent dedca2c994
commit ddd24186d9
1 changed files with 49 additions and 22 deletions

View File

@ -5,6 +5,7 @@
#include "NOD_node_declaration.hh"
#include "BLI_resource_scope.hh"
#include "BLI_set.hh"
#include "BLI_stack.hh"
@ -91,9 +92,10 @@ static OutputFieldDependency get_interface_output_field_dependency(const bNode &
return socket_decl.output_field_dependency();
}
static FieldInferencingInterface get_dummy_field_inferencing_interface(const bNode &node)
static const FieldInferencingInterface &get_dummy_field_inferencing_interface(const bNode &node,
ResourceScope &scope)
{
FieldInferencingInterface inferencing_interface;
auto &inferencing_interface = scope.construct<FieldInferencingInterface>();
inferencing_interface.inputs.append_n_times(InputSocketFieldType::None,
node.input_sockets().size());
inferencing_interface.outputs.append_n_times(OutputFieldDependency::ForDataSource(),
@ -106,17 +108,19 @@ static FieldInferencingInterface get_dummy_field_inferencing_interface(const bNo
* In the future, this information can be stored in the node declaration. Then the interface would
* not have to be built during inferencing at all and no #ResourceScope would be needed to own it.
*/
static FieldInferencingInterface get_node_field_inferencing_interface(const bNode &node)
static const FieldInferencingInterface &get_node_field_inferencing_interface(const bNode &node,
ResourceScope &scope)
{
/* Node groups already reference all required information, so just return that. */
if (node.is_group()) {
bNodeTree *group = (bNodeTree *)node.id;
if (group == nullptr) {
return FieldInferencingInterface();
static const FieldInferencingInterface empty_interface;
return empty_interface;
}
if (!ntreeIsRegistered(group)) {
/* This can happen when there is a linked node group that was not found (see T92799). */
return get_dummy_field_inferencing_interface(node);
return get_dummy_field_inferencing_interface(node, scope);
}
if (!group->runtime->field_inferencing_interface) {
/* This shouldn't happen because referenced node groups should always be updated first. */
@ -125,7 +129,7 @@ static FieldInferencingInterface get_node_field_inferencing_interface(const bNod
return *group->runtime->field_inferencing_interface;
}
FieldInferencingInterface inferencing_interface;
auto &inferencing_interface = scope.construct<FieldInferencingInterface>();
for (const bNodeSocket *input_socket : node.input_sockets()) {
inferencing_interface.inputs.append(get_interface_input_field_type(node, *input_socket));
}
@ -185,7 +189,9 @@ static Vector<const bNodeSocket *> gather_input_socket_dependencies(
* to figure out if it is always a field or if it depends on any group inputs.
*/
static OutputFieldDependency find_group_output_dependencies(
const bNodeSocket &group_output_socket, const Span<SocketFieldState> field_state_by_socket_id)
const bNodeSocket &group_output_socket,
const Span<const FieldInferencingInterface *> interface_by_node,
const Span<SocketFieldState> field_state_by_socket_id)
{
if (!is_field_socket_type(group_output_socket)) {
return OutputFieldDependency::ForDataSource();
@ -227,8 +233,8 @@ static OutputFieldDependency find_group_output_dependencies(
}
}
else if (!origin_state.is_single) {
const FieldInferencingInterface inferencing_interface =
get_node_field_inferencing_interface(origin_node);
const FieldInferencingInterface &inferencing_interface =
*interface_by_node[origin_node.index()];
const OutputFieldDependency &field_dependency =
inferencing_interface.outputs[origin_socket->index()];
@ -251,13 +257,14 @@ static OutputFieldDependency find_group_output_dependencies(
}
static void propagate_data_requirements_from_right_to_left(
const bNodeTree &tree, const MutableSpan<SocketFieldState> field_state_by_socket_id)
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
const MutableSpan<SocketFieldState> field_state_by_socket_id)
{
const Span<const bNode *> toposort_result = tree.toposort_right_to_left();
for (const bNode *node : toposort_result) {
const FieldInferencingInterface inferencing_interface = get_node_field_inferencing_interface(
*node);
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
for (const bNodeSocket *output_socket : node->output_sockets()) {
SocketFieldState &state = field_state_by_socket_id[output_socket->index_in_tree()];
@ -369,7 +376,9 @@ static void determine_group_input_states(
}
static void propagate_field_status_from_left_to_right(
const bNodeTree &tree, const MutableSpan<SocketFieldState> field_state_by_socket_id)
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
const MutableSpan<SocketFieldState> field_state_by_socket_id)
{
const Span<const bNode *> toposort_result = tree.toposort_left_to_right();
@ -378,8 +387,7 @@ static void propagate_field_status_from_left_to_right(
continue;
}
const FieldInferencingInterface inferencing_interface = get_node_field_inferencing_interface(
*node);
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
/* Update field state of input sockets, also taking into account linked origin sockets. */
for (const bNodeSocket *input_socket : node->input_sockets()) {
@ -440,9 +448,11 @@ static void propagate_field_status_from_left_to_right(
}
}
static void determine_group_output_states(const bNodeTree &tree,
FieldInferencingInterface &new_inferencing_interface,
const Span<SocketFieldState> field_state_by_socket_id)
static void determine_group_output_states(
const bNodeTree &tree,
FieldInferencingInterface &new_inferencing_interface,
const Span<const FieldInferencingInterface *> interface_by_node,
const Span<SocketFieldState> field_state_by_socket_id)
{
const bNode *group_output_node = tree.group_output_node();
if (!group_output_node) {
@ -451,7 +461,7 @@ static void determine_group_output_states(const bNodeTree &tree,
for (const bNodeSocket *group_output_socket : group_output_node->input_sockets().drop_back(1)) {
OutputFieldDependency field_dependency = find_group_output_dependencies(
*group_output_socket, field_state_by_socket_id);
*group_output_socket, interface_by_node, field_state_by_socket_id);
new_inferencing_interface.outputs[group_output_socket->index()] = std::move(field_dependency);
}
}
@ -486,10 +496,25 @@ static void update_socket_shapes(const bNodeTree &tree,
}
}
static void prepare_inferencing_interfaces(
const Span<const bNode *> nodes,
MutableSpan<const FieldInferencingInterface *> interface_by_node,
ResourceScope &scope)
{
for (const int i : nodes.index_range()) {
interface_by_node[i] = &get_node_field_inferencing_interface(*nodes[i], scope);
}
}
bool update_field_inferencing(const bNodeTree &tree)
{
tree.ensure_topology_cache();
const Span<const bNode *> nodes = tree.all_nodes();
ResourceScope scope;
Array<const FieldInferencingInterface *> interface_by_node(nodes.size());
prepare_inferencing_interfaces(nodes, interface_by_node, scope);
/* Create new inferencing interface for this node group. */
std::unique_ptr<FieldInferencingInterface> new_inferencing_interface =
std::make_unique<FieldInferencingInterface>();
@ -501,10 +526,12 @@ bool update_field_inferencing(const bNodeTree &tree)
/* Keep track of the state of all sockets. The index into this array is
 * #bNodeSocket::index_in_tree(). */
Array<SocketFieldState> field_state_by_socket_id(tree.all_sockets().size());
propagate_data_requirements_from_right_to_left(tree, field_state_by_socket_id);
propagate_data_requirements_from_right_to_left(
tree, interface_by_node, field_state_by_socket_id);
determine_group_input_states(tree, *new_inferencing_interface, field_state_by_socket_id);
propagate_field_status_from_left_to_right(tree, field_state_by_socket_id);
determine_group_output_states(tree, *new_inferencing_interface, field_state_by_socket_id);
propagate_field_status_from_left_to_right(tree, interface_by_node, field_state_by_socket_id);
determine_group_output_states(
tree, *new_inferencing_interface, interface_by_node, field_state_by_socket_id);
update_socket_shapes(tree, field_state_by_socket_id);
/* Update the previous group interface. */