diff --git a/utils/TableGen/DAGISelMatcher.h b/utils/TableGen/DAGISelMatcher.h index ec61fcd1dab..9af98f77f35 100644 --- a/utils/TableGen/DAGISelMatcher.h +++ b/utils/TableGen/DAGISelMatcher.h @@ -104,6 +104,12 @@ public: return ((getHashImpl() << 4) ^ getKind()) & (~0U>>1); } + /// isSafeToReorderWithPatternPredicate - Return true if it is safe to sink a + /// PatternPredicate node past this one. + virtual bool isSafeToReorderWithPatternPredicate() const { + return false; + } + void print(raw_ostream &OS, unsigned indent = 0) const; void dump() const; protected: @@ -173,6 +179,7 @@ public: return N->getKind() == RecordNode; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { return true; } @@ -199,6 +206,8 @@ public: return N->getKind() == RecordChild; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -216,6 +225,8 @@ public: return N->getKind() == RecordMemRef; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { return true; } @@ -233,6 +244,8 @@ public: return N->getKind() == CaptureFlagInput; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { return true; } @@ -252,6 +265,8 @@ public: return N->getKind() == MoveChild; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -270,6 +285,8 @@ public: return N->getKind() == MoveParent; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { return true; } @@ -291,6 +308,8 @@ public: return N->getKind() == CheckSame; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -314,6 +333,8 @@ public: return N->getKind() == CheckPatternPredicate; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -336,6 +357,9 @@ public: return N->getKind() == CheckPredicate; } + // TODO: Ok? + //virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -359,6 +383,8 @@ public: return N->getKind() == CheckOpcode; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -382,6 +408,8 @@ public: return N->getKind() == CheckMultiOpcode; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -406,6 +434,8 @@ public: return N->getKind() == CheckType; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -430,6 +460,8 @@ public: return N->getKind() == CheckChildType; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -454,6 +486,8 @@ public: return N->getKind() == CheckInteger; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -476,6 +510,8 @@ public: return N->getKind() == CheckCondCode; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -498,6 +534,8 @@ public: return N->getKind() == CheckValueType; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -522,6 +560,9 @@ public: return N->getKind() == CheckComplexPat; } + // Not safe to move a pattern predicate past a complex pattern. + virtual bool isSafeToReorderWithPatternPredicate() const { return false; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -546,6 +587,8 @@ public: return N->getKind() == CheckAndImm; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -568,6 +611,8 @@ public: return N->getKind() == CheckOrImm; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { @@ -587,6 +632,8 @@ public: return N->getKind() == CheckFoldableChainNode; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { return true; } @@ -607,6 +654,8 @@ public: return N->getKind() == CheckChainCompatible; } + virtual bool isSafeToReorderWithPatternPredicate() const { return true; } + private: virtual void printImpl(raw_ostream &OS, unsigned indent) const; virtual bool isEqualImpl(const Matcher *M) const { diff --git a/utils/TableGen/DAGISelMatcherOpt.cpp b/utils/TableGen/DAGISelMatcherOpt.cpp index 5aaa51f97cf..55c25389336 100644 --- a/utils/TableGen/DAGISelMatcherOpt.cpp +++ b/utils/TableGen/DAGISelMatcherOpt.cpp @@ -16,6 +16,8 @@ #include using namespace llvm; +/// ContractNodes - Turn multiple matcher node patterns like 'MoveChild+Record' +/// into single compound nodes like RecordChild. static void ContractNodes(OwningPtr &MatcherPtr) { // If we reached the end of the chain, we're done. Matcher *N = MatcherPtr.get(); @@ -61,6 +63,71 @@ static void ContractNodes(OwningPtr &MatcherPtr) { ContractNodes(N->getNextPtr()); } +/// SinkPatternPredicates - Pattern predicates can be checked at any level of +/// the matching tree. The generator dumps them at the top level of the pattern +/// though, which prevents factoring from being able to see past them. This +/// optimization sinks them as far down into the pattern as possible. +/// +/// Conceptually, we'd like to sink these predicates all the way to the last +/// matcher predicate in the series. However, it turns out that some +/// ComplexPatterns have side effects on the graph, so we really don't want to +/// run a the complex pattern if the pattern predicate will fail. For this +/// reason, we refuse to sink the pattern predicate past a ComplexPattern. +/// +static void SinkPatternPredicates(OwningPtr &MatcherPtr) { + // Recursively scan for a PatternPredicate. + // If we reached the end of the chain, we're done. + Matcher *N = MatcherPtr.get(); + if (N == 0) return; + + // Walk down all members of a scope node. + if (ScopeMatcher *Scope = dyn_cast(N)) { + for (unsigned i = 0, e = Scope->getNumChildren(); i != e; ++i) { + OwningPtr Child(Scope->takeChild(i)); + SinkPatternPredicates(Child); + Scope->resetChild(i, Child.take()); + } + return; + } + + // If this node isn't a CheckPatternPredicateMatcher we keep scanning until + // we find one. + CheckPatternPredicateMatcher *CPPM =dyn_cast(N); + if (CPPM == 0) + return SinkPatternPredicates(N->getNextPtr()); + + // Ok, we found one, lets try to sink it. Check if we can sink it past the + // next node in the chain. If not, we won't be able to change anything and + // might as well bail. + if (!CPPM->getNext()->isSafeToReorderWithPatternPredicate()) + return; + + // Okay, we know we can sink it past at least one node. Unlink it from the + // chain and scan for the new insertion point. + MatcherPtr.take(); // Don't delete CPPM. + MatcherPtr.reset(CPPM->takeNext()); + + N = MatcherPtr.get(); + while (N->getNext()->isSafeToReorderWithPatternPredicate()) + N = N->getNext(); + + // At this point, we want to insert CPPM after N. + CPPM->setNext(N->takeNext()); + N->setNext(CPPM); +} + +/// FactorNodes - Turn matches like this: +/// Scope +/// OPC_CheckType i32 +/// ABC +/// OPC_CheckType i32 +/// XYZ +/// into: +/// OPC_CheckType i32 +/// Scope +/// ABC +/// XYZ +/// static void FactorNodes(OwningPtr &MatcherPtr) { // If we reached the end of the chain, we're done. Matcher *N = MatcherPtr.get(); @@ -145,6 +212,7 @@ static void FactorNodes(OwningPtr &MatcherPtr) { Matcher *llvm::OptimizeMatcher(Matcher *TheMatcher) { OwningPtr MatcherPtr(TheMatcher); ContractNodes(MatcherPtr); + SinkPatternPredicates(MatcherPtr); FactorNodes(MatcherPtr); return MatcherPtr.take(); }