diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index c3d829fde8cf..1a45a31cd750 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -476,8 +476,66 @@ def post_init(self):
         from torchao.quantization.quant_api import AOBaseConfig
 
         if not isinstance(self.quant_type, AOBaseConfig):
+            if isinstance(self.quant_type, str):
+                raise TypeError(self._build_string_quant_type_error(self.quant_type))
             raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
 
+    @staticmethod
+    def _build_string_quant_type_error(quant_type: str) -> str:
+        """Build a migration-guidance error for legacy string ``quant_type`` values.
+
+        Older diffusers releases accepted lowercase strings such as ``"int8_weight_only"`` or ``"float8dq_e4m3_row"``.
+        That path was removed (see PR #13291) in favour of passing an ``AOBaseConfig`` subclass instance from
+        ``torchao.quantization`` directly. Users on torchao >= 0.16 also hit this because the legacy lowercase
+        factories were removed upstream. This helper surfaces the rename so users can self-migrate.
+
+        Args:
+            quant_type (`str`):
+                The legacy string that was passed as ``quant_type``.
+
+        Returns:
+            `str`: The message for the ``TypeError`` raised by ``post_init``. It names the replacement Config
+            class when ``quant_type`` is a known legacy string and always links to the torchao quantization docs.
+        """
+        # Map common legacy strings to their torchao Config-class replacements. We deliberately
+        # do not auto-instantiate the Config; the new API exposes options (granularity, dtype,
+        # version, ...) the legacy strings hard-coded, and silent defaults would be surprising.
+        legacy_to_config = {
+            "int4wo": "Int4WeightOnlyConfig",
+            "int4_weight_only": "Int4WeightOnlyConfig",
+            "int8wo": "Int8WeightOnlyConfig",
+            "int8_weight_only": "Int8WeightOnlyConfig",
+            "int8dq": "Int8DynamicActivationInt8WeightConfig",
+            "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
+            "float8wo": "Float8WeightOnlyConfig",
+            "float8wo_e4m3": "Float8WeightOnlyConfig",
+            "float8wo_e5m2": "Float8WeightOnlyConfig",
+            "float8_weight_only": "Float8WeightOnlyConfig",
+            "float8dq": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3_tensor": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
+        }
+        suggestion = legacy_to_config.get(quant_type)
+        message = (
+            f"TorchAoConfig no longer accepts string quant_type values (got {quant_type!r}); "
+            "pass an AOBaseConfig instance from torchao.quantization instead."
+        )
+        if suggestion is not None:
+            # Show the constructor with an ellipsis rather than a zero-argument call: several legacy
+            # strings share one Config class, so its defaults cannot reproduce every one of them.
+            message += (
+                f" For {quant_type!r}, use `from torchao.quantization import {suggestion}` and pass "
+                f"`TorchAoConfig({suggestion}(...))`, setting the options the legacy string implied explicitly."
+            )
+        message += (
+            " See https://huggingface.co/docs/diffusers/main/en/quantization/torchao for "
+            "the full list of supported AOBaseConfig classes."
+        )
+        return message
+
     def to_dict(self):
         """Convert configuration to a dictionary."""
         d = super().to_dict()
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 8a811cfc1c73..842b3c9e10d7 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -112,6 +112,33 @@ def test_post_init_check(self):
         with self.assertRaises(TypeError):
             _ = TorchAoConfig(42)
 
+    def test_string_quant_type_error_includes_migration_hint(self):
+        """
+        Passing a legacy string quant_type should raise TypeError and the message should name the
+        replacement AOBaseConfig class so users on torchao >= 0.16 (where the legacy lowercase
+        factories were removed) can self-migrate. See issues #13286 and #13266.
+        """
+        legacy_to_config = {
+            "int8_weight_only": "Int8WeightOnlyConfig",
+            "int8wo": "Int8WeightOnlyConfig",
+            "float8_weight_only": "Float8WeightOnlyConfig",
+            "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
+        }
+        for legacy, config_name in legacy_to_config.items():
+            with self.assertRaises(TypeError) as cm:
+                TorchAoConfig(legacy)
+            message = str(cm.exception)
+            self.assertIn(repr(legacy), message)
+            self.assertIn(config_name, message)
+            # The hint must not advertise a zero-argument constructor call.
+            self.assertNotIn(f"{config_name}()", message)
+
+        # Strings without a known mapping should still raise TypeError and point at the docs.
+        with self.assertRaises(TypeError) as cm:
+            TorchAoConfig("not_a_real_quant_type")
+        self.assertIn("AOBaseConfig", str(cm.exception))
+
     def test_repr(self):
         """
         Check that there is no error in the repr