Commit 7567ee2

Add pytorch_cuda_alloc_conf config to tune VRAM memory allocation (#7673)
## Summary

This PR adds a `pytorch_cuda_alloc_conf` config flag to control the torch memory allocator behavior.

- `pytorch_cuda_alloc_conf` defaults to `None`, preserving the current behavior.
- The configuration options are explained here: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf. Tuning this configuration can reduce peak reserved VRAM and improve performance.
- Setting `pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"` in `invokeai.yaml` is expected to work well on many systems. This is a good first step for those looking to tune this config. (We may make this the default in the future.)
- The optimal configuration seems to be dependent on a number of factors such as device version, VRAM, CUDA kernel version, etc. For now, users will have to experiment with this config to see if it hurts or helps on their systems. In most cases, I expect it to help.

### Memory Tests

```
VAE decode memory usage comparison:

- SDXL, fp16, 1024x1024:
  - `cudaMallocAsync`: allocated=2593 MB, reserved=3200 MB
  - `native`: allocated=2595 MB, reserved=4418 MB
- SDXL, fp32, 1024x1024:
  - `cudaMallocAsync`: allocated=3982 MB, reserved=5536 MB
  - `native`: allocated=3982 MB, reserved=7276 MB
- SDXL, fp32, 1536x1536:
  - `cudaMallocAsync`: allocated=8643 MB, reserved=12032 MB
  - `native`: allocated=8643 MB, reserved=15900 MB
```

## Related Issues / Discussions

N/A

## QA Instructions

- [x] Performance tests with `pytorch_cuda_alloc_conf` unset.
- [x] Performance tests with `pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"`.

## Merge Plan

- [x] Merge #7668 first and change target branch to `main`

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 7feae5e + 0e632db commit 7567ee2
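
The benchmark script behind the Memory Tests above is not part of this commit. As a rough sketch of how such peak-VRAM figures can be collected with PyTorch's built-in memory counters (the `workload` callable below is a hypothetical stand-in for the actual SDXL VAE decode):

```python
import torch


def measure_peak_vram(workload) -> tuple[float, float]:
    """Run `workload` on the current CUDA device and return peak (allocated, reserved) VRAM in MB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    workload()
    torch.cuda.synchronize()
    allocated_mb = torch.cuda.max_memory_allocated() / (1024**2)
    reserved_mb = torch.cuda.max_memory_reserved() / (1024**2)
    return allocated_mb, reserved_mb


# Hypothetical usage: run the same decode twice, once with PYTORCH_CUDA_ALLOC_CONF unset and
# once with "backend:cudaMallocAsync", and compare the returned peaks.
# allocated, reserved = measure_peak_vram(lambda: vae.decode(latents))
# print(f"allocated={allocated:.0f} MB, reserved={reserved:.0f} MB")
```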

File tree

5 files changed, +85 -7 lines changed

docs/features/low-vram.md

Lines changed: 11 additions & 0 deletions
@@ -31,6 +31,7 @@ It is possible to fine-tune the settings for best performance or if you still ge
 Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
 
 - Partial model loading (`enable_partial_loading`)
+- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
 - Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
 - Working memory (`device_working_mem_gb`)
 - Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
@@ -51,6 +52,16 @@ As described above, you can enable partial model loading by adding this line to
 enable_partial_loading: true
 ```
 
+### PyTorch CUDA allocator config
+
+The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
+
+```yaml
+pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
+```
+
+A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
+
 ### Dynamic RAM and VRAM cache sizes
 
 Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
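
The docs change above only covers the `invokeai.yaml` setting. As a quick sanity check outside of Invoke, the allocator switch can be verified in a fresh Python process with PyTorch's public API (a minimal sketch, assuming a CUDA-capable torch install; the environment variable must be set before torch is imported):

```python
import os

# Equivalent of `pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"`, applied manually.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch  # noqa: E402  # imported only after the allocator config is in place

if torch.cuda.is_available():
    # Expected to print "cudaMallocAsync"; "native" is PyTorch's default backend.
    print(torch.cuda.get_allocator_backend())
```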

invokeai/app/run_app.py

Lines changed: 15 additions & 7 deletions
@@ -2,13 +2,7 @@
 
 from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.startup_utils import (
-    apply_monkeypatches,
-    check_cudnn,
-    enable_dev_reload,
-    find_open_port,
-    register_mime_types,
-)
+from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.frontend.cli.arg_parser import InvokeAIArgs
 
@@ -32,6 +26,20 @@ def run_app() -> None:
 
     logger = InvokeAILogger.get_logger(config=app_config)
 
+    # Configure the torch CUDA memory allocator.
+    # NOTE: It is important that this happens before torch is imported.
+    if app_config.pytorch_cuda_alloc_conf:
+        configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
+
+    # Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
+    from invokeai.app.util.startup_utils import (
+        apply_monkeypatches,
+        check_cudnn,
+        enable_dev_reload,
+        find_open_port,
+        register_mime_types,
+    )
+
     # Find an open port, and modify the config accordingly.
     orig_config_port = app_config.port
     app_config.port = find_open_port(app_config.port)
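
The deferred `startup_utils` import is the key detail in this hunk: per the NOTE in the diff, the allocator config only takes effect if it is applied before torch is imported anywhere in the process. A minimal sketch of the failure mode the deferred import avoids:

```python
import os

import torch  # torch is imported before the allocator config is applied

# Per the NOTE above, this assignment comes too late to influence the allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

if torch.cuda.is_available():
    # Expected to still report "native" rather than "cudaMallocAsync".
    print(torch.cuda.get_allocator_backend())
```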

invokeai/app/services/config/config_default.py

Lines changed: 4 additions & 0 deletions
@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
     ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
     vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
     lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
+    pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
     device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
     precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
     sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -169,6 +170,9 @@ class InvokeAIAppConfig(BaseSettings):
     vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
     lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
 
+    # PyTorch Memory Allocator
+    pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
+
     # DEVICE
     device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
     precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
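
For reference, a minimal sketch of how the new field surfaces through the existing config accessor (`get_config` is the same helper imported in `run_app.py` above; the value is `None` unless set in `invokeai.yaml`):

```python
from invokeai.app.services.config.config_default import get_config

config = get_config()

# Defaults to None; set `pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"` in invokeai.yaml to enable.
if config.pytorch_cuda_alloc_conf:
    print(f"Allocator config requested: {config.pytorch_cuda_alloc_conf}")
else:
    print("Using PyTorch's default (native) CUDA allocator.")
```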

invokeai/app/util/torch_cuda_allocator.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import logging
+import os
+
+
+def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger | None = None):
+    """Configure the PyTorch CUDA memory allocator. See
+    https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
+    configurations.
+    """
+
+    # Raise if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
+    prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
+    if prev_cuda_alloc_conf is not None:
+        raise RuntimeError(
+            f"Attempted to configure the PyTorch CUDA memory allocator, but PYTORCH_CUDA_ALLOC_CONF is already set to "
+            f"'{prev_cuda_alloc_conf}'."
+        )
+
+    # Configure the PyTorch CUDA memory allocator.
+    # NOTE: It is important that this happens before torch is imported.
+    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf
+
+    import torch
+
+    # Relevant docs: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "Attempted to configure the PyTorch CUDA memory allocator, but no CUDA devices are available."
+        )
+
+    # Verify that the torch allocator was properly configured.
+    allocator_backend = torch.cuda.get_allocator_backend()
+    expected_backend = "cudaMallocAsync" if "cudaMallocAsync" in pytorch_cuda_alloc_conf else "native"
+    if allocator_backend != expected_backend:
+        raise RuntimeError(
+            f"Failed to configure the PyTorch CUDA memory allocator. Expected backend: '{expected_backend}', but got "
+            f"'{allocator_backend}'. Verify that 1) the pytorch_cuda_alloc_conf is set correctly, and 2) that torch is "
+            "not imported before calling configure_torch_cuda_allocator()."
+        )
+
+    if logger is not None:
+        logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")
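
Outside of `run_app.py`, the helper can be exercised in a standalone script with the same ordering constraint: it must run before anything imports torch. A minimal sketch, assuming a CUDA-capable environment (the function raises otherwise):

```python
import logging

from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Must be called before torch is imported anywhere in this process.
configure_torch_cuda_allocator("backend:cudaMallocAsync", logger)

import torch  # noqa: E402  # safe to import now; the env var is already set

print(torch.cuda.get_allocator_backend())  # expected: "cudaMallocAsync"
```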

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+import pytest
+import torch
+
+from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
+def test_configure_torch_cuda_allocator_raises_if_torch_is_already_imported():
+    """Test that configure_torch_cuda_allocator() raises a RuntimeError if torch is already imported."""
+    import torch  # noqa: F401
+
+    with pytest.raises(RuntimeError, match="Failed to configure the PyTorch CUDA memory allocator."):
+        configure_torch_cuda_allocator("backend:cudaMallocAsync")
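
The test above covers the torch-already-imported failure path. A companion test for the other guard, refusing to overwrite an externally set PYTORCH_CUDA_ALLOC_CONF, is not part of this diff, but a sketch using pytest's `monkeypatch` fixture could look like:

```python
import pytest

from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator


def test_configure_torch_cuda_allocator_raises_if_env_var_already_set(monkeypatch: pytest.MonkeyPatch):
    """configure_torch_cuda_allocator() should refuse to overwrite an externally set PYTORCH_CUDA_ALLOC_CONF."""
    monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "backend:native")

    with pytest.raises(RuntimeError, match="PYTORCH_CUDA_ALLOC_CONF is already set"):
        configure_torch_cuda_allocator("backend:cudaMallocAsync")
```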
