torch.Tensor.view(dtype) is not supported #2212

Open
crcrpar opened this issue Jun 10, 2025 · 0 comments · May be fixed by #2213


🐛 Bug

torch.Tensor.view with a dtype argument is not supported, even though it is used in the wild, e.g. in torchao's to_mx function: https://github.com/pytorch/ao/blob/v0.10.0/torchao/prototype/mx_formats/mx_tensor.py#L146-L329.

To Reproduce

Code sample

import torch
import thunder


def f(t: torch.Tensor) -> torch.Tensor:
    return t.view(torch.float8_e4m3fn)


if __name__ == "__main__":
    x = torch.testing.make_tensor(
        (4, 4),
        dtype=torch.int8,
        device=torch.device("cuda"),
        requires_grad=False,
    )
    jitted = thunder.jit(f)
    out = jitted(x)
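For context, eager PyTorch accepts the dtype overload and reinterprets the tensor's bytes in place without copying. The sketch below uses CPU tensors and plain integer dtypes so it runs without a GPU or float8 support; it only illustrates the semantics Thunder needs to reproduce:

```python
import torch

# Tensor.view(dtype) reinterprets raw bytes without copying.
x = torch.arange(16, dtype=torch.int8).reshape(4, 4)

# Same itemsize (1 byte -> 1 byte): the shape is unchanged.
y = x.view(torch.uint8)
print(y.shape, y.dtype)  # torch.Size([4, 4]) torch.uint8

# Different itemsize (1 byte -> 4 bytes): the last dim shrinks 4x.
z = x.view(torch.int32)
print(z.shape, z.dtype)  # torch.Size([4, 1]) torch.int32

# No copy is made: the views share storage with the original tensor.
assert z.data_ptr() == x.data_ptr()
```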

This fails as follows:

Traceback (most recent call last):
  File "/path/to/c.py", line 18, in <module>
    out = jitted(x)
          ^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 830, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 870, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 809, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 236, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 774, in get_computation_and_inputs
    prologue_trc, computation_trc, epilogue_trc = acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/__init__.py", line 434, in acquire_initial_trace
    jit_results: TraceResults = thunder_general_jit(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/jit_ext.py", line 2133, in thunder_general_jit
    result = jfn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/interpreter.py", line 7567, in fn_
    raise e
  File "/path/to/thunder/core/interpreter.py", line 7526, in fn_2
    return fn(*args, **kwargs)

  File "/path/to/c.py", line 7, in f
    return t.view(torch.float8_e4m3fn)
^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/interpreter.py", line 6952, in partial_call_impl
    return partial_function.func(*(partial_function.args + args), **(partial_function.keywords | kwargs))
^^^^^^^^^^
  File "/path/to/thunder/core/interpreter.py", line 1302, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/jit_ext.py", line 388, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/symbol.py", line 320, in __call__
    result = self.meta(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/torch/__init__.py", line 1398, in view
    return reshape(a, shape)
           ^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/symbol.py", line 320, in __call__
    result = self.meta(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/torch/__init__.py", line 1138, in reshape
    return clang.reshape(a, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/path/to/thunder/clang/__init__.py", line 1069, in reshape
    if l >= 0:
       ^^^^^^
TypeError: '>=' not supported between instances of 'torch.dtype' and 'int'
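The failure comes from Thunder's torch-level `view` forwarding every argument to `reshape`, so the `torch.dtype` ends up compared against an integer shape entry (`if l >= 0`). A hypothetical, standalone sketch of the dispatch it would need follows; the real fix belongs in `thunder/torch/__init__.py` and would lower to Thunder primitives, not to eager `Tensor.view` as done here for illustration:

```python
import torch

# Hypothetical sketch, not thunder's actual code: branch on the argument
# type instead of unconditionally treating it as a shape.
def view(a, *shape_or_dtype):
    if len(shape_or_dtype) == 1 and isinstance(shape_or_dtype[0], torch.dtype):
        # dtype overload: reinterpret bytes (eager fallback used here
        # purely for illustration).
        return a.view(shape_or_dtype[0])
    # shape overload: unchanged behavior, delegate to reshape.
    return a.reshape(*shape_or_dtype)

t = torch.zeros(4, 4, dtype=torch.int8)
print(view(t, torch.uint8).dtype)  # torch.uint8
print(view(t, 16).shape)           # torch.Size([16])
```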

With thunderfx, the program runs thanks to its splitter, but the unsupported op falls back to Inductor. The module after the split is as follows:

class GraphModule(torch.nn.Module):
    def forward(self, l_t_: "i8[4, 4]"):
        # No stacktrace found for following nodes
        inductor_0 = self.inductor_0(l_t_);  l_t_ = None
        return (inductor_0,)

    class inductor_0(torch.nn.Module):
        def forward(self, l_t_: "i8[4, 4]"):
            view: "f8e4m3fn[4, 4]" = l_t_.view(torch.float8_e4m3fn);  l_t_ = None
            return view

        class _orig_mod(torch.nn.Module):
            def forward(self, l_t_: "i8[4, 4]"):
                view: "f8e4m3fn[4, 4]" = l_t_.view(torch.float8_e4m3fn);  l_t_ = None
                return view

Expected behavior

The jitted function should just work, matching eager PyTorch.
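One simple correctness oracle for whatever implementation lands (a suggestion, not from the issue): a dtype view is a pure bit reinterpretation, so round-tripping through it must be byte-exact.

```python
import torch

x = torch.randint(-128, 127, (4, 4), dtype=torch.int8)

# Same-width round trip must be byte-exact.
assert torch.equal(x.view(torch.uint8).view(torch.int8), x)

# Widening then narrowing must also round-trip, restoring the shape.
wide = x.view(torch.int32)  # shape (4, 1)
assert torch.equal(wide.view(torch.int8), x)
```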

@crcrpar crcrpar linked a pull request Jun 10, 2025 that will close this issue