
Importance weights

Use samesame.weights when you want a shift test to focus on the part of feature space that source and target actually share.

Choose an approach

Situation | What to do
--- | ---
No weighting needed | Omit weights
You already have sample weights | Wrap them in ImportanceWeights(source=..., target=...)
You have domain-classifier probabilities | Build weights with from_domain_probabilities(...)

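import samesame as ss
from samesame.weights import ImportanceWeights, from_domain_probabilities

# source_scores, target_scores, source_weights, target_weights,
# source_domain_probs and target_domain_probs stand for your own arrays.

# Option 1: no weighting needed, so omit `weights`.
result = ss.shift.detect_shift(source_scores, target_scores)

# Option 2: wrap sample weights you already have.
result = ss.shift.detect_shift(
    source_scores,
    target_scores,
    weights=ImportanceWeights(
        source=source_weights,
        target=target_weights,
    ),
)

# Option 3: build weights from domain-classifier probabilities.
weights = from_domain_probabilities(
    source_prob=source_domain_probs,
    target_prob=target_domain_probs,
    mode="source",
)

# The same `weights` argument is accepted by detect_harm.
result = ss.shift.detect_harm(
    source_scores,
    target_scores,
    direction="higher-is-worse",
    weights=weights,
)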
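
ImportanceWeights, as its source listing in the API section shows, is a plain frozen dataclass that stores whatever arrays it receives. If your existing sample weights are on an arbitrary scale, a minimal sketch for rescaling them to the documented convention (each array sums to its group's size) might look like this. normalize_to_group_size is a hypothetical helper, not part of samesame, and it is an assumption that detect_shift expects this scale rather than rescaling internally.

```python
import numpy as np


def normalize_to_group_size(weights):
    """Rescale nonnegative weights so they sum to the number of samples.

    Hypothetical helper (not part of samesame) mirroring the normalization
    convention documented for ImportanceWeights.source and .target.
    """
    w = np.asarray(weights, dtype=np.float64)
    if w.ndim != 1 or w.size == 0:
        raise ValueError("weights must be a non-empty 1-D array.")
    if np.any(w < 0) or w.sum() <= 0:
        raise ValueError("weights must be nonnegative with a positive sum.")
    return w * (w.size / w.sum())


# Raw weights on an arbitrary scale, e.g. exported from another tool.
raw_source = np.array([2.0, 6.0, 4.0])
raw_target = np.array([0.5, 0.5])

source_weights = normalize_to_group_size(raw_source)  # now sums to 3
target_weights = normalize_to_group_size(raw_target)  # now sums to 2
```

The rescaled arrays can then be passed as ImportanceWeights(source=source_weights, target=target_weights).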

When from_domain_probabilities(...) helps

Use it when source and target do not overlap well and you want the comparison to emphasize common support rather than low-overlap outliers.

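It takes three main controls:

  • source_prob and target_prob: the domain classifier's probability of target membership for each source sample and each target sample, passed as separate keyword arguments
  • mode: which group to reweight (source, target, or both)
  • lambda_: the trade-off between correction strength and stability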

Choosing a mode

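Mode | Samples reweighted | What it emphasizes
--- | --- | ---
mode="source" (default) | source only | overlap, seen from the source side
mode="target" | target only | overlap, seen from the target side
mode="both" | source and target | common support from both sides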

lambda_=0.5 is a practical default. Lower values correct more aggressively; lambda_=0 gives plain density-ratio weights. Higher values move closer to uniform weights; lambda_=1 applies no correction at all.
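
To make the trade-off concrete, here is a minimal NumPy sketch that applies the RIW formula quoted in the source comments, r / ((1 - λ) + λ·r), to five hypothetical density ratios at several values of lambda_. It reports Kish's effective sample size (ESS) as a stability diagnostic; ESS is chosen here for illustration and is not computed by samesame.

```python
import numpy as np


def riw(r, lam):
    """Relative importance weight r / ((1 - lam) + lam * r).

    Formula taken from the source comment in from_domain_probabilities;
    lam = 0 returns r unchanged and lam = 1 returns all ones.
    """
    return r / ((1.0 - lam) + lam * r)


def kish_ess(w):
    """Kish effective sample size, (sum w)^2 / sum(w^2), as a stability proxy."""
    return w.sum() ** 2 / np.square(w).sum()


# Hypothetical density ratios for five source samples. The 0.05 point looks
# almost nothing like the target, so it is a low-overlap outlier.
r = np.array([0.05, 0.3, 1.0, 2.0, 4.0])

ess_by_lambda = {}
for lam in (0.0, 0.25, 0.5, 0.75, 1.0):
    w = riw(r, lam)
    w = w * (len(w) / w.sum())  # normalize so the weights sum to the group size
    ess_by_lambda[lam] = kish_ess(w)
    print(f"lambda_={lam:.2f}  weights={np.round(w, 3)}  ESS={ess_by_lambda[lam]:.2f}")
```

In this sketch ESS rises steadily from about 2.6 at lambda_=0 to the full sample size of 5 at lambda_=1, while the near-zero weight on the r=0.05 point grows toward 1.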

For a worked example, see Focus on shared support with importance weights. For the intuition behind the formulas, see When importance weights help.

API

Public importance-weight seam.

ImportanceWeights dataclass

Importance weights for Source and Target groups.

Attributes:

Name | Type | Description
--- | --- | ---
source | NDArray[float64] | Importance weights for source samples, normalized to sum to len(source).
target | NDArray[float64] | Importance weights for target samples, normalized to sum to len(target).

Source code in src/samesame/weights.py
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass(frozen=True)
class ImportanceWeights:
    """Importance weights for Source and Target groups.

    Attributes
    ----------
    source : NDArray[np.float64]
        Importance weights for source samples, normalized to sum to
        ``len(source)``.
    target : NDArray[np.float64]
        Importance weights for target samples, normalized to sum to
        ``len(target)``.
    """

    source: NDArray[np.float64]
    target: NDArray[np.float64]
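
Because the dataclass is declared with frozen=True, its attributes cannot be reassigned, but the NumPy arrays they point to are neither copied nor locked. The sketch below uses a local stand-in class that mirrors the listing above (it does not import samesame) to show the difference.

```python
from dataclasses import FrozenInstanceError, dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass(frozen=True)
class ImportanceWeights:
    """Local stand-in mirroring the samesame listing; not the library class."""

    source: NDArray[np.float64]
    target: NDArray[np.float64]


w = ImportanceWeights(source=np.array([0.5, 1.5]), target=np.ones(2))

# frozen=True blocks rebinding an attribute...
try:
    w.source = np.ones(2)
    rebinding_blocked = False
except FrozenInstanceError:
    rebinding_blocked = True

# ...but the array it holds can still be changed in place.
w.source[0] = 99.0
```

If you reuse one ImportanceWeights object across several tests, avoid modifying its arrays in place, or pass copies (for example source_weights.copy()) when constructing it.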

from_domain_probabilities(*, source_prob, target_prob, mode='source', lambda_=0.5)

Build importance weights from domain probabilities.

Computes relative importance weights (RIW) from domain probabilities. The prior ratio n_source / n_target is always inferred from the lengths of source_prob and target_prob.
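
The prior ratio matters because a domain classifier's output mixes the density ratio with the class balance it was trained on. A minimal sketch of the Bayes step, assuming (not confirmed from the listing) that samesame's internal density_ratio helper multiplies the classifier odds by n_source / n_target; density_ratio_from_prob is a hypothetical name.

```python
import numpy as np


def density_ratio_from_prob(prob, n_source, n_target):
    """Sketch: turn domain-classifier probabilities into density ratios.

    prob is P(target | x). By Bayes' theorem,
        p(x | target) / p(x | source)
            = [P(target | x) / P(source | x)] * [P(source) / P(target)],
    with the prior ratio P(source) / P(target) estimated as n_source / n_target.
    """
    prob = np.asarray(prob, dtype=np.float64)
    posterior_odds = prob / (1.0 - prob)
    return posterior_odds * (n_source / n_target)


# Equal group sizes: the prior ratio is 1, so r(x) equals the posterior odds.
r_equal = density_ratio_from_prob([0.25, 0.6], n_source=100, n_target=100)

# Twice as many source samples: the classifier leans toward "source" a priori,
# so the same probability implies a density ratio twice as large.
r_unbalanced = density_ratio_from_prob([0.25, 0.6], n_source=200, n_target=100)
```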

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
source_prob | NDArray | Domain probabilities for source samples: the probability, output by a domain classifier, that each source observation belongs to the target group. Must be in the open interval (0, 1). | required
target_prob | NDArray | Domain probabilities for target samples: the probability, output by a domain classifier, that each target observation belongs to the target group. Must be in the open interval (0, 1). | required
mode | {'source', 'target', 'both'} | Importance-weighting mode; controls which group's samples are reweighted. 'source': reweight source samples only (default); use when correcting the source distribution to match the target. 'target': reweight target samples only; use when correcting the target distribution to match the source. 'both': reweight both groups simultaneously; use when both groups contain low-overlap outliers. | 'source'
lambda_ | float | RIW blending coefficient in [0, 1] controlling the trade-off between correction strength and variance stability. 0.0 gives plain density-ratio weights (maximum correction, highest variance); 1.0 gives uniform weights (no correction). The default 0.5 is a balanced starting point for most applications. | 0.5

Returns:

Type | Description
--- | ---
ImportanceWeights | A frozen dataclass with .source and .target weight arrays. Weights for each active group are normalized so they sum to that group's sample size. Samples not targeted by mode receive weight 1.
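
The rule that samples not targeted by mode keep weight 1 amounts to a simple dispatch. The sketch below assumes the per-group weights have already been computed and normalized, and shows only which arrays survive for each mode; apply_mode is a hypothetical name, and the input values are copied from the mode="both" doctest below.

```python
import numpy as np


def apply_mode(source_w, target_w, mode):
    """Return (source, target) weights, replacing unselected groups with ones.

    Hypothetical helper: source_w and target_w are assumed to be importance
    weights that were already computed and normalized for each group.
    """
    if mode not in ("source", "target", "both"):
        raise ValueError("mode must be one of 'source', 'target', 'both'.")
    out_source = np.ones_like(source_w, dtype=np.float64)
    out_target = np.ones_like(target_w, dtype=np.float64)
    if mode in ("source", "both"):
        out_source = np.asarray(source_w, dtype=np.float64)
    if mode in ("target", "both"):
        out_target = np.asarray(target_w, dtype=np.float64)
    return out_source, out_target


source_w = np.array([0.7692, 1.2308])
target_w = np.array([1.2308, 0.7692])

for mode in ("source", "target", "both"):
    s, t = apply_mode(source_w, target_w, mode)
    print(mode, s, t)
```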

Raises:

Type | Description
--- | ---
ValueError | If any value in source_prob or target_prob is outside (0, 1).
ValueError | If lambda_ is outside [0, 1].
ValueError | If mode is not one of 'source', 'target', 'both'.
ValueError | If source_prob or target_prob is empty.
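
Some classifiers, tree ensembles for instance, can emit probabilities of exactly 0 or 1, which would trigger the first ValueError above. One common workaround is to clip the probabilities into a narrow open band before calling from_domain_probabilities. The helper name clip_to_open_interval and the margin eps=1e-6 are illustrative assumptions, not samesame defaults.

```python
import numpy as np


def clip_to_open_interval(prob, eps=1e-6):
    """Clip probabilities into [eps, 1 - eps] so they lie strictly inside (0, 1).

    Hypothetical pre-processing step, not part of samesame.
    """
    prob = np.asarray(prob, dtype=np.float64)
    if prob.size == 0:
        raise ValueError("probabilities must be non-empty.")
    if not 0.0 < eps < 0.5:
        raise ValueError("eps must be in (0, 0.5).")
    return np.clip(prob, eps, 1.0 - eps)


# Hard 0/1 outputs, e.g. from a fully grown decision tree.
raw = np.array([0.0, 0.2, 0.9, 1.0])
safe = clip_to_open_interval(raw)
```

Clipping only makes the input valid: a clipped point still implies an extreme density ratio. With lambda_ > 0, however, each unnormalized RIW or inverse-RIW weight is at most 1 / lambda_, which is what keeps such points from dominating.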

Examples:

>>> import numpy as np
>>> from samesame.weights import from_domain_probabilities
>>> source_prob = np.array([0.25, 0.4])
>>> target_prob = np.array([0.6, 0.75])
>>> w = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob)
>>> np.round(w.source, 4)
array([0.7692, 1.2308])
>>> np.round(w.target, 4)
array([1., 1.])
>>> w2 = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob, mode="both")
>>> np.round(w2.source, 4)
array([0.7692, 1.2308])
>>> np.round(w2.target, 4)
array([1.2308, 0.7692])
Source code in src/samesame/weights.py
def from_domain_probabilities(
    *,
    source_prob: NDArray,
    target_prob: NDArray,
    mode: WeightingMode = "source",
    lambda_: float = 0.5,
) -> ImportanceWeights:
    """Build importance weights from domain probabilities.

    Computes RIW weights from domain probabilities.
    The prior ratio is always inferred from the lengths of ``source_prob``
    and ``target_prob``.

    Parameters
    ----------
    source_prob : NDArray
        Domain probabilities for source samples — probability, output by a
        domain classifier, that each source observation belongs to the target
        group. Must be in the open interval (0, 1).
    target_prob : NDArray
        Domain probabilities for target samples — probability, output by a
        domain classifier, that each target observation belongs to the target
        group. Must be in the open interval (0, 1).
    mode : {'source', 'target', 'both'}, optional
        Importance-weighting mode — controls which group's samples are
        reweighted:

        - ``'source'``: reweight source samples only (default). Use when
          correcting the source distribution to match target.
        - ``'target'``: reweight target samples only. Use when correcting
          the target distribution to match source.
        - ``'both'``: reweight both groups simultaneously. Use when both
          groups contain low-overlap outliers.

    lambda_ : float, optional
        RIW blending coefficient in [0, 1] controlling the trade-off between
        correction strength and variance stability. ``0.0`` gives plain
        density-ratio weights (maximum correction, highest variance); ``1.0``
        gives uniform weights (no correction). Default ``0.5`` is a balanced
        starting point for most applications.

    Returns
    -------
    ImportanceWeights
        A frozen dataclass with ``.source`` and ``.target`` weight arrays.
        Weights for each active group are normalized so they sum to that
        group's sample size. Samples not targeted by ``mode`` receive weight 1.

    Raises
    ------
    ValueError
        If any value in ``source_prob`` or ``target_prob`` is outside (0, 1).
    ValueError
        If ``lambda_`` is outside [0, 1].
    ValueError
        If ``mode`` is not one of ``'source'``, ``'target'``, ``'both'``.
    ValueError
        If ``source_prob`` or ``target_prob`` is empty.

    Examples
    --------
    >>> import numpy as np
    >>> from samesame.weights import from_domain_probabilities
    >>> source_prob = np.array([0.25, 0.4])
    >>> target_prob = np.array([0.6, 0.75])
    >>> w = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob)
    >>> np.round(w.source, 4)
    array([0.7692, 1.2308])
    >>> np.round(w.target, 4)
    array([1., 1.])
    >>> w2 = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob, mode="both")
    >>> np.round(w2.source, 4)
    array([0.7692, 1.2308])
    >>> np.round(w2.target, 4)
    array([1.2308, 0.7692])
    """
    source_prob = np.asarray(source_prob, dtype=np.float64)
    target_prob = np.asarray(target_prob, dtype=np.float64)
    n_source = len(source_prob)
    n_target = len(target_prob)
    if n_source == 0 or n_target == 0:
        raise ValueError("source_prob and target_prob must both be non-empty.")
    if lambda_ < 0.0 or lambda_ > 1.0:
        raise ValueError("lambda_ must be in [0, 1].")
    validate_mode(mode)

    # Prior ratio: how much more likely a random draw is from source vs target.
    # Inferred from sample sizes rather than supplied explicitly.
    group_balance = n_source / n_target

    # Density ratio r(x) = p(x | target) / p(x | source), derived from the domain
    # classifier probability via Bayes' theorem with the inferred prior ratio.
    source_dr = density_ratio(source_prob, group_balance=group_balance)
    target_dr = density_ratio(target_prob, group_balance=group_balance)

    # Default: leave each group with unit weights (no reweighting).
    out_source = np.ones(n_source, dtype=np.float64)
    out_target = np.ones(n_target, dtype=np.float64)

    if mode in ("source", "both"):
        # RIW formula: r / ((1-λ) + λ·r) blends toward uniform as λ→1.
        out_source = riw(source_dr, lam=lambda_)
        # Normalize so source weights sum to n_source (preserves expected value).
        out_source = out_source * (n_source / out_source.sum())

    if mode in ("target", "both"):
        # Inverse RIW: 1 / (λ + (1-λ)·r) — maps target back to source density.
        out_target = inverse_riw(target_dr, lam=lambda_)
        # Normalize so target weights sum to n_target.
        out_target = out_target * (n_target / out_target.sum())

    return ImportanceWeights(source=out_source, target=out_target)
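
The helpers density_ratio, riw, inverse_riw and validate_mode are not part of this listing. The NumPy-only sketch below re-derives the function under the assumption that they compute what the inline comments describe: density_ratio as the classifier odds p / (1 - p) multiplied by group_balance, riw as r / ((1 - λ) + λ·r), and inverse_riw as 1 / (λ + (1 - λ)·r). sketch_from_domain_probabilities is a hypothetical name for illustration, not the library's implementation.

```python
import numpy as np


def sketch_from_domain_probabilities(source_prob, target_prob, mode="source", lambda_=0.5):
    """Re-derive from_domain_probabilities with NumPy only (illustrative sketch)."""
    source_prob = np.asarray(source_prob, dtype=np.float64)
    target_prob = np.asarray(target_prob, dtype=np.float64)
    n_source, n_target = len(source_prob), len(target_prob)
    if n_source == 0 or n_target == 0:
        raise ValueError("source_prob and target_prob must both be non-empty.")
    if not 0.0 <= lambda_ <= 1.0:
        raise ValueError("lambda_ must be in [0, 1].")
    if mode not in ("source", "target", "both"):
        raise ValueError("mode must be one of 'source', 'target', 'both'.")
    for p in (source_prob, target_prob):
        if np.any((p <= 0.0) | (p >= 1.0)):
            raise ValueError("probabilities must be in the open interval (0, 1).")

    # Assumed density_ratio: posterior odds times the prior ratio n_source / n_target.
    group_balance = n_source / n_target
    source_dr = source_prob / (1.0 - source_prob) * group_balance
    target_dr = target_prob / (1.0 - target_prob) * group_balance

    out_source = np.ones(n_source)
    out_target = np.ones(n_target)
    if mode in ("source", "both"):
        # Assumed riw, then rescale to sum to n_source.
        out_source = source_dr / ((1.0 - lambda_) + lambda_ * source_dr)
        out_source = out_source * (n_source / out_source.sum())
    if mode in ("target", "both"):
        # Assumed inverse_riw, then rescale to sum to n_target.
        out_target = 1.0 / (lambda_ + (1.0 - lambda_) * target_dr)
        out_target = out_target * (n_target / out_target.sum())
    return out_source, out_target


src, tgt = sketch_from_domain_probabilities([0.25, 0.4], [0.6, 0.75], mode="both")
print(np.round(src, 4), np.round(tgt, 4))
```

It reproduces the doctest values in the Examples section. Because that doctest uses equal group sizes (group_balance = 1), the match does not by itself confirm how density_ratio applies group_balance when the sizes differ.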