Compositor: Merge equal operations
Some operations can take a lot of time to execute and any duplication should be avoided. This patch implements a compile step that detects operations with the same type, inputs and parameters that produce the same result and merges them. Operations can now generate a hash that represents their output result. They only need to implement `hash_output_params` and hash any parameter that affects the output result. Reviewed By: jbakker Differential Revision: https://developer.blender.org/D12341
This commit is contained in:
parent
d84c79a218
commit
b225a7c470
|
@ -647,6 +647,7 @@ if(WITH_GTESTS)
|
|||
tests/COM_BufferArea_test.cc
|
||||
tests/COM_BufferRange_test.cc
|
||||
tests/COM_BuffersIterator_test.cc
|
||||
tests/COM_NodeOperation_test.cc
|
||||
)
|
||||
set(TEST_INC
|
||||
)
|
||||
|
|
|
@ -425,7 +425,8 @@ bool DebugInfo::graphviz_system(const ExecutionSystem *system, char *str, int ma
|
|||
}
|
||||
|
||||
const bool has_execution_groups = system->getContext().get_execution_model() ==
|
||||
eExecutionModel::Tiled;
|
||||
eExecutionModel::Tiled &&
|
||||
system->m_groups.size() > 0;
|
||||
len += graphviz_legend(str + len, maxlen > len ? maxlen - len : 0, has_execution_groups);
|
||||
|
||||
len += snprintf(str + len, maxlen > len ? maxlen - len : 0, "}\r\n");
|
||||
|
|
|
@ -41,6 +41,49 @@ NodeOperation::NodeOperation()
|
|||
this->m_btree = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a hash that identifies the operation result in the current execution.
|
||||
* Requires `hash_output_params` to be implemented, otherwise `std::nullopt` is returned.
|
||||
* If the operation parameters or its linked inputs change, the hash must be re-generated.
|
||||
*/
|
||||
std::optional<NodeOperationHash> NodeOperation::generate_hash()
|
||||
{
|
||||
params_hash_ = get_default_hash_2(m_width, m_height);
|
||||
|
||||
/* Hash subclasses params. */
|
||||
is_hash_output_params_implemented_ = true;
|
||||
hash_output_params();
|
||||
if (!is_hash_output_params_implemented_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
hash_param(getOutputSocket()->getDataType());
|
||||
NodeOperationHash hash;
|
||||
hash.params_hash_ = params_hash_;
|
||||
|
||||
hash.parents_hash_ = 0;
|
||||
for (NodeOperationInput &socket : m_inputs) {
|
||||
NodeOperation &input = socket.getLink()->getOperation();
|
||||
const bool is_constant = input.get_flags().is_constant_operation;
|
||||
combine_hashes(hash.parents_hash_, get_default_hash(is_constant));
|
||||
if (is_constant) {
|
||||
const float *elem = ((ConstantOperation *)&input)->get_constant_elem();
|
||||
const int num_channels = COM_data_type_num_channels(socket.getDataType());
|
||||
for (const int i : IndexRange(num_channels)) {
|
||||
combine_hashes(hash.parents_hash_, get_default_hash(elem[i]));
|
||||
}
|
||||
}
|
||||
else {
|
||||
combine_hashes(hash.parents_hash_, get_default_hash(input.get_id()));
|
||||
}
|
||||
}
|
||||
|
||||
hash.type_hash_ = typeid(*this).hash_code();
|
||||
hash.operation_ = this;
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
NodeOperationOutput *NodeOperation::getOutputSocket(unsigned int index)
|
||||
{
|
||||
return &m_outputs[index];
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "BLI_ghash.h"
|
||||
#include "BLI_hash.hh"
|
||||
#include "BLI_math_color.h"
|
||||
#include "BLI_math_vector.h"
|
||||
#include "BLI_threads.h"
|
||||
|
@ -269,6 +271,42 @@ struct NodeOperationFlags {
|
|||
}
|
||||
};
|
||||
|
||||
/** Hash that identifies an operation output result in the current execution. */
|
||||
struct NodeOperationHash {
|
||||
private:
|
||||
NodeOperation *operation_;
|
||||
size_t type_hash_;
|
||||
size_t parents_hash_;
|
||||
size_t params_hash_;
|
||||
|
||||
friend class NodeOperation;
|
||||
|
||||
public:
|
||||
NodeOperation *get_operation() const
|
||||
{
|
||||
return operation_;
|
||||
}
|
||||
|
||||
bool operator==(const NodeOperationHash &other) const
|
||||
{
|
||||
return type_hash_ == other.type_hash_ && parents_hash_ == other.parents_hash_ &&
|
||||
params_hash_ == other.params_hash_;
|
||||
}
|
||||
|
||||
bool operator!=(const NodeOperationHash &other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
bool operator<(const NodeOperationHash &other) const
|
||||
{
|
||||
return type_hash_ < other.type_hash_ ||
|
||||
(type_hash_ == other.type_hash_ && parents_hash_ < other.parents_hash_) ||
|
||||
(type_hash_ == other.type_hash_ && parents_hash_ == other.parents_hash_ &&
|
||||
params_hash_ < other.params_hash_);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief NodeOperation contains calculation logic
|
||||
*
|
||||
|
@ -282,6 +320,9 @@ class NodeOperation {
|
|||
Vector<NodeOperationInput> m_inputs;
|
||||
Vector<NodeOperationOutput> m_outputs;
|
||||
|
||||
size_t params_hash_;
|
||||
bool is_hash_output_params_implemented_;
|
||||
|
||||
/**
|
||||
* \brief the index of the input socket that will be used to determine the resolution
|
||||
*/
|
||||
|
@ -363,6 +404,8 @@ class NodeOperation {
|
|||
return flags;
|
||||
}
|
||||
|
||||
std::optional<NodeOperationHash> generate_hash();
|
||||
|
||||
unsigned int getNumberOfInputSockets() const
|
||||
{
|
||||
return m_inputs.size();
|
||||
|
@ -624,6 +667,33 @@ class NodeOperation {
|
|||
protected:
|
||||
NodeOperation();
|
||||
|
||||
/* Overridden by subclasses to allow merging equal operations on compiling. Implementations must
|
||||
* hash any subclass parameter that affects the output result using `hash_params` methods. */
|
||||
virtual void hash_output_params()
|
||||
{
|
||||
is_hash_output_params_implemented_ = false;
|
||||
}
|
||||
|
||||
static void combine_hashes(size_t &combined, size_t other)
|
||||
{
|
||||
combined = BLI_ghashutil_combine_hash(combined, other);
|
||||
}
|
||||
|
||||
template<typename T> void hash_param(T param)
|
||||
{
|
||||
combine_hashes(params_hash_, get_default_hash(param));
|
||||
}
|
||||
|
||||
template<typename T1, typename T2> void hash_params(T1 param1, T2 param2)
|
||||
{
|
||||
combine_hashes(params_hash_, get_default_hash_2(param1, param2));
|
||||
}
|
||||
|
||||
template<typename T1, typename T2, typename T3> void hash_params(T1 param1, T2 param2, T3 param3)
|
||||
{
|
||||
combine_hashes(params_hash_, get_default_hash_3(param1, param2, param3));
|
||||
}
|
||||
|
||||
void addInputSocket(DataType datatype, ResizeMode resize_mode = ResizeMode::Center);
|
||||
void addOutputSocket(DataType datatype);
|
||||
|
||||
|
|
|
@ -101,16 +101,16 @@ void NodeOperationBuilder::convertToOperations(ExecutionSystem *system)
|
|||
add_datatype_conversions();
|
||||
|
||||
if (m_context->get_execution_model() == eExecutionModel::FullFrame) {
|
||||
/* Copy operations to system. Needed for graphviz. */
|
||||
system->set_operations(m_operations, {});
|
||||
|
||||
DebugInfo::graphviz(system, "compositor_prior_folding");
|
||||
save_graphviz("compositor_prior_folding");
|
||||
ConstantFolder folder(*this);
|
||||
folder.fold_operations();
|
||||
}
|
||||
|
||||
determineResolutions();
|
||||
|
||||
save_graphviz("compositor_prior_merging");
|
||||
merge_equal_operations();
|
||||
|
||||
if (m_context->get_execution_model() == eExecutionModel::Tiled) {
|
||||
/* surround complex ops with read/write buffer */
|
||||
add_complex_operation_buffers();
|
||||
|
@ -149,22 +149,28 @@ void NodeOperationBuilder::replace_operation_with_constant(NodeOperation *operat
|
|||
ConstantOperation *constant_operation)
|
||||
{
|
||||
BLI_assert(constant_operation->getNumberOfInputSockets() == 0);
|
||||
unlink_inputs_and_relink_outputs(operation, constant_operation);
|
||||
addOperation(constant_operation);
|
||||
}
|
||||
|
||||
void NodeOperationBuilder::unlink_inputs_and_relink_outputs(NodeOperation *unlinked_op,
|
||||
NodeOperation *linked_op)
|
||||
{
|
||||
int i = 0;
|
||||
while (i < m_links.size()) {
|
||||
Link &link = m_links[i];
|
||||
if (&link.to()->getOperation() == operation) {
|
||||
if (&link.to()->getOperation() == unlinked_op) {
|
||||
link.to()->setLink(nullptr);
|
||||
m_links.remove(i);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (&link.from()->getOperation() == operation) {
|
||||
link.to()->setLink(constant_operation->getOutputSocket());
|
||||
m_links[i] = Link(constant_operation->getOutputSocket(), link.to());
|
||||
if (&link.from()->getOperation() == unlinked_op) {
|
||||
link.to()->setLink(linked_op->getOutputSocket());
|
||||
m_links[i] = Link(linked_op->getOutputSocket(), link.to());
|
||||
}
|
||||
i++;
|
||||
}
|
||||
addOperation(constant_operation);
|
||||
}
|
||||
|
||||
void NodeOperationBuilder::mapInputSocket(NodeInput *node_socket,
|
||||
|
@ -456,6 +462,48 @@ void NodeOperationBuilder::determineResolutions()
|
|||
}
|
||||
}
|
||||
|
||||
/* Collect the hashes of every operation that can generate one. Operations whose
 * `generate_hash` returns `std::nullopt` are left out. */
static Vector<NodeOperationHash> generate_hashes(Span<NodeOperation *> operations)
{
  Vector<NodeOperationHash> result;
  for (NodeOperation *operation : operations) {
    if (std::optional<NodeOperationHash> op_hash = operation->generate_hash()) {
      result.append(*op_hash);
    }
  }
  return result;
}
|
||||
|
||||
/** Merge operations with same type, inputs and parameters that produce the same result. */
|
||||
void NodeOperationBuilder::merge_equal_operations()
|
||||
{
|
||||
bool any_merged = true;
|
||||
while (any_merged) {
|
||||
/* Re-generate hashes with any change. */
|
||||
Vector<NodeOperationHash> hashes = generate_hashes(m_operations);
|
||||
|
||||
/* Make hashes be consecutive when they are equal. */
|
||||
std::sort(hashes.begin(), hashes.end());
|
||||
|
||||
any_merged = false;
|
||||
const NodeOperationHash *prev_hash = nullptr;
|
||||
for (const NodeOperationHash &hash : hashes) {
|
||||
if (prev_hash && *prev_hash == hash) {
|
||||
merge_equal_operations(prev_hash->get_operation(), hash.get_operation());
|
||||
any_merged = true;
|
||||
}
|
||||
prev_hash = &hash;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Replace `from` by `into`. First unlink the inputs of `from` and redirect the links that read
 * its output to the output of `into`. Then remove `from` from the operations list and delete it,
 * since nothing in the builder references it anymore. */
void NodeOperationBuilder::merge_equal_operations(NodeOperation *from, NodeOperation *into)
{
  this->unlink_inputs_and_relink_outputs(from, into);
  this->m_operations.remove_first_occurrence_and_reorder(from);
  delete from;
}
|
||||
|
||||
Vector<NodeOperationInput *> NodeOperationBuilder::cache_output_links(
|
||||
NodeOperationOutput *output) const
|
||||
{
|
||||
|
@ -728,6 +776,14 @@ void NodeOperationBuilder::group_operations()
|
|||
}
|
||||
}
|
||||
|
||||
/* When graphviz export is enabled, copy the current operations and groups into the execution
 * system and write its graphviz representation under `name`. */
void NodeOperationBuilder::save_graphviz(StringRefNull name)
{
  const bool export_graphviz = COM_EXPORT_GRAPHVIZ;
  if (!export_graphviz) {
    return;
  }

  exec_system_->set_operations(m_operations, m_groups);
  DebugInfo::graphviz(exec_system_, name);
}
|
||||
|
||||
/** Create a graphviz representation of the NodeOperationBuilder. */
|
||||
std::ostream &operator<<(std::ostream &os, const NodeOperationBuilder &builder)
|
||||
{
|
||||
|
|
|
@ -169,7 +169,10 @@ class NodeOperationBuilder {
|
|||
|
||||
private:
|
||||
PreviewOperation *make_preview_operation() const;
|
||||
|
||||
void unlink_inputs_and_relink_outputs(NodeOperation *unlinked_op, NodeOperation *linked_op);
|
||||
void merge_equal_operations();
|
||||
void merge_equal_operations(NodeOperation *from, NodeOperation *into);
|
||||
void save_graphviz(StringRefNull name = "");
|
||||
#ifdef WITH_CXX_GUARDEDALLOC
|
||||
MEM_CXX_CLASS_ALLOC_FUNCS("COM:NodeCompilerImpl")
|
||||
#endif
|
||||
|
|
|
@ -40,6 +40,10 @@ void ConvertBaseOperation::deinitExecution()
|
|||
this->m_inputOperation = nullptr;
|
||||
}
|
||||
|
||||
/* Base converters have no parameters of their own to hash. This override does not call the
 * `NodeOperation` default (which flags the hash as unimplemented), so converter operations
 * stay eligible for merging. */
void ConvertBaseOperation::hash_output_params()
{
  /* Intentionally empty. */
}
|
||||
|
||||
void ConvertBaseOperation::update_memory_buffer_partial(MemoryBuffer *output,
|
||||
const rcti &area,
|
||||
Span<MemoryBuffer *> inputs)
|
||||
|
@ -269,6 +273,12 @@ void ConvertRGBToYCCOperation::executePixelSampled(float output[4],
|
|||
output[3] = inputColor[3];
|
||||
}
|
||||
|
||||
void ConvertRGBToYCCOperation::hash_output_params()
|
||||
{
|
||||
ConvertBaseOperation::hash_output_params();
|
||||
hash_param(m_mode);
|
||||
}
|
||||
|
||||
void ConvertRGBToYCCOperation::update_memory_buffer_partial(BuffersIterator<float> &it)
|
||||
{
|
||||
for (; !it.is_end(); ++it) {
|
||||
|
@ -327,6 +337,12 @@ void ConvertYCCToRGBOperation::executePixelSampled(float output[4],
|
|||
output[3] = inputColor[3];
|
||||
}
|
||||
|
||||
void ConvertYCCToRGBOperation::hash_output_params()
|
||||
{
|
||||
ConvertBaseOperation::hash_output_params();
|
||||
hash_param(m_mode);
|
||||
}
|
||||
|
||||
void ConvertYCCToRGBOperation::update_memory_buffer_partial(BuffersIterator<float> &it)
|
||||
{
|
||||
for (; !it.is_end(); ++it) {
|
||||
|
|
|
@ -37,6 +37,7 @@ class ConvertBaseOperation : public MultiThreadedOperation {
|
|||
Span<MemoryBuffer *> inputs) final;
|
||||
|
||||
protected:
|
||||
virtual void hash_output_params() override;
|
||||
virtual void update_memory_buffer_partial(BuffersIterator<float> &it) = 0;
|
||||
};
|
||||
|
||||
|
@ -124,6 +125,7 @@ class ConvertRGBToYCCOperation : public ConvertBaseOperation {
|
|||
void setMode(int mode);
|
||||
|
||||
protected:
|
||||
void hash_output_params() override;
|
||||
void update_memory_buffer_partial(BuffersIterator<float> &it) override;
|
||||
};
|
||||
|
||||
|
@ -141,6 +143,7 @@ class ConvertYCCToRGBOperation : public ConvertBaseOperation {
|
|||
void setMode(int mode);
|
||||
|
||||
protected:
|
||||
void hash_output_params() override;
|
||||
void update_memory_buffer_partial(BuffersIterator<float> &it) override;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
/*
|
||||
* This program is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU General Public License
|
||||
* as published by the Free Software Foundation; either version 2
|
||||
* of the License, or (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
*
|
||||
* Copyright 2021, Blender Foundation.
|
||||
*/
|
||||
|
||||
#include "testing/testing.h"
|
||||
|
||||
#include "COM_ConstantOperation.h"
|
||||
|
||||
namespace blender::compositor::tests {
|
||||
|
||||
class NonHashedOperation : public NodeOperation {
|
||||
public:
|
||||
NonHashedOperation(int id)
|
||||
{
|
||||
set_id(id);
|
||||
addOutputSocket(DataType::Value);
|
||||
setWidth(2);
|
||||
setHeight(3);
|
||||
}
|
||||
};
|
||||
|
||||
class NonHashedConstantOperation : public ConstantOperation {
|
||||
float constant_;
|
||||
|
||||
public:
|
||||
NonHashedConstantOperation(int id)
|
||||
{
|
||||
set_id(id);
|
||||
addOutputSocket(DataType::Value);
|
||||
setWidth(2);
|
||||
setHeight(3);
|
||||
constant_ = 1.0f;
|
||||
}
|
||||
|
||||
const float *get_constant_elem() override
|
||||
{
|
||||
return &constant_;
|
||||
}
|
||||
|
||||
void set_constant(float value)
|
||||
{
|
||||
constant_ = value;
|
||||
}
|
||||
};
|
||||
|
||||
class HashedOperation : public NodeOperation {
|
||||
private:
|
||||
int param1;
|
||||
float param2;
|
||||
|
||||
public:
|
||||
HashedOperation(NodeOperation &input, int width, int height)
|
||||
{
|
||||
addInputSocket(DataType::Value);
|
||||
addOutputSocket(DataType::Color);
|
||||
setWidth(width);
|
||||
setHeight(height);
|
||||
param1 = 2;
|
||||
param2 = 7.0f;
|
||||
|
||||
getInputSocket(0)->setLink(input.getOutputSocket());
|
||||
}
|
||||
|
||||
void set_param1(int value)
|
||||
{
|
||||
param1 = value;
|
||||
}
|
||||
|
||||
void hash_output_params() override
|
||||
{
|
||||
hash_params(param1, param2);
|
||||
}
|
||||
};
|
||||
|
||||
static void test_non_equal_hashes_compare(NodeOperationHash &h1,
|
||||
NodeOperationHash &h2,
|
||||
NodeOperationHash &h3)
|
||||
{
|
||||
if (h1 < h2) {
|
||||
if (h3 < h1) {
|
||||
EXPECT_TRUE(h3 < h2);
|
||||
}
|
||||
else if (h3 < h2) {
|
||||
EXPECT_TRUE(h1 < h3);
|
||||
}
|
||||
else {
|
||||
EXPECT_TRUE(h1 < h3);
|
||||
EXPECT_TRUE(h2 < h3);
|
||||
}
|
||||
}
|
||||
else {
|
||||
EXPECT_TRUE(h2 < h1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NodeOperation, generate_hash)
{
  /* Constant input. */
  {
    NonHashedConstantOperation input_op1(1);
    input_op1.set_constant(1.0f);
    EXPECT_EQ(input_op1.generate_hash(), std::nullopt);

    HashedOperation op1(input_op1, 6, 4);
    std::optional<NodeOperationHash> hash1_opt = op1.generate_hash();
    /* ASSERT rather than EXPECT: an empty optional must abort the test before it is
     * dereferenced on the next line. */
    ASSERT_NE(hash1_opt, std::nullopt);
    NodeOperationHash hash1 = *hash1_opt;

    NonHashedConstantOperation input_op2(1);
    input_op2.set_constant(1.0f);
    HashedOperation op2(input_op2, 6, 4);
    NodeOperationHash hash2 = *op2.generate_hash();
    EXPECT_EQ(hash1, hash2);

    input_op2.set_constant(3.0f);
    hash2 = *op2.generate_hash();
    EXPECT_NE(hash1, hash2);
  }

  /* Non constant input. */
  {
    NonHashedOperation input_op(1);
    EXPECT_EQ(input_op.generate_hash(), std::nullopt);

    HashedOperation op1(input_op, 6, 4);
    HashedOperation op2(input_op, 6, 4);
    NodeOperationHash hash1 = *op1.generate_hash();
    NodeOperationHash hash2 = *op2.generate_hash();
    EXPECT_EQ(hash1, hash2);
    op1.set_param1(-1);
    hash1 = *op1.generate_hash();
    EXPECT_NE(hash1, hash2);

    HashedOperation op3(input_op, 11, 14);
    NodeOperationHash hash3 = *op3.generate_hash();
    EXPECT_NE(hash2, hash3);
    EXPECT_NE(hash1, hash3);

    test_non_equal_hashes_compare(hash1, hash2, hash3);
    test_non_equal_hashes_compare(hash3, hash2, hash1);
    test_non_equal_hashes_compare(hash2, hash3, hash1);
    test_non_equal_hashes_compare(hash3, hash1, hash2);

    NonHashedOperation input_op2(2);
    HashedOperation op4(input_op2, 11, 14);
    NodeOperationHash hash4 = *op4.generate_hash();
    EXPECT_NE(hash3, hash4);

    input_op2.set_id(1);
    hash4 = *op4.generate_hash();
    EXPECT_EQ(hash3, hash4);
  }
}
|
||||
|
||||
} // namespace blender::compositor::tests
|
Loading…
Reference in New Issue