[mlir][AMDGPU] Enable emulating vector buffer_atomic_fadd for bf16 on gfx942 #129029


Merged (4 commits, Mar 13, 2025)
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -189,6 +189,15 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
} else {
target.addIllegalOp<RawBufferAtomicFmaxOp>();
}
// TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
// this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
if (chipset < Chipset(9, 5, 0)) {

Contributor:
Can you go based off the subtarget feature instead of ISA versions? Really, no code at all should ever be checking the ISA version.

Contributor:
Also, is there something special about this emulation? Why can't you use the ordinary backend AtomicExpand handling?

Contributor:
This pass is for handling the MLIR ops that correspond to buffer intrinsics; ptr addrspace(7) will go through the backend's atomic handling.

(This pass is useful for code that doesn't or can't use buffer fat pointers, e.g. Triton setting cache modifiers.)

Also ... we don't have access to subtarget features.
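
For reference, this is the MLIR-level form of the op this pass rewrites (the vector<2xbf16> variant, taken from the new test case further down); on chips without native packed-bf16 buffer atomics, the pass expands it into a raw_buffer_load plus a raw_buffer_atomic_cmpswap retry loop.

```mlir
// A raw buffer atomic fadd at the MLIR level. Per the new tests, gfx90a and
// gfx942 get the emulated cmpswap loop for this bf16 form, while gfx12 and
// gfx950 keep the single native atomic.
func.func @buffer_fadd_v2bf16(%val: vector<2xbf16>, %buffer: memref<?xbf16>, %idx: i32) {
  amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xbf16> -> memref<?xbf16>, i32
  func.return
}
```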

Contributor:
> we don't have access to subtarget features

If you are producing IR, you must have a system for managing subtarget features. ISA version checks are not acceptable.

Contributor (author):
This type of check is already done in this .cpp (for Chipset(9, 0, 8)). I agree this is not the best way to handle it, but that seems like a refactor that could be done in a future PR?

Contributor:
And re: version checks not being acceptable, there are a bunch of them in the HIP headers, for example (see https://github.com/ROCm/clr/blob/amd-staging/hipamd/include/hip/amd_detail/amd_hip_unsafe_atomics.h#L211).

The procedural history of this pass is that it initially started as a way to implement the gfx941 workarounds for buffer intrinsics.

arsenm (Contributor), Mar 11, 2025:
The example of the HIP headers is a perfect demonstration of why this is a terrible idea you should not be emulating. Those headers are a superfund site that should be purged of all target-specific implementation details, particularly of the ISA-version-listing variety.

Contributor:
Unfortunately, I'm going to call for not letting the perfect be the enemy of the good here.

Contributor:
In terms of this patch, sure. But this system needs to be fixed.

Contributor:
  1. Yeah, especially since every downstream inventing its own thing here is a bit awkward.
  2. The keyword here is DLTI (Data Layout and Target Information), which is getting a decent amount of setup but doesn't seem to see much use yet; a rough sketch of the idea follows below.
  3. @rengolin is a relevant person for this.
  4. In all honesty ... patches welcome on this.
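
(A hypothetical sketch only, not something this pass does today: DLTI attaches target information to the module as attributes, so a pass could query a named capability entry instead of parsing a chipset string. The `amdgpu.has_bf16_buffer_atomics` key below is invented purely for illustration.)

```mlir
// Hypothetical only: a module-level DLTI spec carrying a target capability.
// The key name is made up for illustration; it is not an existing DLTI entry,
// and the emulation pass does not currently read DLTI at all.
module attributes {
  dlti.dl_spec = #dlti.dl_spec<
    #dlti.dl_entry<"amdgpu.has_bf16_buffer_atomics", true>
  >
} {
  // A pass could then decide legality of amdgpu.raw_buffer_atomic_fadd on
  // vector<2xbf16> from this entry rather than from `chipset < Chipset(9, 5, 0)`.
}
```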

target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
[](RawBufferAtomicFaddOp op) -> bool {
Type elemType = getElementTypeOrSelf(op.getValue().getType());
return !isa<BFloat16Type>(elemType);
});
}
}
patterns.add<
RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
77 changes: 63 additions & 14 deletions mlir/test/Dialect/AMDGPU/amdgpu-emulate-atomics.mlir
@@ -1,6 +1,9 @@
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX9
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX90A
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1030 %s | FileCheck %s --check-prefixes=CHECK,GFX10
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1100 %s | FileCheck %s --check-prefixes=CHECK,GFX11
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1200 %s | FileCheck %s --check-prefixes=CHECK,GFX12
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx942 %s | FileCheck %s --check-prefixes=CHECK,GFX942
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx950 %s | FileCheck %s --check-prefixes=CHECK,GFX950

// -----

@@ -10,16 +13,37 @@ func.func @atomic_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// CHECK: gpu.printf "Begin\0A"
// GFX10: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX11: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX9: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX9: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX9: [[loop]]([[arg:%.+]]: f32):
// GFX9: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
// GFX9: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
// GFX9: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
// GFX9: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
// GFX9: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
// GFX9: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
// GFX9: [[post]]:
// GFX12: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX90A: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX90A: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX90A: [[loop]]([[arg:%.+]]: f32):
// GFX90A: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
// GFX90A: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
// GFX90A: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
// GFX90A: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
// GFX90A: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
// GFX90A: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
// GFX90A: [[post]]:
// GFX942: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX942: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX942: [[loop]]([[arg:%.+]]: f32):
// GFX942: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
// GFX942: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
// GFX942: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
// GFX942: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
// GFX942: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
// GFX942: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
// GFX942: [[post]]:
// GFX950: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX950: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX950: [[loop]]([[arg:%.+]]: f32):
// GFX950: [[operated:%.+]] = arith.maximumf [[val]], [[arg]]
// GFX950: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
// GFX950: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
// GFX950: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
// GFX950: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
// GFX950: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
// GFX950: [[post]]:
// CHECK-NEXT: gpu.printf "End\0A"
gpu.printf "Begin\n"
amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
@@ -33,9 +57,12 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
// CHECK: func @atomic_fmax_f64
// CHECK-SAME: ([[val:%.+]]: f64, [[buffer:%.+]]: memref<?xf64>, [[idx:%.+]]: i32)
// CHECK: gpu.printf "Begin\0A"
// GFX9: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX90A: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX10: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX11: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX12: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX942: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX950: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// CHECK-NEXT: gpu.printf "End\0A"
gpu.printf "Begin\n"
amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f64 -> memref<?xf64>, i32
@@ -47,17 +74,20 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {

func.func @atomic_fadd(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// CHECK: func @atomic_fadd
// GFX9: amdgpu.raw_buffer_atomic_fadd
// GFX90A: amdgpu.raw_buffer_atomic_fadd
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
// GFX11: amdgpu.raw_buffer_atomic_fadd
// GFX12: amdgpu.raw_buffer_atomic_fadd
// GFX942: amdgpu.raw_buffer_atomic_fadd
// GFX950: amdgpu.raw_buffer_atomic_fadd
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
func.return
}

// CHECK: func @atomic_fadd_v2f16
func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx: i32) {
// GFX9: amdgpu.raw_buffer_atomic_fadd
// GFX90A: amdgpu.raw_buffer_atomic_fadd
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
// Note: the atomic operation itself will be done over i32, and then we use bitcasts
@@ -69,6 +99,25 @@ func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx:
// GFX11: %[[vecCastOld:.+]] = vector.bitcast %[[old]] : vector<2xf16> to vector<1xi32>
// GFX11: %[[scalarOld:.+]] = vector.extract %[[vecCastOld]][0]
// GFX11: arith.cmpi eq, %[[scalarOld]], %[[scalarExpected]]
// GFX942: amdgpu.raw_buffer_atomic_fadd
// GFX12: amdgpu.raw_buffer_atomic_fadd
// GFX950: amdgpu.raw_buffer_atomic_fadd
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xf16> -> memref<?xf16>, i32
func.return
}

// CHECK: func @atomic_fadd_v2bf16
func.func @atomic_fadd_v2bf16(%val: vector<2xbf16>, %buffer: memref<?xbf16>, %idx: i32) {
// GFX90A: amdgpu.raw_buffer_load
// GFX90A: amdgpu.raw_buffer_atomic_cmpswap
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
// GFX11: amdgpu.raw_buffer_load
// GFX11: amdgpu.raw_buffer_atomic_cmpswap
// GFX942: amdgpu.raw_buffer_load
// GFX942: amdgpu.raw_buffer_atomic_cmpswap
// GFX12: amdgpu.raw_buffer_atomic_fadd
// GFX950: amdgpu.raw_buffer_atomic_fadd
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xbf16> -> memref<?xbf16>, i32
func.return
}
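
For readers following the CHECK patterns above, this is roughly the shape of the compare-and-swap expansion that the GFX90A/GFX942 prefixes describe for the scalar f32 fmax case. It is hand-reconstructed from those CHECK lines rather than taken from pass output, so the type suffixes on the amdgpu ops are a best-effort guess; the vector<2xf16> and vector<2xbf16> cases follow the same loop but additionally bitcast the packed vector to i32 before the comparison, as the note in @atomic_fadd_v2f16 says.

```mlir
// Sketch of the emulation loop (reconstructed from the CHECK lines above,
// not literal pass output): read the current value, apply the op, try to
// publish the result with a cmpswap, and retry with the observed value on failure.
func.func @emulated_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
  %init = amdgpu.raw_buffer_load %buffer[%idx] : memref<?xf32>, i32 -> f32
  cf.br ^loop(%init : f32)
^loop(%current: f32):
  %updated = arith.maximumf %val, %current : f32
  // Returns the value that was actually in memory at swap time.
  %seen = amdgpu.raw_buffer_atomic_cmpswap %updated, %current -> %buffer[%idx] : f32 -> memref<?xf32>, i32
  // Compare bit patterns rather than floats so NaNs and -0.0 still terminate the loop.
  %currentBits = arith.bitcast %current : f32 to i32
  %seenBits = arith.bitcast %seen : f32 to i32
  %done = arith.cmpi eq, %seenBits, %currentBits : i32
  cf.cond_br %done, ^post, ^loop(%seen : f32)
^post:
  func.return
}
```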