Metal: Add support for all matrix types

* Add missing constructors for square matrix types.
* Use `using` instead of `#define` for type names.
* Add `==`, `!=` and unary `-` operators.
* Catch all functional-style constructors inside the regex.
This commit is contained in:
Clément Foucault 2023-01-03 17:54:01 +01:00
parent ecd4533615
commit a6355a4542
Notes: blender-bot 2024-04-11 14:26:06 +02:00
Referenced by issue #103637, Nodes can no longer be connected in node editors
2 changed files with 224 additions and 82 deletions

View File

@ -519,8 +519,8 @@ char *MSLGeneratorInterface::msl_patch_default_get()
}
std::stringstream ss_patch;
ss_patch << datatoc_mtl_shader_shared_h << std::endl;
ss_patch << datatoc_mtl_shader_defines_msl << std::endl;
ss_patch << datatoc_mtl_shader_shared_h << std::endl;
size_t len = strlen(ss_patch.str().c_str());
msl_patch_default = (char *)malloc(len * sizeof(char));
@ -607,7 +607,7 @@ bool MTLShader::generate_msl_from_glsl(const shader::ShaderCreateInfo *info)
/*** Regex Commands ***/
/* Source cleanup and syntax replacement. */
static std::regex remove_excess_newlines("\\n+");
static std::regex replace_mat3("mat3\\s*\\(");
static std::regex replace_matrix_construct("mat([234](x[234])?)\\s*\\(");
/* Special condition - mat3 and array constructor replacement.
* Also replace excessive new lines to ensure cases are not missed.
@ -615,14 +615,14 @@ bool MTLShader::generate_msl_from_glsl(const shader::ShaderCreateInfo *info)
shd_builder_->glsl_vertex_source_ = std::regex_replace(
shd_builder_->glsl_vertex_source_, remove_excess_newlines, "\n");
shd_builder_->glsl_vertex_source_ = std::regex_replace(
shd_builder_->glsl_vertex_source_, replace_mat3, "MAT3(");
shd_builder_->glsl_vertex_source_, replace_matrix_construct, "MAT$1(");
replace_array_initializers_func(shd_builder_->glsl_vertex_source_);
if (!msl_iface.uses_transform_feedback) {
shd_builder_->glsl_fragment_source_ = std::regex_replace(
shd_builder_->glsl_fragment_source_, remove_excess_newlines, "\n");
shd_builder_->glsl_fragment_source_ = std::regex_replace(
shd_builder_->glsl_fragment_source_, replace_mat3, "MAT3(");
shd_builder_->glsl_fragment_source_, replace_matrix_construct, "MAT$1(");
replace_array_initializers_func(shd_builder_->glsl_fragment_source_);
}

View File

@ -3,7 +3,7 @@
/** Special header for mapping commonly defined tokens to API-specific variations.
* Where possible, this will adhere closely to base GLSL, where semantics are the same.
* However, host code shader code may need modifying to support types where necessary variations
* exist between APIs but are not expressed through the source. (e.g. distinctio between depth2d
* exist between APIs but are not expressed through the source. (e.g. distinction between depth2d
* and texture2d types in metal).
*/
@ -16,19 +16,27 @@
#define DFDY_SIGN 1.0
/* Type definitions. */
#define vec2 float2
#define vec3 float3
#define vec4 float4
#define mat2 float2x2
#define mat2x2 float2x2
#define mat3 float3x3
#define mat4 float4x4
#define ivec2 int2
#define ivec3 int3
#define ivec4 int4
#define uvec2 uint2
#define uvec3 uint3
#define uvec4 uint4
/* Float vector type names mapped onto the Metal float vector types. */
using vec2 = float2;
using vec3 = float3;
using vec4 = float4;
/* Matrix type names with explicit dimensions, one alias per Metal `floatNxM` type. */
using mat2x2 = float2x2;
using mat2x3 = float2x3;
using mat2x4 = float2x4;
using mat3x2 = float3x2;
using mat3x3 = float3x3;
using mat3x4 = float3x4;
using mat4x2 = float4x2;
using mat4x3 = float4x3;
using mat4x4 = float4x4;
/* Short names for the square matrices; same underlying types as mat2x2 / mat3x3 / mat4x4. */
using mat2 = float2x2;
using mat3 = float3x3;
using mat4 = float4x4;
/* Signed integer vector type names. */
using ivec2 = int2;
using ivec3 = int3;
using ivec4 = int4;
/* Unsigned integer vector type names. */
using uvec2 = uint2;
using uvec3 = uint3;
using uvec4 = uint4;
/* MTLBOOL is used for native booleans generated by the Metal backend, to avoid type-emulation
* for GLSL bools, which are treated as integers. */
#define MTLBOOL bool
@ -687,6 +695,76 @@ inline void _texture_write_internal(thread _mtl_combined_image_sampler_3d<S, A>
}
}
/* Matrix compare operators. */
/** TODO(fclem): Template. */
/**
 * Exact equality for 4x4 float matrices.
 * Compares the four `a[i]` / `b[i]` vector pairs in order and reports inequality as soon as
 * any component of a pair differs.
 */
