Description
Context
Currently, the decomposition system in compile, which will soon be integrated into export, is based on PyTorch's core_aten_decompositions set.
This set of decompositions, while exhaustive, also decomposes certain operators that we would like to preserve, including aten.cudnn_batch_norm and other batch norm implementations. Keeping these operators intact, combined with freezing utilities such as those discussed in #2124 and #2128, can greatly improve performance on networks that use them.
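To see which batch norm variants a given decomposition table would decompose, its keys can be filtered by name. The sketch below is a hypothetical helper (find_decomposed_ops is not an existing API) that assumes the table can be iterated like a Python dict. It matches on str(key), so it handles plain string keys as well as PyTorch operator overloads, whose str() is a dotted name such as "aten.cudnn_batch_norm.default". The table used here is an illustrative stand-in, not the real contents of core_aten_decompositions.

```python
from typing import Hashable, Iterable, List, Mapping


def find_decomposed_ops(table: Mapping[Hashable, object], fragments: Iterable[str]) -> List[Hashable]:
    """Return the keys of `table` whose string form contains any of `fragments`.

    Matching is case-insensitive and uses str(key), so string keys and operator
    objects with a dotted-name str() are treated the same way.
    """
    needles = [f.lower() for f in fragments]
    return [op for op in table if any(n in str(op).lower() for n in needles)]


# Illustrative stand-in for a decomposition table (hypothetical keys and values).
table = {
    "aten.cudnn_batch_norm.default": "decomp_a",
    "aten.native_batch_norm.default": "decomp_b",
    "aten.addmm.default": "decomp_c",
}

print(find_decomposed_ops(table, ["batch_norm"]))
# → ['aten.cudnn_batch_norm.default', 'aten.native_batch_norm.default']
```

The list this returns is a natural input for deciding which operators to exclude from the decomposition set.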
As a bonus, this may make it easier to remove or reduce the pre-AOT decompositions, which are currently difficult to integrate into export.
Proposal
Introduce a new decomposition dictionary which does one of two things:
1. Selectively remove operators from the core_aten_decompositions dictionary, so that these operators are not decomposed and remain in their original form. This lets us keep using the most up-to-date decompositions dictionary while excluding only the operators that are better left as-is.
2. Selectively keep operators from the core_aten_decompositions dictionary, so that the set of decompositions is fully defined and does not depend on the state of PyTorch's dictionary. This gives more control over which decompositions are allowed, but requires updating our decompositions library more often, since PyTorch may frequently introduce new decompositions we would want.
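The two options can be sketched as small dictionary transformations. This is a minimal sketch, assuming the upstream table behaves like a Python dict mapping operator keys to decomposition functions; the helper names exclude_decompositions and select_decompositions are hypothetical, and the string keys below stand in for real operator overloads such as torch.ops.aten.cudnn_batch_norm.default.

```python
from typing import Dict, Hashable, Iterable, Mapping, TypeVar

Op = TypeVar("Op", bound=Hashable)
Fn = TypeVar("Fn")


def exclude_decompositions(table: Mapping[Op, Fn], keep_intact: Iterable[Op]) -> Dict[Op, Fn]:
    """Option 1: copy the upstream table minus the operators to preserve.

    Decompositions added upstream later are picked up automatically.
    The input table is not modified.
    """
    excluded = set(keep_intact)
    return {op: fn for op, fn in table.items() if op not in excluded}


def select_decompositions(table: Mapping[Op, Fn], allowed: Iterable[Op]) -> Dict[Op, Fn]:
    """Option 2: keep only an explicit allowlist of operators.

    The resulting key set does not grow when the upstream table does.
    Allowed operators that are missing from the upstream table are skipped.
    """
    return {op: table[op] for op in allowed if op in table}


# Illustrative stand-in for the upstream table (hypothetical keys and values).
upstream = {
    "aten.cudnn_batch_norm.default": "decomp_cudnn_batch_norm",
    "aten.native_batch_norm.default": "decomp_native_batch_norm",
    "aten.addmm.default": "decomp_addmm",
}

option_1 = exclude_decompositions(upstream, ["aten.cudnn_batch_norm.default"])
option_2 = select_decompositions(upstream, ["aten.addmm.default"])
print(sorted(option_1))  # → ['aten.addmm.default', 'aten.native_batch_norm.default']
print(sorted(option_2))  # → ['aten.addmm.default']
```

Option 1 needs only a short exclusion list, but any new upstream decomposition flows in unreviewed; option 2 pins the set, at the cost of editing the allowlist whenever a new upstream decomposition is wanted. A stricter variant of select_decompositions could raise instead of skipping missing operators, so that upstream removals are noticed.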