1
0
mirror of https://github.com/RPCS3/rpcs3.git synced 2025-01-31 12:31:45 +01:00

LLVM DSL: rewrite add_sat and sub_sat

Simplify constant folding logic
This commit is contained in:
Nekotekina 2019-04-22 15:32:52 +03:00
parent ac473eb400
commit 2eac59f59a

View File

@ -1286,6 +1286,125 @@ struct llvm_max
}
};
// DSL expression node: addition with saturation of two operands (signed or unsigned integer types only).
// a1/a2 hold the operand expressions; eval() produces the resulting llvm::Value.
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
struct llvm_add_sat
{
	using type = T;

	llvm_expr_t<A1> a1;
	llvm_expr_t<A2> a2;

	static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint, "llvm_add_sat<>: invalid type");

	// True when T is a signed or unsigned integer type (used by callers in enable_if)
	static constexpr bool is_ok = llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint;

	// Get llvm.sadd.sat declaration (overloaded on T) in the module owning the current insert block
	static llvm::Function* get_sadd_sat(llvm::IRBuilder<>* ir)
	{
		const auto owner = ir->GetInsertBlock()->getParent()->getParent();
		return llvm::Intrinsic::getDeclaration(owner, llvm::Intrinsic::sadd_sat, {llvm_value_t<T>::get_type(ir->getContext())});
	}

	// Get llvm.uadd.sat declaration (overloaded on T) in the module owning the current insert block
	static llvm::Function* get_uadd_sat(llvm::IRBuilder<>* ir)
	{
		const auto owner = ir->GetInsertBlock()->getParent()->getParent();
		return llvm::Intrinsic::getDeclaration(owner, llvm::Intrinsic::uadd_sat, {llvm_value_t<T>::get_type(ir->getContext())});
	}

	// Evaluate both operands, then either compute the saturated sum with plain builder
	// operations (both operands constant) or emit a call to the saturating add intrinsic.
	llvm::Value* eval(llvm::IRBuilder<>* ir) const
	{
		constexpr bool is_signed = llvm_value_t<T>::is_sint;
		constexpr auto bits = llvm_value_t<T>::esize;

		const auto lhs = a1.eval(ir);
		const auto rhs = a2.eval(ir);

		// General case: at least one operand is not a constant, use the intrinsic
		if (!llvm::isa<llvm::Constant>(lhs) || !llvm::isa<llvm::Constant>(rhs))
		{
			if constexpr (is_signed)
			{
				return ir->CreateCall(get_sadd_sat(ir), {lhs, rhs});
			}
			else
			{
				// static_assert above guarantees T is unsigned here
				return ir->CreateCall(get_uadd_sat(ir), {lhs, rhs});
			}
		}

		// Constant case: wrapping sum first, saturation is selected afterwards
		const auto val_type = lhs->getType();
		const auto sum = ir->CreateAdd(lhs, rhs);

		if constexpr (is_signed)
		{
			const auto smax = llvm::ConstantInt::get(val_type, llvm::APInt::getSignedMaxValue(bits));

			// Sign of lhs replicated to all bits; xor turns max into min if lhs < 0
			const auto sign = ir->CreateAShr(lhs, bits - 1);
			const auto sat = ir->CreateXor(sign, smax);

			// Sign bit of ovf is set if operands have equal signs and the sum sign differs from them
			const auto sum_flip = ir->CreateXor(rhs, sum);
			const auto same_sign = ir->CreateNot(ir->CreateXor(lhs, rhs));
			const auto ovf = ir->CreateAnd(sum_flip, same_sign);

			const auto zero = llvm::ConstantInt::get(val_type, 0);
			return ir->CreateSelect(ir->CreateICmpSLT(ovf, zero), sat, sum);
		}
		else
		{
			// Unsigned wraparound happened if the sum is below the first operand
			const auto umax = llvm::ConstantInt::get(val_type, llvm::APInt::getMaxValue(bits));
			const auto wrapped = ir->CreateICmpULT(sum, lhs);
			return ir->CreateSelect(wrapped, umax, sum);
		}
	}
};
// DSL expression node: subtraction with saturation of two operands (signed or unsigned integer types only).
// a1/a2 hold the operand expressions; eval() produces the resulting llvm::Value.
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
struct llvm_sub_sat
{
	using type = T;