inline bool operator==(float4x4 a, float4x4 b)
{
  bool mismatch = false;
  for (int col = 0; col < 4 && !mismatch; col++) {
    mismatch = any(a[col] != b[col]);
  }
  return !mismatch;
}
/**
 * Exact equality for 3x3 float matrices.
 * Compares the three `a[i]` / `b[i]` vector pairs in order and reports inequality as soon as
 * any component of a pair differs.
 */
inline bool operator==(float3x3 a, float3x3 b)
{
  bool mismatch = false;
  for (int col = 0; col < 3 && !mismatch; col++) {
    mismatch = any(a[col] != b[col]);
  }
  return !mismatch;
}
/**
 * Exact equality for 2x2 float matrices.
 * True only when neither `a[0]` / `b[0]` nor `a[1]` / `b[1]` differ in any component.
 */
inline bool operator==(float2x2 a, float2x2 b)
{
  /* `||` short-circuits, so the second pair is only inspected when the first one matches. */
  return !(any(a[0] != b[0]) || any(a[1] != b[1]));
}
/** Inequality for 4x4 float matrices, defined as the negation of `a == b`. */
inline bool operator!=(float4x4 a, float4x4 b)
{
  const bool is_equal = (a == b);
  return !is_equal;
}
/** Inequality for 3x3 float matrices, defined as the negation of `a == b`. */
inline bool operator!=(float3x3 a, float3x3 b)
{
  const bool is_equal = (a == b);
  return !is_equal;
}
/** Inequality for 2x2 float matrices, defined as the negation of `a == b`. */
inline bool operator!=(float2x2 a, float2x2 b)
{
  const bool is_equal = (a == b);
  return !is_equal;
}
/* Matrix unary minus operator. */
/** Unary minus for 4x4 float matrices: the result holds the negation of each `a[i]` vector. */
inline float4x4 operator-(float4x4 a)
{
  return float4x4(-a[0], -a[1], -a[2], -a[3]);
}
/** Unary minus for 3x3 float matrices: the result holds the negation of each `a[i]` vector. */
inline float3x3 operator-(float3x3 a)
{
  return float3x3(-a[0], -a[1], -a[2]);
}
/** Unary minus for 2x2 float matrices: the result holds the negation of each `a[i]` vector. */
inline float2x2 operator-(float2x2 a)
{
  return float2x2(-a[0], -a[1]);
}
/* SSBO Vertex Fetch Mode. */
#ifdef MTL_SSBO_VERTEX_FETCH
/* Enabled when geometry is passed via raw buffer bindings, rather than using
@ -997,47 +1075,59 @@ float4x4 inverse(float4x4 a)
float b10 = a[2][1] * a[3][3] - a[2][3] * a[3][1];
float b11 = a[2][2] * a[3][3] - a[2][3] * a[3][2];
float invdet = 1.0 / (b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - b04 * b07 + b05 * b06);
float inv_det = 1.0 / (b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - b04 * b07 + b05 * b06);
return float4x4(a[1][1] * b11 - a[1][2] * b10 + a[1][3] * b09,
a[0][2] * b10 - a[0][1] * b11 - a[0][3] * b09,
a[3][1] * b05 - a[3][2] * b04 + a[3][3] * b03,
a[2][2] * b04 - a[2][1] * b05 - a[2][3] * b03,
a[1][2] * b08 - a[1][0] * b11 - a[1][3] * b07,
a[0][0] * b11 - a[0][2] * b08 + a[0][3] * b07,
a[3][2] * b02 - a[3][0] * b05 - a[3][3] * b01,
a[2][0] * b05 - a[2][2] * b02 + a[2][3] * b01,
a[1][0] * b10 - a[1][1] * b08 + a[1][3] * b06,
a[0][1] * b08 - a[0][0] * b10 - a[0][3] * b06,
a[3][0] * b04 - a[3][1] * b02 + a[3][3] * b00,
a[2][1] * b02 - a[2][0] * b04 - a[2][3] * b00,
a[1][1] * b07 - a[1][0] * b09 - a[1][2] * b06,
a[0][0] * b09 - a[0][1] * b07 + a[0][2] * b06,
a[3][1] * b01 - a[3][0] * b03 - a[3][2] * b00,
a[2][0] * b03 - a[2][1] * b01 + a[2][2] * b00) *
invdet;
float4x4 adjoint{};
adjoint[0][0] = a[1][1] * b11 - a[1][2] * b10 + a[1][3] * b09;
adjoint[0][1] = a[0][2] * b10 - a[0][1] * b11 - a[0][3] * b09;
adjoint[0][2] = a[3][1] * b05 - a[3][2] * b04 + a[3][3] * b03;
adjoint[0][3] = a[2][2] * b04 - a[2][1] * b05 - a[2][3] * b03;
adjoint[1][0] = a[1][2] * b08 - a[1][0] * b11 - a[1][3] * b07;
adjoint[1][1] = a[0][0] * b11 - a[0][2] * b08 + a[0][3] * b07;
adjoint[1][2] = a[3][2] * b02 - a[3][0] * b05 - a[3][3] * b01;
adjoint[1][3] = a[2][0] * b05 - a[2][2] * b02 + a[2][3] * b01;
adjoint[2][0] = a[1][0] * b10 - a[1][1] * b08 + a[1][3] * b06;
adjoint[2][1] = a[0][1] * b08 - a[0][0] * b10 - a[0][3] * b06;
adjoint[2][2] = a[3][0] * b04 - a[3][1] * b02 + a[3][3] * b00;
adjoint[2][3] = a[2][1] * b02 - a[2][0] * b04 - a[2][3] * b00;
adjoint[3][0] = a[1][1] * b07 - a[1][0] * b09 - a[1][2] * b06;
adjoint[3][1] = a[0][0] * b09 - a[0][1] * b07 + a[0][2] * b06;
adjoint[3][2] = a[3][1] * b01 - a[3][0] * b03 - a[3][2] * b00;
adjoint[3][3] = a[2][0] * b03 - a[2][1] * b01 + a[2][2] * b00;
return adjoint * inv_det;
}
float3x3 inverse(float3x3 m)
{
float b00 = m[1][1] * m[2][2] - m[2][1] * m[1][2];
float b01 = m[0][1] * m[2][2] - m[2][1] * m[0][2];
float b02 = m[0][1] * m[1][2] - m[1][1] * m[0][2];
float invdet = 1.0 / (m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) -
m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) +
m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]));
float inv_det = 1.0 / (m[0][0] * b00 - m[1][0] * b01 + m[2][0] * b02);
float3x3 inverse(0);
inverse[0][0] = +(m[1][1] * m[2][2] - m[2][1] * m[1][2]);
inverse[1][0] = -(m[1][0] * m[2][2] - m[2][0] * m[1][2]);
inverse[2][0] = +(m[1][0] * m[2][1] - m[2][0] * m[1][1]);
inverse[0][1] = -(m[0][1] * m[2][2] - m[2][1] * m[0][2]);
inverse[1][1] = +(m[0][0] * m[2][2] - m[2][0] * m[0][2]);
inverse[2][1] = -(m[0][0] * m[2][1] - m[2][0] * m[0][1]);
inverse[0][2] = +(m[0][1] * m[1][2] - m[1][1] * m[0][2]);
inverse[1][2] = -(m[0][0] * m[1][2] - m[1][0] * m[0][2]);
inverse[2][2] = +(m[0][0] * m[1][1] - m[1][0] * m[0][1]);
inverse = inverse * invdet;
float3x3 adjoint{};
adjoint[0][0] = +b00;
adjoint[0][1] = -b01;
adjoint[0][2] = +b02;
adjoint[1][0] = -(m[1][0] * m[2][2] - m[2][0] * m[1][2]);
adjoint[1][1] = +(m[0][0] * m[2][2] - m[2][0] * m[0][2]);
adjoint[1][2] = -(m[0][0] * m[1][2] - m[1][0] * m[0][2]);
adjoint[2][0] = +(m[1][0] * m[2][1] - m[2][0] * m[1][1]);
adjoint[2][1] = -(m[0][0] * m[2][1] - m[2][0] * m[0][1]);
adjoint[2][2] = +(m[0][0] * m[1][1] - m[1][0] * m[0][1]);
return adjoint * inv_det;
}
return inverse;
/**
 * Inverse of a 2x2 float matrix, computed as the adjugate scaled by the reciprocal of the
 * determinant. The determinant is not checked against zero before dividing.
 */
