Skip to content

Commit 1d1fd29

Browse files
committed
Implement a svd_only_u as well
1 parent d7ba3dd commit 1d1fd29

File tree

3 files changed

+168
-85
lines changed

3 files changed

+168
-85
lines changed

varipeps/utils/extensions/svd_ffi.cpp

Lines changed: 114 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1+
#include "svd_ffi.h"
12
#include "nanobind/nanobind.h"
23
#include "xla/ffi/api/ffi.h"
34

5+
namespace nb = nanobind;
6+
47
using lapack_int = int;
58

6-
namespace nb = nanobind;
7-
using namespace ::xla;
9+
namespace ffi = xla::ffi;
810

9-
inline constexpr auto LapackIntDtype = ffi::DataType::S32;
10-
static_assert(std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
11+
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(UVtMode);
1112

1213
template <ffi::DataType dtype>
1314
static ffi::Error SvdOnlyVtImpl(
1415
ffi::Buffer<dtype> x,
1516
ffi::ResultBuffer<dtype> x_out,
1617
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
17-
ffi::ResultBuffer<dtype> vt,
18-
ffi::ResultBuffer<LapackIntDtype> info) {
18+
ffi::ResultBuffer<dtype> u_or_vt,
19+
ffi::ResultBuffer<ffi::DataType::S32> info,
20+
UVtMode mode) {
1921

2022
using MachineType = ffi::NativeType<dtype>;
2123
using RealType = ffi::NativeType<ffi::ToReal(dtype)>;
@@ -76,48 +78,68 @@ static ffi::Error SvdOnlyVtImpl(
7678

7779
const auto lapack_int_max = std::numeric_limits<lapack_int>::max();
7880

79-
ffi::Span<const int64_t> dims = x.dimensions();
81+
const ffi::Span<const int64_t> dims = x.dimensions();
8082
if (dims.size() != 2) {
8183
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only 2d arrays supported as input.");
8284
}
83-
int64_t x_rows = dims.front();
84-
int64_t x_cols = dims.back();
85+
const int64_t x_rows = dims.front();
86+
const int64_t x_cols = dims.back();
8587

86-
if (x_rows < x_cols) [[unlikely]] {
88+
if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
89+
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only matrices with M <= N supported.");
90+
} else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
8791
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only matrices with M >= N supported.");
8892
}
8993

9094
if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
9195
return ffi::Error(ffi::ErrorCode::kOutOfRange, "Dimension of input out of range for lapack integer.");
9296
}
9397

94-
lapack_int x_rows_lapack = static_cast<lapack_int>(x_rows);
95-
lapack_int x_cols_lapack = static_cast<lapack_int>(x_cols);
98+
const lapack_int x_rows_lapack = static_cast<lapack_int>(x_rows);
99+
const lapack_int x_cols_lapack = static_cast<lapack_int>(x_cols);
96100

97101
auto* x_out_data = x_out->typed_data();
98102
auto* s_data = s->typed_data();
99-
auto* vt_data = vt->typed_data();
103+
auto* u_or_vt_data = u_or_vt->typed_data();
100104
auto* info_data = info->typed_data();
101105

106+
MachineType* u_data;
107+
MachineType* vt_data;
108+
if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
109+
u_data = u_or_vt_data;
110+
vt_data = nullptr;
111+
} else {
112+
u_data = nullptr;
113+
vt_data = u_or_vt_data;
114+
}
115+
102116
if (x.typed_data() != x_out_data) {
103117
std::copy_n(x.typed_data(), x.element_count(), x_out_data);
104118
}
105119

106120
ffi::NativeType<dtype> work_size = {};
107121
lapack_int lwork = -1;
108-
char jobz = 'O';
109-
lapack_int ldu = 1;
122+
const char jobz = 'O';
123+
lapack_int ldu;
124+
lapack_int ldvt;
125+
if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
126+
ldu = x_rows_lapack;
127+
ldvt = 1;
128+
} else {
129+
ldu = 1;
130+
ldvt = x_cols_lapack;
131+
}
110132

111133
if constexpr (ffi::IsComplexType<dtype>()) {
112134
fn(&jobz, &x_rows_lapack, &x_cols_lapack, nullptr,
113135
&x_rows_lapack, nullptr, nullptr,
114-
&ldu, nullptr, &x_cols_lapack, &work_size,
136+
&ldu, nullptr, &ldvt, &work_size,
115137
&lwork, nullptr, nullptr, info_data
116138
);
117139
} else {
118140
fn(&jobz, &x_rows_lapack, &x_cols_lapack, nullptr,
119141
&x_rows_lapack, nullptr, nullptr,
120-
&ldu, nullptr, &x_cols_lapack,
142+
&ldu, nullptr, &ldvt,
121143
&work_size, &lwork, nullptr, info_data
122144
);
123145
}
@@ -147,14 +169,14 @@ static ffi::Error SvdOnlyVtImpl(
147169

148170
if constexpr (ffi::IsComplexType<dtype>()) {
149171
fn(&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
150-
&x_rows_lapack, s_data, nullptr,
151-
&ldu, vt_data, &x_cols_lapack, work.get(),
172+
&x_rows_lapack, s_data, u_data,
173+
&ldu, vt_data, &ldvt, work.get(),
152174
&lwork, rwork.get(), iwork.get(), info_data
153175
);
154176
} else {
155177
fn(&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
156-
&x_rows_lapack, s_data, nullptr,
157-
&ldu, vt_data, &x_cols_lapack,
178+
&x_rows_lapack, s_data, u_data,
179+
&ldu, vt_data, &ldvt,
158180
work.get(), &lwork, iwork.get(), info_data
159181
);
160182
}
@@ -171,7 +193,8 @@ static ffi::Error SvdOnlyVtQRImpl(
171193
ffi::Buffer<dtype> x,
172194
ffi::ResultBuffer<dtype> x_out,
173195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
174-
ffi::ResultBuffer<LapackIntDtype> info) {
196+
ffi::ResultBuffer<ffi::DataType::S32> info,
197+
UVtMode mode) {
175198

176199
using MachineType = ffi::NativeType<dtype>;
177200
using RealType = ffi::NativeType<ffi::ToReal(dtype)>;
@@ -230,23 +253,25 @@ static ffi::Error SvdOnlyVtQRImpl(
230253

231254
const auto lapack_int_max = std::numeric_limits<lapack_int>::max();
232255

233-
ffi::Span<const int64_t> dims = x.dimensions();
256+
const ffi::Span<const int64_t> dims = x.dimensions();
234257
if (dims.size() != 2) {
235258
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only 2d arrays supported as input.");
236259
}
237-
int64_t x_rows = dims.front();
238-
int64_t x_cols = dims.back();
260+
const int64_t x_rows = dims.front();
261+
const int64_t x_cols = dims.back();
239262

240-
if (x_rows < x_cols) [[unlikely]] {
263+
if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
264+
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only matrices with M <= N supported.");
265+
} else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
241266
return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Only matrices with M >= N supported.");
242267
}
243268

244269
if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
245270
return ffi::Error(ffi::ErrorCode::kOutOfRange, "Dimension of input out of range for lapack integer.");
246271
}
247272

248-
lapack_int x_rows_lapack = static_cast<lapack_int>(x_rows);
249-
lapack_int x_cols_lapack = static_cast<lapack_int>(x_cols);
273+
const lapack_int x_rows_lapack = static_cast<lapack_int>(x_rows);
274+
const lapack_int x_cols_lapack = static_cast<lapack_int>(x_cols);
250275

251276
auto* x_out_data = x_out->typed_data();
252277
auto* s_data = s->typed_data();
@@ -259,20 +284,33 @@ static ffi::Error SvdOnlyVtQRImpl(
259284

260285
ffi::NativeType<dtype> work_size = {};
261286
lapack_int lwork = -1;
262-
char jobu = 'N';
263-
char jobvt = 'O';
264-
lapack_int ldu = 1;
287+
288+
char jobu;
289+
char jobvt;
290+
const lapack_int ldu = 1;
291+
const lapack_int ldvt = 1;
292+
if (mode == UVtMode::computeOnlyU) {
293+
jobu = 'O';
294+
jobvt = 'N';
295+
// ldu = 1;
296+
// ldvt = 1;
297+
} else {
298+
jobu = 'N';
299+
jobvt = 'O';
300+
// ldu = 1;
301+
// ldvt = 1;
302+
}
265303

266304
if constexpr (ffi::IsComplexType<dtype>()) {
267305
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr,
268306
&x_rows_lapack, nullptr, nullptr,
269-
&ldu, nullptr, &x_cols_lapack, &work_size,
307+
&ldu, nullptr, &ldvt, &work_size,
270308
&lwork, nullptr, info_data
271309
);
272310
} else {
273311
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr,
274312
&x_rows_lapack, nullptr, nullptr,
275-
&ldu, nullptr, &x_cols_lapack,
313+
&ldu, nullptr, &ldvt,
276314
&work_size, &lwork, info_data
277315
);
278316
}
@@ -300,13 +338,13 @@ static ffi::Error SvdOnlyVtQRImpl(
300338
if constexpr (ffi::IsComplexType<dtype>()) {
301339
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
302340
&x_rows_lapack, s_data, nullptr,
303-
&ldu, nullptr, &x_cols_lapack, work.get(),
341+
&ldu, nullptr, &ldvt, work.get(),
304342
&lwork, rwork.get(), info_data
305343
);
306344
} else {
307345
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
308346
&x_rows_lapack, s_data, nullptr,
309-
&ldu, nullptr, &x_cols_lapack,
347+
&ldu, nullptr, &ldvt,
310348
work.get(), &lwork, info_data
311349
);
312350
}
@@ -318,56 +356,60 @@ static ffi::Error SvdOnlyVtQRImpl(
318356
return ffi::Error::Success();
319357
}
320358

