Functions: implement common subnetwork elimination optimization
This was the last of the three network optimizations I developed in the functions branch. Common subnetwork elimination and constant folding together can get rid of most unnecessary nodes.
This commit is contained in:
parent
e3e42c00cb
commit
d1f4546a59
|
@ -25,6 +25,7 @@ namespace blender::fn::mf_network_optimization {
|
|||
|
||||
void dead_node_removal(MFNetwork &network);
|
||||
void constant_folding(MFNetwork &network, ResourceCollector &resources);
|
||||
void common_subnetwork_elimination(MFNetwork &network);
|
||||
|
||||
} // namespace blender::fn::mf_network_optimization
|
||||
|
||||
|
|
|
@ -18,10 +18,17 @@
|
|||
* \ingroup fn
|
||||
*/
|
||||
|
||||
/* Used to check if two multi-functions have the exact same type. */
|
||||
#include <typeinfo>
|
||||
|
||||
#include "FN_multi_function_builder.hh"
|
||||
#include "FN_multi_function_network_evaluation.hh"
|
||||
#include "FN_multi_function_network_optimization.hh"
|
||||
|
||||
#include "BLI_disjoint_set.hh"
|
||||
#include "BLI_ghash.h"
|
||||
#include "BLI_map.hh"
|
||||
#include "BLI_rand.h"
|
||||
#include "BLI_stack.hh"
|
||||
|
||||
namespace blender::fn::mf_network_optimization {
|
||||
|
@ -292,4 +299,179 @@ void constant_folding(MFNetwork &network, ResourceCollector &resources)
|
|||
|
||||
/** \} */
|
||||
|
||||
/* -------------------------------------------------------------------- */
|
||||
/** \name Common Subnetwork Elimination
|
||||
*
|
||||
* \{ */
|
||||
|
||||
/**
 * Hashes a single function node by combining its function's hash with, for every input, either
 * the hash of the origin node paired with the origin socket index or, for an unlinked input, a
 * fresh value drawn from \a rng. The random value makes nodes with unlinked inputs very unlikely
 * to collide with any other node.
 *
 * Every origin node of \a node must already have its entry in \a node_hashes filled in.
 */
static uint32_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span<uint32_t> node_hashes)
{
  uint32_t inputs_hash = 394659347u;
  for (MFInputSocket *socket : node.inputs()) {
    MFOutputSocket *origin = socket->origin();
    const uint32_t socket_hash = (origin != nullptr) ?
                                     BLI_ghashutil_combine_hash(node_hashes[origin->node().id()],
                                                                origin->index()) :
                                     BLI_rng_get_uint(rng);
    inputs_hash = BLI_ghashutil_combine_hash(inputs_hash, socket_hash);
  }
  return BLI_ghashutil_combine_hash(inputs_hash, node.function().hash());
}
|
||||
|
||||
/**
|
||||
* Produces a hash for every node. Two nodes with the same hash should have a high probability of
|
||||
* outputting the same values.
|
||||
*/
|
||||
static Array<uint32_t> compute_node_hashes(MFNetwork &network)
|
||||
{
|
||||
RNG *rng = BLI_rng_new(0);
|
||||
Array<uint32_t> node_hashes(network.node_id_amount());
|
||||
Array<bool> node_is_hashed(network.node_id_amount(), false);
|
||||
|
||||
/* No dummy nodes are not assumed to output the same values. */
|
||||
for (MFDummyNode *node : network.dummy_nodes()) {
|
||||
uint32_t node_hash = BLI_rng_get_uint(rng);
|
||||
node_hashes[node->id()] = node_hash;
|
||||
node_is_hashed[node->id()] = true;
|
||||
}
|
||||
|
||||
Stack<MFFunctionNode *> nodes_to_check;
|
||||
nodes_to_check.push_multiple(network.function_nodes());
|
||||
|
||||
while (!nodes_to_check.is_empty()) {
|
||||
MFFunctionNode &node = *nodes_to_check.peek();
|
||||
if (node_is_hashed[node.id()]) {
|
||||
nodes_to_check.pop();
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Make sure that origin nodes are hashed first. */
|
||||
bool all_dependencies_ready = true;
|
||||
for (MFInputSocket *input_socket : node.inputs()) {
|
||||
MFOutputSocket *origin_socket = input_socket->origin();
|
||||
if (origin_socket != nullptr) {
|
||||
MFNode &origin_node = origin_socket->node();
|
||||
if (!node_is_hashed[origin_node.id()]) {
|
||||
all_dependencies_ready = false;
|
||||
nodes_to_check.push(&origin_node.as_function());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!all_dependencies_ready) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t node_hash = compute_node_hash(node, rng, node_hashes);
|
||||
node_hashes[node.id()] = node_hash;
|
||||
node_is_hashed[node.id()] = true;
|
||||
nodes_to_check.pop();
|
||||
}
|
||||
|
||||
BLI_rng_free(rng);
|
||||
return node_hashes;
|
||||
}
|
||||
|
||||
/**
 * Buckets every node of \a network by its entry in \a node_hashes (indexed by node id). Ids for
 * which MFNetwork::node_or_null_by_id returns null are skipped.
 */
static Map<uint32_t, Vector<MFNode *, 1>> group_nodes_by_hash(MFNetwork &network,
                                                               Span<uint32_t> node_hashes)
{
  Map<uint32_t, Vector<MFNode *, 1>> groups;
  for (uint id : IndexRange(network.node_id_amount())) {
    MFNode *node = network.node_or_null_by_id(id);
    if (node == nullptr) {
      continue;
    }
    groups.lookup_or_add_default(node_hashes[id]).append(node);
  }
  return groups;
}
|
||||
|
||||
static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b)
|
||||
{
|
||||
if (&a == &b) {
|
||||
return true;
|
||||
}
|
||||
if (typeid(a) == typeid(b)) {
|
||||
return a.equals(b);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
 * Checks structurally whether nodes \a a and \a b output the same values. Both must be function
 * nodes with equal multi-functions. Every pair of corresponding inputs must be linked to the
 * same output socket index of nodes that, recursively, output the same values. Dummy nodes and
 * unlinked inputs are never considered equal to anything except the exact same node.
 *
 * \param cache: Disjoint set over node ids. Nodes already proven equal share a set. When \a a
 * and \a b are proven equal, their sets are joined so that repeated queries are cheap.
 */
static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b)
{
  if (cache.in_same_set(a.id(), b.id())) {
    return true;
  }

  if (a.is_dummy() || b.is_dummy()) {
    return false;
  }
  if (!functions_are_equal(a.as_function().function(), b.as_function().function())) {
    return false;
  }
  /* Equal functions are expected to have the same signature. Don't index past the end of `b`'s
   * inputs if that ever does not hold. */
  if (a.inputs().size() != b.inputs().size()) {
    return false;
  }
  for (uint i : a.inputs().index_range()) {
    const MFOutputSocket *origin_a = a.input(i).origin();
    const MFOutputSocket *origin_b = b.input(i).origin();
    if (origin_a == nullptr || origin_b == nullptr) {
      return false;
    }
    /* Equivalent origin nodes are not enough: the inputs must also come from the same output of
     * those nodes, because different outputs generally carry different values. Callers use this
     * function to guard against hash collisions, so it cannot rely on the hashes (which include
     * the socket index) to have ruled this case out already. */
    if (origin_a->index() != origin_b->index()) {
      return false;
    }
    if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) {
      return false;
    }
  }

  cache.join(a.id(), b.id());
  return true;
}
|
||||
|
||||
/**
 * Within every group of nodes that share a hash, finds the nodes that really output the same
 * values (see nodes_output_same_values). For each duplicate, it calls MFNetwork::relink on every
 * output with the matching output of one representative node. No nodes are removed here.
 *
 * Nodes that do not match the current representative are collected and processed in further
 * rounds, each with the first remaining node as the new representative.
 *
 * \param nodes_by_hash: Nodes grouped by the hash from compute_node_hashes.
 */
static void relink_duplicate_nodes(MFNetwork &network,
                                   Map<uint32_t, Vector<MFNode *, 1>> &nodes_by_hash)
{
  DisjointSet same_node_cache{network.node_id_amount()};

  for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) {
    if (nodes_with_same_hash.size() <= 1) {
      continue;
    }

    Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash;
    while (nodes_to_check.size() >= 2) {
      /* Declared per round so that it always starts empty, instead of depending on the state it
       * is left in after being moved from at the end of the previous round. */
      Vector<MFNode *, 16> remaining_nodes;
      MFNode &deduplicated_node = *nodes_to_check[0];
      for (MFNode *node : nodes_to_check.as_span().drop_front(1)) {
        /* Equal hashes make this true with fairly high probability, but hash collisions can
         * happen, so check whether the nodes actually output the same values. */
        if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) {
          for (uint i : deduplicated_node.outputs().index_range()) {
            network.relink(node->output(i), deduplicated_node.output(i));
          }
        }
        else {
          remaining_nodes.append(node);
        }
      }
      /* Nodes that did not match the representative form the next round. */
      nodes_to_check = std::move(remaining_nodes);
    }
  }
}
|
||||
|
||||
/**
|
||||
* Tries to detect duplicate subnetworks and eliminates them. This can help quite a lot when node
|
||||
* groups were used to create the network.
|
||||
*/
|
||||
void common_subnetwork_elimination(MFNetwork &network)
|
||||
{
|
||||
Array<uint32_t> node_hashes = compute_node_hashes(network);
|
||||
Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash = group_nodes_by_hash(network, node_hashes);
|
||||
relink_duplicate_nodes(network, nodes_by_hash);
|
||||
}
|
||||
|
||||
/** \} */
|
||||
|
||||
} // namespace blender::fn::mf_network_optimization
|
||||
|
|
Loading…
Reference in New Issue