float2x2 inverse(float2x2 m)
{
  const float determinant = m[0][0] * m[1][1] - m[1][0] * m[0][1];
  const float scale = 1.0 / determinant;

  /* Adjugate: swap the diagonal entries, negate the off-diagonal ones. */
  float2x2 cofactors{};
  cofactors[0][0] = +m[1][1];
  cofactors[0][1] = -m[0][1];
  cofactors[1][0] = -m[1][0];
  cofactors[1][1] = +m[0][0];

  return cofactors * scale;
}
/* Additional overloads for builtin functions. */
@ -1110,44 +1200,96 @@ template<typename T, unsigned int Size> bool is_zero(vec<T, Size> a)
return true;
}
/* Matrix conversion fallback. */
mat3 MAT3(vec3 a, vec3 b, vec3 c)
/**
* Matrix conversion fallback for functional style casting & constructors.
* To avoid name collision with the types, they are replaced with uppercase version
* before compilation.
*/
/** Fallback for `mat2(vec2, vec2)`: forwards both vectors to the `mat2` constructor. */
mat2 MAT2x2(vec2 a, vec2 b)
{
  const mat2 result = mat2(a, b);
  return result;
}
/**
 * Fallback for `mat2` built from four scalars: `(a1, a2)` forms the first vector passed to
 * `mat2` and `(b1, b2)` the second.
 */
