Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `STCH` from [Smooth Tchebycheff Scalarization for Multi-Objective
Optimization](https://openreview.net/pdf?id=m4dO5L6eCp), a `Scalarizer` that combines the input
tensor of values into a smooth approximation of their (weighted, shifted) maximum.
- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a simplex-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`.
- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ Abstract base class
geometric_mean.rst
mean.rst
random.rst
stch.rst
sum.rst
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/stch.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

STCH
====

.. autoclass:: torchjd.scalarization.STCH
:members: __call__
3 changes: 2 additions & 1 deletion src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._stch import STCH
from ._sum import Sum

__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"]
__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "STCH", "Sum"]
88 changes: 88 additions & 0 deletions src/torchjd/scalarization/_stch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch
from torch import Tensor

from ._scalarizer_base import Scalarizer


class STCH(Scalarizer):
r"""
:class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using smooth
Tchebycheff scalarization, as defined in `Smooth Tchebycheff Scalarization for Multi-Objective
Optimization <https://openreview.net/pdf?id=m4dO5L6eCp>`_.

It returns

.. math::
\mu \log \sum_{i=1}^m \exp\left(\frac{\lambda_i (f_i - z_i^*)}{\mu}\right),

a smooth approximation of the (non-differentiable) weighted maximum
:math:`\max_i \lambda_i (f_i - z_i^*)` that becomes tighter as ``mu`` decreases.

Following the paper's notation:

- :math:`f_i` is the :math:`i`-th input value (the :math:`i`-th objective),
- :math:`m` is the number of objectives (the number of elements of the input),
- :math:`\lambda_i` is its preference weight (the ``weights`` parameter),
- :math:`z_i^*` is the :math:`i`-th component of the ideal point (the ``reference`` parameter),
- :math:`\mu` is the smoothing parameter (the ``mu`` parameter).

:param mu: The smoothing parameter :math:`\mu`. Must be strictly positive. Smaller values make
the scalarization closer to the maximum. The paper evaluates :math:`\mu \in \{0.01, 0.1,
0.5, 1\}` and reports that a small :math:`\mu` works reasonably well, while no single value
is best across all problems.
:param weights: The preference vector :math:`\lambda` applied to the values (in the paper, on
the probability simplex). If ``None``, a uniform preference summing to one is used. If
provided, it must have the same shape as the values passed at call time.
:param reference: The ideal point :math:`z^*` subtracted from the values. If ``None``, no shift
is applied. If provided, it must have the same shape as the values passed at call time.
"""

def __init__(
self,
mu: float,
weights: Tensor | None = None,
reference: Tensor | None = None,
) -> None:
if mu <= 0.0:
raise ValueError(f"Parameter `mu` should be strictly positive. Found `mu = {mu}`.")

super().__init__()
self.mu = mu
self.weights = weights
self.reference = reference

def forward(self, values: Tensor, /) -> Tensor:
if self.weights is not None and self.weights.shape != values.shape:
raise ValueError(
f"Parameter `weights` should have the same shape as `values`. Found "
f"`weights.shape = {tuple(self.weights.shape)}` and `values.shape = "
f"{tuple(values.shape)}`."
)
if self.reference is not None and self.reference.shape != values.shape:
raise ValueError(
f"Parameter `reference` should have the same shape as `values`. Found "
f"`reference.shape = {tuple(self.reference.shape)}` and `values.shape = "
f"{tuple(values.shape)}`."
)

if self.weights is None:
weights = torch.full_like(values, 1.0 / values.numel())
else:
weights = self.weights

shifted = values if self.reference is None else values - self.reference

# Center the weighted values before dividing by mu (Appendix B.1 of the paper). This keeps
# the largest exponent at 0 so the `/ mu` step never overflows for large values and small
# mu. Adding `max_y` back makes it value-preserving: the result and its gradient are
# mathematically identical to `mu * logsumexp(weights * shifted / mu)`.
y = weights * shifted
max_y = y.max()
exponents = (y - max_y) / self.mu
return self.mu * torch.logsumexp(exponents.flatten(), dim=-1) + max_y

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(mu={self.mu}, weights={self.weights!r}, "
f"reference={self.reference!r})"
)
82 changes: 82 additions & 0 deletions tests/unit/scalarization/test_stch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
from pytest import mark, raises
from torch import Tensor
from utils.tensors import tensor_

from torchjd.scalarization import STCH

from ._asserts import (
assert_grad_flow,
assert_permutation_invariant,
assert_returns_scalar,
)
from ._inputs import all_inputs


def test_value_default() -> None:
# Uniform weights, no reference: mu * logsumexp([0, 0]) = log(2).
out = STCH(mu=1.0)(tensor_([0.0, 0.0]))
torch.testing.assert_close(out, torch.log(tensor_(2.0)))


def test_value_with_weights() -> None:
# weights = [1, 1] on values [1, 1]: mu * logsumexp([1, 1]) = 1 + log(2).
out = STCH(mu=1.0, weights=tensor_([1.0, 1.0]))(tensor_([1.0, 1.0]))
torch.testing.assert_close(out, 1.0 + torch.log(tensor_(2.0)))


def test_value_with_reference() -> None:
# reference shifts values to [0, 0], so the result collapses back to log(2).
out = STCH(mu=1.0, weights=tensor_([1.0, 1.0]), reference=tensor_([1.0, 1.0]))(
tensor_([1.0, 1.0])
)
torch.testing.assert_close(out, torch.log(tensor_(2.0)))


@mark.parametrize("losses", all_inputs)
def test_expected_structure(losses: Tensor) -> None:
assert_returns_scalar(STCH(mu=1.0), losses)


@mark.parametrize("losses", all_inputs)
def test_grad_flow(losses: Tensor) -> None:
assert_grad_flow(STCH(mu=1.0), losses)


@mark.parametrize("losses", all_inputs)
def test_permutation_invariant(losses: Tensor) -> None:
# With uniform weights and no reference, STCH is symmetric in its inputs.
assert_permutation_invariant(STCH(mu=1.0), losses)


def test_does_not_overflow_for_large_values_and_small_mu() -> None:
# `weights * values / mu` would overflow to inf before logsumexp can stabilize it. The
# value-preserving centering keeps the result finite and equal to the dominant (max) term.
values = tensor_([1e30, 2e30, 3e30])
out = STCH(mu=1e-10)(values)
assert out.isfinite()
torch.testing.assert_close(out, tensor_(1e30)) # 3e30 weighted by the uniform 1/3.


@mark.parametrize("mu", [0.0, -1.0])
def test_raises_on_non_positive_mu(mu: float) -> None:
with raises(ValueError):
STCH(mu=mu)


def test_raises_on_weights_shape_mismatch() -> None:
scalarizer = STCH(mu=1.0, weights=tensor_([1.0, 1.0, 1.0]))
with raises(ValueError):
scalarizer(tensor_([1.0, 1.0]))


def test_raises_on_reference_shape_mismatch() -> None:
scalarizer = STCH(mu=1.0, reference=tensor_([1.0, 1.0, 1.0]))
with raises(ValueError):
scalarizer(tensor_([1.0, 1.0]))


def test_representations() -> None:
s = STCH(mu=0.5)
assert repr(s) == "STCH(mu=0.5, weights=None, reference=None)"
assert str(s) == "STCH"
Loading