Compositor: Merge equal operations

Some operations can take a lot of time to execute and
any duplication should be avoided.

This patch implements a compile step that detects
operations with the same type, inputs and parameters, which
therefore produce the same result, and merges them. Operations
can now generate a hash that represents their output result. They only
need to implement `hash_output_params` and hash any parameter
that affects the output result.

Reviewed By: jbakker

Differential Revision: https://developer.blender.org/D12341
This commit is contained in:
Manuel Castilla 2021-09-04 16:59:02 +02:00
parent d84c79a218
commit b225a7c470
9 changed files with 373 additions and 11 deletions

View File

@ -647,6 +647,7 @@ if(WITH_GTESTS)
tests/COM_BufferArea_test.cc
tests/COM_BufferRange_test.cc
tests/COM_BuffersIterator_test.cc
tests/COM_NodeOperation_test.cc
)
set(TEST_INC
)

View File

@ -425,7 +425,8 @@ bool DebugInfo::graphviz_system(const ExecutionSystem *system, char *str, int ma
}
const bool has_execution_groups = system->getContext().get_execution_model() ==
eExecutionModel::Tiled;
eExecutionModel::Tiled &&
system->m_groups.size() > 0;
len += graphviz_legend(str + len, maxlen > len ? maxlen - len : 0, has_execution_groups);
len += snprintf(str + len, maxlen > len ? maxlen - len : 0, "}\r\n");

View File

@ -41,6 +41,49 @@ NodeOperation::NodeOperation()
this->m_btree = nullptr;
}
/**
* Generate a hash that identifies the operation result in the current execution.
* Requires `hash_output_params` to be implemented, otherwise `std::nullopt` is returned.
* If the operation parameters or its linked inputs change, the hash must be re-generated.
*/
std::optional<NodeOperationHash> NodeOperation::generate_hash()
{
params_hash_ = get_default_hash_2(m_width, m_height);
/* Hash subclasses params. */
is_hash_output_params_implemented_ = true;
hash_output_params();
if (!is_hash_output_params_implemented_) {
return std::nullopt;
}
hash_param(getOutputSocket()->getDataType());
NodeOperationHash hash;
hash.params_hash_ = params_hash_;
hash.parents_hash_ = 0;
for (NodeOperationInput &socket : m_inputs) {
NodeOperation &input = socket.getLink()->getOperation();
const bool is_constant = input.get_flags().is_constant_operation;
combine_hashes(hash.parents_hash_, get_default_hash(is_constant));
if (is_constant) {
const float *elem = ((ConstantOperation *)&input)->get_constant_elem();
const int num_channels = COM_data_type_num_channels(socket.getDataType());
for (const int i : IndexRange(num_channels)) {
combine_hashes(hash.parents_hash_, get_default_hash(elem[i]));
}
}
else {
combine_hashes(hash.parents_hash_, get_default_hash(input.get_id()));
}
}
hash.type_hash_ = typeid(*this).hash_code();
hash.operation_ = this;
return hash;
}
NodeOperationOutput *NodeOperation::getOutputSocket(unsigned int index)
{
return &m_outputs[index];

View File

@ -22,6 +22,8 @@
#include <sstream>
#include <string>
#include "BLI_ghash.h"
#include "BLI_hash.hh"
#include "BLI_math_color.h"
#include "BLI_math_vector.h"
#include "BLI_threads.h"
@ -269,6 +271,42 @@ struct NodeOperationFlags {
}
};
/** Hash that identifies an operation output result in the current execution. */
struct NodeOperationHash {
 private:
  /** Operation the hash was generated from. Not taken into account by comparisons. */
  NodeOperation *operation_ = nullptr;
  size_t type_hash_ = 0;
  size_t parents_hash_ = 0;
  size_t params_hash_ = 0;

  friend class NodeOperation;

 public:
  NodeOperation *get_operation() const
  {
    return operation_;
  }

  /** Hashes are equal when type, parents and params hashes all match; `operation_` is ignored. */
  bool operator==(const NodeOperationHash &other) const
  {
    return type_hash_ == other.type_hash_ && parents_hash_ == other.parents_hash_ &&
           params_hash_ == other.params_hash_;
  }

  bool operator!=(const NodeOperationHash &other) const
  {
    return !(*this == other);
  }

  /** Lexicographic order on (type, parents, params), so equal hashes sort consecutively. */
  bool operator<(const NodeOperationHash &other) const
  {
    if (type_hash_ != other.type_hash_) {
      return type_hash_ < other.type_hash_;
    }
    if (parents_hash_ != other.parents_hash_) {
      return parents_hash_ < other.parents_hash_;
    }
    return params_hash_ < other.params_hash_;
  }
};
/**
* \brief NodeOperation contains calculation logic
*
@ -282,6 +320,9 @@ class NodeOperation {
Vector<NodeOperationInput> m_inputs;
Vector<NodeOperationOutput> m_outputs;
/** Accumulated hash of output-affecting parameters; rebuilt by `generate_hash`. */
size_t params_hash_ = 0;
/** Set to true by `generate_hash` before calling `hash_output_params`; the default
 * `hash_output_params` clears it to signal the operation is not hashable. */
