1
0
mirror of https://github.com/RPCS3/rpcs3.git synced 2025-01-31 12:31:45 +01:00

LLVM DSL: rewrite splat, fsplat, vsplat

Add llvm_const_float and llvm_splat templates.
This commit is contained in:
Nekotekina 2019-04-23 20:08:18 +03:00
parent c83e65f29e
commit b02503963e
2 changed files with 91 additions and 59 deletions

View File

@ -368,6 +368,8 @@ struct llvm_const_int
u64 val;
static constexpr bool is_ok = llvm_value_t<T>::is_int;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
static_assert(llvm_value_t<T>::is_int, "llvm_const_int<>: invalid type");
@ -376,6 +378,23 @@ struct llvm_const_int
}
};
// DSL expression node holding a floating-point constant of DSL type T.
// The value is stored as f64 and only turned into an LLVM constant by eval().
template <typename T>
struct llvm_const_float
{
	using type = T;

	// Constant payload (the only data member, so brace-init with a single value works)
	f64 val;

	// True when T is a floating-point DSL type; used by callers for SFINAE
	static constexpr bool is_ok = llvm_value_t<T>::is_float;

	// Materialize the constant using the LLVM type that corresponds to T
	llvm::Value* eval(llvm::IRBuilder<>* ir) const
	{
		static_assert(is_ok, "llvm_const_float<>: invalid type");

		const auto fp_type = llvm_value_t<T>::get_type(ir->getContext());
		return llvm::ConstantFP::get(fp_type, val);
	}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
struct llvm_add
{
@ -1454,6 +1473,30 @@ struct llvm_insert
}
};
// DSL expression node: evaluates a scalar sub-expression (A1) and splats it into vector type U.
// T is the common type of the operand and must equal the element type of U.
template <typename U, typename A1, typename T = llvm_common_t<A1>>
struct llvm_splat
{
	using type = U;

	// Scalar operand expression (the only data member)
	llvm_expr_t<A1> a1;

	// Same three conditions as the static_asserts below, exposed for enable_if checks
	static constexpr bool is_ok =
		!llvm_value_t<T>::is_vector &&
		llvm_value_t<U>::is_vector &&
		std::is_same_v<T, std::remove_extent_t<U>>;

	static_assert(!llvm_value_t<T>::is_vector, "llvm_splat<>: invalid type");
	static_assert(llvm_value_t<U>::is_vector, "llvm_splat<>: invalid result type");
	static_assert(std::is_same_v<T, std::remove_extent_t<U>>, "llvm_splat<>: incompatible splat type");

	llvm::Value* eval(llvm::IRBuilder<>* ir) const
	{
		// NOTE(review): is_vector is passed as the element count, so it presumably
		// holds the number of lanes of U rather than a plain bool - confirm in llvm_value_t
		return ir->CreateVectorSplat(llvm_value_t<U>::is_vector, a1.eval(ir));
	}
};
class cpu_translator
{
protected:
@ -1632,6 +1675,24 @@ public:
return llvm_insert<T, llvm_const_int<u32>, V>{std::forward<T>(v), llvm_const_int<u32>{i}, std::forward<V>(e)};
}
// Integer constant: returns an unevaluated llvm_const_int<T> expression carrying c.
// Participates in overload resolution only when llvm_const_int<T>::is_ok is true.
template <typename T, typename = std::enable_if_t<llvm_const_int<T>::is_ok>>
static auto splat(u64 c)
{
	llvm_const_int<T> expr{c};
	return expr;
}
// Floating-point constant: returns an unevaluated llvm_const_float<T> expression carrying c.
// Participates in overload resolution only when llvm_const_float<T>::is_ok is true.
template <typename T, typename = std::enable_if_t<llvm_const_float<T>::is_ok>>
static auto fsplat(f64 c)
{
	llvm_const_float<T> expr{c};
	return expr;
}
template <typename T, typename U, typename = std::enable_if_t<llvm_splat<T, U>::is_ok>>
static auto vsplat(U&& v)
{
return llvm_splat<T, U>{std::forward<U>(v)};
}
// Average: (a + b + 1) >> 1
template <typename T>
inline auto avg(T a, T b)
@ -1653,31 +1714,6 @@ public:
return result;
}
// Immediate integer constant: builds the llvm::ConstantInt right away (using m_context)
// and returns it wrapped in value_t<T>; signedness is taken from value_t<T>::is_sint.
template <typename T>
auto splat(u64 c)
{
	value_t<T> out;
	const auto int_type = out.get_type(m_context);
	out.value = llvm::ConstantInt::get(int_type, c, out.is_sint);
	return out;
}
// Immediate floating-point constant: builds the llvm::ConstantFP right away (using m_context)
// and returns it wrapped in value_t<T>.
template <typename T>
auto fsplat(f64 c)
{
	value_t<T> out;
	const auto fp_type = out.get_type(m_context);
	out.value = llvm::ConstantFP::get(fp_type, c);
	return out;
}
template <typename T, typename V>
auto vsplat(V v)
{
value_t<T> result;
static_assert(result.is_vector);
result.value = m_ir->CreateVectorSplat(result.is_vector, v.eval(m_ir));
return result;
}
// Shuffle single vector using all zeros second vector of the same size
template <typename T, typename T1, typename... Args>
auto zshuffle(T1 a, Args... args)

