diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index a1f87a637a614..8c3391c8d9293 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -401,8 +401,29 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, << expectedNumArguments; auto functionType = kernelGPUFunction.getFunctionType(); + auto typesMatch = [&](Type launchOpArgType, Type gpuFuncArgType) { + auto launchOpMemref = dyn_cast(launchOpArgType); + auto kernelMemref = dyn_cast(gpuFuncArgType); + // Allow address space incompatibility for OpenCL kernels: `gpu.launch`'s + // argument memref without address space attribute will match a kernel + // function's memref argument with address space `Global`. + if (launchOpMemref && kernelMemref) { + auto launchAS = llvm::dyn_cast_or_null( + launchOpMemref.getMemorySpace()); + auto kernelAS = llvm::dyn_cast_or_null( + kernelMemref.getMemorySpace()); + if (!launchAS && kernelAS && + kernelAS.getValue() == gpu::AddressSpace::Global) + return launchOpMemref.getShape() == kernelMemref.getShape() && + launchOpMemref.getLayout() == kernelMemref.getLayout() && + launchOpMemref.getElementType() == + kernelMemref.getElementType(); + } + return launchOpArgType == gpuFuncArgType; + }; for (unsigned i = 0; i < expectedNumArguments; ++i) { - if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) { + if (!typesMatch(launchOp.getKernelOperand(i).getType(), + functionType.getInput(i))) { return launchOp.emitOpError("type of function argument ") << i << " does not match"; } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index ba7897f4e80cb..fdfd9fcc8b185 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -441,3 +441,17 @@ gpu.module @module_with_two_target [#nvvm.target, #rocdl.target gpu.module @module_with_offload_handler <#gpu.select_object<0>> [#nvvm.target] { } + +// Check kernel memref args are valid even if the address space differs +module attributes {gpu.container_module} { + func.func @foo(%mem : memref<5xf32>) { + %c0 = arith.constant 0 : i32 + gpu.launch_func @gpu_kernels::@kernel blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) : i32 args(%mem : memref<5xf32>) + return + } + gpu.module @gpu_kernels { + gpu.func @kernel(%arg0 : memref<5xf32, #gpu.address_space>) kernel { + gpu.return + } + } +}