Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Accurate Tensor.device for TFEager backends #1077

Closed
wants to merge 12 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion Sources/TensorFlow/Core/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

import CTensorFlow
import Foundation
import _Differentiation

infix operator .==: ComparisonPrecedence
Expand Down Expand Up @@ -768,7 +769,40 @@ extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorF
case .XLA:
return xlaTensor.device
case .TF_EAGER:
return Device.defaultTFEager
var kind: Device.Kind = .CPU
var ordinal = 0
let status = _ExecutionContext.global.status

// Find out what the underlying libraries think the default is.
if let cString = TFE_TensorHandleDeviceName(handle._cTensorHandle, status) {
checkOk(status)
let tfDeviceName = String(cString: cString)

// Parse type and ordinal from a string with the expected syntax:
// /job:localhost/replica:0/task:0/device:CPU:0
let pattern = ".+device:(.+):(\\d+)$"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe break this String -> Device out as a separate function?
Also, I'm concerned that the string parsing will be expensive. This function is called a lot. (whenever there is a scalar constant or anything like that). I think it would be best to see if you can add your own TFE_TensorHandleDevice_Type and TFE_TensorHandleDevice_Id. Some benchmarking results might work instead.

let regex = try! NSRegularExpression(pattern: pattern)
let nsrange = NSRange(tfDeviceName.startIndex..., in: tfDeviceName)
if let match = regex.firstMatch(in: tfDeviceName, range: nsrange) {
if let kindRange = Range(match.range(at: 1), in: tfDeviceName) {
switch String(tfDeviceName[kindRange]).uppercased() {
case "CPU":
kind = .CPU
case "GPU":
kind = .GPU
case "TPU":
kind = .TPU
default:
kind = .CPU
}
}
if let ordinalRange = Range(match.range(at: 2), in: tfDeviceName) {
ordinal = Int(tfDeviceName[ordinalRange]) ?? 0
}
}
}

return Device(kind: kind, ordinal: ordinal, backend: .TF_EAGER)
}
}
}
Expand Down