From da6cf2632cd4dc0d2b0278353fcaee0789b418c0 Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Sun, 4 Apr 2021 05:17:17 -0300 Subject: shader: Add subgroup masks --- .../backend/spirv/emit_context.cpp | 10 ++++- src/shader_recompiler/backend/spirv/emit_context.h | 5 +++ src/shader_recompiler/backend/spirv/emit_spirv.h | 5 +++ .../backend/spirv/emit_spirv_warp.cpp | 46 +++++++++++++++++----- 4 files changed, 56 insertions(+), 10 deletions(-) (limited to 'src/shader_recompiler/backend') diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp index e70b78a28..5ef637fe7 100644 --- a/src/shader_recompiler/backend/spirv/emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/emit_context.cpp @@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) { if (info.uses_local_invocation_id) { local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); } + if (info.uses_subgroup_mask) { + subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); + subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); + subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); + subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); + subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); + } if (info.uses_subgroup_invocation_id || - (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) { + (profile.warp_size_potentially_larger_than_guest && + (info.uses_subgroup_vote || info.uses_subgroup_mask))) { subgroup_local_invocation_id = DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); } diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h index 3a686a78c..03c5a6aba 100644 --- a/src/shader_recompiler/backend/spirv/emit_context.h +++ b/src/shader_recompiler/backend/spirv/emit_context.h @@ -97,6 +97,11 @@ public: Id workgroup_id{}; Id local_invocation_id{}; Id subgroup_local_invocation_id{}; + Id subgroup_mask_eq{}; + Id subgroup_mask_lt{}; + Id subgroup_mask_le{}; + Id subgroup_mask_gt{}; + Id subgroup_mask_ge{}; Id instance_id{}; Id instance_index{}; Id base_instance{}; diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h index 032b0b2f9..712c5e61f 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv.h @@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred); Id EmitVoteAny(EmitContext& ctx, Id pred); Id EmitVoteEqual(EmitContext& ctx, Id pred); Id EmitSubgroupBallot(EmitContext& ctx, Id pred); +Id EmitSubgroupEqMask(EmitContext& ctx); +Id EmitSubgroupLtMask(EmitContext& ctx); +Id EmitSubgroupLeMask(EmitContext& ctx); +Id EmitSubgroupGtMask(EmitContext& ctx); +Id EmitSubgroupGeMask(EmitContext& ctx); Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, Id segmentation_mask); Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp index cbc5b1c96..c57bd291d 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp @@ -6,10 +6,18 @@ namespace Shader::Backend::SPIRV { namespace { -Id LargeWarpBallot(EmitContext& ctx, Id ballot) { +Id WarpExtract(EmitContext& ctx, Id value) { const Id shift{ctx.Constant(ctx.U32[1], 5)}; const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; - return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index); + return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); +} + +Id LoadMask(EmitContext& ctx, Id mask) { + const Id value{ctx.OpLoad(ctx.U32[4], mask)}; + if (!ctx.profile.warp_size_potentially_larger_than_guest) { + return ctx.OpCompositeExtract(ctx.U32[1], value, 0U); + } + return WarpExtract(ctx, value); } void SetInBoundsFlag(IR::Inst* inst, Id result) { @@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) { return ctx.OpSubgroupAllKHR(ctx.U1, pred); } const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; - const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; - const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; + const Id active_mask{WarpExtract(ctx, mask_ballot)}; + const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; return ctx.OpIEqual(ctx.U1, lhs, active_mask); } @@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) { return ctx.OpSubgroupAnyKHR(ctx.U1, pred); } const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; - const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; - const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; + const Id active_mask{WarpExtract(ctx, mask_ballot)}; + const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); } @@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) { return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); } const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; - const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; - const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; + const Id active_mask{WarpExtract(ctx, mask_ballot)}; + const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), ctx.OpIEqual(ctx.U1, lhs, active_mask)); @@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) { if (!ctx.profile.warp_size_potentially_larger_than_guest) { return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); } - return LargeWarpBallot(ctx, ballot); + return WarpExtract(ctx, ballot); +} + +Id EmitSubgroupEqMask(EmitContext& ctx) { + return LoadMask(ctx, ctx.subgroup_mask_eq); +} + +Id EmitSubgroupLtMask(EmitContext& ctx) { + return LoadMask(ctx, ctx.subgroup_mask_lt); +} + +Id EmitSubgroupLeMask(EmitContext& ctx) { + return LoadMask(ctx, ctx.subgroup_mask_le); +} + +Id EmitSubgroupGtMask(EmitContext& ctx) { + return LoadMask(ctx, ctx.subgroup_mask_gt); +} + +Id EmitSubgroupGeMask(EmitContext& ctx) { + return LoadMask(ctx, ctx.subgroup_mask_ge); } Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, -- cgit v1.2.3