summaryrefslogtreecommitdiffstats
path: root/src/shader_recompiler/ir_opt
diff options
context:
space:
mode:
authorReinUsesLisp <reinuseslisp@airmail.cc>2021-02-11 20:39:06 +0100
committerameerj <52414509+ameerj@users.noreply.github.com>2021-07-23 03:51:22 +0200
commit9170200a11715d131645d1ffb92e86e6ef0d7e88 (patch)
tree6c6f84c38a9b59d023ecb09c0737ea56da166b64 /src/shader_recompiler/ir_opt
parentspirv: Initial SPIR-V support (diff)
downloadyuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.gz
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.bz2
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.lz
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.xz
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.zst
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.zip
Diffstat (limited to 'src/shader_recompiler/ir_opt')
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp50
-rw-r--r--src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp24
-rw-r--r--src/shader_recompiler/ir_opt/verification_pass.cpp4
3 files changed, 75 insertions, 3 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1170c61e..9fba6ac23 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -132,6 +132,32 @@ void FoldLogicalAnd(IR::Inst& inst) {
}
}
+void FoldLogicalOr(IR::Inst& inst) {
+ if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) {
+ return;
+ }
+ const IR::Value rhs{inst.Arg(1)};
+ if (rhs.IsImmediate()) {
+ if (rhs.U1()) {
+ inst.ReplaceUsesWith(IR::Value{true});
+ } else {
+ inst.ReplaceUsesWith(inst.Arg(0));
+ }
+ }
+}
+
+void FoldLogicalNot(IR::Inst& inst) {
+ const IR::U1 value{inst.Arg(0)};
+ if (value.IsImmediate()) {
+ inst.ReplaceUsesWith(IR::Value{!value.U1()});
+ return;
+ }
+ IR::Inst* const arg{value.InstRecursive()};
+ if (arg->Opcode() == IR::Opcode::LogicalNot) {
+ inst.ReplaceUsesWith(arg->Arg(0));
+ }
+}
+
template <typename Dest, typename Source>
void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
const IR::Value value{inst.Arg(0)};
@@ -160,6 +186,24 @@ void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
}
+void FoldBranchConditional(IR::Inst& inst) {
+ const IR::U1 cond{inst.Arg(0)};
+ if (cond.IsImmediate()) {
+ // TODO: Convert to Branch
+ return;
+ }
+ const IR::Inst* cond_inst{cond.InstRecursive()};
+ if (cond_inst->Opcode() == IR::Opcode::LogicalNot) {
+ const IR::Value true_label{inst.Arg(1)};
+ const IR::Value false_label{inst.Arg(2)};
+ // Remove negation on the conditional (take the parameter out of LogicalNot) and swap
+ // the branches
+ inst.SetArg(0, cond_inst->Arg(0));
+ inst.SetArg(1, false_label);
+ inst.SetArg(2, true_label);
+ }
+}
+
void ConstantPropagation(IR::Inst& inst) {
switch (inst.Opcode()) {
case IR::Opcode::GetRegister:
@@ -178,6 +222,10 @@ void ConstantPropagation(IR::Inst& inst) {
return FoldSelect<u32>(inst);
case IR::Opcode::LogicalAnd:
return FoldLogicalAnd(inst);
+ case IR::Opcode::LogicalOr:
+ return FoldLogicalOr(inst);
+ case IR::Opcode::LogicalNot:
+ return FoldLogicalNot(inst);
case IR::Opcode::ULessThan:
return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
case IR::Opcode::BitFieldUExtract:
@@ -188,6 +236,8 @@ void ConstantPropagation(IR::Inst& inst) {
}
return (base >> shift) & ((1U << count) - 1);
});
+ case IR::Opcode::BranchConditional:
+ return FoldBranchConditional(inst);
default:
break;
}
diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index 15a9db90a..8ca996e93 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -34,6 +34,13 @@ struct SignFlagTag : FlagTag {};
struct CarryFlagTag : FlagTag {};
struct OverflowFlagTag : FlagTag {};
+struct GotoVariable : FlagTag {
+ GotoVariable() = default;
+ explicit GotoVariable(u32 index_) : index{index_} {}
+
+ u32 index;
+};
+
struct DefTable {
[[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept {
return regs[IR::RegIndex(variable)];
@@ -43,6 +50,10 @@ struct DefTable {
return preds[IR::PredIndex(variable)];
}
+ [[nodiscard]] ValueMap& operator[](GotoVariable goto_variable) {
+ return goto_vars[goto_variable.index];
+ }
+
[[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept {
return zero_flag;
}
@@ -61,6 +72,7 @@ struct DefTable {
std::array<ValueMap, IR::NUM_USER_REGS> regs;
std::array<ValueMap, IR::NUM_USER_PREDS> preds;
+ boost::container::flat_map<u32, ValueMap> goto_vars;
ValueMap zero_flag;
ValueMap sign_flag;
ValueMap carry_flag;
@@ -68,15 +80,15 @@ struct DefTable {
};
IR::Opcode UndefOpcode(IR::Reg) noexcept {
- return IR::Opcode::Undef32;
+ return IR::Opcode::UndefU32;
}
IR::Opcode UndefOpcode(IR::Pred) noexcept {
- return IR::Opcode::Undef1;
+ return IR::Opcode::UndefU1;
}
IR::Opcode UndefOpcode(const FlagTag&) noexcept {
- return IR::Opcode::Undef1;
+ return IR::Opcode::UndefU1;
}
[[nodiscard]] bool IsPhi(const IR::Inst& inst) noexcept {
@@ -165,6 +177,9 @@ void SsaRewritePass(IR::Function& function) {
pass.WriteVariable(pred, block, inst.Arg(1));
}
break;
+ case IR::Opcode::SetGotoVariable:
+ pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1));
+ break;
case IR::Opcode::SetZFlag:
pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0));
break;
@@ -187,6 +202,9 @@ void SsaRewritePass(IR::Function& function) {
inst.ReplaceUsesWith(pass.ReadVariable(pred, block));
}
break;
+ case IR::Opcode::GetGotoVariable:
+ inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block));
+ break;
case IR::Opcode::GetZFlag:
inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block));
break;
diff --git a/src/shader_recompiler/ir_opt/verification_pass.cpp b/src/shader_recompiler/ir_opt/verification_pass.cpp
index 8a5adf5a2..32b56eb57 100644
--- a/src/shader_recompiler/ir_opt/verification_pass.cpp
+++ b/src/shader_recompiler/ir_opt/verification_pass.cpp
@@ -14,6 +14,10 @@ namespace Shader::Optimization {
static void ValidateTypes(const IR::Function& function) {
for (const auto& block : function.blocks) {
for (const IR::Inst& inst : *block) {
+ if (inst.Opcode() == IR::Opcode::Phi) {
+ // Skip validation on phi nodes
+ continue;
+ }
const size_t num_args{inst.NumArgs()};
for (size_t i = 0; i < num_args; ++i) {
const IR::Type t1{inst.Arg(i).Type()};