mat2 MAT2x2(float a1, float a2, float b1, float b2)
{
  const vec2 first = vec2(a1, a2);
  const vec2 second = vec2(b1, b2);
  return mat2(first, second);
}
/** Fallback for `mat2(float)`: forwards the scalar to the single-scalar `mat2` constructor. */
mat2 MAT2x2(float f)
{
  const mat2 result = mat2(f);
  return result;
}
/** Fallback for `mat2(mat3)`: keeps the `.xy` part of `m[0]` and `m[1]`, dropping the rest. */
mat2 MAT2x2(mat3 m)
{
  const vec2 first = m[0].xy;
  const vec2 second = m[1].xy;
  return mat2(first, second);
}
/** Fallback for `mat2(mat4)`: keeps the `.xy` part of `m[0]` and `m[1]`, dropping the rest. */
mat2 MAT2x2(mat4 m)
{
  const vec2 first = m[0].xy;
  const vec2 second = m[1].xy;
  return mat2(first, second);
}
/** Fallback for `mat3(vec3, vec3, vec3)`: forwards the three vectors to the `mat3` constructor. */
mat3 MAT3x3(vec3 a, vec3 b, vec3 c)
{
  const mat3 result = mat3(a, b, c);
  return result;
}
mat3 MAT3(vec3 a, vec3 b, float c1, float c2, float c3)
{
return mat3(a, b, vec3(c1, c2, c3));
}
mat3 MAT3(vec3 a, float b1, float b2, float b3, vec3 c)
{
return mat3(a, vec3(b1, b2, b3), c);
}
mat3 MAT3(vec3 a, float b1, float b2, float b3, float c1, float c2, float c3)
{
return mat3(a, vec3(b1, b2, b3), vec3(c1, c2, c3));
}
mat3 MAT3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, float c2, float c3)
/**
 * Fallback for `mat3` built from nine scalars: the `a*`, `b*` and `c*` triples form the first,
 * second and third vector passed to `mat3`.
 */