	llvm_expr_t<A1> a1;
	llvm_expr_t<A2> a2;

	static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint, "llvm_sub_sat<>: invalid type");

	// True when T is a signed or unsigned integer type (used by callers in enable_if)
	static constexpr bool is_ok = llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint;

	// Get llvm.ssub.sat declaration (overloaded on T) in the module owning the current insert block
	static llvm::Function* get_ssub_sat(llvm::IRBuilder<>* ir)
	{
		const auto owner = ir->GetInsertBlock()->getParent()->getParent();
		return llvm::Intrinsic::getDeclaration(owner, llvm::Intrinsic::ssub_sat, {llvm_value_t<T>::get_type(ir->getContext())});
	}

	// Get llvm.usub.sat declaration (overloaded on T) in the module owning the current insert block
	static llvm::Function* get_usub_sat(llvm::IRBuilder<>* ir)
	{
		const auto owner = ir->GetInsertBlock()->getParent()->getParent();
		return llvm::Intrinsic::getDeclaration(owner, llvm::Intrinsic::usub_sat, {llvm_value_t<T>::get_type(ir->getContext())});
	}

	// Evaluate both operands, then either compute the saturated difference with plain builder
	// operations (both operands constant) or emit a call to the saturating subtract intrinsic.
	llvm::Value* eval(llvm::IRBuilder<>* ir) const
	{
		constexpr bool is_signed = llvm_value_t<T>::is_sint;
		constexpr auto bits = llvm_value_t<T>::esize;

		const auto lhs = a1.eval(ir);
		const auto rhs = a2.eval(ir);

		// General case: at least one operand is not a constant, use the intrinsic
		if (!llvm::isa<llvm::Constant>(lhs) || !llvm::isa<llvm::Constant>(rhs))
		{
			if constexpr (is_signed)
			{
				return ir->CreateCall(get_ssub_sat(ir), {lhs, rhs});
			}
			else
			{
				// static_assert above guarantees T is unsigned here
				return ir->CreateCall(get_usub_sat(ir), {lhs, rhs});
			}
		}

		// Constant case: wrapping difference first, saturation is selected afterwards
		const auto val_type = lhs->getType();
		const auto dif = ir->CreateSub(lhs, rhs);

		if constexpr (is_signed)
		{
			const auto smax = llvm::ConstantInt::get(val_type, llvm::APInt::getSignedMaxValue(bits));

			// Sign of lhs replicated to all bits; xor turns max into min if lhs < 0
			const auto sign = ir->CreateAShr(lhs, bits - 1);
			const auto sat = ir->CreateXor(sign, smax);

			// Sign bit of ovf is set if operands have different signs and the result sign differs from lhs
			const auto res_flip = ir->CreateXor(lhs, dif);
			const auto diff_sign = ir->CreateXor(lhs, rhs);
			const auto ovf = ir->CreateAnd(res_flip, diff_sign);

			const auto zero = llvm::ConstantInt::get(val_type, 0);
			return ir->CreateSelect(ir->CreateICmpSLT(ovf, zero), sat, dif);
		}
		else
		{
			// Unsigned result would go below zero if lhs < rhs: clamp to zero
			const auto below = ir->CreateICmpULT(lhs, rhs);
			return ir->CreateSelect(below, llvm::ConstantInt::get(val_type, 0), dif);
		}
	}
};
class cpu_translator
{
protected:
@ -1428,96 +1547,16 @@ public:
return llvm_rol<T, U>{std::forward<T>(a), std::forward<U>(b)};
}
// Add with saturation
template <typename T>
inline auto add_sat(T a, T b)
template <typename T, typename U, typename = std::enable_if_t<llvm_add_sat<T, U>::is_ok>>
static auto add_sat(T&& a, U&& b)
{
value_t<typename T::type> result;
const auto eva = a.eval(m_ir);
const auto evb = b.eval(m_ir);
// Compute constant result immediately if possible
if (llvm::isa<llvm::Constant>(eva) && llvm::isa<llvm::Constant>(evb))
{
static_assert(result.is_sint || result.is_uint);
if constexpr (result.is_sint)
{
llvm::Type* cast_to = m_ir->getIntNTy(result.esize * 2);
if constexpr (result.is_vector != 0)
cast_to = llvm::VectorType::get(cast_to, result.is_vector);
const auto axt = m_ir->CreateSExt(eva, cast_to);
const auto bxt = m_ir->CreateSExt(evb, cast_to);
result.value = m_ir->CreateAdd(axt, bxt);
const auto _max = m_ir->getInt(llvm::APInt::getSignedMaxValue(result.esize * 2).ashr(result.esize));
const auto _min = m_ir->getInt(llvm::APInt::getSignedMinValue(result.esize * 2).ashr(result.esize));
const auto smax = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _max) : _max;
const auto smin = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _min) : _min;
result.value = m_ir->CreateSelect(m_ir->CreateICmpSGT(result.value, smax), smax, result.value);
result.value = m_ir->CreateSelect(m_ir->CreateICmpSLT(result.value, smin), smin, result.value);
result.value = m_ir->CreateTrunc(result.value, result.get_type(m_context));
}
else
{
const auto _max = m_ir->getInt(llvm::APInt::getMaxValue(result.esize));
const auto ones = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _max) : _max;
result.value = m_ir->CreateAdd(eva, evb);
result.value = m_ir->CreateSelect(m_ir->CreateICmpULT(result.value, eva), ones, result.value);
}
}
else
{
result.value = m_ir->CreateCall(get_intrinsic<typename T::type>(result.is_sint ? llvm::Intrinsic::sadd_sat : llvm::Intrinsic::uadd_sat), {eva, evb});
}
return result;
return llvm_add_sat<T, U>{std::forward<T>(a), std::forward<U>(b)};
}
// Subtract with saturation
template <typename T>
inline auto sub_sat(T a, T b)
template <typename T, typename U, typename = std::enable_if_t<llvm_sub_sat<T, U>::is_ok>>
static auto sub_sat(T&& a, U&& b)
{
value_t<typename T::type> result;
const auto eva = a.eval(m_ir);
const auto evb = b.eval(m_ir);
// Compute constant result immediately if possible
if (llvm::isa<llvm::Constant>(eva) && llvm::isa<llvm::Constant>(evb))
{
static_assert(result.is_sint || result.is_uint);
if constexpr (result.is_sint)
{
llvm::Type* cast_to = m_ir->getIntNTy(result.esize * 2);
if constexpr (result.is_vector != 0)
cast_to = llvm::VectorType::get(cast_to, result.is_vector);
const auto axt = m_ir->CreateSExt(eva, cast_to);
const auto bxt = m_ir->CreateSExt(evb, cast_to);
result.value = m_ir->CreateSub(axt, bxt);
const auto _max = m_ir->getInt(llvm::APInt::getSignedMaxValue(result.esize * 2).ashr(result.esize));
const auto _min = m_ir->getInt(llvm::APInt::getSignedMinValue(result.esize * 2).ashr(result.esize));
const auto smax = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _max) : _max;
const auto smin = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _min) : _min;
result.value = m_ir->CreateSelect(m_ir->CreateICmpSGT(result.value, smax), smax, result.value);
result.value = m_ir->CreateSelect(m_ir->CreateICmpSLT(result.value, smin), smin, result.value);
result.value = m_ir->CreateTrunc(result.value, result.get_type(m_context));
}
else
{
const auto _min = m_ir->getInt(llvm::APInt::getMinValue(result.esize));
const auto zero = result.is_vector != 0 ? llvm::ConstantVector::getSplat(result.is_vector, _min) : _min;
result.value = m_ir->CreateSub(eva, evb);
result.value = m_ir->CreateSelect(m_ir->CreateICmpULT(eva, evb), zero, result.value);
}
}
else
{
result.value = m_ir->CreateCall(get_intrinsic<typename T::type>(result.is_sint ? llvm::Intrinsic::ssub_sat : llvm::Intrinsic::usub_sat), {eva, evb});
}
return result;
return llvm_sub_sat<T, U>{std::forward<T>(a), std::forward<U>(b)};
}
// Average: (a + b + 1) >> 1