Geometry Nodes: add field node type for constants

It is common to have fields that contain a constant value. Before this
commit, such constants were represented by operation nodes that had no
inputs. Having a dedicated node type for constants makes working with
them a bit cheaper.

It also allows skipping some unnecessary processing when evaluating
fields, because constant fields can be detected more easily.

This commit also generalizes the concept of field node types a bit.
This commit is contained in:
Jacques Lucke 2022-01-02 14:27:16 +01:00
parent 4c46203cb5
commit 8be217ada5
3 changed files with 181 additions and 85 deletions

View File

@ -59,12 +59,22 @@ namespace blender::fn {
class FieldInput;
struct FieldInputs;
/**
* Have a fixed set of base node types, because all code that works with field nodes has to
* understand those. The field evaluation code dispatches on #FieldNode::node_type() with
* `switch` statements that handle every enumerator, so adding a new type requires updating them.
*/
enum class FieldNodeType {
/** The node is a #FieldInput; its values are retrieved from the field context when evaluating. */
Input,
/** The node is a #FieldOperation; it computes its outputs from input fields with a multi-function. */
Operation,
/** The node is a #FieldConstant; it has a single output backed by one value owned by the node. */
Constant,
};
/**
* A node in a field-tree. It has at least one output that can be referenced by fields.
*/
class FieldNode {
private:
bool is_input_;
FieldNodeType node_type_;
protected:
/**
@ -76,14 +86,13 @@ class FieldNode {
std::shared_ptr<const FieldInputs> field_inputs_;
public:
FieldNode(bool is_input);
FieldNode(FieldNodeType node_type);
virtual ~FieldNode() = default;
virtual const CPPType &output_cpp_type(int output_index) const = 0;
bool is_input() const;
bool is_operation() const;
FieldNodeType node_type() const;
bool depends_on_input() const;
const std::shared_ptr<const FieldInputs> &field_inputs() const;
@ -267,6 +276,20 @@ class FieldInput : public FieldNode {
const CPPType &output_cpp_type(int output_index) const override;
};
class FieldConstant : public FieldNode {
private:
const CPPType &type_;
void *value_;
public:
FieldConstant(const CPPType &type, const void *value);
~FieldConstant();
const CPPType &output_cpp_type(int output_index) const override;
const CPPType &type() const;
const GPointer value() const;
};
/**
* Keeps track of the inputs of a field.
*/
@ -468,9 +491,7 @@ template<typename T> T evaluate_constant_field(const Field<T> &field)
/**
 * Create a field that outputs `value` for every element.
 *
 * The value is copied into a new #FieldConstant node, so the returned field does not reference
 * the local `value` after this function returns.
 */
template<typename T> Field<T> make_constant_field(T value)
{
  return make_constant_field(CPPType::get<T>(), &value);
}
GField make_constant_field(const CPPType &type, const void *value);
@ -552,18 +573,13 @@ template<typename T> struct ValueOrField {
/** \name #FieldNode Inline Methods
* \{ */
inline FieldNode::FieldNode(const FieldNodeType node_type) : node_type_(node_type)
{
}

/**
 * Legacy constructor kept for callers that still pass a boolean:
 * `true` maps to #FieldNodeType::Input, `false` to #FieldNodeType::Operation.
 */
inline FieldNode::FieldNode(const bool is_input)
    : FieldNode(is_input ? FieldNodeType::Input : FieldNodeType::Operation)
{
}
inline bool FieldNode::is_input() const
inline FieldNodeType FieldNode::node_type() const
{
return is_input_;
}
inline bool FieldNode::is_operation() const
{
return !is_input_;
return node_type_;
}
inline bool FieldNode::depends_on_input() const

View File

@ -268,6 +268,7 @@ class MFProcedure : NonCopyable, NonMovable {
Vector<MFReturnInstruction *> return_instructions_;
Vector<MFVariable *> variables_;
Vector<MFParameter> params_;
Vector<destruct_ptr<MultiFunction>> owned_functions_;
MFInstruction *entry_ = nullptr;
friend class MFProcedureDotExport;
@ -284,9 +285,10 @@ class MFProcedure : NonCopyable, NonMovable {
MFReturnInstruction &new_return_instruction();
void add_parameter(MFParamType::InterfaceType interface_type, MFVariable &variable);
Span<ConstMFParameter> params() const;
template<typename T, typename... Args> const MultiFunction &construct_function(Args &&...args);
MFInstruction *entry();
const MFInstruction *entry() const;
void set_entry(MFInstruction &entry);
@ -550,6 +552,15 @@ inline Span<const MFVariable *> MFProcedure::variables() const
return variables_;
}
/**
 * Construct a multi-function of type `T` from `args` and hand ownership to this procedure.
 * The object is allocated with #allocator_ and kept in #owned_functions_; only a reference
 * to it is returned.
 */
template<typename T, typename... Args>
inline const MultiFunction &MFProcedure::construct_function(Args &&...args)
{
  destruct_ptr<T> owned_function = allocator_.construct<T>(std::forward<Args>(args)...);
  /* Take the reference before the owning pointer is moved into the vector. */
  const MultiFunction &function = *owned_function;
  owned_functions_.append(std::move(owned_function));
  return function;
}
/** \} */
} // namespace blender::fn

View File

@ -64,17 +64,26 @@ static FieldTreeInfo preprocess_field_tree(Span<GFieldRef> entry_fields)
while (!fields_to_check.is_empty()) {
GFieldRef field = fields_to_check.pop();
if (field.node().is_input()) {
const FieldInput &field_input = static_cast<const FieldInput &>(field.node());
field_tree_info.deduplicated_field_inputs.add(field_input);
continue;
}
BLI_assert(field.node().is_operation());
const FieldOperation &operation = static_cast<const FieldOperation &>(field.node());
for (const GFieldRef operation_input : operation.inputs()) {
field_tree_info.field_users.add(operation_input, field);
if (handled_fields.add(operation_input)) {
fields_to_check.push(operation_input);
const FieldNode &field_node = field.node();
switch (field_node.node_type()) {
case FieldNodeType::Input: {
const FieldInput &field_input = static_cast<const FieldInput &>(field_node);
field_tree_info.deduplicated_field_inputs.add(field_input);
break;
}
case FieldNodeType::Operation: {
const FieldOperation &operation = static_cast<const FieldOperation &>(field_node);
for (const GFieldRef operation_input : operation.inputs()) {
field_tree_info.field_users.add(operation_input, field);
if (handled_fields.add(operation_input)) {
fields_to_check.push(operation_input);
}
}
break;
}
case FieldNodeType::Constant: {
/* Nothing to do. */
break;
}
}
}
@ -179,56 +188,71 @@ static void build_multi_function_procedure_for_fields(MFProcedure &procedure,
fields_to_check.pop();
continue;
}
/* Field inputs should already be handled above. */
BLI_assert(field.node().is_operation());
const FieldNode &field_node = field.node();
switch (field_node.node_type()) {
case FieldNodeType::Input: {
/* Field inputs should already be handled above. */
break;
}
case FieldNodeType::Operation: {
const FieldOperation &operation_node = static_cast<const FieldOperation &>(field.node());
const Span<GField> operation_inputs = operation_node.inputs();
const FieldOperation &operation = static_cast<const FieldOperation &>(field.node());
const Span<GField> operation_inputs = operation.inputs();
if (field_with_index.current_input_index < operation_inputs.size()) {
/* Not all inputs are handled yet. Push the next input field to the stack and increment the
* input index. */
fields_to_check.push({operation_inputs[field_with_index.current_input_index]});
field_with_index.current_input_index++;
}
else {
/* All inputs variables are ready, now gather all variables that are used by the function
* and call it. */
const MultiFunction &multi_function = operation.multi_function();
Vector<MFVariable *> variables(multi_function.param_amount());
int param_input_index = 0;
int param_output_index = 0;
for (const int param_index : multi_function.param_indices()) {
const MFParamType param_type = multi_function.param_type(param_index);
const MFParamType::InterfaceType interface_type = param_type.interface_type();
if (interface_type == MFParamType::Input) {
const GField &input_field = operation_inputs[param_input_index];
variables[param_index] = variable_by_field.lookup(input_field);
param_input_index++;
}
else if (interface_type == MFParamType::Output) {
const GFieldRef output_field{operation, param_output_index};
const bool output_is_ignored =
field_tree_info.field_users.lookup(output_field).is_empty() &&
!output_fields.contains(output_field);
if (output_is_ignored) {
/* Ignored outputs don't need a variable. */
variables[param_index] = nullptr;
}
else {
/* Create a new variable for used outputs. */
MFVariable &new_variable = procedure.new_variable(param_type.data_type());
variables[param_index] = &new_variable;
variable_by_field.add_new(output_field, &new_variable);
}
param_output_index++;
if (field_with_index.current_input_index < operation_inputs.size()) {
/* Not all inputs are handled yet. Push the next input field to the stack and increment
* the input index. */
fields_to_check.push({operation_inputs[field_with_index.current_input_index]});
field_with_index.current_input_index++;
}
else {
BLI_assert_unreachable();
/* All inputs variables are ready, now gather all variables that are used by the
* function and call it. */
const MultiFunction &multi_function = operation_node.multi_function();
Vector<MFVariable *> variables(multi_function.param_amount());
int param_input_index = 0;
int param_output_index = 0;
for (const int param_index : multi_function.param_indices()) {
const MFParamType param_type = multi_function.param_type(param_index);
const MFParamType::InterfaceType interface_type = param_type.interface_type();
if (interface_type == MFParamType::Input) {
const GField &input_field = operation_inputs[param_input_index];
variables[param_index] = variable_by_field.lookup(input_field);
param_input_index++;
}
else if (interface_type == MFParamType::Output) {
const GFieldRef output_field{operation_node, param_output_index};
const bool output_is_ignored =
field_tree_info.field_users.lookup(output_field).is_empty() &&
!output_fields.contains(output_field);
if (output_is_ignored) {
/* Ignored outputs don't need a variable. */
variables[param_index] = nullptr;
}
else {
/* Create a new variable for used outputs. */
MFVariable &new_variable = procedure.new_variable(param_type.data_type());
variables[param_index] = &new_variable;
variable_by_field.add_new(output_field, &new_variable);
}
param_output_index++;
}
else {
BLI_assert_unreachable();
}
}
builder.add_call_with_all_variables(multi_function, variables);
}
break;
}
case FieldNodeType::Constant: {
const FieldConstant &constant_node = static_cast<const FieldConstant &>(field_node);
const MultiFunction &fn = procedure.construct_function<CustomMF_GenericConstant>(
constant_node.type(), constant_node.value().get(), false);
MFVariable &new_variable = *builder.add_call<1>(fn)[0];
variable_by_field.add_new(field, &new_variable);
break;
}
builder.add_call_with_all_variables(multi_function, variables);
}
}
}
@ -301,17 +325,29 @@ Vector<GVArray> evaluate_fields(ResourceScope &scope,
Vector<GVArray> field_context_inputs = get_field_context_inputs(
scope, mask, context, field_tree_info.deduplicated_field_inputs);
/* Finish fields that output an input varray directly. For those we don't have to do any further
* processing. */
/* Finish fields that don't need any processing directly. */
for (const int out_index : fields_to_evaluate.index_range()) {
const GFieldRef &field = fields_to_evaluate[out_index];
if (!field.node().is_input()) {
continue;
const FieldNode &field_node = field.node();
switch (field_node.node_type()) {
case FieldNodeType::Input: {
const FieldInput &field_input = static_cast<const FieldInput &>(field.node());
const int field_input_index = field_tree_info.deduplicated_field_inputs.index_of(
field_input);
const GVArray &varray = field_context_inputs[field_input_index];
r_varrays[out_index] = varray;
break;
}
case FieldNodeType::Constant: {
const FieldConstant &field_constant = static_cast<const FieldConstant &>(field.node());
r_varrays[out_index] = GVArray::ForSingleRef(
field_constant.type(), mask.min_array_size(), field_constant.value().get());
break;
}
case FieldNodeType::Operation: {
break;
}
}
const FieldInput &field_input = static_cast<const FieldInput &>(field.node());
const int field_input_index = field_tree_info.deduplicated_field_inputs.index_of(field_input);
const GVArray &varray = field_context_inputs[field_input_index];
r_varrays[out_index] = varray;
}
Set<GFieldRef> varying_fields = find_varying_fields(field_tree_info, field_context_inputs);
@ -491,9 +527,8 @@ GField make_field_constant_if_possible(GField field)
/**
 * Create a field backed by a single #FieldConstant node. The node copy-constructs its own value
 * from `value`, which therefore has to point to an initialized object of type `type`.
 */
GField make_constant_field(const CPPType &type, const void *value)
{
  auto constant_node = std::make_shared<FieldConstant>(type, value);
  return GField{std::move(constant_node)};
}
GVArray FieldContext::get_varray_for_input(const FieldInput &field_input,
@ -602,7 +637,7 @@ static std::shared_ptr<const FieldInputs> combine_field_inputs(Span<GField> fiel
}
/**
 * Create an operation node that evaluates `function` on the given input fields.
 * Only a pointer to `function` is stored, so the caller has to keep it alive while this node is
 * used. The field inputs this node depends on are gathered from `inputs` via
 * combine_field_inputs().
 */
FieldOperation::FieldOperation(const MultiFunction &function, Vector<GField> inputs)
    : FieldNode(FieldNodeType::Operation), function_(&function), inputs_(std::move(inputs))
{
  field_inputs_ = combine_field_inputs(inputs_);
}
@ -612,7 +647,7 @@ FieldOperation::FieldOperation(const MultiFunction &function, Vector<GField> inp
*/
FieldInput::FieldInput(const CPPType &type, std::string debug_name)
: FieldNode(true), type_(&type), debug_name_(std::move(debug_name))
: FieldNode(FieldNodeType::Input), type_(&type), debug_name_(std::move(debug_name))
{
std::shared_ptr<FieldInputs> field_inputs = std::make_shared<FieldInputs>();
field_inputs->nodes.add_new(this);
@ -620,6 +655,40 @@ FieldInput::FieldInput(const CPPType &type, std::string debug_name)
field_inputs_ = std::move(field_inputs);
}
/* --------------------------------------------------------------------
* FieldConstant.
*/
FieldConstant::FieldConstant(const CPPType &type, const void *value)
    : FieldNode(FieldNodeType::Constant),
      type_(type),
      value_(MEM_mallocN_aligned(type.size(), type.alignment(), __func__))
{
  /* The freshly allocated buffer is uninitialized, so copy-construct the value into it. */
  type.copy_construct(value, value_);
}

FieldConstant::~FieldConstant()
{
  /* Destruct the stored value first, then release the buffer that held it. */
  type_.destruct(value_);
  MEM_freeN(value_);
}
const CPPType &FieldConstant::output_cpp_type(int output_index) const
{
  /* A constant node exposes exactly one output. */
  BLI_assert(output_index == 0);
  UNUSED_VARS_NDEBUG(output_index);
  return this->type();
}

const CPPType &FieldConstant::type() const
{
  return type_;
}

const GPointer FieldConstant::value() const
{
  /* Non-owning view; the buffer remains owned by this node. */
  return GPointer(type_, value_);
}
/* --------------------------------------------------------------------
* FieldEvaluator.
*/