0x0. 前言
更多的深度學(xué)習(xí)編譯器知識可以在 https://github.com/BBuf/tvm_mlir_learn 找到。同時也維護了一個cuda學(xué)習(xí)倉庫 https://github.com/BBuf/how-to-optim-algorithm-in-cuda 以及一個如何學(xué)習(xí)深度學(xué)習(xí)框架(PyTorch和OneFlow)的學(xué)習(xí)倉庫,https://github.com/BBuf/how-to-learn-deep-learning-framework , 有需要的小伙伴可以點一點star 。在https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note 這個目錄下收集了一系列和LLM訓(xùn)練,推理相關(guān)的文章。
【省流】上次介紹了深度學(xué)習(xí)編譯器之Layerout Transform優(yōu)化 ,在這篇文章中提到還會介紹常量折疊優(yōu)化Pass的實現(xiàn),但在介紹常量折疊Pass之前我想再介紹一個類似的優(yōu)化方法也就是公共子表達(dá)式消除實現(xiàn)(CSE)。仍然是以O(shè)neFlow中基于MLIR進行實現(xiàn)的CSE Pass為例子來講解。在解析代碼實現(xiàn)的過程中,我發(fā)現(xiàn)基于MLIR來做公共子表達(dá)式消除的時候還順帶做了死代碼消除的功能。另外,在考慮公共子表達(dá)式消除的時候需要保證兩個重復(fù)的操作處于同一個基本塊中以及兩個重復(fù)操作之間沒有其它具有副作用的操作才可以消除。在OneFlow的實現(xiàn)中只是對OneFlow的UserOp的特殊屬性即OpName和SymbolID進行了擦除,用一個魔法屬性來代替,這是因為這兩個屬性不應(yīng)該去影響公共子表達(dá)式的消除。這個優(yōu)化還是比較有用的,在OneFlow的Stable Diffusion優(yōu)化中發(fā)揮了不小的作用。
0x1. 效果
公共子表達(dá)式消除的作用很簡單,就是把公共的表達(dá)式折疊為1個表達(dá)式來避免重復(fù)的計算開銷。我們以O(shè)neFlow針對CSE Pass寫的2個測試為例子來進行說明。這兩個例子在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/cse.mlir ,這里提供了一個 MLIR Module,包含兩個函數(shù):@Cast_1__FUSE__ScalarMulByTensor_2 和 @f2。
其中,第一個函數(shù) @Cast_1__FUSE__ScalarMulByTensor_2 接受一個形狀為 96x96xi64 的張量作為輸入,并執(zhí)行兩個類型轉(zhuǎn)換操作,將輸入轉(zhuǎn)換為 96x96xf32 張量。然后,它使用 oneflow.add_n 操作將兩個結(jié)果張量相加,并返回結(jié)果 96x96xf32 張量。FileCheck 命令驗證了具有 "ScalarMulByTensor_2" op_name 屬性的 "oneflow.cast" 和 "oneflow.add_n2" 操作的存在。這里再解釋一下 CHECK 指定,比如CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.cast" 是一個 FileCheck 指令,用于驗證生成的代碼是否符合預(yù)期。FileCheck 是 LLVM 項目的一部分,用于為編譯器測試提供模式匹配功能。%[[OUT:[a-zA-Z0-9_]+]] 是一個正則表達(dá)式捕獲組,用于捕獲一個以 % 開頭、后跟一系列字母、數(shù)字或下劃線的字符串。這個字符串對應(yīng)于 MLIR 中的一個值名稱。"oneflow.cast" 表示我們希望找到一個名為 "oneflow.cast" 的操作。
第二個函數(shù) @f2 接受三個輸入張量:一個形狀為 2x64x64x320xf16 的張量,一個形狀為 320x320x3x3xf16 的張量,和一個形狀為 320xf16 的張量。它將第二個輸入張量轉(zhuǎn)置兩次,并使用轉(zhuǎn)置后的張量、第一個輸入張量和第三個輸入張量執(zhí)行兩個 conv2d 操作。該函數(shù)返回兩個形狀為 2x64x64x320xf16 的結(jié)果張量。FileCheck 命令驗證了具有等于 163 的 scope_symbol_id 屬性的 "oneflow.conv2d" 操作的存在,并檢查輸出的兩個結(jié)果張量。
這兩個函數(shù)有一個共同點,那就是它們都存在一個完全相同的公共Op,我們可以編譯oneflow之后使用下面的命令將CSE Pass添加到opt pass pipline里面來運行這個mlir表達(dá)式做變換,我們可以關(guān)注變換后的表達(dá)式。命令如下:
oneflow/build/oneflow/ir/bin/oneflow-optoneflow/oneflow/ir/test/OneFlow/cse.mlir-cse-with-attributes-ignored-cse-cse-put-attributes-canonicalize
解釋一下這里的幾個選項:
cse-with-attributes-ignored: 此參數(shù)告訴優(yōu)化器在執(zhí)行公共子表達(dá)式消除(CSE)時忽略O(shè)neFlow IR特有的會影響CSE的屬性(這里是OpName和SymbolID)。
cse: 這個參數(shù)開啟公共子表達(dá)式消除(CSE)優(yōu)化。CSE 是一種編譯器優(yōu)化技術(shù),用于刪除冗余的子表達(dá)式,從而減少計算量和提高程序運行速度。
cse-put-attributes: 此參數(shù)指示優(yōu)化器在執(zhí)行 CSE 之后,將原始屬性放回原始操作。這有助于確保在優(yōu)化過程中保留操作的屬性信息。(也暗示我們必須把原始的屬性保存下來)
canonicalize: 這個參數(shù)開啟規(guī)范化優(yōu)化。規(guī)范化優(yōu)化會將程序中的操作和表達(dá)式轉(zhuǎn)換為一種統(tǒng)一的標(biāo)準(zhǔn)形式,從而簡化后續(xù)優(yōu)化的實現(xiàn)和提高效率。(這兩個給定的例子里,不開啟canonicalize也不會影響輸出IR的表達(dá))
接下來是運行上述命令后輸出的MLIR Module。
module{
func.func@Cast_1__FUSE__ScalarMulByTensor_2(%arg0:tensor<96x96xi64>)->tensor<96x96xf32>{
%0="oneflow.cast"(%arg0){device_name=["0:0"],device_tag="cpu",dtype=2:i32,hierarchy=[1],op_name="Cast_1",op_type_name="cast",pin_memory=false,scope_symbol_id=4611686018427416574:i64}:(tensor<96x96xi64>)->tensor<96x96xf32>
%1="oneflow.add_n2"(%0,%0){device_name=["0:0"],device_tag="cpu",hierarchy=[1],op_name="ScalarMulByTensor_2",op_type_name="add_n",scope_symbol_id=4611686018427416574:i64}:(tensor<96x96xf32>,tensor<96x96xf32>)->tensor<96x96xf32>
return%1:tensor<96x96xf32>
}
func.func@f2(%arg0:tensor<2x64x64x320xf16>,%arg1:tensor<320x320x3x3xf16>,%arg2:tensor<320xf16>)->(tensor<2x64x64x320xf16>,tensor<2x64x64x320xf16>){
%0="oneflow.transpose"(%arg1){device_name=["@0:0"],device_tag="cuda",hierarchy=[1],op_name="unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1",perm=[0:si32,2:si32,3:si32,1:si32],scope_symbol_id=163:i64}:(tensor<320x320x3x3xf16>)->tensor<320x3x3x320xf16>
%1="oneflow.conv2d"(%arg0,%0,%arg2){data_format="channels_last",device_name=["@0:0"],device_tag="cuda",dilation_rate=[1:si32,1:si32],filters=320:si32,groups=1:si32,hierarchy=[1],kernel_size=[3:si32,3:si32],op_name="unet.down_blocks.0.resnets.0.conv1-conv2d-31",operand_segment_sizes=array,padding_before=[1:si32,1:si32],scope_symbol_id=163:i64,strides=[1:si32,1:si32],tuning_cache=""}:(tensor<2x64x64x320xf16>,tensor<320x3x3x320xf16>,tensor<320xf16>)->tensor<2x64x64x320xf16>
return%1,%1:tensor<2x64x64x320xf16>,tensor<2x64x64x320xf16>
}
}
和原始的MLIR ModuleOp對比,我們發(fā)現(xiàn)這兩個函數(shù)里面的公共子表達(dá)式(cast和transpose)都只保留了一個,實現(xiàn)了公共子表達(dá)式消除的目的。在OneFlow編譯器中,這個優(yōu)化率先在OneFlow的Stable Diffusion引人,加速了模型的推理速度。
0x2. 原理&代碼實現(xiàn)
基于 OneFlow 實現(xiàn) CSE 的原理是,我們需要先消除 OneFlow 的 UserOp 的 OpName 和 SymbolID 這兩個屬性,這兩個屬性對 CSE 來說是沒影響的,但是是由 OneFlow 系統(tǒng)添加的,所以我們需要做個預(yù)處理忽略掉這兩個不一致。然后調(diào)用MLIR系統(tǒng)的 CSE Pass 之后我們需要把這個忽略的屬性加回來。這樣才可以保證優(yōu)化后的IR可以轉(zhuǎn)回OneFlow的圖并正確執(zhí)行。
首先基于ODS在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/OneFlowPasses.td#L156-L172 定義了兩個CSE相關(guān)的Pass類,MLIR會自動生成這兩個Pass的定義。我們詳細(xì)看一下細(xì)節(jié):
defCSEWithAttributesIgnored:Pass<"cse-with-attributes-ignored",?"ModuleOp">{//定義了一個名為"cse-with-attributes-ignored"的Pass,它作用在MLIR中的模塊操作(ModuleOp)上。
letsummary="ignoreoneflowattributestohavecsework";//summary和description:提供了有關(guān)Pass功能的簡短描述和詳細(xì)說明。這個Pass的目的是執(zhí)行CSE優(yōu)化,同時忽略O(shè)neFlow屬性(如操作名、符號ID等)。
letdescription=[{
cseandignoreoneflowattributeslikeopname,symbolid,etc.
}];
letconstructor="mlir::createCSEWithAttributesIgnored()";//指定用于創(chuàng)建這個Pass的函數(shù),即mlir::createCSEWithAttributesIgnored()。
letdependentDialects=[];//列出這個Pass依賴的其他方言。在這種情況下,它是空的,表示沒有依賴關(guān)系。
}
defCSEPutAttributes:Pass<"cse-put-attributes",?"ModuleOp">{
letsummary="cseandignoreoneflowattributes";
letdescription=[{
putbackoneflowattributeslikeopname,symbolid,etc.
}];
letconstructor="mlir::createCSEPutAttributes()";
letdependentDialects=[];
}
可以看到 CSE 的預(yù)處理和后處理 Pass 主要就是實現(xiàn) createCSEWithAttributesIgnored 和 createCSEPutAttributes 這兩個函數(shù)。它們的定義在:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h#L25-L33
//CSEState結(jié)構(gòu)體包含兩個成員:
//scopeSymbolIDs:一個llvm::DenseMap,將Operation*類型的指針映射到IntegerAttr類型的屬性。這個映射可能用于存儲操作的范圍符號ID。
//opNames:一個llvm::DenseMap,將Operation*類型的指針映射到StringAttr類型的屬性。這個映射可能用于存儲操作的名稱。
structCSEState{
llvm::DenseMapscopeSymbolIDs;
llvm::DenseMapopNames;
};
//這個函數(shù)返回一個std::unique_ptr類型的對象。根據(jù)函數(shù)名稱,這個函數(shù)創(chuàng)建一個CSEPass,其中忽略了屬性。
std::unique_ptrcreateCSEWithAttributesIgnored();
//這個函數(shù)也返回一個std::unique_ptr類型的對象。根據(jù)函數(shù)名稱,這個函數(shù)創(chuàng)建一個CSEPass,會處理或放置屬性。
std::unique_ptrcreateCSEPutAttributes();
//這個函數(shù)接受一個std::shared_ptr類型的參數(shù),并返回一個std::pair,其中包含兩個std::unique_ptr類型的對象。這個函數(shù)創(chuàng)建一對CSEPass,它們共享給定的CSEState。
std::pair,std::unique_ptr>createCSEPasses(
std::shared_ptrstate);
//這個函數(shù)接受一個std::shared_ptr類型的參數(shù)。根據(jù)函數(shù)名稱,這個函數(shù)可能會注冊一組CSEPass,它們共享給定的CSEState。
voidregisterCSEPasses(std::shared_ptrstate);
接下來看下這幾個 Pass 的具體實現(xiàn)。代碼在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp
首先來看createCSEWithAttributesIgnored:
structEraseAttributes:publicmlir::OpInterfaceRewritePattern{ explicitEraseAttributes(mlir::MLIRContext*context,std::shared_ptr state) :OpInterfaceRewritePattern (context,/*benefit=*/1),state_{state}{} mlir::LogicalResultmatchAndRewrite(UserOpCompatibleop, mlir::PatternRewriter&rewriter)constoverride{ if(op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()) .getValue() .str() !=MAGIC_OP_NAME){ if(state_){ state_->opNames[op]= op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()); state_->scopeSymbolIDs[op]=op->getAttrOfType ( OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr()); } op->setAttr(OpTrait::IsOpConfCompatible ::getOpNameAttr(), rewriter.getStringAttr(MAGIC_OP_NAME)); op->setAttr(OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr(), rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID)); returnsuccess(); }else{ returnfailure(); } } private: std::shared_ptr state_; }; classCSEWithAttributesIgnored:publicCSEWithAttributesIgnoredBase { public: explicitCSEWithAttributesIgnored(){} explicitCSEWithAttributesIgnored(std::shared_ptr state):state_(state){} voidrunOnOperation()override{ Operation*op=getOperation(); RewritePatternSetpatterns(op->getContext()); patterns.add (op->getContext(),state_); (void)applyPatternsAndFoldGreedily(op,std::move(patterns)); } private: std::shared_ptr state_; }; std::unique_ptr createCSEWithAttributesIgnored(){ returnstd::make_unique (); }
這段代碼定義了一個 EraseAttributes 重寫類, 它會移除 op 中的某些屬性。它繼承自 OpInterfaceRewritePattern, 意味著它可以匹配實現(xiàn)了 UserOpCompatible 這個 OpInterface 的 op。然后 EraseAttributes 構(gòu)造函數(shù)接受一個 MLIRContext* 和一個shared_ptr。CSEState 用于跟蹤已重寫的 op 的屬性。matchAndRewrite 方法檢查 op 是否有名為 OpNameAttr 的 StringAttr 屬性, 如果有, 并且其值不等于 MAGIC_OP_NAME, 則該方法會:
將 op 的 OpNameAttr 和 ScopeSymbolIDAttr 屬性記錄在 CSEState 中。
將 OpNameAttr 設(shè)置為 MAGIC_OP_NAME, 將 ScopeSymbolIDAttr 設(shè)置為 MAGIC_SCOPE_SYMBOL_ID。
然后,CSEWithAttributesIgnored 繼承自 CSEWithAttributesIgnoredBase, 重寫了其 runOnOperation 方法。該方法會實例化一個 RewritePatternSet, 添加 EraseAttributes 這個匹配重寫模板, 然后應(yīng)用該模板, 從而移除user op 中的屬性。它還保存一個指向CSEState 的 shared_ptr , 可以在 EraseAttributes 中使用。注意這里的 CSEWithAttributesIgnoredBase 是通過ODS自動生成的 Pass 類定義。createCSEWithAttributesIgnored 函數(shù)會創(chuàng)建一個 CSEWithAttributesIgnored pass 并返回。
接著看一下 createCSEPutAttributes 的實現(xiàn),
structPutAttributes:publicmlir::OpInterfaceRewritePattern{ explicitPutAttributes(mlir::MLIRContext*context,std::shared_ptr state) :OpInterfaceRewritePattern (context,/*benefit=*/1),state_{state}{} mlir::LogicalResultmatchAndRewrite(UserOpCompatibleop, mlir::PatternRewriter&rewriter)constoverride{ if(op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()) .getValue() .str() ==MAGIC_OP_NAME){ if(state_){ op->setAttr(OpTrait::IsOpConfCompatible ::getOpNameAttr(),state_->opNames[op]); op->setAttr(OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr(), state_->scopeSymbolIDs[op]); } returnsuccess(); }else{ returnfailure(); } } private: std::shared_ptr state_; }; classCSEPutAttributes:publicCSEPutAttributesBase { public: explicitCSEPutAttributes(){} explicitCSEPutAttributes(std::shared_ptr state){state_=state;} voidrunOnOperation()override{ Operation*op=getOperation(); RewritePatternSetpatterns(op->getContext()); patterns.add (op->getContext(),state_); (void)applyPatternsAndFoldGreedily(op,std::move(patterns)); } private: std::shared_ptr state_; }; std::unique_ptr createCSEPutAttributes(){returnstd::make_unique ();}
這個 PutAttributes 重寫模板與 EraseAttributes 相反, 它會將先前刪除的屬性恢復(fù)回 op。PutAttributes 構(gòu)造函數(shù)也接受一個 MLIRContext* 和一個 shared_ptr。它使用 CSEState 來查找先前刪除的屬性值。matchAndRewrite 方法檢查 op 是否有一個名為 OpNameAttr 的 StringAttr 屬性,其值等 于 MAGIC_OP_NAME 。如果是,它會從 CSEState 中查找原先的 OpNameAttr 和 ScopeSymbolIDAttr 屬性值。將 OpNameAttr 設(shè)置為原先的值,將 ScopeSymbolIDAttr 設(shè)置為原先的值。
上面的2個Pass都是OneFlow中的預(yù)處理和后處理,而真的CSE Pass則是MLIR自帶的CSE Pass(oneflow/build/oneflow/ir/llvm_monorepo-src/mlir/lib/Transforms/CSE.cpp), 我們來解析一下。
structSimpleOperationInfo:publicllvm::DenseMapInfo{ staticunsignedgetHashValue(constOperation*opC){ returnOperationEquivalence::computeHash( const_cast (opC), /*hashOperands=*/OperationEquivalence::directHashValue, /*hashResults=*/OperationEquivalence::ignoreHashValue, OperationEquivalence::IgnoreLocations); } staticboolisEqual(constOperation*lhsC,constOperation*rhsC){ auto*lhs=const_cast (lhsC); auto*rhs=const_cast (rhsC); if(lhs==rhs) returntrue; if(lhs==getTombstoneKey()||lhs==getEmptyKey()|| rhs==getTombstoneKey()||rhs==getEmptyKey()) returnfalse; returnOperationEquivalence::isEquivalentTo( const_cast (lhsC),const_cast (rhsC), OperationEquivalence::IgnoreLocations); } };
SimpleOperationInfo 這個結(jié)構(gòu)體繼承自 llvm::DenseMapInfo
getHashValue: 為 Operation* 計算哈希值。它使用 OperationEquivalence::computeHash 來計算哈希值,并傳遞 hashOperands=directHashValue 和 hashResults=ignoreHashValue。這意味著它會直接對 op 的操作數(shù)計算哈希值,但會忽略結(jié)果。
isEqual: 檢查兩個 Operation* 是否相等。它首先檢查是否是相同的 op , 如果是,則返回 true。否則,它使用OperationEquivalence::isEquivalentTo 檢查兩個 op 是否等價。同樣,它傳遞了 IgnoreLocations, 意味著它會忽略 op 的位置信息。
所以, 這個 DenseMapInfo 允許以忽略結(jié)果和位置的方式將 Operation* 用作 DenseMap 的鍵。操作數(shù)用于等價性檢查和哈希值計算。
///Simplecommonsub-expressionelimination. //這是一個名為CSE(CommonSub-expressionElimination,公共子表達(dá)式消除)的結(jié)構(gòu)體定義,用于執(zhí)行簡單的公共子表達(dá)式消除。CSE是一種編譯器優(yōu)化技術(shù),用于消除程序中的重復(fù)計算,提高執(zhí)行效率。 structCSE:publicimpl::CSEBase{ ///Sharedimplementationofoperationeliminationandscopedmapdefinitions. //使用AllocatorTy和ScopedMapTy來定義分配器和作用域映射。ScopedMapTy是一個散列表,用于存儲操作之間的映射關(guān)系。 usingAllocatorTy=llvm::RecyclingAllocator< ??????llvm::BumpPtrAllocator, ??????llvm::ScopedHashTableVal >; usingScopedMapTy=llvm::ScopedHashTable ; ///CacheholdingMemoryEffectsinformationbetweentwooperations.Thefirst ///operationisstoredhasthekey.Thesecondoperationisstoredinsidea ///pairinthevalue.ThepairalsoholdtheMemoryEffectsbetweenthose ///twooperations.IftheMemoryEffectsisnullptrthenweassumethereis ///nooperationwithMemoryEffects::Writebetweenthetwooperations. //MemEffectsCache用于在兩個操作之間緩存MemoryEffects信息。MemoryEffects表示某個操作對內(nèi)存的影響。 usingMemEffectsCache= DenseMap >; ///RepresentsasingleentryinthedepthfirsttraversalofaCFG. //CFGStackNode結(jié)構(gòu)體表示控制流圖(CFG)深度優(yōu)先遍歷中的一個節(jié)點。包括作用域、節(jié)點、子節(jié)點迭代器等信息。 structCFGStackNode{ CFGStackNode(ScopedMapTy&knownValues,DominanceInfoNode*node) :scope(knownValues),node(node),childIterator(node->begin()){} ///Scopefortheknownvalues. ScopedMapTy::ScopeTyscope; DominanceInfoNode*node; DominanceInfoNode::const_iteratorchildIterator; ///Ifthisnodehasbeenfullyprocessedyetornot. boolprocessed=false; }; ///Attempttoeliminatearedundantoperation.Returnssuccessifthe ///operationwasmarkedforremoval,failureotherwise. //simplifyOperation函數(shù)嘗試消除冗余操作。如果操作被標(biāo)記為移除,則返回成功,否則返回失敗。 LogicalResultsimplifyOperation(ScopedMapTy&knownValues,Operation*op, boolhasSSADominance); //simplifyBlock函數(shù)簡化指定的基本塊(Block)。 voidsimplifyBlock(ScopedMapTy&knownValues,Block*bb,boolhasSSADominance); //simplifyRegion函數(shù)簡化指定的區(qū)域(Region)。 voidsimplifyRegion(ScopedMapTy&knownValues,Region®ion); //runOnOperation函數(shù)是重寫的基類方法,用于執(zhí)行CSE優(yōu)化。 voidrunOnOperation()override; private: //replaceUsesAndDelete函數(shù)用于替換操作的使用和刪除操作。 voidreplaceUsesAndDelete(ScopedMapTy&knownValues,Operation*op, Operation*existing,boolhasSSADominance); ///Checkifthereisside-effectingoperationsotherthanthegiveneffect ///betweenthetwooperations. //hasOtherSideEffectingOpInBetween函數(shù)檢查給定操作之間是否存在其他具有副作用的操作。 boolhasOtherSideEffectingOpInBetween(Operation*fromOp,Operation*toOp); ///Operationsmarkedasdeadandtobeerased. //opsToErase是一個用于存儲將要刪除的操作的向量。 std::vector opsToErase; //domInfo是一個指向支配信息(DominanceInfo)的指針。 DominanceInfo*domInfo=nullptr; //memEffectsCache是一個緩存,用于存儲操作之間的內(nèi)存效果信息。 MemEffectsCachememEffectsCache; }; }//namespace
我們先看一下核心的runOperation方法。
voidCSE::runOnOperation(){ ///Ascopedhashtableofdefiningoperationswithinaregion. //定義一個名為knownValues的局部變量。它是一個作用域內(nèi)的哈希表,用于存儲在一個區(qū)域內(nèi)定義的操作。 ScopedMapTyknownValues; //從DominanceInfo分析中獲取支配關(guān)系信息,并將其存儲在名為domInfo的變量中。 domInfo=&getAnalysis(); //獲取當(dāng)前操作(rootOp),并遍歷其所有區(qū)域。對每個區(qū)域執(zhí)行簡化操作(simplifyRegion)。 Operation*rootOp=getOperation(); for(auto®ion:rootOp->getRegions()) simplifyRegion(knownValues,region); //如果opsToErase(要刪除的操作)為空,說明沒有操作被刪除,因此保留所有分析。 //Ifnooperationswereerased,thenwemarkallanalysesaspreserved. if(opsToErase.empty()) returnmarkAllAnalysesPreserved(); ///Eraseanyoperationsthatweremarkedasdeadduringsimplification. //如果opsToErase中有操作,遍歷opsToErase并刪除其中的操作。然后清空opsToErase。 for(auto*op:opsToErase) op->erase(); opsToErase.clear(); //Wecurrentlydon'tremoveregionoperations,somarkdominanceas //preserved. //由于當(dāng)前代碼不會刪除區(qū)域操作,因此將支配關(guān)系信息(DominanceInfo)和后支配關(guān)系信息(PostDominanceInfo)標(biāo)記為已保留。將domInfo設(shè)置為nullptr。 markAnalysesPreserved (); domInfo=nullptr; }
這里首先會獲取當(dāng)前 ModuleOp 中 Region 里的支配關(guān)系,以便后續(xù)執(zhí)行完 CSE 之后刪除 Op 后可以更新支配信息。這里的重點是 simplifyRegion 函數(shù),這是執(zhí)行 CSE 的具體細(xì)節(jié)。這個函數(shù)主要使用支配樹遍歷區(qū)域中的基本塊,并調(diào)用 simplifyBlock() 函數(shù)對每個基本塊進行簡化。
//函數(shù)接受一個類型為ScopedMapTy的引用knownValues和一個類型為Region的引用region作為參數(shù)。
voidCSE::simplifyRegion(ScopedMapTy&knownValues,Region®ion){
//Iftheregionisemptythereisnothingtodo.
if(region.empty())
return;
//判斷區(qū)域是否具有SSA支配關(guān)系(StaticSingleAssignmentDominance),并將結(jié)果存儲在變量hasSSADominance中。
boolhasSSADominance=domInfo->hasSSADominance(®ion);
//Iftheregiononlycontainsoneblock,thensimplifyitdirectly.
//如果區(qū)域只包含一個基本塊,那么直接對其進行簡化。創(chuàng)建一個名為scope的ScopedMapTy::ScopeTy對象,然后調(diào)用simplifyBlock()函數(shù)對該基本塊進行簡化。
if(region.hasOneBlock()){
ScopedMapTy::ScopeTyscope(knownValues);
simplifyBlock(knownValues,®ion.front(),hasSSADominance);
return;
}
//IftheregiondoesnothavedominanceInfo,thenskipit.
//TODO:RegionswithoutSSAdominanceshoulddefineadifferent
//traversalorderwhichisappropriateandcanbeusedhere.
//如果區(qū)域沒有支配關(guān)系信息(hasSSADominance為false),則跳過它。此處提到了一個TODO:對于沒有SSA支配關(guān)系的區(qū)域,應(yīng)該定義一個不同的遍歷順序。
if(!hasSSADominance)
return;
//Note,dequeisbeingusedherebecausetherewassignificantperformance
//gainsovervectorwhenthecontainerbecomesverylargeduetothe
//specificaccesspatterns.If/whentheseperformanceissuesareno
//longeraproblemwecanchangethistovector.Formoreinformationsee
//thellvmmailinglistdiscussiononthis:
//http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
//定義一個名為stack的std::deque容器,用于存儲CFGStackNode的std::unique_ptr。這里使用deque是因為它在容器變大時具有更好的性能表現(xiàn)。
std::deque>stack;
//Processthenodesofthedomtreeforthisregion.
//處理這個區(qū)域的支配樹節(jié)點。將區(qū)域的根節(jié)點壓入棧中。
stack.emplace_back(std::make_unique(
knownValues,domInfo->getRootNode(®ion)));
//當(dāng)棧不為空時,執(zhí)行以下循環(huán)操作:
while(!stack.empty()){
//獲取棧頂?shù)漠?dāng)前節(jié)點(currentNode)。
auto¤tNode=stack.back();
//Checktoseeifweneedtoprocessthisnode.
//檢查當(dāng)前節(jié)點是否需要被處理。如果未處理,則將其標(biāo)記為已處理,并調(diào)用simplifyBlock()函數(shù)對當(dāng)前節(jié)點所在的基本塊進行簡化。
if(!currentNode->processed){
currentNode->processed=true;
simplifyBlock(knownValues,currentNode->node->getBlock(),
hasSSADominance);
}
//Otherwise,checktoseeifweneedtoprocessachildnode.
//檢查是否需要處理子節(jié)點。如果當(dāng)前節(jié)點的子節(jié)點迭代器未到達(dá)末尾,將子節(jié)點壓入棧中。
if(currentNode->childIterator!=currentNode->node->end()){
auto*childNode=*(currentNode->childIterator++);
stack.emplace_back(
std::make_unique(knownValues,childNode));
}else{
//Finally,ifthenodeandallofitschildrenhavebeenprocessed
//thenwedeletethenode.
//如果當(dāng)前節(jié)點及其所有子節(jié)點都已處理完畢,則將節(jié)點從棧中彈出。
stack.pop_back();
}
}
}
函數(shù)的執(zhí)行流程請看注釋,到這一步之后CSE的具體實現(xiàn)實際上就在 simplifyBlock 函數(shù)了,我們繼續(xù)追蹤。函數(shù)接受一個類型為 ScopedMapTy 的引用 knownValues,一個類型為 Block 的指針 bb,以及一個布爾值 hasSSADominance 作為參數(shù)。從代碼中可以推測,該函數(shù)的目的是簡化一個給定的基本塊。
voidCSE::simplifyBlock(ScopedMapTy&knownValues,Block*bb,
boolhasSSADominance){
//遍歷基本塊bb中的所有操作(op)
for(auto&op:*bb){
//Mostoperationsdon'thaveregions,sofastpaththatcase.
//檢查操作是否包含區(qū)域。如果操作包含區(qū)域,執(zhí)行以下操作:
if(op.getNumRegions()!=0){
//Ifthisoperationisisolatedabove,wecan'tprocessnestedregions
//withthegiven'knownValues'map.Thiswouldcausetheinsertionof
//implicitcapturesinexplicitcaptureonlyregions.
//如果操作具有IsIsolatedFromAbove特性,那么我們不能使用給定的knownValues映射來處理嵌套區(qū)域,
//因為這可能導(dǎo)致在僅顯式捕獲的區(qū)域中插入隱式捕獲。在這種情況下,創(chuàng)建一個新的nestedKnownValues映射,
//并對操作的每個區(qū)域調(diào)用simplifyRegion()函數(shù)。
if(op.mightHaveTrait()){
ScopedMapTynestedKnownValues;
for(auto®ion:op.getRegions())
simplifyRegion(nestedKnownValues,region);
}else{
//Otherwise,processnestedregionsnormally.
//如果操作沒有IsIsolatedFromAbove特性,那么正常處理嵌套區(qū)域。
//對操作的每個區(qū)域調(diào)用simplifyRegion()函數(shù),傳入knownValues映射。
for(auto®ion:op.getRegions())
simplifyRegion(knownValues,region);
}
}
//如果操作被簡化(調(diào)用simplifyOperation()函數(shù)并檢查其返回值),則不處理操作包含的任何區(qū)域,繼續(xù)處理下一個操作。
//Iftheoperationissimplified,wedon'tprocessanyheldregions.
if(succeeded(simplifyOperation(knownValues,&op,hasSSADominance)))
continue;
}
//CleartheMemoryEffectscachesinceitsusageisbyblockonly.
//在處理完所有操作后,清空memEffectsCache,因為它的使用僅限于單個基本塊。
memEffectsCache.clear();
}
在 simplifyBlock 中會進一步調(diào)用到 simplifyOperation 來對 Operation 做優(yōu)化。我們最后跟進這個函數(shù)看一下。函數(shù)的參數(shù)和 simplifyBlock 一樣,接受一個類型為 ScopedMapTy 的引用 knownValues,一個類型為 Operation 的指針op,以及一個布爾值 hasSSADominance 作為參數(shù)。
///Attempttoeliminatearedundantoperation.
LogicalResultCSE::simplifyOperation(ScopedMapTy&knownValues,Operation*op,
boolhasSSADominance){
//Don'tsimplifyterminatoroperations.
//如果操作是終止操作(具有IsTerminator特性),則不對其進行簡化。
if(op->hasTrait())
returnfailure();
//Iftheoperationisalreadytriviallydeadjustaddittotheeraselist.
//如果操作已經(jīng)是無關(guān)緊要的死代碼,將其添加到待擦除操作列表opsToErase中,增加死代碼消除計數(shù),然后返回成功。
if(isOpTriviallyDead(op)){
opsToErase.push_back(op);
++numDCE;
returnsuccess();
}
//Don'tsimplifyoperationswithregionsthathavemultipleblocks.
//TODO:WeneedadditionalteststoverifythatwehandlesuchIRcorrectly.
//不簡化具有多個基本塊的區(qū)域中的操作。這里提到了一個TODO:需要額外的測試來驗證處理此類IR的正確性。
if(!llvm::all_of(op->getRegions(),[](Region&r){
returnr.getBlocks().empty()||llvm::hasSingleElement(r.getBlocks());
}))
returnfailure();
//Somesimpleusecaseofoperationwithmemoryside-effectaredealtwith
//here.Operationswithnoside-effectaredoneafter.
//首先處理具有內(nèi)存副作用的簡單操作。沒有副作用的操作會在后面處理。
if(!isMemoryEffectFree(op)){
automemEffects=dyn_cast(op);
//TODO:OnlybasicusecaseforoperationswithMemoryEffects::Readcanbe
//eleminatednow.Moreworkneedstobedoneformorecomplicatedpatterns
//andotherside-effects.
//如果操作不是無內(nèi)存副作用的,嘗試獲取其MemoryEffectOpInterface。
//如果操作沒有MemoryEffectOpInterface,或者它不僅僅具有MemoryEffects::Read副作用,則返回失敗。
if(!memEffects||!memEffects.onlyHasEffect())
returnfailure();
//Lookforanexistingdefinitionfortheoperation.
//查找操作的現(xiàn)有定義。如果找到現(xiàn)有定義,并且操作在同一個基本塊中,并且兩者之間沒有其它具有副作用的操作,
//則可以刪除冗余操作。調(diào)用replaceUsesAndDelete()函數(shù)替換使用并刪除操作。
if(auto*existing=knownValues.lookup(op)){
if(existing->getBlock()==op->getBlock()&&
!hasOtherSideEffectingOpInBetween(existing,op)){
//Theoperationthatcanbedeletedhasbeenreachwithno
//side-effectingoperationsinbetweentheexistingoperationand
//thisonesowecanremovetheduplicate.
replaceUsesAndDelete(knownValues,op,existing,hasSSADominance);
returnsuccess();
}
}
//將操作插入knownValues映射中,并返回失敗。
knownValues.insert(op,op);
returnfailure();
}
//Lookforanexistingdefinitionfortheoperation.
//查找操作的現(xiàn)有定義。如果找到現(xiàn)有定義,調(diào)用replaceUsesAndDelete()函數(shù)替換使用并刪除操作,
//增加公共子表達(dá)式消除計數(shù),并返回成功。
if(auto*existing=knownValues.lookup(op)){
replaceUsesAndDelete(knownValues,op,existing,hasSSADominance);
++numCSE;
returnsuccess();
}
//Otherwise,weaddthisoperationtotheknownvaluesmap.
//否則,將此操作添加到knownValues映射中,并返回失敗。
knownValues.insert(op,op);
returnfailure();
}
我們可以看到在 simplifyOperation 中,不僅僅包含公共子表達(dá)式消除(CSE),而且包含了死代碼消除(DCE)。此外,在處理 Operation 時,它會考慮 Operation 的內(nèi)存副作用以及 Operation 是否在具有多個基本塊的區(qū)域中。
0x3. 總結(jié)
在閱讀代碼實現(xiàn)的過程中,我發(fā)現(xiàn)基于MLIR來做公共子表達(dá)式消除的時候還順帶做了死代碼消除的功能。另外,在考慮公共子表達(dá)式消除的時候需要保證兩個重復(fù)的操作處于同一個基本塊中以及兩個重復(fù)操作之間沒有其它具有副作用的操作才可以消除。在OneFlow的實現(xiàn)中只是對OneFlow的UserOp的特殊屬性即OpName和SymbolID進行了擦除,用一個魔法屬性來代替,這是因為這兩個屬性不應(yīng)該去影響公共子表達(dá)式的消除。這個優(yōu)化還是比較有用的,在OneFlow的Stable Diffusion優(yōu)化中發(fā)揮了不小的作用。
-
代碼
+關(guān)注
關(guān)注
30文章
4979瀏覽量
74455 -
編譯器
+關(guān)注
關(guān)注
1文章
1673瀏覽量
52016 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5612瀏覽量
124685
原文標(biāo)題:0x4. 相關(guān)鏈接
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
國產(chǎn)深度學(xué)習(xí)框架的挑戰(zhàn)和機會
Nanopi深度學(xué)習(xí)之路(1)深度學(xué)習(xí)框架分析
深度學(xué)習(xí)框架你了解多少
如何學(xué)習(xí)深度學(xué)習(xí)框架
評論