diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index da11ee9960e1f..b507fa656d601 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -391,6 +391,8 @@ static llvm::Expected convertOmpOpRegions( Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl *continuationBlockPHIs = nullptr) { + bool isLoopWrapper = isa(region.getParentOp()); + llvm::BasicBlock *continuationBlock = splitBB(builder, true, "omp.region.cont"); llvm::BasicBlock *sourceBlock = builder.GetInsertBlock(); @@ -407,30 +409,34 @@ static llvm::Expected convertOmpOpRegions( // Terminators (namely YieldOp) may be forwarding values to the region that // need to be available in the continuation block. Collect the types of these - // operands in preparation of creating PHI nodes. + // operands in preparation of creating PHI nodes. This is skipped for loop + // wrapper operations, for which we know in advance they have no terminators. SmallVector continuationBlockPHITypes; - bool operandsProcessed = false; unsigned numYields = 0; - for (Block &bb : region.getBlocks()) { - if (omp::YieldOp yield = dyn_cast(bb.getTerminator())) { - if (!operandsProcessed) { - for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { - continuationBlockPHITypes.push_back( - moduleTranslation.convertType(yield->getOperand(i).getType())); - } - operandsProcessed = true; - } else { - assert(continuationBlockPHITypes.size() == yield->getNumOperands() && - "mismatching number of values yielded from the region"); - for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { - llvm::Type *operandType = - moduleTranslation.convertType(yield->getOperand(i).getType()); - (void)operandType; - assert(continuationBlockPHITypes[i] == operandType && - "values of mismatching types yielded from the region"); + + if (!isLoopWrapper) { + bool operandsProcessed = false; + for (Block &bb : region.getBlocks()) { + if (omp::YieldOp yield = dyn_cast(bb.getTerminator())) { + if (!operandsProcessed) { + for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { + continuationBlockPHITypes.push_back( + moduleTranslation.convertType(yield->getOperand(i).getType())); + } + operandsProcessed = true; + } else { + assert(continuationBlockPHITypes.size() == yield->getNumOperands() && + "mismatching number of values yielded from the region"); + for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { + llvm::Type *operandType = + moduleTranslation.convertType(yield->getOperand(i).getType()); + (void)operandType; + assert(continuationBlockPHITypes[i] == operandType && + "values of mismatching types yielded from the region"); + } } + numYields++; } - numYields++; } } @@ -468,6 +474,13 @@ static llvm::Expected convertOmpOpRegions( moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) return llvm::make_error(); + // Create a direct branch here for loop wrappers to prevent their lack of a + // terminator from causing a crash below. + if (isLoopWrapper) { + builder.CreateBr(continuationBlock); + continue; + } + // Special handling for `omp.yield` and `omp.terminator` (we may have more // than one): they return the control to the parent OpenMP dialect operation // so replace them with the branch to the continuation block. We handle this