bool is_hash_output_params_implemented_ = false;
/**
* \brief the index of the input socket that will be used to determine the resolution
*/
@ -363,6 +404,8 @@ class NodeOperation {
return flags;
}
/** Hash identifying this operation's output result, or `std::nullopt` when the subclass does not
 * implement `hash_output_params`. Must be regenerated after parameters or input links change. */
std::optional<NodeOperationHash> generate_hash();
unsigned int getNumberOfInputSockets() const
{
return m_inputs.size();
@ -624,6 +667,33 @@ class NodeOperation {
protected:
NodeOperation();
/* Overridden by subclasses to allow merging equal operations on compiling. Implementations must
* hash any subclass parameter that affects the output result using `hash_params` methods. */
virtual void hash_output_params()
{
is_hash_output_params_implemented_ = false;
}
/** Mix `other` into the running hash `combined`, updating it in place. */
static void combine_hashes(size_t &combined, size_t other)
{
  const size_t mixed = BLI_ghashutil_combine_hash(combined, other);
  combined = mixed;
}
template<typename T> void hash_param(T param)
{
combine_hashes(params_hash_, get_default_hash(param));
}
/** Mix two output-affecting parameters into the parameters hash in a single step. */
template<typename T1, typename T2> void hash_params(T1 param1, T2 param2)
{
  const auto pair_hash = get_default_hash_2(param1, param2);
  combine_hashes(params_hash_, pair_hash);
}
/** Mix three output-affecting parameters into the parameters hash in a single step. */
template<typename T1, typename T2, typename T3> void hash_params(T1 param1, T2 param2, T3 param3)
{
  const auto triple_hash = get_default_hash_3(param1, param2, param3);
  combine_hashes(params_hash_, triple_hash);
}
void addInputSocket(DataType datatype, ResizeMode resize_mode = ResizeMode::Center);
void addOutputSocket(DataType datatype);

View File

@ -101,16 +101,16 @@ void NodeOperationBuilder::convertToOperations(ExecutionSystem *system)
add_datatype_conversions();
if (m_context->get_execution_model() == eExecutionModel::FullFrame) {
/* Copy operations to system. Needed for graphviz. */
system->set_operations(m_operations, {});
DebugInfo::graphviz(system, "compositor_prior_folding");
save_graphviz("compositor_prior_folding");
ConstantFolder folder(*this);
folder.fold_operations();
}
determineResolutions();
save_graphviz("compositor_prior_merging");
merge_equal_operations();
if (m_context->get_execution_model() == eExecutionModel::Tiled) {
/* surround complex ops with read/write buffer */
add_complex_operation_buffers();
@ -149,22 +149,28 @@ void NodeOperationBuilder::replace_operation_with_constant(NodeOperation *operat
ConstantOperation *constant_operation)
{
BLI_assert(constant_operation->getNumberOfInputSockets() == 0);
unlink_inputs_and_relink_outputs(operation, constant_operation);
addOperation(constant_operation);
}
/**
 * Remove every link feeding an input of `unlinked_op`, and redirect every link reading from
 * `unlinked_op` so it reads from the output socket of `linked_op` instead.
 * `unlinked_op` is not removed from the operations list by this function.
 */
void NodeOperationBuilder::unlink_inputs_and_relink_outputs(NodeOperation *unlinked_op,
                                                            NodeOperation *linked_op)
{
  /* Relinking an operation onto itself would only strip its own inputs. */
  BLI_assert(unlinked_op != linked_op);
  int i = 0;
  while (i < m_links.size()) {
    Link &link = m_links[i];
    if (&link.to()->getOperation() == unlinked_op) {
      link.to()->setLink(nullptr);
      m_links.remove(i);
      /* Don't advance: after removal, index `i` holds a link not visited yet (if any). */
      continue;
    }

    if (&link.from()->getOperation() == unlinked_op) {
      link.to()->setLink(linked_op->getOutputSocket());
      m_links[i] = Link(linked_op->getOutputSocket(), link.to());
    }
    i++;
  }
}
void NodeOperationBuilder::mapInputSocket(NodeInput *node_socket,
@ -456,6 +462,48 @@ void NodeOperationBuilder::determineResolutions()
}
}
/** Collect the hashes of all operations that support hashing, in input order. */
static Vector<NodeOperationHash> generate_hashes(Span<NodeOperation *> operations)
{
  Vector<NodeOperationHash> result;
  result.reserve(operations.size());
  for (NodeOperation *operation : operations) {
    const std::optional<NodeOperationHash> maybe_hash = operation->generate_hash();
    if (maybe_hash.has_value()) {
      result.append(*maybe_hash);
    }
  }
  return result;
}
/** Merge operations with same type, inputs and parameters that produce the same result. */
void NodeOperationBuilder::merge_equal_operations()
{
bool any_merged = true;
while (any_merged) {
/* Re-generate hashes with any change. */
Vector<NodeOperationHash> hashes = generate_hashes(m_operations);
/* Make hashes be consecutive when they are equal. */
std::sort(hashes.begin(), hashes.end());
any_merged = false;
const NodeOperationHash *prev_hash = nullptr;
for (const NodeOperationHash &hash : hashes) {
if (prev_hash && *prev_hash == hash) {
merge_equal_operations(prev_hash->get_operation(), hash.get_operation());
any_merged = true;
}
prev_hash = &hash;
}
}
}
/**
 * Replace `from` with `into`: drop `from` from the operation list, move its consumers over to
 * `into`, remove the links feeding `from`, and free it.
 */
void NodeOperationBuilder::merge_equal_operations(NodeOperation *from, NodeOperation *into)
{
  m_operations.remove_first_occurrence_and_reorder(from);
  unlink_inputs_and_relink_outputs(from, into);
  delete from;
}
Vector<NodeOperationInput *> NodeOperationBuilder::cache_output_links(
NodeOperationOutput *output) const
{
@ -728,6 +776,14 @@ void NodeOperationBuilder::group_operations()
}
}
/** Export current operations and groups to graphviz when `COM_EXPORT_GRAPHVIZ` is enabled. */
void NodeOperationBuilder::save_graphviz(StringRefNull name)
{
  if (!COM_EXPORT_GRAPHVIZ) {
    return;
  }
  exec_system_->set_operations(m_operations, m_groups);
  DebugInfo::graphviz(exec_system_, name);
}
/** Create a graphviz representation of the NodeOperationBuilder. */
std::ostream &operator<<(std::ostream &os, const NodeOperationBuilder &builder)
{

View File

@ -169,7 +169,10 @@ class NodeOperationBuilder {
private:
PreviewOperation *make_preview_operation() const;
void unlink_inputs_and_relink_outputs(NodeOperation *unlinked_op, NodeOperation *linked_op);
void merge_equal_operations();
void merge_equal_operations(NodeOperation *from, NodeOperation *into);
void save_graphviz(StringRefNull name = "");
#ifdef WITH_CXX_GUARDEDALLOC
MEM_CXX_CLASS_ALLOC_FUNCS("COM:NodeCompilerImpl")
#endif

View File

@ -40,6 +40,10 @@ void ConvertBaseOperation::deinitExecution()
this->m_inputOperation = nullptr;
}
/* Intentionally empty. Overriding the default `hash_output_params` (which flags the operation
 * as not hashable) is what lets equal conversions be merged. Subclasses with result-affecting
 * parameters call this and then hash their own parameters. */
void ConvertBaseOperation::hash_output_params()
{
}
void ConvertBaseOperation::update_memory_buffer_partial(MemoryBuffer *output,
const rcti &area,
Span<MemoryBuffer *> inputs)
@ -269,6 +273,12 @@ void ConvertRGBToYCCOperation::executePixelSampled(float output[4],
output[3] = inputColor[3];
}
void ConvertRGBToYCCOperation::hash_output_params()
{
ConvertBaseOperation::hash_output_params();
hash_param(m_mode);
}
void ConvertRGBToYCCOperation::update_memory_buffer_partial(BuffersIterator<float> &it)
{
for (; !it.is_end(); ++it) {
@ -327,6 +337,12 @@ void ConvertYCCToRGBOperation::executePixelSampled(float output[4],
output[3] = inputColor[3];
}
void ConvertYCCToRGBOperation::hash_output_params()
{
ConvertBaseOperation::hash_output_params();
hash_param(m_mode);
}
void ConvertYCCToRGBOperation::update_memory_buffer_partial(BuffersIterator<float> &it)
{
for (; !it.is_end(); ++it) {

View File

@ -37,6 +37,7 @@ class ConvertBaseOperation : public MultiThreadedOperation {
Span<MemoryBuffer *> inputs) final;
protected:
void hash_output_params() override;
virtual void update_memory_buffer_partial(BuffersIterator<float> &it) = 0;
};
@ -124,6 +125,7 @@ class ConvertRGBToYCCOperation : public ConvertBaseOperation {
void setMode(int mode);
protected:
void hash_output_params() override;
void update_memory_buffer_partial(BuffersIterator<float> &it) override;
};
@ -141,6 +143,7 @@ class ConvertYCCToRGBOperation : public ConvertBaseOperation {
void setMode(int mode);
protected:
void hash_output_params() override;
void update_memory_buffer_partial(BuffersIterator<float> &it) override;
};

View File

@ -0,0 +1,169 @@
/*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software Foundation,
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*
* Copyright 2021, Blender Foundation.
*/
#include "testing/testing.h"
#include "COM_ConstantOperation.h"
namespace blender::compositor::tests {
/** Operation without a `hash_output_params` override, so `generate_hash` returns nullopt. */
class NonHashedOperation : public NodeOperation {
 public:
  explicit NonHashedOperation(int id)
  {
    set_id(id);
    addOutputSocket(DataType::Value);
    setWidth(2);
    setHeight(3);
  }
};
/** Constant operation with a settable value and no `hash_output_params` override. */
class NonHashedConstantOperation : public ConstantOperation {
  float constant_ = 1.0f;

 public:
  explicit NonHashedConstantOperation(int id)
  {
    set_id(id);
    addOutputSocket(DataType::Value);
    setWidth(2);
    setHeight(3);
  }

  const float *get_constant_elem() override
  {
    return &constant_;
  }

  void set_constant(float value)
  {
    constant_ = value;
  }
};
/** Operation that hashes two parameters and reads a single value input. */
class HashedOperation : public NodeOperation {
 private:
  int param1_ = 2;
  float param2_ = 7.0f;

 public:
  HashedOperation(NodeOperation &input, int width, int height)
  {
    addInputSocket(DataType::Value);
    addOutputSocket(DataType::Color);
    setWidth(width);
    setHeight(height);
    getInputSocket(0)->setLink(input.getOutputSocket());
  }

  void set_param1(int value)
  {
    param1_ = value;
  }

  void hash_output_params() override
  {
    hash_params(param1_, param2_);
  }
};
static void test_non_equal_hashes_compare(NodeOperationHash &h1,
NodeOperationHash &h2,
NodeOperationHash &h3)
{
if (h1 < h2) {
if (h3 < h1) {
EXPECT_TRUE(h3 < h2);
}
else if (h3 < h2) {
EXPECT_TRUE(h1 < h3);
}
else {
EXPECT_TRUE(h1 < h3);
EXPECT_TRUE(h2 < h3);
}
}
else {
EXPECT_TRUE(h2 < h1);
}
}
/* Hashes must match for equal type/inputs/params and differ when any of them changes. */
TEST(NodeOperation, generate_hash)
{
  /* Constant input. */
  {
    NonHashedConstantOperation input_op1(1);
    input_op1.set_constant(1.0f);
    EXPECT_EQ(input_op1.generate_hash(), std::nullopt);

    HashedOperation op1(input_op1, 6, 4);
    std::optional<NodeOperationHash> hash1_opt = op1.generate_hash();
    /* Fatal: `hash1_opt` is dereferenced right after. */
    ASSERT_NE(hash1_opt, std::nullopt);
    NodeOperationHash hash1 = *hash1_opt;

    NonHashedConstantOperation input_op2(1);
    input_op2.set_constant(1.0f);
    HashedOperation op2(input_op2, 6, 4);
    /* `value()` turns a missing hash into a reported failure instead of undefined behavior. */
    NodeOperationHash hash2 = op2.generate_hash().value();
    EXPECT_EQ(hash1, hash2);

    input_op2.set_constant(3.0f);
    hash2 = op2.generate_hash().value();
    EXPECT_NE(hash1, hash2);
  }

  /* Non constant input. */
  {
    NonHashedOperation input_op(1);
    EXPECT_EQ(input_op.generate_hash(), std::nullopt);

    HashedOperation op1(input_op, 6, 4);
    HashedOperation op2(input_op, 6, 4);
    NodeOperationHash hash1 = op1.generate_hash().value();
    NodeOperationHash hash2 = op2.generate_hash().value();
    EXPECT_EQ(hash1, hash2);
    op1.set_param1(-1);
    hash1 = op1.generate_hash().value();
    EXPECT_NE(hash1, hash2);

    HashedOperation op3(input_op, 11, 14);
    NodeOperationHash hash3 = op3.generate_hash().value();
    EXPECT_NE(hash2, hash3);
    EXPECT_NE(hash1, hash3);

    test_non_equal_hashes_compare(hash1, hash2, hash3);
    test_non_equal_hashes_compare(hash3, hash2, hash1);
    test_non_equal_hashes_compare(hash2, hash3, hash1);
    test_non_equal_hashes_compare(hash3, hash1, hash2);

    NonHashedOperation input_op2(2);
    HashedOperation op4(input_op2, 11, 14);
    NodeOperationHash hash4 = op4.generate_hash().value();
    EXPECT_NE(hash3, hash4);
    input_op2.set_id(1);
    hash4 = op4.generate_hash().value();
    EXPECT_EQ(hash3, hash4);
  }
}
} // namespace blender::compositor::tests