@@ -23,7 +23,7 @@ import TensorFlowAPI
23
23
public extension Array {
24
24
25
25
/// method that can 'completely' flatten a multi-dimensional array
26
- public static func Flat ( _ array: Array < Any > ) -> Array < Any > {
26
+ static func Flat ( _ array: Array < Any > ) -> Array < Any > {
27
27
if let a = array as? Array < Array < Any > > {
28
28
return Flat ( a. flatMap ( { $0} ) )
29
29
}
@@ -35,11 +35,11 @@ public extension Array {
35
35
} //end func
36
36
37
37
/// instance method
38
- public func flat( ) -> Array < Any > {
38
+ func flat( ) -> Array < Any > {
39
39
return Array . Flat ( self )
40
40
}
41
41
42
- public var shape : [ Int ] {
42
+ var shape : [ Int ] {
43
43
var _shape = [ Int] ( )
44
44
var a = self as Array < Any >
45
45
while a. count > 0 {
@@ -53,7 +53,7 @@ public extension Array {
53
53
return _shape
54
54
} //end var
55
55
56
- public func column( index: Int ) -> Array < Any > {
56
+ func column( index: Int ) -> Array < Any > {
57
57
var b = [ Any] ( )
58
58
let s = shape
59
59
guard s. count > 1 , index > - 1 , index < s [ 1 ] ,
@@ -72,17 +72,26 @@ public extension Array {
72
72
73
73
typealias SwiftArray < T> = Array < T >
74
74
public extension Data {
75
- public static func From( _ string: String ) -> Data {
75
+ static func From( _ string: String ) -> Data {
76
76
return string. withCString { p -> Data in
77
77
return Data ( bytes: p, count: string. utf8. count)
78
78
} //end return
79
79
} //end from
80
- public var string : String {
80
+ var string : String {
81
+ #if swift(>=5.0)
82
+ return self . withUnsafeBytes { ( buffered: UnsafeRawBufferPointer ) -> String in
83
+ let p = buffered. baseAddress? . assumingMemoryBound ( to: CChar . self)
84
+ var q = Array ( UnsafeBufferPointer ( start: p, count: self . count) )
85
+ q. append ( 0 )
86
+ return String ( cString: q)
87
+ }
88
+ #else
81
89
return self . withUnsafeBytes { ( p: UnsafePointer < CChar > ) -> String in
82
90
var q = Array ( UnsafeBufferPointer ( start: p, count: self . count) )
83
91
q. append ( 0 )
84
92
return String ( cString: q)
85
93
} //end return
94
+ #endif
86
95
} //end var
87
96
} //end extension
88
97
@@ -409,10 +418,18 @@ public class TensorFlow {
409
418
/// - parameters:
410
419
/// - data: data to copy with
411
420
public init ( data: Data ) throws {
421
+ #if swift(>=5.0)
422
+ guard let pData = ( data. withUnsafeBytes { ( buffered: UnsafeRawBufferPointer ) -> UnsafePointer < Int8 > ? in
423
+ return buffered. baseAddress? . assumingMemoryBound ( to: Int8 . self)
424
+ } ) else {
425
+ throw Panic . INVALID
426
+ }
427
+ #else
412
428
let pData = data. withUnsafeBytes {
413
429
( ptr: UnsafePointer < Int8 > ) -> UnsafePointer < Int8 > in
414
430
return ptr
415
431
} //end let
432
+ #endif
416
433
guard let _ = TFLib . libDLL,
417
434
let buf = TFLib . NewBufferFromString ( pData, data. count)
418
435
else { throw Panic . CALL }
@@ -792,10 +809,20 @@ public class TensorFlow {
792
809
// *NOTE* DON'T USE MAP
793
810
// UInt8(128) to Int8 will cause segment fault
794
811
795
- let s = strings [ i] . withUnsafeBytes { ( ptr: UnsafePointer < Int8 > ) -> [ Int8 ] in
812
+ #if swift(>=5.0)
813
+ let s = strings [ i] . withUnsafeBytes {
814
+ ( ptr: UnsafeRawBufferPointer ) -> [ Int8 ] in
815
+ let p = ptr. baseAddress? . bindMemory ( to: Int8 . self, capacity: strings [ i] . count)
816
+ let buffered = UnsafeBufferPointer ( start: p, count: strings [ i] . count)
817
+ return Array ( buffered)
818
+ }
819
+ #else
820
+ let s = strings [ i] . withUnsafeBytes {
821
+ ( ptr: UnsafePointer < Int8 > ) -> [ Int8 ] in
796
822
let buffered = UnsafeBufferPointer ( start: ptr, count: strings [ i] . count)
797
823
return Array ( buffered)
798
824
}
825
+ #endif
799
826
800
827
let encoded = try TensorFlow . Encode ( string: s)
801
828
size += UInt64 ( encoded. count)
@@ -943,7 +970,12 @@ public class TensorFlow {
943
970
public func `set`( config: Config ) throws -> SessionOptions {
944
971
let s = try Status ( )
945
972
let data = try config. serializedData ( )
973
+ #if swift(>=5.0)
974
+ guard let p = ( data. withUnsafeBytes { $0. baseAddress? . assumingMemoryBound ( to: CChar . self) } )
975
+ else { throw Panic . INVALID }
976
+ #else
946
977
let p = data. withUnsafeBytes { ( ptr: UnsafePointer < CChar > ) in return ptr }
978
+ #endif
947
979
TFLib . SetConfig ( options, p, data. count, s. status)
948
980
guard s. code == . OK else { throw Panic . FAULT ( reason: s. message) }
949
981
return self
@@ -1139,19 +1171,36 @@ public class TensorFlow {
1139
1171
} else if v is TensorProto , let p = v as? TensorProto {
1140
1172
let data = try p. serializedData ( )
1141
1173
let status = try Status ( )
1174
+ #if swift(>=5.0)
1175
+ data. withUnsafeBytes { ( buffered: UnsafeRawBufferPointer ) in
1176
+ if let ptr = buffered. baseAddress? . assumingMemoryBound ( to: CChar . self) {
1177
+ TFLib . SetAttrTensorShapeProto ( descriptor, k, ptr, data. count, status. status)
1178
+ total += 1
1179
+ }
1180
+ }
1181
+ #else
1142
1182
data. withUnsafeBytes { ( ptr: UnsafePointer < CChar > ) in
1143
1183
TFLib . SetAttrTensorShapeProto ( descriptor, k, ptr, data. count, status. status)
1144
1184
total += 1
1145
1185
} //end bytes
1186
+ #endif
1146
1187
guard status. code == . OK else { throw Panic . FAULT ( reason: status. message) }
1147
1188
} else if v is [ TensorProto ] , let pv = v as? [ TensorProto ] , pv. count > 0 {
1148
1189
let array = UnsafeMutablePointer< UnsafePointer< CChar>>. allocate( capacity: pv. count)
1149
1190
let lens = UnsafeMutablePointer< Int> . allocate( capacity: pv. count)
1150
1191
let data = try pv. map { try $0. serializedData ( ) }
1151
1192
for i in 0 ... pv. count - 1 {
1193
+ #if swift(>=5.0)
1194
+ data [ i] . withUnsafeBytes { ( buffered: UnsafeRawBufferPointer ) in
1195
+ if let ptr = buffered. baseAddress? . assumingMemoryBound ( to: CChar . self) {
1196
+ array. advanced ( by: i) . pointee = ptr
1197
+ }
1198
+ }
1199
+ #else
1152
1200
data [ i] . withUnsafeBytes { ( ptr: UnsafePointer < CChar > ) in
1153
1201
array. advanced ( by: i) . pointee = ptr
1154
1202
} //end bytes
1203
+ #endif
1155
1204
lens. advanced ( by: i) . pointee = data [ i] . count
1156
1205
} //next
1157
1206
let status = try Status ( )
@@ -1185,7 +1234,12 @@ public class TensorFlow {
1185
1234
guard status. code == . OK else { throw Panic . FAULT ( reason: status. message) }
1186
1235
} else if v is Data , let d = v as? Data , d. count > 0 {
1187
1236
let status = try Status ( )
1237
+ #if swift(>=5.0)
1238
+ guard let p = ( d. withUnsafeBytes { $0. baseAddress? . assumingMemoryBound ( to: Int8 . self) } )
1239
+ else { throw Panic . FAULT ( reason: " invalid pointer " ) }
1240
+ #else
1188
1241
let p = d. withUnsafeBytes { pointer -> UnsafePointer < Int8 > in return pointer }
1242
+ #endif
1189
1243
TFLib . SetAttrValueProto ( descriptor, k, p, d. count, status. status)
1190
1244
total += 1
1191
1245
guard status. code == . OK else { throw Panic . FAULT ( reason: status. message) }
@@ -2038,6 +2092,12 @@ public class TensorFlow {
2038
2092
}
2039
2093
}
2040
2094
2095
+ #if swift(>=5.0)
2096
+ private func assignOpaque( pointer: UnsafeMutablePointer < OpaquePointer ? > , count: Int ) -> Array < OpaquePointer ? > {
2097
+ return Array ( UnsafeBufferPointer ( start: pointer, count: count) )
2098
+ }
2099
+ #endif
2100
+
2041
2101
public func getFunctions( ) throws -> [ Function ? ] {
2042
2102
let count = TFLib . GraphNumFunctions ( self . graph)
2043
2103
guard count > 0 else {
@@ -2046,11 +2106,15 @@ public class TensorFlow {
2046
2106
var funcs : OpaquePointer ? = nil
2047
2107
let status = try Status ( )
2048
2108
let num = TFLib . GraphGetFunctions ( self . graph, & funcs, count, status. status)
2049
- guard num > 0 && num <= count else {
2109
+ guard let _ = funcs , num > 0 && num <= count else {
2050
2110
throw Panic . FAULT ( reason: " no function exisis in this graph " )
2051
2111
}
2112
+ #if swift(>=5.0)
2113
+ let array = assignOpaque ( pointer: & funcs, count: Int ( num) )
2114
+ #else
2052
2115
let pointers = UnsafeBufferPointer ( start: & funcs, count: Int ( num) )
2053
2116
let array = Array ( pointers)
2117
+ #endif
2054
2118
return array. map { if let f = $0 { return Function ( f) } else { return nil } }
2055
2119
}
2056
2120
@@ -2072,6 +2136,21 @@ public class TensorFlow {
2072
2136
/// - throws: Panic.Fault(reason: status.message)
2073
2137
public init ( importDefinition: FunctionDef ) throws {
2074
2138
let proto = try importDefinition. serializedData ( )
2139
+ #if swift(>=5.0)
2140
+ ref = try proto. withUnsafeBytes {
2141
+ ( buffered: UnsafeRawBufferPointer ) throws -> OpaquePointer in
2142
+ guard let p = buffered. baseAddress? . assumingMemoryBound ( to: UInt8 . self) else {
2143
+ throw Panic . INVALID
2144
+ }
2145
+ let status = try Status ( )
2146
+ guard let function = TFLib . FunctionImportFunctionDef (
2147
+ p, Int32 ( proto. count) , status. status) ,
2148
+ status. code == . OK else {
2149
+ throw Panic . FAULT ( reason: status. message)
2150
+ }
2151
+ return function
2152
+ }
2153
+ #else
2075
2154
ref = try proto. withUnsafeBytes {
2076
2155
( p: UnsafePointer < UInt8 > ) throws -> OpaquePointer in
2077
2156
let status = try Status ( )
@@ -2082,6 +2161,7 @@ public class TensorFlow {
2082
2161
}
2083
2162
return function
2084
2163
}
2164
+ #endif
2085
2165
}
2086
2166
2087
2167
/// Sets function attribute named `name` to value.
@@ -2092,6 +2172,19 @@ public class TensorFlow {
2092
2172
/// - throws: Panic.Fault(reason: status.message)
2093
2173
public func setAttributeFor( _ name: String , value: AttrValue ) throws {
2094
2174
let proto = try value. serializedData ( )
2175
+ #if swift(>=5.0)
2176
+ try proto. withUnsafeBytes {
2177
+ ( buffered: UnsafeRawBufferPointer ) throws in
2178
+ guard let p = buffered. baseAddress? . assumingMemoryBound ( to: UInt8 . self)
2179
+ else { throw Panic . INVALID }
2180
+ let status = try Status ( )
2181
+ TFLib . FunctionSetAttrValueProto (
2182
+ ref, name, p, Int32 ( proto. count) , status. status)
2183
+ guard status. code == . OK else {
2184
+ throw Panic . FAULT ( reason: status. message)
2185
+ }
2186
+ }
2187
+ #else
2095
2188
try proto. withUnsafeBytes {
2096
2189
( p: UnsafePointer < UInt8 > ) throws in
2097
2190
let status = try Status ( )
@@ -2101,6 +2194,7 @@ public class TensorFlow {
2101
2194
throw Panic . FAULT ( reason: status. message)
2102
2195
}
2103
2196
}
2197
+ #endif
2104
2198
}
2105
2199
2106
2200
/// get an attribute value by its name
0 commit comments