From a1955d13791a6c69e8a531d301b3bd807563916b Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 26 Aug 2024 11:54:41 -0700 Subject: [PATCH] [flang][cuda] Simplify data transfer when possible --- flang/lib/Lower/Bridge.cpp | 45 ++++++++++++++------ flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 4 +- flang/test/Lower/CUDA/cuda-data-transfer.cuf | 19 ++++++++- 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 24cd6b22b8925..a414cd5cb0b93 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -4251,15 +4251,37 @@ class FirConverter : public Fortran::lower::AbstractConverter { bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs); bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs); - auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value { + auto getRefFromValue = [](mlir::Value val) -> mlir::Value { if (auto loadOp = mlir::dyn_cast_or_null(val.getDefiningOp())) return loadOp.getMemref(); + if (!mlir::isa(val.getType())) + return val; + if (auto declOp = + mlir::dyn_cast_or_null(val.getDefiningOp())) { + if (!declOp.getShape()) + return val; + if (mlir::isa(declOp.getMemref().getType())) + return declOp.getMemref(); + } return val; }; - mlir::Value rhsVal = getRefIfLoaded(rhs.getBase()); - mlir::Value lhsVal = getRefIfLoaded(lhs.getBase()); + auto getShapeFromDecl = [](mlir::Value val) -> mlir::Value { + if (!mlir::isa(val.getType())) + return {}; + if (auto declOp = + mlir::dyn_cast_or_null(val.getDefiningOp())) + return declOp.getShape(); + return {}; + }; + + mlir::Value rhsVal = getRefFromValue(rhs.getBase()); + mlir::Value lhsVal = getRefFromValue(lhs.getBase()); + // Get shape from the rhs if available otherwise get it from lhs. + mlir::Value shape = getShapeFromDecl(rhs.getBase()); + if (!shape) + shape = getShapeFromDecl(lhs.getBase()); // device = host if (lhsIsDevice && !rhsIsDevice) { @@ -4272,19 +4294,18 @@ class FirConverter : public Fortran::lower::AbstractConverter { base = convertOp.getValue(); // Special case if the rhs is a constant. if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) { - builder.create( - loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr); + builder.create(loc, base, lhsVal, shape, + transferKindAttr); } else { auto associate = hlfir::genAssociateExpr( loc, builder, rhs, rhs.getType(), ".cuf_host_tmp"); builder.create(loc, associate.getBase(), lhsVal, - /*shape=*/mlir::Value{}, - transferKindAttr); + shape, transferKindAttr); builder.create(loc, associate); } } else { - builder.create( - loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr); + builder.create(loc, rhsVal, lhsVal, shape, + transferKindAttr); } return; } @@ -4293,8 +4314,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { if (!lhsIsDevice && rhsIsDevice) { auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceHost); - builder.create(loc, rhsVal, lhsVal, - /*shape=*/mlir::Value{}, + builder.create(loc, rhsVal, lhsVal, shape, transferKindAttr); return; } @@ -4304,8 +4324,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal"); auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceDevice); - builder.create(loc, rhsVal, lhsVal, - /*shape=*/mlir::Value{}, + builder.create(loc, rhsVal, lhsVal, shape, transferKindAttr); return; } diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index 3b4ad95cafe6b..7fb2dcf4af115 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -112,9 +112,11 @@ llvm::LogicalResult cuf::DataTransferOp::verify() { if (fir::isa_trivial(srcTy) && matchPattern(getSrc().getDefiningOp(), mlir::m_Constant())) return mlir::success(); + return emitOpError() << "expect src and dst to be references or descriptors or src to " - "be a constant"; + "be a constant: " + << srcTy << " - " << dstTy; } //===----------------------------------------------------------------------===// diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf index 7eb74a4234f5a..b2ae8c9f82ebb 100644 --- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf +++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf @@ -11,6 +11,7 @@ contains function dev1(a) integer, device :: a(:) integer :: dev1 + dev1 = 1 end function end @@ -198,8 +199,8 @@ end subroutine ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref {fir.bindc_name = "n"}) ! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]](%{{.*}}) dummy_scope %{{.*}} {uniq_name = "_QFsub8Eb"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.ref>, !fir.ref>) ! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) dummy_scope %{{.*}} {data_attr = #cuf.cuda, uniq_name = "_QFsub8Ea"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.box>, !fir.ref>) -! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer} : !fir.box>, !fir.ref> -! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.box> +! CHECK: cuf.data_transfer %[[ARG0]] to %[[B]]#0, %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> +! CHECK: cuf.data_transfer %[[B]]#0 to %[[ARG0]], %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> subroutine sub9(a) integer, pinned, allocatable :: a(:) @@ -274,3 +275,17 @@ end subroutine ! CHECK-LABEL: func.func @_QPsub14() ! CHECK: %[[TRUE:.*]] = arith.constant true ! CHECK: cuf.data_transfer %[[TRUE]] to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer} : i1, !fir.ref>> + +subroutine sub15(a_dev, a_host, n, m) + integer, intent(in) :: n, m + real, device :: a_dev(n*m) + real :: a_host(n*m) + + a_dev = a_host +end subroutine + +! CHECK-LABEL: func.func @_QPsub15( +! CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "a_dev"}, %[[ARG1:.*]]: !fir.ref> {fir.bindc_name = "a_host"} +! CHECK: %{{.*}} = fir.shape %{{.*}} : (index) -> !fir.shape<1> +! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1> +! CHECK: cuf.data_transfer %[[ARG1]] to %[[ARG0]], %[[SHAPE]] : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref>