1
+ #include " svd_ffi.h"
1
2
#include " nanobind/nanobind.h"
2
3
#include " xla/ffi/api/ffi.h"
3
4
5
+ namespace nb = nanobind;
6
+
4
7
using lapack_int = int ;
5
8
6
- namespace nb = nanobind;
7
- using namespace ::xla;
9
+ namespace ffi = xla::ffi;
8
10
9
- inline constexpr auto LapackIntDtype = ffi::DataType::S32;
10
- static_assert (std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
11
+ XLA_FFI_REGISTER_ENUM_ATTR_DECODING (UVtMode);
11
12
12
13
template <ffi::DataType dtype>
13
14
static ffi::Error SvdOnlyVtImpl (
14
15
ffi::Buffer<dtype> x,
15
16
ffi::ResultBuffer<dtype> x_out,
16
17
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
17
- ffi::ResultBuffer<dtype> vt,
18
- ffi::ResultBuffer<LapackIntDtype> info) {
18
+ ffi::ResultBuffer<dtype> u_or_vt,
19
+ ffi::ResultBuffer<ffi::DataType::S32> info,
20
+ UVtMode mode) {
19
21
20
22
using MachineType = ffi::NativeType<dtype>;
21
23
using RealType = ffi::NativeType<ffi::ToReal (dtype)>;
@@ -76,48 +78,68 @@ static ffi::Error SvdOnlyVtImpl(
76
78
77
79
const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
78
80
79
- ffi::Span<const int64_t > dims = x.dimensions ();
81
+ const ffi::Span<const int64_t > dims = x.dimensions ();
80
82
if (dims.size () != 2 ) {
81
83
return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only 2d arrays supported as input." );
82
84
}
83
- int64_t x_rows = dims.front ();
84
- int64_t x_cols = dims.back ();
85
+ const int64_t x_rows = dims.front ();
86
+ const int64_t x_cols = dims.back ();
85
87
86
- if (x_rows < x_cols) [[unlikely]] {
88
+ if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
89
+ return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M <= N supported." );
90
+ } else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
87
91
return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M >= N supported." );
88
92
}
89
93
90
94
if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
91
95
return ffi::Error (ffi::ErrorCode::kOutOfRange , " Dimension of input out of range for lapack integer." );
92
96
}
93
97
94
- lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
95
- lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
98
+ const lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
99
+ const lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
96
100
97
101
auto * x_out_data = x_out->typed_data ();
98
102
auto * s_data = s->typed_data ();
99
- auto * vt_data = vt ->typed_data ();
103
+ auto * u_or_vt_data = u_or_vt ->typed_data ();
100
104
auto * info_data = info->typed_data ();
101
105
106
+ MachineType* u_data;
107
+ MachineType* vt_data;
108
+ if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
109
+ u_data = u_or_vt_data;
110
+ vt_data = nullptr ;
111
+ } else {
112
+ u_data = nullptr ;
113
+ vt_data = u_or_vt_data;
114
+ }
115
+
102
116
if (x.typed_data () != x_out_data) {
103
117
std::copy_n (x.typed_data (), x.element_count (), x_out_data);
104
118
}
105
119
106
120
ffi::NativeType<dtype> work_size = {};
107
121
lapack_int lwork = -1 ;
108
- char jobz = ' O' ;
109
- lapack_int ldu = 1 ;
122
+ const char jobz = ' O' ;
123
+ lapack_int ldu;
124
+ lapack_int ldvt;
125
+ if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
126
+ ldu = x_rows_lapack;
127
+ ldvt = 1 ;
128
+ } else {
129
+ ldu = 1 ;
130
+ ldvt = x_cols_lapack;
131
+ }
110
132
111
133
if constexpr (ffi::IsComplexType<dtype>()) {
112
134
fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
113
135
&x_rows_lapack, nullptr , nullptr ,
114
- &ldu, nullptr , &x_cols_lapack , &work_size,
136
+ &ldu, nullptr , &ldvt , &work_size,
115
137
&lwork, nullptr , nullptr , info_data
116
138
);
117
139
} else {
118
140
fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
119
141
&x_rows_lapack, nullptr , nullptr ,
120
- &ldu, nullptr , &x_cols_lapack ,
142
+ &ldu, nullptr , &ldvt ,
121
143
&work_size, &lwork, nullptr , info_data
122
144
);
123
145
}
@@ -147,14 +169,14 @@ static ffi::Error SvdOnlyVtImpl(
147
169
148
170
if constexpr (ffi::IsComplexType<dtype>()) {
149
171
fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
150
- &x_rows_lapack, s_data, nullptr ,
151
- &ldu, vt_data, &x_cols_lapack , work.get (),
172
+ &x_rows_lapack, s_data, u_data ,
173
+ &ldu, vt_data, &ldvt , work.get (),
152
174
&lwork, rwork.get (), iwork.get (), info_data
153
175
);
154
176
} else {
155
177
fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
156
- &x_rows_lapack, s_data, nullptr ,
157
- &ldu, vt_data, &x_cols_lapack ,
178
+ &x_rows_lapack, s_data, u_data ,
179
+ &ldu, vt_data, &ldvt ,
158
180
work.get (), &lwork, iwork.get (), info_data
159
181
);
160
182
}
@@ -171,7 +193,8 @@ static ffi::Error SvdOnlyVtQRImpl(
171
193
ffi::Buffer<dtype> x,
172
194
ffi::ResultBuffer<dtype> x_out,
173
195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
174
- ffi::ResultBuffer<LapackIntDtype> info) {
196
+ ffi::ResultBuffer<ffi::DataType::S32> info,
197
+ UVtMode mode) {
175
198
176
199
using MachineType = ffi::NativeType<dtype>;
177
200
using RealType = ffi::NativeType<ffi::ToReal (dtype)>;
@@ -230,23 +253,25 @@ static ffi::Error SvdOnlyVtQRImpl(
230
253
231
254
const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
232
255
233
- ffi::Span<const int64_t > dims = x.dimensions ();
256
+ const ffi::Span<const int64_t > dims = x.dimensions ();
234
257
if (dims.size () != 2 ) {
235
258
return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only 2d arrays supported as input." );
236
259
}
237
- int64_t x_rows = dims.front ();
238
- int64_t x_cols = dims.back ();
260
+ const int64_t x_rows = dims.front ();
261
+ const int64_t x_cols = dims.back ();
239
262
240
- if (x_rows < x_cols) [[unlikely]] {
263
+ if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
264
+ return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M <= N supported." );
265
+ } else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
241
266
return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M >= N supported." );
242
267
}
243
268
244
269
if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
245
270
return ffi::Error (ffi::ErrorCode::kOutOfRange , " Dimension of input out of range for lapack integer." );
246
271
}
247
272
248
- lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
249
- lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
273
+ const lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
274
+ const lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
250
275
251
276
auto * x_out_data = x_out->typed_data ();
252
277
auto * s_data = s->typed_data ();
@@ -259,20 +284,33 @@ static ffi::Error SvdOnlyVtQRImpl(
259
284
260
285
ffi::NativeType<dtype> work_size = {};
261
286
lapack_int lwork = -1 ;
262
- char jobu = ' N' ;
263
- char jobvt = ' O' ;
264
- lapack_int ldu = 1 ;
287
+
288
+ char jobu;
289
+ char jobvt;
290
+ const lapack_int ldu = 1 ;
291
+ const lapack_int ldvt = 1 ;
292
+ if (mode == UVtMode::computeOnlyU) {
293
+ jobu = ' O' ;
294
+ jobvt = ' N' ;
295
+ // ldu = 1;
296
+ // ldvt = 1;
297
+ } else {
298
+ jobu = ' N' ;
299
+ jobvt = ' O' ;
300
+ // ldu = 1;
301
+ // ldvt = 1;
302
+ }
265
303
266
304
if constexpr (ffi::IsComplexType<dtype>()) {
267
305
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
268
306
&x_rows_lapack, nullptr , nullptr ,
269
- &ldu, nullptr , &x_cols_lapack , &work_size,
307
+ &ldu, nullptr , &ldvt , &work_size,
270
308
&lwork, nullptr , info_data
271
309
);
272
310
} else {
273
311
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
274
312
&x_rows_lapack, nullptr , nullptr ,
275
- &ldu, nullptr , &x_cols_lapack ,
313
+ &ldu, nullptr , &ldvt ,
276
314
&work_size, &lwork, info_data
277
315
);
278
316
}
@@ -300,13 +338,13 @@ static ffi::Error SvdOnlyVtQRImpl(
300
338
if constexpr (ffi::IsComplexType<dtype>()) {
301
339
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
302
340
&x_rows_lapack, s_data, nullptr ,
303
- &ldu, nullptr , &x_cols_lapack , work.get (),
341
+ &ldu, nullptr , &ldvt , work.get (),
304
342
&lwork, rwork.get (), info_data
305
343
);
306
344
} else {
307
345
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
308
346
&x_rows_lapack, s_data, nullptr ,
309
- &ldu, nullptr , &x_cols_lapack ,
347
+ &ldu, nullptr , &ldvt ,
310
348
work.get (), &lwork, info_data
311
349
);
312
350
}
@@ -318,56 +356,60 @@ static ffi::Error SvdOnlyVtQRImpl(
318
356
return ffi::Error::Success ();
319
357
}
320
358
321
- #define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
322
- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
323
- fname, SvdOnlyVtImpl<dtype>, \
324
- ffi::Ffi::Bind () \
325
- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
326
- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
327
- .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
328
- .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
329
- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
330
-
331
- #define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
332
- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
333
- fname, SvdOnlyVtImpl<dtype>, \
334
- ffi::Ffi::Bind () \
335
- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
336
- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
337
- .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
338
- .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
339
- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
359
+ #define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
360
+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
361
+ fname, SvdOnlyVtImpl<dtype>, \
362
+ ffi::Ffi::Bind () \
363
+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
364
+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
365
+ .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
366
+ .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
367
+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
368
+ .Attr<UVtMode>(" mode" ))
369
+
370
+ #define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
371
+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
372
+ fname, SvdOnlyVtImpl<dtype>, \
373
+ ffi::Ffi::Bind () \
374
+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
375
+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
376
+ .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
377
+ .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
378
+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
379
+ .Attr<UVtMode>(" mode" ))
340
380
341
381
DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32, ffi::DataType::F32);
342
382
DEFINE_REAL_SVD_ONLY_VT (svd_only_vt_f64, ffi::DataType::F64);
343
383
DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c64, ffi::DataType::C64);
344
384
DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c128, ffi::DataType::C128);
345
385
346
- #define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
347
- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
348
- fname, SvdOnlyVtQRImpl<dtype>, \
349
- ffi::Ffi::Bind () \
350
- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
351
- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
352
- .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
353
- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
354
-
355
- #define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
356
- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
357
- fname, SvdOnlyVtQRImpl<dtype>, \
358
- ffi::Ffi::Bind () \
359
- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
360
- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
361
- .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
362
- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
386
+ #define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
387
+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
388
+ fname, SvdOnlyVtQRImpl<dtype>, \
389
+ ffi::Ffi::Bind () \
390
+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
391
+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
392
+ .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
393
+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
394
+ .Attr<UVtMode>(" mode" ))
395
+
396
+ #define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
397
+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
398
+ fname, SvdOnlyVtQRImpl<dtype>, \
399
+ ffi::Ffi::Bind () \
400
+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
401
+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
402
+ .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
403
+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
404
+ .Attr<UVtMode>(" mode" ))
363
405
364
406
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32, ffi::DataType::F32);
365
407
DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_vt_qr_f64, ffi::DataType::F64);
366
408
DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c64, ffi::DataType::C64);
367
409
DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c128, ffi::DataType::C128);
368
410
369
411
template <typename T>
370
- nb::capsule EncapsulateFfiCall (T *fn) {
412
+ static nb::capsule EncapsulateFfiCall (T *fn) {
371
413
static_assert (std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
372
414
" Encapsulated function must be and XLA FFI handler" );
373
415
return nb::capsule (reinterpret_cast <void *>(fn));
0 commit comments