|
| 1 | +from typing import Optional, Union |
| 2 | + |
| 3 | +import tensorrt as trt |
| 4 | +import torch |
| 5 | +from torch.fx.node import Target |
| 6 | +from torch_tensorrt.fx.converters.converter_utils import ( |
| 7 | + SourceIR, |
| 8 | + get_trt_tensor, |
| 9 | + set_layer_name, |
| 10 | + to_numpy, |
| 11 | +) |
| 12 | +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor |
| 13 | + |
| 14 | + |
| 15 | +def linear( |
| 16 | + network: TRTNetwork, |
| 17 | + target: Union[Target, str], |
| 18 | + source_ir: Optional[SourceIR], |
| 19 | + name: str, |
| 20 | + input: TRTTensor, |
| 21 | + weight: Union[TRTTensor, torch.Tensor], |
| 22 | + bias: Optional[Union[TRTTensor, torch.Tensor]], |
| 23 | +) -> TRTTensor: |
| 24 | + """ |
| 25 | + TensorRT fully connected layer implicitly flatten last three dimensions at |
| 26 | + the start and implicitly reshape the result to (K, 1, 1) at the end. |
| 27 | +
|
| 28 | + e.g. If input is (N, C, H, W), first it gets flatten to (N, C*H*W). Then after |
| 29 | + going through fully connected operation, it becomes (N, K). Before sending it |
| 30 | + out, it gets reshaped into (N, K, 1, 1) and this is the final output. |
| 31 | +
|
| 32 | + TODO: We can optimize this to get rid of unneccesary transformation. |
| 33 | + """ |
| 34 | + |
| 35 | + if not isinstance(input, trt.tensorrt.ITensor): |
| 36 | + raise RuntimeError( |
| 37 | + f"Linear received input {input} that is not part " "of the TensorRT region!" |
| 38 | + ) |
| 39 | + |
| 40 | + # reshape the input to (*, X, 1, 1) |
| 41 | + pre_shuffle_layer = network.add_shuffle(input) |
| 42 | + pre_shuffle_layer.reshape_dims = tuple(input.shape) + (1, 1) |
| 43 | + set_layer_name(pre_shuffle_layer, target, f"{name}_pre_shuffle", source_ir) |
| 44 | + |
| 45 | + # Process bias terms |
| 46 | + if isinstance(bias, torch.Tensor): |
| 47 | + # Transform the bias constant into a Numpy array |
| 48 | + bias = to_numpy(bias) |
| 49 | + |
| 50 | + elif isinstance(bias, TRTTensor): |
| 51 | + bias = get_trt_tensor(network, bias, f"{name}_bias") |
| 52 | + |
| 53 | + elif bias is not None: |
| 54 | + raise RuntimeError( |
| 55 | + f"Linear layer {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor" |
| 56 | + ) |
| 57 | + |
| 58 | + # Process weight terms |
| 59 | + if network.has_explicit_precision or isinstance(weight, TRTTensor): |
| 60 | + weight = get_trt_tensor(network, weight, f"{name}_weight") |
| 61 | + |
| 62 | + elif isinstance(weight, torch.Tensor): |
| 63 | + # Transform the weight constant into a Numpy array |
| 64 | + weight = to_numpy(weight) |
| 65 | + |
| 66 | + else: |
| 67 | + raise RuntimeError( |
| 68 | + f"Linear layer {name} has weight of type {type(weight)}, Expect Optional[Tensor]" |
| 69 | + ) |
| 70 | + |
| 71 | + # add fully connected layer |
| 72 | + fully_connected_layer = network.add_fully_connected( |
| 73 | + input=pre_shuffle_layer.get_output(0), |
| 74 | + num_outputs=weight.shape[0], |
| 75 | + kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, |
| 76 | + bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, |
| 77 | + ) |
| 78 | + set_layer_name(fully_connected_layer, target, f"{name}_linear", source_ir) |
| 79 | + |
| 80 | + # reshape the output from (*, K, 1, 1) to (*, K) |
| 81 | + post_shuffle_layer = network.add_shuffle(fully_connected_layer.get_output(0)) |
| 82 | + post_shuffle_layer.reshape_dims = tuple(input.shape[:-1]) + (weight.shape[0],) |
| 83 | + set_layer_name(post_shuffle_layer, target, f"{name}_post_shuffle", source_ir) |
| 84 | + |
| 85 | + return post_shuffle_layer.get_output(0) |
0 commit comments