From 42a82dd83bee68f184f85b39e1ae1a6860a50aea Mon Sep 17 00:00:00 2001 From: Bruce Forstall Date: Fri, 6 Jan 2023 11:55:04 -0800 Subject: [PATCH] Convert JitHashTable iteration to range-based `for` (#80265) --- src/coreclr/jit/assertionprop.cpp | 6 +- src/coreclr/jit/compiler.cpp | 7 +- src/coreclr/jit/copyprop.cpp | 20 +- src/coreclr/jit/fgehopt.cpp | 10 +- src/coreclr/jit/jithashtable.h | 445 +++++++++++++++++------------- src/coreclr/jit/optimizer.cpp | 10 +- src/coreclr/jit/ssabuilder.cpp | 3 +- src/coreclr/jit/valuenum.cpp | 16 +- 8 files changed, 276 insertions(+), 241 deletions(-) diff --git a/src/coreclr/jit/assertionprop.cpp b/src/coreclr/jit/assertionprop.cpp index 1e45bed9a215c..cc7fc272a8d65 100644 --- a/src/coreclr/jit/assertionprop.cpp +++ b/src/coreclr/jit/assertionprop.cpp @@ -1926,11 +1926,9 @@ void Compiler::optPrintVnAssertionMapping() { printf("\nVN Assertion Mapping\n"); printf("---------------------\n"); - for (ValueNumToAssertsMap::KeyIterator ki = optValueNumToAsserts->Begin(); !ki.Equal(optValueNumToAsserts->End()); - ++ki) + for (ValueNumToAssertsMap::Node* const iter : ValueNumToAssertsMap::KeyValueIteration(optValueNumToAsserts)) { - printf("(%d => ", ki.Get()); - printf("%s)\n", BitVecOps::ToString(apTraits, ki.GetValue())); + printf("(%d => %s)\n", iter->GetKey(), BitVecOps::ToString(apTraits, iter->GetValue())); } } #endif diff --git a/src/coreclr/jit/compiler.cpp b/src/coreclr/jit/compiler.cpp index bed028e7c2d00..3a124f87faa9e 100644 --- a/src/coreclr/jit/compiler.cpp +++ b/src/coreclr/jit/compiler.cpp @@ -1226,11 +1226,10 @@ void DisplayNowayAssertMap() NowayAssertCountMap* nacp = new NowayAssertCountMap[count]; unsigned i = 0; - for (FileLineToCountMap::KeyIterator iter = NowayAssertMap->Begin(), end = NowayAssertMap->End(); - !iter.Equal(end); ++iter) + for (FileLineToCountMap::Node* const iter : FileLineToCountMap::KeyValueIteration(NowayAssertMap)) { - nacp[i].count = iter.GetValue(); - nacp[i].fl = iter.Get(); + nacp[i].count = iter->GetValue(); + nacp[i].fl = iter->GetKey(); ++i; } diff --git a/src/coreclr/jit/copyprop.cpp b/src/coreclr/jit/copyprop.cpp index 38314f586e4a4..1714c8ab0ba24 100644 --- a/src/coreclr/jit/copyprop.cpp +++ b/src/coreclr/jit/copyprop.cpp @@ -75,11 +75,11 @@ void Compiler::optBlockCopyPropPopStacks(BasicBlock* block, LclNumToLiveDefsMap* void Compiler::optDumpCopyPropStack(LclNumToLiveDefsMap* curSsaName) { JITDUMP("{ "); - for (LclNumToLiveDefsMap::KeyIterator iter = curSsaName->Begin(); !iter.Equal(curSsaName->End()); ++iter) + for (LclNumToLiveDefsMap::Node* const iter : LclNumToLiveDefsMap::KeyValueIteration(curSsaName)) { - unsigned defLclNum = iter.Get(); - GenTreeLclVarCommon* lclDefNode = iter.GetValue()->Top().GetDefNode()->AsLclVarCommon(); - LclSsaVarDsc* ssaDef = iter.GetValue()->Top().GetSsaDef(); + unsigned defLclNum = iter->GetKey(); + GenTreeLclVarCommon* lclDefNode = iter->GetValue()->Top().GetDefNode()->AsLclVarCommon(); + LclSsaVarDsc* ssaDef = iter->GetValue()->Top().GetSsaDef(); if (ssaDef != nullptr) { @@ -165,9 +165,9 @@ bool Compiler::optCopyProp( ValueNum lclDefVN = varDsc->GetPerSsaData(tree->GetSsaNum())->m_vnPair.GetConservative(); assert(lclDefVN != ValueNumStore::NoVN); - for (LclNumToLiveDefsMap::KeyIterator iter = curSsaName->Begin(); !iter.Equal(curSsaName->End()); ++iter) + for (LclNumToLiveDefsMap::Node* const iter : LclNumToLiveDefsMap::KeyValueIteration(curSsaName)) { - unsigned newLclNum = iter.Get(); + unsigned newLclNum = iter->GetKey(); // Nothing to do if same. if (lclNum == newLclNum) @@ -175,7 +175,7 @@ bool Compiler::optCopyProp( continue; } - CopyPropSsaDef newLclDef = iter.GetValue()->Top(); + CopyPropSsaDef newLclDef = iter->GetValue()->Top(); LclSsaVarDsc* const newLclSsaDef = newLclDef.GetSsaDef(); // Likewise, nothing to do if the most recent def is not available. @@ -494,12 +494,12 @@ PhaseStatus Compiler::optVnCopyProp() #ifdef DEBUG // Verify the definitions remaining are only those we pushed for parameters. - for (LclNumToLiveDefsMap::KeyIterator iter = m_curSsaName.Begin(); !iter.Equal(m_curSsaName.End()); ++iter) + for (LclNumToLiveDefsMap::Node* const iter : LclNumToLiveDefsMap::KeyValueIteration(&m_curSsaName)) { - unsigned lclNum = iter.Get(); + unsigned lclNum = iter->GetKey(); assert(m_compiler->lvaGetDesc(lclNum)->lvIsParam || (lclNum == m_compiler->info.compThisArg)); - CopyPropSsaDefStack* defStack = iter.GetValue(); + CopyPropSsaDefStack* defStack = iter->GetValue(); assert(defStack->Height() == 1); } #endif // DEBUG diff --git a/src/coreclr/jit/fgehopt.cpp b/src/coreclr/jit/fgehopt.cpp index 3b7b5f67525ee..c0bcdda777fb1 100644 --- a/src/coreclr/jit/fgehopt.cpp +++ b/src/coreclr/jit/fgehopt.cpp @@ -2121,14 +2121,12 @@ PhaseStatus Compiler::fgTailMergeThrows() // Second pass. // // We walk the map rather than the block list, to save a bit of time. - BlockToBlockMap::KeyIterator iter(blockMap.Begin()); - BlockToBlockMap::KeyIterator end(blockMap.End()); - unsigned updateCount = 0; + unsigned updateCount = 0; - for (; !iter.Equal(end); iter++) + for (BlockToBlockMap::Node* const iter : BlockToBlockMap::KeyValueIteration(&blockMap)) { - BasicBlock* const nonCanonicalBlock = iter.Get(); - BasicBlock* const canonicalBlock = iter.GetValue(); + BasicBlock* const nonCanonicalBlock = iter->GetKey(); + BasicBlock* const canonicalBlock = iter->GetValue(); flowList* nextPredEdge = nullptr; bool updated = false; diff --git a/src/coreclr/jit/jithashtable.h b/src/coreclr/jit/jithashtable.h index 238928ca0dc98..9ad73dbf2f7d5 100644 --- a/src/coreclr/jit/jithashtable.h +++ b/src/coreclr/jit/jithashtable.h @@ -94,7 +94,21 @@ class JitPrimeInfo extern const JitPrimeInfo jitPrimeInfo[27]; // Hash table class definition - +// +// Several iterators are defined that work with range-based `for`: +// KeyIteration: yields just the hash table keys +// ValueIteration: yields just the hash table values +// KeyValueIteration: yields just the hash table pairs +// +// For example: +// +// for (const unsigned int lclNum : LclVarRefCounts::KeyIteration(&defsInBlock)) +// +// for (ValueNumToAssertsMap::Node* const iter : ValueNumToAssertsMap::KeyValueIteration(optValueNumToAsserts)) +// { +// // use iter->GetKey(), iter->GetValue() +// } +// template + Node(Node* next, Key k, Args&&... args) : m_next(next), m_key(k), m_val(std::forward(args)...) + { + } + + void* operator new(size_t sz, Allocator alloc) + { + return alloc.template allocate(sz); + } + + void operator delete(void* p, Allocator alloc) + { + alloc.deallocate(p); + } + + public: + Key GetKey() const + { + return m_key; + } + + Value GetValue() const + { + return m_val; + } + }; //------------------------------------------------------------------------ // JitHashTable: Construct an empty JitHashTable object. @@ -344,22 +395,205 @@ class JitHashTable m_tableSizeInfo = JitPrimeInfo(); m_tableCount = 0; m_tableMax = 0; - - return; } - // Get an iterator to the first key in the table. - KeyIterator Begin() const + // + // Iteration support + // + + class NodeIterator { - KeyIterator i(this, true); - return i; - } + protected: + Node** m_table; + Node* m_node; + unsigned m_tableSize; + unsigned m_index; + + //------------------------------------------------------------------------ + // NodeIterator: Construct an iterator for the specified JitHashTable. + // + // Arguments: + // hash - the hashtable + // begin - `true` to construct an "begin" iterator, + // `false` to construct an "end" iterator + // + NodeIterator(const JitHashTable* hash, bool begin) + : m_table(hash->m_table) + , m_node(nullptr) + , m_tableSize(hash->m_tableSizeInfo.prime) + , m_index(begin ? 0 : m_tableSize) + { + if (begin && (hash->m_tableCount > 0)) + { + assert(m_table != nullptr); + while ((m_index < m_tableSize) && (m_table[m_index] == nullptr)) + { + m_index++; + } + + if (m_index < m_tableSize) + { + m_node = m_table[m_index]; + assert(m_node != nullptr); + } + } + } - // Get an iterator following the last key in the table. - KeyIterator End() const + //------------------------------------------------------------------------ + // Next: Advance the iterator to the next node. + // + // Notes: + // Advancing the end iterator has no effect. + // + void Next() + { + if (m_node != nullptr) + { + m_node = m_node->m_next; + if (m_node != nullptr) + { + return; + } + + // Otherwise... + m_index++; + } + while ((m_index < m_tableSize) && (m_table[m_index] == nullptr)) + { + m_index++; + } + + if (m_index < m_tableSize) + { + m_node = m_table[m_index]; + assert(m_node != nullptr); + } + else + { + m_node = nullptr; + } + } + + public: + // Advance the iterator to the next node + NodeIterator& operator++() + { + Next(); + return *this; + } + + bool operator!=(const NodeIterator& i) const + { + return i.m_node != m_node; + } + }; + + // KeyIterator: an iterator which yields only the hash table keys. + class KeyIterator : public NodeIterator { - return KeyIterator(this, false); - } + public: + KeyIterator(const JitHashTable* hash, bool begin) : NodeIterator(hash, begin) + { + } + + Key operator*() const + { + return this->m_node->GetKey(); + } + }; + + // ValueIterator: an iterator which yields only the hash table values. + class ValueIterator : public NodeIterator + { + public: + ValueIterator(const JitHashTable* hash, bool begin) : NodeIterator(hash, begin) + { + } + + Value operator*() const + { + return this->m_node->GetValue(); + } + }; + + // KeyValueIterator: an iterator which yields the hash table pairs. It exposes a bit of the + // hash table implementation by returning a `Node*` that contains the data. + class KeyValueIterator : public NodeIterator + { + public: + KeyValueIterator(const JitHashTable* hash, bool begin) : NodeIterator(hash, begin) + { + } + + // We could return a new struct, but why bother copying data? + Node* operator*() const + { + return this->m_node; + } + }; + + // KeyIteration: an adaptor to use for range-based `for` iteration over the hash table keys. + class KeyIteration + { + const JitHashTable* const m_hash; + + public: + KeyIteration(const JitHashTable* hash) : m_hash(hash) + { + } + + KeyIterator begin() const + { + return KeyIterator(m_hash, true); + } + + KeyIterator end() const + { + return KeyIterator(m_hash, false); + } + }; + + // ValueIteration: an adaptor to use for range-based `for` iteration over the hash table values. + class ValueIteration + { + const JitHashTable* const m_hash; + + public: + ValueIteration(const JitHashTable* hash) : m_hash(hash) + { + } + + ValueIterator begin() const + { + return ValueIterator(m_hash, true); + } + + ValueIterator end() const + { + return ValueIterator(m_hash, false); + } + }; + + // KeyValueIteration: an adaptor to use for range-based `for` iteration over the hash table pairs. + class KeyValueIteration + { + const JitHashTable* const m_hash; + + public: + KeyValueIteration(const JitHashTable* hash) : m_hash(hash) + { + } + + KeyValueIterator begin() const + { + return KeyValueIterator(m_hash, true); + } + + KeyValueIterator end() const + { + return KeyValueIterator(m_hash, false); + } + }; // Get the number of keys currently stored in the table. unsigned GetCount() const @@ -374,8 +608,6 @@ class JitHashTable } private: - struct Node; - //------------------------------------------------------------------------ // GetIndexForKey: Get the bucket index for the specified key. // @@ -525,165 +757,6 @@ class JitHashTable (unsigned)(newTableSize * Behavior::s_density_factor_numerator / Behavior::s_density_factor_denominator); } - // For iteration, we use a pattern similar to the STL "forward - // iterator" pattern. It basically consists of wrapping an - // "iteration variable" in an object, and providing pointer-like - // operators on the iterator. Example usage: - // - // for (JitHashTable::KeyIterator iter = foo->Begin(), end = foo->End(); !iter.Equal(end); iter++) - // { - // // use foo, iter. - // } - // iter.Get() will yield (a reference to) the - // current key. It will assert the equivalent of "iter != end." - class KeyIterator - { - private: - friend class JitHashTable; - - Node** m_table; - Node* m_node; - unsigned m_tableSize; - unsigned m_index; - - public: - //------------------------------------------------------------------------ - // KeyIterator: Construct an iterator for the specified JitHashTable. - // - // Arguments: - // hash - the hashtable - // begin - `true` to construct an "begin" iterator, - // `false` to construct an "end" iterator - // - KeyIterator(const JitHashTable* hash, bool begin) - : m_table(hash->m_table) - , m_node(nullptr) - , m_tableSize(hash->m_tableSizeInfo.prime) - , m_index(begin ? 0 : m_tableSize) - { - if (begin && (hash->m_tableCount > 0)) - { - assert(m_table != nullptr); - while ((m_index < m_tableSize) && (m_table[m_index] == nullptr)) - { - m_index++; - } - - if (m_index >= m_tableSize) - { - return; - } - else - { - m_node = m_table[m_index]; - } - assert(m_node != nullptr); - } - } - - //------------------------------------------------------------------------ - // Get: Get a reference to this iterator's key. - // - // Return Value: - // A reference to this iterator's key. - // - // Assumptions: - // This must not be the "end" iterator. - // - const Key& Get() const - { - assert(m_node != nullptr); - - return m_node->m_key; - } - - //------------------------------------------------------------------------ - // GetValue: Get a reference to this iterator's value. - // - // Return Value: - // A reference to this iterator's value. - // - // Assumptions: - // This must not be the "end" iterator. - // - Value& GetValue() const - { - assert(m_node != nullptr); - - return m_node->m_val; - } - - //------------------------------------------------------------------------ - // SetValue: Assign a new value to this iterator's key - // - // Arguments: - // value - the value to assign - // - // Assumptions: - // This must not be the "end" iterator. - // - void SetValue(const Value& value) const - { - assert(m_node != nullptr); - - m_node->m_val = value; - } - - //------------------------------------------------------------------------ - // Next: Advance the iterator to the next node. - // - // Notes: - // Advancing the end iterator has no effect. - // - void Next() - { - if (m_node != nullptr) - { - m_node = m_node->m_next; - if (m_node != nullptr) - { - return; - } - - // Otherwise... - m_index++; - } - while ((m_index < m_tableSize) && (m_table[m_index] == nullptr)) - { - m_index++; - } - - if (m_index >= m_tableSize) - { - m_node = nullptr; - return; - } - else - { - m_node = m_table[m_index]; - } - assert(m_node != nullptr); - } - - // Return `true` if the specified iterator has the same position as this iterator - bool Equal(const KeyIterator& i) const - { - return i.m_node == m_node; - } - - // Advance the iterator to the next node - void operator++() - { - Next(); - } - - // Advance the iterator to the next node - void operator++(int) - { - Next(); - } - }; - //------------------------------------------------------------------------ // operator[]: Get a reference to the value associated with the specified key. // @@ -727,30 +800,6 @@ class JitHashTable Behavior::NoMemory(); } - // The node type. - struct Node - { - Node* m_next; // Assume that the alignment requirement of Key and Value are no greater than Node*, - // so put m_next first to avoid unnecessary padding. - Key m_key; - Value m_val; - - template - Node(Node* next, Key k, Args&&... args) : m_next(next), m_key(k), m_val(std::forward(args)...) - { - } - - void* operator new(size_t sz, Allocator alloc) - { - return alloc.template allocate(sz); - } - - void operator delete(void* p, Allocator alloc) - { - alloc.deallocate(p); - } - }; - // Instance members Allocator m_alloc; // Allocator to use in this table. Node** m_table; // pointer to table diff --git a/src/coreclr/jit/optimizer.cpp b/src/coreclr/jit/optimizer.cpp index 1b0d6cc450296..4bc5be6982aaa 100644 --- a/src/coreclr/jit/optimizer.cpp +++ b/src/coreclr/jit/optimizer.cpp @@ -6603,11 +6603,10 @@ PhaseStatus Compiler::optHoistLoopCode() if (m_nodeTestData == nullptr) { NodeToTestDataMap* testData = GetNodeTestData(); - for (NodeToTestDataMap::KeyIterator ki = testData->Begin(); !ki.Equal(testData->End()); ++ki) + for (GenTree* const node : NodeToTestDataMap::KeyIteration(testData)) { TestLabelAndNum tlAndN; - GenTree* node = ki.Get(); - bool b = testData->Lookup(node, &tlAndN); + bool b = testData->Lookup(node, &tlAndN); assert(b); if (tlAndN.m_tl != TL_LoopHoist) { @@ -10309,11 +10308,8 @@ void Compiler::optRemoveRedundantZeroInits() if (removedTrackedDefs) { - LclVarRefCounts::KeyIterator iter(defsInBlock.Begin()); - LclVarRefCounts::KeyIterator end(defsInBlock.End()); - for (; !iter.Equal(end); iter++) + for (const unsigned int lclNum : LclVarRefCounts::KeyIteration(&defsInBlock)) { - unsigned int lclNum = iter.Get(); if (defsInBlock[lclNum] == 0) { VarSetOps::RemoveElemD(this, block->bbVarDef, lvaGetDesc(lclNum)->lvVarIndex); diff --git a/src/coreclr/jit/ssabuilder.cpp b/src/coreclr/jit/ssabuilder.cpp index e936b3e472049..3e8c0559f7f28 100644 --- a/src/coreclr/jit/ssabuilder.cpp +++ b/src/coreclr/jit/ssabuilder.cpp @@ -1723,10 +1723,9 @@ void Compiler::JitTestCheckSSA() { printf("\nJit Testing: SSA names.\n"); } - for (NodeToTestDataMap::KeyIterator ki = testData->Begin(); !ki.Equal(testData->End()); ++ki) + for (GenTree* const node : NodeToTestDataMap::KeyIteration(testData)) { TestLabelAndNum tlAndN; - GenTree* node = ki.Get(); bool nodeExists = testData->Lookup(node, &tlAndN); assert(nodeExists); if (tlAndN.m_tl == TL_SsaName) diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index 1bc05151623d1..8169a73dd814d 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -8063,11 +8063,11 @@ ValueNum Compiler::fgMemoryVNForLoopSideEffects(MemoryKind memoryKind, Compiler::LoopDsc::FieldHandleSet* fieldsMod = optLoopTable[loopNum].lpFieldsModified; if (fieldsMod != nullptr) { - for (Compiler::LoopDsc::FieldHandleSet::KeyIterator ki = fieldsMod->Begin(); !ki.Equal(fieldsMod->End()); - ++ki) + for (Compiler::LoopDsc::FieldHandleSet::Node* const ki : + Compiler::LoopDsc::FieldHandleSet::KeyValueIteration(fieldsMod)) { - CORINFO_FIELD_HANDLE fldHnd = ki.Get(); - FieldKindForVN fieldKind = ki.GetValue(); + CORINFO_FIELD_HANDLE fldHnd = ki->GetKey(); + FieldKindForVN fieldKind = ki->GetValue(); ValueNum fldHndVN = vnStore->VNForHandle(ssize_t(fldHnd), GTF_ICON_FIELD_HDL); #ifdef DEBUG @@ -8090,11 +8090,8 @@ ValueNum Compiler::fgMemoryVNForLoopSideEffects(MemoryKind memoryKind, Compiler::LoopDsc::ClassHandleSet* elemTypesMod = optLoopTable[loopNum].lpArrayElemTypesModified; if (elemTypesMod != nullptr) { - for (Compiler::LoopDsc::ClassHandleSet::KeyIterator ki = elemTypesMod->Begin(); - !ki.Equal(elemTypesMod->End()); ++ki) + for (const CORINFO_CLASS_HANDLE elemClsHnd : Compiler::LoopDsc::ClassHandleSet::KeyIteration(elemTypesMod)) { - CORINFO_CLASS_HANDLE elemClsHnd = ki.Get(); - #ifdef DEBUG if (verbose) { @@ -11240,10 +11237,9 @@ void Compiler::JitTestCheckVN() { printf("\nJit Testing: Value numbering.\n"); } - for (NodeToTestDataMap::KeyIterator ki = testData->Begin(); !ki.Equal(testData->End()); ++ki) + for (GenTree* const node : NodeToTestDataMap::KeyIteration(testData)) { TestLabelAndNum tlAndN; - GenTree* node = ki.Get(); ValueNum nodeVN = node->GetVN(VNK_Liberal); bool b = testData->Lookup(node, &tlAndN);