1
0
mirror of https://github.com/RPCS3/rpcs3.git synced 2024-11-22 18:53:28 +01:00

SPU LLVM: implement spu_re, spu_rsqrte

Improve matching with peek_through_bitcasts() helper.
Implement erase_stores() helper.
This commit is contained in:
Nekotekina 2021-09-07 19:42:05 +03:00
parent aba332d4c4
commit d28b0ba2fa
3 changed files with 96 additions and 7 deletions

View File

@ -7,6 +7,18 @@
llvm::LLVMContext g_llvm_ctx;
// Strip any chain of bitcast instructions from `arg` and return the value underneath.
// A null pointer, or a value that is not a bitcast instruction, is returned unchanged.
llvm::Value* peek_through_bitcasts(llvm::Value* arg)
{
	for (;;)
	{
		const auto cast = llvm::dyn_cast_or_null<llvm::CastInst>(arg);

		if (!cast || cast->getOpcode() != llvm::Instruction::BitCast)
		{
			// Null, not a cast instruction, or a cast other than bitcast: nothing more to strip
			return arg;
		}

		// Continue with the operand that the bitcast reinterprets
		arg = cast->getOperand(0);
	}
}
cpu_translator::cpu_translator(llvm::Module* _module, bool is_be)
: m_context(g_llvm_ctx)
, m_module(_module)
@ -312,4 +324,27 @@ void cpu_translator::replace_intrinsics(llvm::Function& f)
}
}
// Erase every store instruction whose stored value is one of `args`, either directly
// or through a chain of bitcast instructions. Null entries in `args` are skipped.
void cpu_translator::erase_stores(llvm::ArrayRef<llvm::Value*> args)
{
	for (llvm::Value* root : args)
	{
		if (!root)
		{
			continue;
		}

		// Stores are collected first and erased only after the traversal is complete:
		// erasing a user while walking a use list invalidates the iterator.
		llvm::SmallVector<llvm::StoreInst*, 4> stores;
		llvm::SmallVector<llvm::Value*, 4> pending;
		pending.push_back(root);

		while (!pending.empty())
		{
			llvm::Value* const cur = pending.pop_back_val();

			// Iterate over users(): dereferencing a use_iterator and converting the Use
			// to Value* yields the used value (`cur` itself), not the instruction using it.
			for (llvm::User* user : cur->users())
			{
				if (const auto si = llvm::dyn_cast<llvm::StoreInst>(user))
				{
					// Only a store *of* the value is removed, not a store *through* it
					if (si->getValueOperand() == cur)
					{
						stores.push_back(si);
					}
				}
				else if (const auto bci = llvm::dyn_cast<llvm::CastInst>(user); bci && bci->getOpcode() == llvm::Instruction::BitCast)
				{
					// Walk through bitcasts: follow all users of the cast, not just the first one
					pending.push_back(bci);
				}
			}
		}

		for (llvm::StoreInst* si : stores)
		{
			si->eraseFromParent();
		}
	}
}
#endif

View File