View File

@ -2422,7 +2422,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
{
if (llvm::isa<llvm::ConstantAggregateZero>(val))
{
return splat<u64[4]>(0).value;
return splat<u64[4]>(0).eval(m_ir);
}
if (auto cv = llvm::dyn_cast<llvm::ConstantDataVector>(val))
@ -2450,7 +2450,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
{
if (llvm::isa<llvm::ConstantAggregateZero>(val))
{
return fsplat<f64[4]>(0.).value;
return fsplat<f64[4]>(0.).eval(m_ir);
}
if (auto cv = llvm::dyn_cast<llvm::ConstantDataVector>(val))
@ -2503,7 +2503,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
const auto s = m_ir->CreateAnd(m_ir->CreateLShr(d, 32), 0x80000000);
const auto m = m_ir->CreateXor(m_ir->CreateLShr(d, 29), 0x40000000);
const auto r = m_ir->CreateOr(m_ir->CreateAnd(m, 0x7fffffff), s);
return m_ir->CreateTrunc(m_ir->CreateSelect(m_ir->CreateIsNotNull(d), r, splat<u64[4]>(0).value), get_type<u32[4]>());
return m_ir->CreateTrunc(m_ir->CreateSelect(m_ir->CreateIsNotNull(d), r, splat<u64[4]>(0).eval(m_ir)), get_type<u32[4]>());
}
llvm::Value* xfloat_to_double(llvm::Value* val)
@ -2513,8 +2513,8 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
const auto x = m_ir->CreateZExt(val, get_type<u64[4]>());
const auto s = m_ir->CreateShl(m_ir->CreateAnd(x, 0x80000000), 32);
const auto a = m_ir->CreateAnd(x, 0x7fffffff);
const auto m = m_ir->CreateShl(m_ir->CreateAdd(a, splat<u64[4]>(0x1c0000000).value), 29);
const auto r = m_ir->CreateSelect(m_ir->CreateICmpSGT(a, splat<u64[4]>(0x7fffff).value), m, splat<u64[4]>(0).value);
const auto m = m_ir->CreateShl(m_ir->CreateAdd(a, splat<u64[4]>(0x1c0000000).eval(m_ir)), 29);
const auto r = m_ir->CreateSelect(m_ir->CreateICmpSGT(a, splat<u64[4]>(0x7fffff).eval(m_ir)), m, splat<u64[4]>(0).eval(m_ir));
const auto f = m_ir->CreateOr(s, r);
return uint64_as_double(f);
}
@ -2524,8 +2524,8 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
{
verify("xfloat_in_double" HERE), val, val->getType() == get_type<f64[4]>();
const auto smax = uint64_as_double(splat<u64[4]>(0x47ffffffe0000000).value);
const auto smin = uint64_as_double(splat<u64[4]>(0x3810000000000000).value);
const auto smax = uint64_as_double(splat<u64[4]>(0x47ffffffe0000000).eval(m_ir));
const auto smin = uint64_as_double(splat<u64[4]>(0x3810000000000000).eval(m_ir));
const auto d = double_as_uint64(val);
const auto s = m_ir->CreateAnd(d, 0x8000000000000000);
@ -2533,7 +2533,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
const auto n = m_ir->CreateFCmpOLT(a, smax);
const auto z = m_ir->CreateFCmpOLT(a, smin);
const auto c = double_as_uint64(m_ir->CreateSelect(n, a, smax));
return m_ir->CreateSelect(z, fsplat<f64[4]>(0.).value, uint64_as_double(m_ir->CreateOr(c, s)));
return m_ir->CreateSelect(z, fsplat<f64[4]>(0.).eval(m_ir), uint64_as_double(m_ir->CreateOr(c, s)));
}
// Expand 32-bit mask for xfloat values to 64-bit, 29 least significant bits are always zero
@ -2773,7 +2773,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
return r;
}
return splat<T>(imm);
return eval(splat<T>(imm));
}
template <typename T = u32[4], uint I, uint N>
@ -2807,7 +2807,7 @@ class spu_llvm_recompiler : public spu_recompiler_base, public cpu_translator
return r;
}
return splat<T>(imm);
return eval(splat<T>(imm));
}
// Return either basic block addr with single dominating value, or negative number of PHI entries
@ -4854,7 +4854,7 @@ public:
if constexpr (!by.is_vector)
sh.value = m_ir->CreateVectorSplat(sh.is_vector, sh.value);
value_t<R> max_sh = splat<R>(by.esize - 1);
value_t<R> max_sh = eval(splat<R>(by.esize - 1));
sh.value = m_ir->CreateSelect(m_ir->CreateICmpUGT(max_sh.value, sh.value), sh.value, max_sh.value);
set_vr(op.rt, get_vr<R>(op.ra) >> sh);
}
@ -6063,9 +6063,9 @@ public:
value_t<f64[4]> a = get_vr<f64[4]>(op.ra);
value_t<f64[4]> s;
if (m_interp_magn)
s = vsplat<f64[4]>(bitcast<f64>(((1023 + 173) - get_imm<u64>(op.i8)) << 52));
s = eval(vsplat<f64[4]>(bitcast<f64>(((1023 + 173) - get_imm<u64>(op.i8)) << 52)));
else
s = fsplat<f64[4]>(std::exp2(static_cast<int>(173 - op.i8)));
s = eval(fsplat<f64[4]>(std::exp2(static_cast<int>(173 - op.i8))));
if (op.i8 != 173 || m_interp_magn)
a = eval(a * s);
@ -6118,9 +6118,9 @@ public:
value_t<f32[4]> a = get_vr<f32[4]>(op.ra);
value_t<f32[4]> s;
if (m_interp_magn)
s = vsplat<f32[4]>(load_const<f32>(m_scale_float_to, get_imm<u8>(op.i8)));
s = eval(vsplat<f32[4]>(load_const<f32>(m_scale_float_to, get_imm<u8>(op.i8))));
else
s = fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(173 - op.i8))));
s = eval(fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(173 - op.i8)))));
if (op.i8 != 173 || m_interp_magn)
a = eval(a * s);
@ -6137,9 +6137,9 @@ public:
value_t<f64[4]> a = get_vr<f64[4]>(op.ra);
value_t<f64[4]> s;
if (m_interp_magn)
s = vsplat<f64[4]>(bitcast<f64>(((1023 + 173) - get_imm<u64>(op.i8)) << 52));
s = eval(vsplat<f64[4]>(bitcast<f64>(((1023 + 173) - get_imm<u64>(op.i8)) << 52)));
else
s = fsplat<f64[4]>(std::exp2(static_cast<int>(173 - op.i8)));
s = eval(fsplat<f64[4]>(std::exp2(static_cast<int>(173 - op.i8))));
if (op.i8 != 173 || m_interp_magn)
a = eval(a * s);
@ -6184,27 +6184,23 @@ public:
return;
}
const auto _max = fsplat<f64[4]>(std::exp2(32.f));
r.value = m_ir->CreateFPToUI(a.value, get_type<s32[4]>());
r.value = m_ir->CreateSelect(m_ir->CreateFCmpUGE(a.value, _max.value), splat<s32[4]>(-1).eval(m_ir), (r & sext<s32[4]>(fcmp_ord(a >= fsplat<f64[4]>(0.)))).eval(m_ir));
set_vr(op.rt, r);
set_vr(op.rt, select(fcmp_uno(a >= fsplat<f64[4]>(std::exp2(32.f))), splat<s32[4]>(-1), r & sext<s32[4]>(fcmp_ord(a >= fsplat<f64[4]>(0.)))));
}
else
{
value_t<f32[4]> a = get_vr<f32[4]>(op.ra);
value_t<f32[4]> s;
if (m_interp_magn)
s = vsplat<f32[4]>(load_const<f32>(m_scale_float_to, get_imm<u8>(op.i8)));
s = eval(vsplat<f32[4]>(load_const<f32>(m_scale_float_to, get_imm<u8>(op.i8))));
else
s = fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(173 - op.i8))));
s = eval(fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(173 - op.i8)))));
if (op.i8 != 173 || m_interp_magn)
a = eval(a * s);
value_t<s32[4]> r;
const auto _max = fsplat<f32[4]>(std::exp2(32.f));
r.value = m_ir->CreateFPToUI(a.value, get_type<s32[4]>());
r.value = m_ir->CreateSelect(m_ir->CreateFCmpUGE(a.value, _max.value), splat<s32[4]>(-1).eval(m_ir), (r & ~(bitcast<s32[4]>(a) >> 31)).eval(m_ir));
set_vr(op.rt, r);
set_vr(op.rt, select(fcmp_uno(a >= fsplat<f32[4]>(std::exp2(32.f))), splat<s32[4]>(-1), r & ~(bitcast<s32[4]>(a) >> 31)));
}
}
@ -6227,9 +6223,9 @@ public:
value_t<f64[4]> s;
if (m_interp_magn)
s = vsplat<f64[4]>(bitcast<f64>((get_imm<u64>(op.i8) + (1023 - 155)) << 52));
s = eval(vsplat<f64[4]>(bitcast<f64>((get_imm<u64>(op.i8) + (1023 - 155)) << 52)));
else
s = fsplat<f64[4]>(std::exp2(static_cast<int>(op.i8 - 155)));
s = eval(fsplat<f64[4]>(std::exp2(static_cast<int>(op.i8 - 155))));
if (op.i8 != 155 || m_interp_magn)
r = eval(r * s);
set_vr(op.rt, r);
@ -6240,9 +6236,9 @@ public:
r.value = m_ir->CreateSIToFP(get_vr<s32[4]>(op.ra).value, get_type<f32[4]>());
value_t<f32[4]> s;
if (m_interp_magn)
s = vsplat<f32[4]>(load_const<f32>(m_scale_to_float, get_imm<u8>(op.i8)));
s = eval(vsplat<f32[4]>(load_const<f32>(m_scale_to_float, get_imm<u8>(op.i8))));
else
s = fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(op.i8 - 155))));
s = eval(fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(op.i8 - 155)))));
if (op.i8 != 155 || m_interp_magn)
r = eval(r * s);
set_vr(op.rt, r);
@ -6268,9 +6264,9 @@ public:
value_t<f64[4]> s;
if (m_interp_magn)
s = vsplat<f64[4]>(bitcast<f64>((get_imm<u64>(op.i8) + (1023 - 155)) << 52));
s = eval(vsplat<f64[4]>(bitcast<f64>((get_imm<u64>(op.i8) + (1023 - 155)) << 52)));
else
s = fsplat<f64[4]>(std::exp2(static_cast<int>(op.i8 - 155)));
s = eval(fsplat<f64[4]>(std::exp2(static_cast<int>(op.i8 - 155))));
if (op.i8 != 155 || m_interp_magn)
r = eval(r * s);
set_vr(op.rt, r);
@ -6281,9 +6277,9 @@ public:
r.value = m_ir->CreateUIToFP(get_vr<s32[4]>(op.ra).value, get_type<f32[4]>());
value_t<f32[4]> s;
if (m_interp_magn)
s = vsplat<f32[4]>(load_const<f32>(m_scale_to_float, get_imm<u8>(op.i8)));
s = eval(vsplat<f32[4]>(load_const<f32>(m_scale_to_float, get_imm<u8>(op.i8))));
else
s = fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(op.i8 - 155))));
s = eval(fsplat<f32[4]>(std::exp2(static_cast<float>(static_cast<s16>(op.i8 - 155)))));
if (op.i8 != 155 || m_interp_magn)
r = eval(r * s);
set_vr(op.rt, r);
@ -6555,7 +6551,7 @@ public:
m_ir->SetInsertPoint(done);
// Clear stack mirror and return by tail call to the provided return address
m_ir->CreateStore(splat<u64[2]>(-1).value, m_ir->CreateBitCast(m_ir->CreateGEP(m_thread, stack0.value), get_type<u64(*)[2]>()));
m_ir->CreateStore(splat<u64[2]>(-1).eval(m_ir), m_ir->CreateBitCast(m_ir->CreateGEP(m_thread, stack0.value), get_type<u64(*)[2]>()));
tail(_ret);
m_ir->SetInsertPoint(fail);
}