Skip to content

[mlir][nvgpu] Use the strides of the memref descriptor to construct the TMA descriptor #85403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

apaszke
Copy link
Member

@apaszke apaszke commented Mar 15, 2024

The previous version of the code assumed that the tensor was contiguous, which is not required and can cause surprising miscompiles.

The previous version of the code assumed that the tensor was contiguous,
which is not required and can cause surprising miscompiles.
@llvmbot
Copy link
Member

llvmbot commented Mar 15, 2024

@llvm/pr-subscribers-mlir-execution-engine

@llvm/pr-subscribers-mlir

Author: Adam Paszke (apaszke)

Changes

The previous version of the code assumed that the tensor was contiguous, which is not required and can cause surprising miscompiles.


Full diff: https://github.com/llvm/llvm-project/pull/85403.diff

1 File Affected:

  • (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+19-13)
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b885..c76f8d77dff558 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
 
 template <int rank>
 void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
   auto descriptor =
       reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
   *addr = descriptor->data;
   for (int i = 0; i < rank; ++i) {
     globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
   }
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  // TODO(grypp): Check that the minormost stride is equal to the element size.
+  for (int i = 0; i < rank - 1; ++i) {
+    globalStrides[i] = static_cast<uint64_t>(
+        descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+  }
 }
 
 } // namespace
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     return NULL;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome thanks! We need a test for this code (llvm rule)

auto descriptor =
reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
*addr = descriptor->data;
for (int i = 0; i < rank; ++i) {
globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
}
static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
4, 8, 2, 4, 4, 4};
// TODO(grypp): Check that the minormost stride is equal to the element size.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM doesn't use TODO with name. Let's just keep this as TODO

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants