1
0
mirror of https://github.com/RPCS3/rpcs3.git synced 2024-11-26 12:42:41 +01:00

LLVM DSL: implement expression matching (preview)

Only literal match for binary ops implemented.
This commit is contained in:
Nekotekina 2019-04-24 23:53:49 +03:00
parent 8754bbd444
commit aca61fdcf9

View File

@ -60,6 +60,16 @@ struct llvm_value_t
return value;
}
std::tuple<> match(llvm::Value*& value) const
{
if (value != this->value)
{
value = nullptr;
}
return {};
}
llvm::Value* value;
// llvm_value_t() = default;
@ -361,6 +371,9 @@ struct is_llvm_expr_of<T, Of, std::void_t<typename is_llvm_expr<T>::type, typena
template <typename T, typename... Types>
using llvm_common_t = std::enable_if_t<(is_llvm_expr_of<T, Types>::ok && ...), typename is_llvm_expr<T>::type>;
template <typename... Args>
using llvm_match_tuple = decltype(std::tuple_cat(std::declval<llvm_expr_t<Args>&>().match(std::declval<llvm::Value*&>())...));
template <typename T, typename U = llvm_common_t<llvm_value_t<T>>>
struct llvm_match_t
{
@ -377,6 +390,38 @@ struct llvm_match_t
{
return value;
}
std::tuple<> match(llvm::Value*& value) const
{
if (value != this->value)
{
value = nullptr;
}
return {};
}
};
template <typename T, typename U = llvm_common_t<llvm_value_t<T>>>
struct llvm_placeholder_t
{
using type = T;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
return nullptr;
}
std::tuple<llvm_match_t<T>> match(llvm::Value*& value) const
{
if (value && value->getType() == llvm_value_t<T>::get_type(value->getContext()))
{
return {value};
}
value = nullptr;
return {};
}
};
template <typename T, bool ForceSigned = false>
@ -394,6 +439,17 @@ struct llvm_const_int
return llvm::ConstantInt::get(llvm_value_t<T>::get_type(ir->getContext()), val, ForceSigned || llvm_value_t<T>::is_sint);
}
std::tuple<> match(llvm::Value*& value) const
{
if (value && value == llvm::ConstantInt::get(llvm_value_t<T>::get_type(value->getContext()), val, ForceSigned || llvm_value_t<T>::is_sint))
{
return {};
}
value = nullptr;
return {};
}
};
template <typename T>
@ -411,6 +467,17 @@ struct llvm_const_float
return llvm::ConstantFP::get(llvm_value_t<T>::get_type(ir->getContext()), val);
}
std::tuple<> match(llvm::Value*& value) const
{
if (value && value == llvm::ConstantFP::get(llvm_value_t<T>::get_type(value->getContext()), val))
{
return {};
}
value = nullptr;
return {};
}
};
template <uint N, typename T>
@ -428,6 +495,17 @@ struct llvm_const_vector
return llvm::ConstantDataVector::get(ir->getContext(), data);
}
std::tuple<> match(llvm::Value*& value) const
{
if (value && value == llvm::ConstantDataVector::get(value->getContext(), data))
{
return {};
}
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -440,20 +518,36 @@ struct llvm_add
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint || llvm_value_t<T>::is_float, "llvm_add<>: invalid type");
static constexpr auto opc = llvm_value_t<T>::is_float ? llvm::Instruction::FAdd : llvm::Instruction::Add;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateBinOp(opc, v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
return ir->CreateAdd(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_float)
{
return ir->CreateFAdd(v1, v2);
}
value = nullptr;
return {};
}
};
@ -485,11 +579,13 @@ struct llvm_sum
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
const auto v3 = a3.eval(ir);
return ir->CreateAdd(ir->CreateAdd(v1, v2), v3);
}
if constexpr (llvm_value_t<T>::is_int)
{
return ir->CreateAdd(ir->CreateAdd(v1, v2), v3);
}
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
@ -506,20 +602,36 @@ struct llvm_sub
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint || llvm_value_t<T>::is_float, "llvm_sub<>: invalid type");
static constexpr auto opc = llvm_value_t<T>::is_float ? llvm::Instruction::FSub : llvm::Instruction::Sub;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateBinOp(opc, v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
return ir->CreateSub(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_float)
{
return ir->CreateFSub(v1, v2);
}
value = nullptr;
return {};
}
};
@ -551,20 +663,36 @@ struct llvm_mul
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint || llvm_value_t<T>::is_float, "llvm_mul<>: invalid type");
static constexpr auto opc = llvm_value_t<T>::is_float ? llvm::Instruction::FMul : llvm::Instruction::Mul;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateBinOp(opc, v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
return ir->CreateMul(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_float)
{
return ir->CreateFMul(v1, v2);
}
value = nullptr;
return {};
}
};
@ -584,25 +712,38 @@ struct llvm_div
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint || llvm_value_t<T>::is_float, "llvm_div<>: invalid type");
static constexpr auto opc =
llvm_value_t<T>::is_float ? llvm::Instruction::FDiv :
llvm_value_t<T>::is_uint ? llvm::Instruction::UDiv : llvm::Instruction::SDiv;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateBinOp(opc, v1, v2);
}
if constexpr (llvm_value_t<T>::is_sint)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
return ir->CreateSDiv(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_uint)
{
return ir->CreateUDiv(v1, v2);
}
if constexpr (llvm_value_t<T>::is_float)
{
return ir->CreateFDiv(v1, v2);
}
value = nullptr;
return {};
}
};
@ -621,6 +762,8 @@ struct llvm_neg
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint || llvm_value_t<T>::is_float, "llvm_neg<>: invalid type");
static constexpr auto opc = llvm_value_t<T>::is_float ? llvm::Instruction::FSub : llvm::Instruction::Sub;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
@ -635,6 +778,12 @@ struct llvm_neg
return ir->CreateFNeg(v1);
}
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename T1>
@ -657,16 +806,30 @@ struct llvm_shl
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateShl(v1, v2);
}
if constexpr (llvm_value_t<T>::is_sint)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == llvm::Instruction::Shl)
{
return ir->CreateShl(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_uint)
{
return ir->CreateShl(v1, v2);
}
value = nullptr;
return {};
}
};
@ -692,20 +855,36 @@ struct llvm_shr
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint, "llvm_shr<>: invalid type");
static constexpr auto opc = llvm_value_t<T>::is_uint ? llvm::Instruction::LShr : llvm::Instruction::AShr;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateBinOp(opc, v1, v2);
}
if constexpr (llvm_value_t<T>::is_sint)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
return ir->CreateAShr(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
if constexpr (llvm_value_t<T>::is_uint)
{
return ir->CreateLShr(v1, v2);
}
value = nullptr;
return {};
}
};
@ -763,6 +942,12 @@ struct llvm_fshl
return ir->CreateCall(get_fshl(ir), {v1, v2, v3});
}
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename A3, typename T = llvm_common_t<A1, A2, A3>>
@ -807,6 +992,12 @@ struct llvm_fshr
return ir->CreateCall(get_fshr(ir), {v1, v2, v3});
}
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -833,6 +1024,12 @@ struct llvm_rol
return ir->CreateCall(llvm_fshl<A1, A1, A2>::get_fshl(ir), {v1, v1, v2});
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -849,11 +1046,30 @@ struct llvm_and
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateAnd(v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == llvm::Instruction::And)
{
return ir->CreateAnd(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
};
@ -883,11 +1099,30 @@ struct llvm_or
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateOr(v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == llvm::Instruction::Or)
{
return ir->CreateOr(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
};
@ -917,11 +1152,30 @@ struct llvm_xor
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateXor(v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == llvm::Instruction::Xor)
{
return ir->CreateXor(v1, v2);
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
};
@ -970,11 +1224,13 @@ struct llvm_cmp
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
return ir->CreateICmp(pred, v1, v2);
}
if constexpr (llvm_value_t<T>::is_int)
{
return ir->CreateICmp(pred, v1, v2);
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
@ -1013,6 +1269,12 @@ struct llvm_ord
const auto v2 = cmp.a2.eval(ir);
return ir->CreateFCmp(pred, v1, v2);
}
llvm_match_tuple<Cmp> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename T>
@ -1043,6 +1305,12 @@ struct llvm_uno
const auto v2 = cmp.a2.eval(ir);
return ir->CreateFCmp(pred, v1, v2);
}
llvm_match_tuple<Cmp> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename T>
@ -1143,6 +1411,12 @@ struct llvm_noncast
// No operation required
return a1.eval(ir);
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename U, typename A1, typename T = llvm_common_t<A1>>
@ -1188,6 +1462,12 @@ struct llvm_bitcast
return ir->CreateBitCast(v1, rt);
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename U, typename A1, typename T = llvm_common_t<A1>>
@ -1212,6 +1492,12 @@ struct llvm_trunc
{
return ir->CreateTrunc(a1.eval(ir), llvm_value_t<U>::get_type(ir->getContext()));
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename U, typename A1, typename T = llvm_common_t<A1>>
@ -1236,6 +1522,12 @@ struct llvm_sext
{
return ir->CreateSExt(a1.eval(ir), llvm_value_t<U>::get_type(ir->getContext()));
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename U, typename A1, typename T = llvm_common_t<A1>>
@ -1260,6 +1552,12 @@ struct llvm_zext
{
return ir->CreateZExt(a1.eval(ir), llvm_value_t<U>::get_type(ir->getContext()));
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename A3, typename T = llvm_common_t<A2, A3>, typename U = llvm_common_t<A1>>
@ -1282,6 +1580,12 @@ struct llvm_select
{
return ir->CreateSelect(cond.eval(ir), a2.eval(ir), a3.eval(ir));
}
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -1311,6 +1615,12 @@ struct llvm_min
return ir->CreateSelect(ir->CreateICmpULT(v1, v2), v1, v2);
}
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -1338,6 +1648,12 @@ struct llvm_max
return ir->CreateSelect(ir->CreateICmpULT(v1, v2), v2, v1);
}
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -1398,6 +1714,12 @@ struct llvm_add_sat
return ir->CreateCall(get_uadd_sat(ir), {v1, v2});
}
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -1457,6 +1779,12 @@ struct llvm_sub_sat
return ir->CreateCall(get_usub_sat(ir), {v1, v2});
}
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename I2, typename T = llvm_common_t<A1>, typename U = llvm_common_t<I2>>
@ -1480,6 +1808,12 @@ struct llvm_extract
return ir->CreateExtractElement(v1, v2);
}
llvm_match_tuple<A1, I2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename A1, typename I2, typename A3, typename T = llvm_common_t<A1>, typename U = llvm_common_t<I2>, typename V = llvm_common_t<A3>>
@ -1506,6 +1840,12 @@ struct llvm_insert
return ir->CreateInsertElement(v1, v3, v2);
}
llvm_match_tuple<A1, I2, A3> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <typename U, typename A1, typename T = llvm_common_t<A1>>
@ -1530,6 +1870,12 @@ struct llvm_splat
return ir->CreateVectorSplat(llvm_value_t<U>::is_vector, v1);
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <uint N, typename A1, typename T = llvm_common_t<A1>>
@ -1550,6 +1896,12 @@ struct llvm_zshuffle
return ir->CreateShuffleVector(v1, llvm::ConstantAggregateZero::get(v1->getType()), index_array);
}
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
template <uint N, typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
@ -1572,6 +1924,12 @@ struct llvm_shuffle2
return ir->CreateShuffleVector(v1, v2, index_array);
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
value = nullptr;
return {};
}
};
class cpu_translator