@ -427,6 +427,9 @@ using llvm_common_t = std::enable_if_t<(is_llvm_expr_of<T, Types>::ok && ...), t
// Tuple type obtained by concatenating (std::tuple_cat) the results of calling
// match(llvm::Value*&, nullptr) on an llvm_expr_t<Args> lvalue for every type in Args.
// Only used in an unevaluated context (decltype/declval); nothing is called at run time.
template <typename... Args>
using llvm_match_tuple = decltype(std::tuple_cat(std::declval<llvm_expr_t<Args>&>().match(std::declval<llvm::Value*&>(), nullptr)...));
// Helper function: follows a chain of bitcast instructions back to the value they were
// applied to and returns it; any other value (including null) is returned as is.
// Declared here so that the matchers below can use it; defined out of line.
llvm::Value* peek_through_bitcasts(llvm::Value*);
template <typename T, typename U = llvm_common_t<llvm_value_t<T>>>
struct llvm_match_t
{
@ -442,7 +445,8 @@ struct llvm_match_t
template <typename... Args>
bool eq(const Args&... args) const
{
return value && ((value == args.value) && ...);
llvm::Value* lhs = nullptr;
return value && (lhs = peek_through_bitcasts(value)) && ((lhs == peek_through_bitcasts(args.value)) && ...);
}
llvm::Value* eval(llvm::IRBuilder<>*) const
@ -3491,6 +3495,15 @@ public:
// Finalize processing custom intrinsics for the given function
// (cpu_translator member, defined out of line)
void replace_intrinsics(llvm::Function&);
// Erase store instructions of the provided values (bitcasts of a value are walked through)
void erase_stores(llvm::ArrayRef<llvm::Value*> args);
// Convenience overload: takes any number of arguments exposing a `.value` member and
// passes those members, as a braced list, to the llvm::ArrayRef overload of erase_stores.
// NOTE(review): the braced list is what selects the ArrayRef overload; keep the call in this form.
template <typename... Args>
void erase_stores(Args... args)
{
erase_stores({args.value...});
}
template <typename T, typename U>
static auto pshufb(T&& a, U&& b)
{

View File

@ -7778,12 +7778,9 @@ public:
bool is_input_positive(value_t<f32[4]> a)
{
if (auto [ok, v0, v1] = match_expr(a, match<f32[4]>() * match<f32[4]>()); ok)
if (auto [ok, v0, v1] = match_expr(a, match<f32[4]>() * match<f32[4]>()); ok && v0.eq(v1))
{
if (v0.value == v1.value)
{
return true;
}
return true;
}
return false;
@ -8496,6 +8493,18 @@ public:
return {"spu_fi", {std::forward<T>(a), std::forward<U>(b)}};
}
// Build a call expression for the custom intrinsic named "spu_re" with the single operand `a`.
// The operand is perfectly forwarded into the expression's argument list.
template <typename T>
static llvm_calli<f32[4], T> spu_re(T&& a)
{
	using call_type = llvm_calli<f32[4], T>;

	return call_type{"spu_re", {std::forward<T>(a)}};
}
// Build a call expression for the custom intrinsic named "spu_rsqrte" with the single operand `a`.
// The operand is perfectly forwarded into the expression's argument list.
template <typename T>
static llvm_calli<f32[4], T> spu_rsqrte(T&& a)
{
	using call_type = llvm_calli<f32[4], T>;

	return call_type{"spu_rsqrte", {std::forward<T>(a)}};
}
void FI(spu_opcode_t op)
{
// TODO
@ -8527,7 +8536,39 @@ public:
return bitcast<f32[4]>((b & 0xff800000u) | (bitcast<u32[4]>(fpcast<f32[4]>(bnew)) & ~0xff800000u)); // Inject old sign and exponent
});
set_vr(op.rt, fi(get_vr<f32[4]>(op.ra), get_vr<f32[4]>(op.rb)));
register_intrinsic("spu_re", [&](llvm::CallInst* ci)
{
const auto a = value<f32[4]>(ci->getOperand(0));
return fre(a);
});
register_intrinsic("spu_rsqrte", [&](llvm::CallInst* ci)
{
const auto a = value<f32[4]>(ci->getOperand(0));
return frsqe(fabs(a));
});
const auto [a, b] = get_vrs<f32[4]>(op.ra, op.rb);
if (const auto [ok, mb] = match_expr(b, frest(match<f32[4]>())); ok && mb.eq(a))
{
erase_stores(b);
set_vr(op.rt, spu_re(a));
return;
}
if (const auto [ok, mb] = match_expr(b, frsqest(match<f32[4]>())); ok && mb.eq(a))
{
erase_stores(b);
set_vr(op.rt, spu_rsqrte(a));
return;
}
const auto r = eval(fi(a, b));
if (!m_interp_magn)
spu_log.todo("[%s:0x%05x] Unmatched spu_fi found", m_hash, m_pos);
set_vr(op.rt, r);
}
void CFLTS(spu_opcode_t op)