321-
#define DEFINE_REAL_SVD_ONLY_VT(fname, dtype) \
322-
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
323-
fname, SvdOnlyVtImpl<dtype>, \
324-
ffi::Ffi::Bind() \
325-
.Arg<ffi::Buffer<dtype>>(/*x*/) \
326-
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
327-
.Ret<ffi::Buffer<dtype>>(/*s*/) \
328-
.Ret<ffi::Buffer<dtype>>(/*vt*/) \
329-
.Ret<ffi::Buffer<LapackIntDtype>>(/*info*/))
330-
331-
#define DEFINE_COMPLEX_SVD_ONLY_VT(fname, dtype) \
332-
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
333-
fname, SvdOnlyVtImpl<dtype>, \
334-
ffi::Ffi::Bind() \
335-
.Arg<ffi::Buffer<dtype>>(/*x*/) \
336-
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
337-
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
338-
.Ret<ffi::Buffer<dtype>>(/*vt*/) \
339-
.Ret<ffi::Buffer<LapackIntDtype>>(/*info*/))
359+
// Defines the XLA FFI handler symbol `fname` that dispatches to
// SvdOnlyVtImpl<dtype> for a real element type.
// Binding order (must match the SvdOnlyVtImpl parameter list):
//   arg    x        - input buffer of type dtype
//   ret    x_out    - result buffer of type dtype
//   ret    s        - result buffer; bound with the same dtype as x here
//   ret    u_or_vt  - result buffer of type dtype; the implementation uses it
//                     as U or as Vt depending on `mode`
//   ret    info     - S32 result buffer
//   attr   "mode"   - UVtMode attribute
#define DEFINE_REAL_SVD_ONLY_VT(fname, dtype) \
360+
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
361+
fname, SvdOnlyVtImpl<dtype>, \
362+
ffi::Ffi::Bind() \
363+
.Arg<ffi::Buffer<dtype>>(/*x*/) \
364+
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
365+
.Ret<ffi::Buffer<dtype>>(/*s*/) \
366+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
367+
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
368+
.Attr<UVtMode>("mode"))
369+
370+
// Defines the XLA FFI handler symbol `fname` that dispatches to
// SvdOnlyVtImpl<dtype> for a complex element type.
// Binding order (must match the SvdOnlyVtImpl parameter list):
//   arg    x        - input buffer of type dtype
//   ret    x_out    - result buffer of type dtype
//   ret    s        - result buffer of the real type ffi::ToReal(dtype)
//   ret    u_or_vt  - result buffer of type dtype; the implementation uses it
//                     as U or as Vt depending on `mode`
//   ret    info     - S32 result buffer
//   attr   "mode"   - UVtMode attribute
#define DEFINE_COMPLEX_SVD_ONLY_VT(fname, dtype) \
371+
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
372+
fname, SvdOnlyVtImpl<dtype>, \
373+
ffi::Ffi::Bind() \
374+
.Arg<ffi::Buffer<dtype>>(/*x*/) \
375+
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
376+
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
377+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
378+
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
379+
.Attr<UVtMode>("mode"))
340380

341381
DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32, ffi::DataType::F32);
342382
DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f64, ffi::DataType::F64);
343383
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c64, ffi::DataType::C64);
344384
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
345385

