BLI: add atomic disjoint set data structure
The existing `DisjointSet` data structure only supports single-threaded access, which limits performance severely in some cases. This patch implements `AtomicDisjointSet` based on "Wait-free Parallel Algorithms for the Union-Find Problem" by Richard J. Anderson and Heather Woll. The Mesh Island node also got updated to make use of the new data structure. In my tests it got 2-5 times faster. More details are in D16653. Differential Revision: https://developer.blender.org/D16653
This commit is contained in:
parent
5f0120cd35
commit
39615cd3b7
|
@ -0,0 +1,146 @@
|
|||
/* SPDX-License-Identifier: GPL-2.0-or-later */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "BLI_array.hh"
|
||||
|
||||
namespace blender {
|
||||
|
||||
/**
|
||||
* Same as `DisjointSet` but is thread safe (at slightly higher cost for the single threaded case).
|
||||
*
|
||||
* The implementation is based on the following paper:
|
||||
* "Wait-free Parallel Algorithms for the Union-Find Problem"
|
||||
* by Richard J. Anderson and Heather Woll.
|
||||
*
|
||||
* It's also inspired by this implementation: https://github.com/wjakob/dset.
|
||||
*/
|
||||
class AtomicDisjointSet {
|
||||
private:
|
||||
/* Can generally used relaxed memory order with this algorithm. */
|
||||
static constexpr auto relaxed = std::memory_order_relaxed;
|
||||
|
||||
struct Item {
|
||||
int parent;
|
||||
int rank;
|
||||
};
|
||||
|
||||
/**
|
||||
* An #Item per element. It's important that the entire item is in a single atomic, so that it
|
||||
* can be updated atomically. */
|
||||
mutable Array<std::atomic<Item>> items_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Create a new disjoing set with the given set. Initially, every element is in a separate set.
|
||||
*/
|
||||
AtomicDisjointSet(const int size);
|
||||
|
||||
/**
|
||||
* Join the sets containing elements x and y. Nothing happens when they were in the same set
|
||||
* before.
|
||||
*/
|
||||
void join(int x, int y)
|
||||
{
|
||||
while (true) {
|
||||
x = this->find_root(x);
|
||||
y = this->find_root(y);
|
||||
|
||||
if (x == y) {
|
||||
/* They are in the same set already. */
|
||||
return;
|
||||
}
|
||||
|
||||
Item x_item = items_[x].load(relaxed);
|
||||
Item y_item = items_[y].load(relaxed);
|
||||
|
||||
if (
|
||||
/* Implement union by rank heuristic. */
|
||||
x_item.rank > y_item.rank
|
||||
/* If the rank is the same, make a consistent decision. */
|
||||
|| (x_item.rank == y_item.rank && x < y)) {
|
||||
std::swap(x_item, y_item);
|
||||
std::swap(x, y);
|
||||
}
|
||||
|
||||
/* Update parent of item x. */
|
||||
const Item x_item_new{y, x_item.rank};
|
||||
if (!items_[x].compare_exchange_strong(x_item, x_item_new, relaxed)) {
|
||||
/* Another thread has updated item x, start again. */
|
||||
continue;
|
||||
}
|
||||
|
||||
if (x_item.rank == y_item.rank) {
|
||||
/* Increase rank of item y. This may fail when another thread has updated item y in the
|
||||
* meantime. That may lead to worse behavior with the union by rank heurist, but seems to
|
||||
* be ok in practice. */
|
||||
const Item y_item_new{y, y_item.rank + 1};
|
||||
items_[y].compare_exchange_weak(y_item, y_item_new, relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return true when x and y are in the same set.
|
||||
*/
|
||||
bool in_same_set(int x, int y) const
|
||||
{
|
||||
while (true) {
|
||||
x = this->find_root(x);
|
||||
y = this->find_root(y);
|
||||
if (x == y) {
|
||||
return true;
|
||||
}
|
||||
if (items_[x].load(relaxed).parent == x) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the element that represents the set containing x currently.
|
||||
*/
|
||||
int find_root(int x) const
|
||||
{
|
||||
while (true) {
|
||||
const Item item = items_[x].load(relaxed);
|
||||
if (x == item.parent) {
|
||||
return x;
|
||||
}
|
||||
const int new_parent = items_[item.parent].load(relaxed).parent;
|
||||
if (item.parent != new_parent) {
|
||||
/* This halves the path for faster future lookups. That fail but that does not change
|
||||
* correctness. */
|
||||
Item expected = item;
|
||||
const Item desired{new_parent, item.rank};
|
||||
items_[x].compare_exchange_weak(expected, desired, relaxed);
|
||||
}
|
||||
x = new_parent;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* True when x represents a set.
|
||||
*/
|
||||
bool is_root(const int x) const
|
||||
{
|
||||
const Item item = items_[x].load(relaxed);
|
||||
return item.parent == x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an identifier for each id. This is deterministic and does not depend on the order of
|
||||
* joins. The ids are ordered by their first occurence. Consequently, `result[0]` is always zero
|
||||
* (unless there are no elements).
|
||||
*/
|
||||
void calc_reduced_ids(MutableSpan<int> result) const;
|
||||
|
||||
/**
|
||||
* Count the number of disjoint sets.
|
||||
*/
|
||||
int count_sets() const;
|
||||
};
|
||||
|
||||
} // namespace blender
|
|
@ -50,6 +50,7 @@ set(SRC
|
|||
intern/array_utils.c
|
||||
intern/array_utils.cc
|
||||
intern/astar.c
|
||||
intern/atomic_disjoint_set.cc
|
||||
intern/bitmap.c
|
||||
intern/bitmap_draw_2d.c
|
||||
intern/boxpack_2d.c
|
||||
|
@ -172,6 +173,7 @@ set(SRC
|
|||
BLI_asan.h
|
||||
BLI_assert.h
|
||||
BLI_astar.h
|
||||
BLI_atomic_disjoint_set.hh
|
||||
BLI_bit_vector.hh
|
||||
BLI_bitmap.h
|
||||
BLI_bitmap_draw_2d.h
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/* SPDX-License-Identifier: GPL-2.0-or-later */
|
||||
|
||||
#include "BLI_atomic_disjoint_set.hh"
|
||||
#include "BLI_enumerable_thread_specific.hh"
|
||||
#include "BLI_map.hh"
|
||||
#include "BLI_sort.hh"
|
||||
#include "BLI_task.hh"
|
||||
|
||||
namespace blender {
|
||||
|
||||
/**
 * Allocate one item per element and make every element the root of its own set with rank zero.
 * The initial stores are issued through #threading::parallel_for in chunks of 4096 elements.
 */
AtomicDisjointSet::AtomicDisjointSet(const int size) : items_(size)
{
  const IndexRange all_elements(size);
  threading::parallel_for(all_elements, 4096, [&](const IndexRange sub_range) {
    for (const int element : sub_range) {
      /* Every element starts out as its own parent. */
      const Item self_rooted{element, 0};
      items_[element].store(self_rooted, relaxed);
    }
  });
}
|
||||
|
||||
/**
 * Record that the element `index` belongs to the set represented by `root`. The map keeps the
 * smallest index that has been passed in for each root.
 */
static void update_first_occurence(Map<int, int> &map, const int root, const int index)
{
  /* The root has not been seen before, so this index is its first occurrence so far. */
  const auto store_initial = [&](int *first_occurence) { *first_occurence = index; };
  /* The root is known already, only keep the smaller of the two indices. */
  const auto keep_smaller = [&](int *first_occurence) {
    if (index < *first_occurence) {
      *first_occurence = index;
    }
  };
  map.add_or_modify(root, store_initial, keep_smaller);
}
|
||||
|
||||
/**
 * Fill `result` with one set identifier per element. Identifiers are assigned in the order in
 * which the sets first occur when walking the elements from index zero upwards. `result` must
 * have the same size as the disjoint set.
 */
void AtomicDisjointSet::calc_reduced_ids(MutableSpan<int> result) const
{
  BLI_assert(result.size() == items_.size());

  const int size = result.size();
  const IndexRange all_elements(size);

  /* Step 1: Store the root of every element. With multi-threading, this root is not
   * deterministic, so the steps below replace it by a deterministic value. While doing so, every
   * thread also remembers the smallest element index it has seen for each root. */
  threading::EnumerableThreadSpecific<Map<int, int>> maps_per_thread;
  threading::parallel_for(all_elements, 1024, [&](const IndexRange sub_range) {
    Map<int, int> &local_map = maps_per_thread.local();
    for (const int element : sub_range) {
      const int root = this->find_root(element);
      result[element] = root;
      update_first_occurence(local_map, root, element);
    }
  });

  /* Step 2: Merge the per-thread maps into the map of the current thread, so that it contains
   * the first element index for every root. */
  Map<int, int> &merged_map = maps_per_thread.local();
  for (const Map<int, int> &thread_map : maps_per_thread) {
    if (&thread_map != &merged_map) {
      for (const auto entry : thread_map.items()) {
        update_first_occurence(merged_map, entry.key, entry.value);
      }
    }
  }

  struct RootWithFirstIndex {
    int root;
    int first_index;
  };

  /* Step 3: Sort the roots by their first occurrence. This removes the non-determinism from
   * step 1. */
  Vector<RootWithFirstIndex, 16> sorted_roots;
  sorted_roots.reserve(merged_map.size());
  for (const auto entry : merged_map.items()) {
    sorted_roots.append({entry.key, entry.value});
  }
  const auto occurs_earlier = [](const RootWithFirstIndex &lhs, const RootWithFirstIndex &rhs) {
    return lhs.first_index < rhs.first_index;
  };
  parallel_sort(sorted_roots.begin(), sorted_roots.end(), occurs_earlier);

  /* Step 4: The position of a root in the sorted list becomes the identifier of its set. Replace
   * the roots stored in step 1 with these identifiers. */
  Map<int, int> id_by_root;
  id_by_root.reserve(sorted_roots.size());
  for (const int id : sorted_roots.index_range()) {
    id_by_root.add_new(sorted_roots[id].root, id);
  }
  threading::parallel_for(all_elements, 1024, [&](const IndexRange sub_range) {
    for (const int element : sub_range) {
      const int root = result[element];
      result[element] = id_by_root.lookup(root);
    }
  });
}
|
||||
|
||||
int AtomicDisjointSet::count_sets() const
|
||||
{
|
||||
return threading::parallel_reduce<int>(
|
||||
items_.index_range(),
|
||||
1024,
|
||||
0,
|
||||
[&](const IndexRange range, int count) {
|
||||
for (const int i : range) {
|
||||
if (this->is_root(i)) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
},
|
||||
[](const int a, const int b) { return a + b; });
|
||||
}
|
||||
|
||||
} // namespace blender
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
#include "BKE_mesh.h"
|
||||
|
||||
#include "BLI_disjoint_set.hh"
|
||||
#include "BLI_atomic_disjoint_set.hh"
|
||||
#include "BLI_task.hh"
|
||||
|
||||
#include "node_geometry_util.hh"
|
||||
|
||||
|
@ -35,17 +36,15 @@ class IslandFieldInput final : public bke::MeshFieldInput {
|
|||
{
|
||||
const Span<MEdge> edges = mesh.edges();
|
||||
|
||||
DisjointSet<int> islands(mesh.totvert);
|
||||
for (const int i : edges.index_range()) {
|
||||
islands.join(edges[i].v1, edges[i].v2);
|
||||
}
|
||||
AtomicDisjointSet islands(mesh.totvert);
|
||||
threading::parallel_for(edges.index_range(), 1024, [&](const IndexRange range) {
|
||||
for (const MEdge &edge : edges.slice(range)) {
|
||||
islands.join(edge.v1, edge.v2);
|
||||
}
|
||||
});
|
||||
|
||||
Array<int> output(mesh.totvert);
|
||||
VectorSet<int> ordered_roots;
|
||||
for (const int i : IndexRange(mesh.totvert)) {
|
||||
const int root = islands.find_root(i);
|
||||
output[i] = ordered_roots.index_of_or_add(root);
|
||||
}
|
||||
islands.calc_reduced_ids(output);
|
||||
|
||||
return mesh.attributes().adapt_domain<int>(
|
||||
VArray<int>::ForContainer(std::move(output)), ATTR_DOMAIN_POINT, domain);
|
||||
|
@ -81,18 +80,15 @@ class IslandCountFieldInput final : public bke::MeshFieldInput {
|
|||
{
|
||||
const Span<MEdge> edges = mesh.edges();
|
||||
|
||||
DisjointSet<int> islands(mesh.totvert);
|
||||
for (const int i : edges.index_range()) {
|
||||
islands.join(edges[i].v1, edges[i].v2);
|
||||
}
|
||||
AtomicDisjointSet islands(mesh.totvert);
|
||||
threading::parallel_for(edges.index_range(), 1024, [&](const IndexRange range) {
|
||||
for (const MEdge &edge : edges.slice(range)) {
|
||||
islands.join(edge.v1, edge.v2);
|
||||
}
|
||||
});
|
||||
|
||||
Set<int> island_list;
|
||||
for (const int i_vert : IndexRange(mesh.totvert)) {
|
||||
const int root = islands.find_root(i_vert);
|
||||
island_list.add(root);
|
||||
}
|
||||
|
||||
return VArray<int>::ForSingle(island_list.size(), mesh.attributes().domain_size(domain));
|
||||
const int islands_num = islands.count_sets();
|
||||
return VArray<int>::ForSingle(islands_num, mesh.attributes().domain_size(domain));
|
||||
}
|
||||
|
||||
uint64_t hash() const override
|
||||
|
|
Loading…
Reference in New Issue