mat3 MAT3x3(
    float a1, float a2, float a3, float b1, float b2, float b3, float c1, float c2, float c3)
{
  const vec3 first = vec3(a1, a2, a3);
  const vec3 second = vec3(b1, b2, b3);
  const vec3 third = vec3(c1, c2, c3);
  return mat3(first, second, third);
}
mat3 MAT3(float a1, float a2, float a3, vec3 b, vec3 c)
{
return mat3(vec3(a1, a2, a3), b, c);
}
mat3 MAT3(float a1, float a2, float a3, vec3 b, float c1, float c2, float c3)
{
return mat3(vec3(a1, a2, a3), b, vec3(c1, c2, c3));
}
mat3 MAT3(float a1, float a2, float a3, float b1, float b2, float b3, vec3 c)
{
return mat3(vec3(a1, a2, a3), vec3(b1, b2, b3), c);
}
mat3 MAT3(float f)
/** Fallback for `mat3(float)`: forwards the scalar to the single-scalar `mat3` constructor. */
mat3 MAT3x3(float f)
{
  const mat3 result = mat3(f);
  return result;
}
mat3 MAT3(mat4 m)
mat3 MAT3x3(mat4 m)
{
return mat4_to_mat3(m);
}
return mat3(m[0].xyz, m[1].xyz, m[2].xyz);
}
/**
 * Fallback for `mat3(mat2)`: `m[0]` and `m[1]` are extended with a trailing 0.0 and the third
 * vector is set to (0, 0, 1).
 */
mat3 MAT3x3(mat2 m)
{
  const vec3 first = vec3(m[0].xy, 0.0);
  const vec3 second = vec3(m[1].xy, 0.0);
  const vec3 third = vec3(0.0, 0.0, 1.0);
  return mat3(first, second, third);
}
/** Fallback for `mat4(vec4, vec4, vec4, vec4)`: forwards the four vectors to `mat4`. */
mat4 MAT4x4(vec4 a, vec4 b, vec4 c, vec4 d)
{
  const mat4 result = mat4(a, b, c, d);
  return result;
}
/**
 * Fallback for `mat4` built from sixteen scalars: the `a*`, `b*`, `c*` and `d*` groups form the
 * first, second, third and fourth vector passed to `mat4`.
 */
mat4 MAT4x4(float a1,
            float a2,
            float a3,
            float a4,
            float b1,
            float b2,
            float b3,
            float b4,
            float c1,
            float c2,
            float c3,
            float c4,
            float d1,
            float d2,
            float d3,
            float d4)
{
  const vec4 first = vec4(a1, a2, a3, a4);
  const vec4 second = vec4(b1, b2, b3, b4);
  const vec4 third = vec4(c1, c2, c3, c4);
  const vec4 fourth = vec4(d1, d2, d3, d4);
  return mat4(first, second, third, fourth);
}
/** Fallback for `mat4(float)`: forwards the scalar to the single-scalar `mat4` constructor. */
mat4 MAT4x4(float f)
{
  const mat4 result = mat4(f);
  return result;
}
/**
 * Fallback for `mat4(mat3)`: each of `m[0]`, `m[1]`, `m[2]` is extended with a trailing 0.0 and
 * the fourth vector is set to (0, 0, 0, 1).
 */
mat4 MAT4x4(mat3 m)
{
  const vec4 first = vec4(m[0].xyz, 0.0);
  const vec4 second = vec4(m[1].xyz, 0.0);
  const vec4 third = vec4(m[2].xyz, 0.0);
  const vec4 fourth = vec4(0.0, 0.0, 0.0, 1.0);
  return mat4(first, second, third, fourth);
}
/**
 * Fallback for `mat4(mat2)`: `m[0]` and `m[1]` are extended with two trailing zeros, and the
 * third and fourth vectors are set to (0, 0, 1, 0) and (0, 0, 0, 1).
 */
mat4 MAT4x4(mat2 m)
{
  const vec4 first = vec4(m[0].xy, 0.0, 0.0);
  const vec4 second = vec4(m[1].xy, 0.0, 0.0);
  const vec4 third = vec4(0.0, 0.0, 1.0, 0.0);
  const vec4 fourth = vec4(0.0, 0.0, 0.0, 1.0);
  return mat4(first, second, third, fourth);
}
/* Short uppercase names expand to the square-matrix fallback functions, so a `MAT3(...)` call
 * resolves to the `MAT3x3(...)` overload set. */
#define MAT2 MAT2x2
#define MAT3 MAT3x3
#define MAT4 MAT4x4