346-
#define DEFINE_REAL_SVD_ONLY_VT_QR(fname, dtype) \
347-
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
348-
fname, SvdOnlyVtQRImpl<dtype>, \
349-
ffi::Ffi::Bind() \
350-
.Arg<ffi::Buffer<dtype>>(/*x*/) \
351-
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
352-
.Ret<ffi::Buffer<dtype>>(/*s*/) \
353-
.Ret<ffi::Buffer<LapackIntDtype>>(/*info*/))
354-
355-
#define DEFINE_COMPLEX_SVD_ONLY_VT_QR(fname, dtype) \
356-
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
357-
fname, SvdOnlyVtQRImpl<dtype>, \
358-
ffi::Ffi::Bind() \
359-
.Arg<ffi::Buffer<dtype>>(/*x*/) \
360-
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
361-
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
362-
.Ret<ffi::Buffer<LapackIntDtype>>(/*info*/))
386+
// Defines the XLA FFI handler symbol `fname` that dispatches to
// SvdOnlyVtQRImpl<dtype> for a real element type.
// Binding order (must match the SvdOnlyVtQRImpl parameter list):
//   arg    x        - input buffer of type dtype
//   ret    x_out    - result buffer of type dtype
//   ret    s        - result buffer; bound with the same dtype as x here
//   ret    info     - S32 result buffer
//   attr   "mode"   - UVtMode attribute; the implementation uses it to choose
//                     the jobu/jobvt flags
// Unlike the SvdOnlyVtImpl bindings there is no separate u/vt result buffer.
#define DEFINE_REAL_SVD_ONLY_VT_QR(fname, dtype) \
387+
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
388+
fname, SvdOnlyVtQRImpl<dtype>, \
389+
ffi::Ffi::Bind() \
390+
.Arg<ffi::Buffer<dtype>>(/*x*/) \
391+
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
392+
.Ret<ffi::Buffer<dtype>>(/*s*/) \
393+
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
394+
.Attr<UVtMode>("mode"))
395+
396+
// Defines the XLA FFI handler symbol `fname` that dispatches to
// SvdOnlyVtQRImpl<dtype> for a complex element type.
// Binding order (must match the SvdOnlyVtQRImpl parameter list):
//   arg    x        - input buffer of type dtype
//   ret    x_out    - result buffer of type dtype
//   ret    s        - result buffer of the real type ffi::ToReal(dtype)
//   ret    info     - S32 result buffer
//   attr   "mode"   - UVtMode attribute; the implementation uses it to choose
//                     the jobu/jobvt flags
// Unlike the SvdOnlyVtImpl bindings there is no separate u/vt result buffer.
#define DEFINE_COMPLEX_SVD_ONLY_VT_QR(fname, dtype) \
397+
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
398+
fname, SvdOnlyVtQRImpl<dtype>, \
399+
ffi::Ffi::Bind() \
400+
.Arg<ffi::Buffer<dtype>>(/*x*/) \
401+
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
402+
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
403+
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
404+
.Attr<UVtMode>("mode"))
363405

364406
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32, ffi::DataType::F32);
365407
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f64, ffi::DataType::F64);
366408
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_vt_qr_c64, ffi::DataType::C64);
367409
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_vt_qr_c128, ffi::DataType::C128);
368410

369411
template <typename T>
370-
nb::capsule EncapsulateFfiCall(T *fn) {
412+
static nb::capsule EncapsulateFfiCall(T *fn) {
371413
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
372414
"Encapsulated function must be and XLA FFI handler");
373415
return nb::capsule(reinterpret_cast<void *>(fn));

varipeps/utils/extensions/svd_ffi.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Declarations shared by the SVD FFI extension: the UVtMode attribute enum and
// the XLA FFI handler symbols defined in svd_ffi.cpp.
#ifndef VARIPEPS_SVD_FFI_H_
2+
#define VARIPEPS_SVD_FFI_H_
3+
4+
#include "xla/ffi/api/ffi.h"
5+
6+
// Selects which factor the SVD handlers compute. It is passed to every handler
// as the FFI attribute named "mode".
// NOTE(review): int8_t is not included explicitly here; presumably it arrives
// transitively via xla/ffi/api/ffi.h - confirm or include <cstdint>.
enum class UVtMode : int8_t {
7+
computeOnlyU = 0, // Compute only U; the handlers reject inputs with rows > cols
8+
computeOnlyVt = 1, // Compute only Vt; the handlers reject inputs with rows < cols
9+
};
10+
11+
// Handlers bound to SvdOnlyVtImpl, one per element type (F32, F64, C64, C128).
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_f32);
12+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_f64);
13+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_c64);
14+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_c128);
15+
16+
// Handlers bound to SvdOnlyVtQRImpl, one per element type (F32, F64, C64, C128).
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_f32);
17+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_f64);
18+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_c64);
19+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_c128);
20+
21+
#endif // VARIPEPS_SVD_FFI_H_

0 commit comments

Comments
 (0)