Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,56 @@ def post_init(self):
from torchao.quantization.quant_api import AOBaseConfig

if not isinstance(self.quant_type, AOBaseConfig):
if isinstance(self.quant_type, str):
raise TypeError(self._build_string_quant_type_error(self.quant_type))
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")

@staticmethod
def _build_string_quant_type_error(quant_type: str) -> str:
    """Build a migration-guidance error message for a legacy string ``quant_type``.

    Older diffusers releases accepted lowercase strings such as ``"int8_weight_only"`` or ``"float8dq_e4m3_row"``.
    That path was removed (see PR #13291) in favour of passing an ``AOBaseConfig`` subclass instance from
    ``torchao.quantization`` directly. Users on torchao >= 0.16 also hit this because the legacy lowercase
    factories were removed upstream. This helper surfaces the rename so users can self-migrate.

    Args:
        quant_type: The string the caller passed as ``quant_type``. The lookup is exact (case-sensitive).

    Returns:
        The error message text. It always names the rejected value and links to the torchao quantization
        docs; when ``quant_type`` is a known legacy name it also contains a replacement snippet naming the
        ``AOBaseConfig`` subclass to use.
    """
    # Map common legacy strings to their torchao Config-class replacements. We deliberately
    # do not auto-instantiate the Config; the new API exposes options (granularity, dtype,
    # version, ...) the legacy strings hard-coded, and silent defaults would be surprising.
    legacy_to_config = {
        "int4wo": "Int4WeightOnlyConfig",
        "int4_weight_only": "Int4WeightOnlyConfig",
        "int8wo": "Int8WeightOnlyConfig",
        "int8_weight_only": "Int8WeightOnlyConfig",
        "int8dq": "Int8DynamicActivationInt8WeightConfig",
        "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
        "float8wo": "Float8WeightOnlyConfig",
        "float8wo_e4m3": "Float8WeightOnlyConfig",
        "float8wo_e5m2": "Float8WeightOnlyConfig",
        "float8_weight_only": "Float8WeightOnlyConfig",
        "float8dq": "Float8DynamicActivationFloat8WeightConfig",
        "float8dq_e4m3": "Float8DynamicActivationFloat8WeightConfig",
        "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
        "float8dq_e4m3_tensor": "Float8DynamicActivationFloat8WeightConfig",
        "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
        "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
    }
    # Legacy names for which a bare ``Config()`` call is NOT an equivalent replacement: either the
    # name pinned a non-default option (e5m2 weights, per-row granularity) or the Config has a
    # required argument (``scale`` for static activation quantization). Each entry is
    # ``(extra name to import from torchao.quantization or None, constructor arguments)``.
    legacy_to_arguments = {
        "float8wo_e5m2": (None, "weight_dtype=torch.float8_e5m2"),
        "float8dq_e4m3_row": ("PerRow", "granularity=PerRow()"),
        "float8_static_activation_float8_weight": (None, "scale=<activation scale tensor>"),
    }
    suggestion = legacy_to_config.get(quant_type)
    message = (
        f"TorchAoConfig no longer accepts string quant_type values (got {quant_type!r}); "
        "pass an AOBaseConfig instance from torchao.quantization instead."
    )
    if suggestion is not None:
        extra_import, arguments = legacy_to_arguments.get(quant_type, (None, ""))
        imported_names = suggestion if extra_import is None else f"{suggestion}, {extra_import}"
        message += (
            f" For {quant_type!r}, use "
            f"`from torchao.quantization import {imported_names}; TorchAoConfig({suggestion}({arguments}))`."
        )
        if arguments:
            # Spell out why the snippet is not a bare constructor so users do not strip the arguments.
            message += (
                " The arguments shown are needed to match the legacy behaviour;"
                f" a bare `{suggestion}()` call would not reproduce it."
            )
    message += (
        " See https://huggingface.co/docs/diffusers/main/en/quantization/torchao for "
        "the full list of supported AOBaseConfig classes."
    )
    return message

def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
Expand Down
25 changes: 25 additions & 0 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,31 @@ def test_post_init_check(self):
with self.assertRaises(TypeError):
_ = TorchAoConfig(42)

def test_string_quant_type_error_includes_migration_hint(self):
    """
    Passing a legacy string quant_type should raise TypeError and the message should name the
    replacement AOBaseConfig class so users on torchao >= 0.16 (where the legacy lowercase
    factories were removed) can self-migrate. See issues #13286 and #13266.
    """
    legacy_to_config = {
        "int8_weight_only": "Int8WeightOnlyConfig",
        "int8wo": "Int8WeightOnlyConfig",
        "float8_weight_only": "Float8WeightOnlyConfig",
        "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
        "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
    }
    for legacy, config_name in legacy_to_config.items():
        # subTest labels a failure with the offending string and lets the remaining
        # legacy names run instead of stopping at the first broken one.
        with self.subTest(quant_type=legacy):
            with self.assertRaises(TypeError) as cm:
                TorchAoConfig(legacy)
            message = str(cm.exception)
            # The rejected value is echoed back and the replacement class is named.
            self.assertIn(repr(legacy), message)
            self.assertIn(config_name, message)

    # Strings without a known mapping should still raise TypeError and point at the docs.
    with self.assertRaises(TypeError) as cm:
        TorchAoConfig("not_a_real_quant_type")
    message = str(cm.exception)
    self.assertIn("AOBaseConfig", message)
    # Actually check the docs pointer promised above, not just the class name.
    self.assertIn("huggingface.co/docs/diffusers", message)

def test_repr(self):
"""
Check that there is no error in the repr
Expand Down
Loading