
Importance weights

When to use importance weights

Use importance weights when you know that source and target have different feature distributions (covariate shift) and you want the shift test to focus on the region where both groups overlap, rather than penalising samples that are simply foreign to the other group. If you have no prior reason to expect covariate shift, omit weights.

Choosing a mode

Mode | What it does
--- | ---
mode="source" | Down-weights source samples that are foreign to the target. Target samples keep unit weight.
mode="target" | Down-weights target samples that are foreign to the source. Source samples keep unit weight.
mode="both" | Down-weights outliers in both groups; focuses the test on common support.
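The direction of each mode in the table can be checked with a few lines of arithmetic. This is a minimal sketch, not samesame's code, and it assumes: equal group sizes, so the density ratio is simply the classifier odds r = p / (1 - p); source samples are weighted by a relative ratio that increases with r; and target samples get the same transform applied to 1 / r. The helpers `density_ratio` and `relative_weight` are hypothetical names. Weights are shown before normalization.

```python
def density_ratio(p: float) -> float:
    """Odds that a sample belongs to the target group (equal group sizes assumed)."""
    return p / (1.0 - p)


def relative_weight(r: float, lam: float = 0.5) -> float:
    """Assumed relative weight r / ((1 - lam) + lam * r); increasing in r."""
    return r / ((1.0 - lam) + lam * r)


# Domain-classifier probabilities P(target | x): low values look source-like,
# high values look target-like.
probs = [0.1, 0.5, 0.9]

# mode="source": source samples that look unlike the target (low p) get small weights.
source_w = [relative_weight(density_ratio(p)) for p in probs]

# mode="target": the inverse ratio is used, so target samples that look
# unlike the source (high p) get small weights.
target_w = [relative_weight(1.0 / density_ratio(p)) for p in probs]

for p, ws, wt in zip(probs, source_w, target_w):
    print(f"p={p:.1f}  source weight={ws:.3f}  target weight={wt:.3f}")
# p=0.1 gives source weight 0.2 and target weight 1.8; p=0.9 gives the reverse.
```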

lambda_ (in [0, 1]) controls the trade-off between correction strength and numerical stability: 0.0 gives the plain density ratio, as in importance-weighted ERM (IWERM); 1.0 gives uniform weights (no correction). The default 0.5 is a safe starting point.
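Both documented endpoints are matched by the relative density-ratio form w = r / ((1 - λ) + λ r), where r is the density ratio. The sketch below assumes that form. It also reproduces the doctest values in the contextual_weights reference, but samesame's private helper may be written differently. For λ > 0 the weight never exceeds 1/λ, which is how lambda_ keeps extreme ratios in check. `relative_weight` is a hypothetical name.

```python
def relative_weight(r: float, lam: float) -> float:
    """Assumed relative importance weight: r at lam=0, 1 at lam=1, at most 1/lam."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must be in [0, 1].")
    return r / ((1.0 - lam) + lam * r)


ratios = [0.01, 1.0, 100.0]  # from very source-like to very target-like samples

for lam in (0.0, 0.5, 1.0):
    weights = [relative_weight(r, lam) for r in ratios]
    print(lam, [round(w, 4) for w in weights])
# 0.0 [0.01, 1.0, 100.0]    plain density ratio: extreme values pass through
# 0.5 [0.0198, 1.0, 1.9802] bounded above by 1 / 0.5 = 2
# 1.0 [1.0, 1.0, 1.0]       uniform weights: no correction
```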

Weights for each active group are automatically normalized to sum to that group's sample size. In mode="both", source and target are normalized independently. Non-active groups always receive unit weights.
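Normalization multiplies a group's raw weights by n / sum(weights), so the weights sum to the group size n and have mean 1. In mode="both", each group is rescaled by its own total. A minimal sketch with made-up raw weights; `normalize` is a hypothetical helper, not part of samesame:

```python
def normalize(weights: list[float]) -> list[float]:
    """Rescale weights so they sum to the number of samples (mean weight 1)."""
    total = sum(weights)
    n = len(weights)
    return [w * n / total for w in weights]


raw_source = [0.5, 0.8]        # illustrative raw weights for two source samples
raw_target = [0.8, 0.5, 0.2]   # illustrative raw weights for three target samples

# mode="both": each group is normalized independently of the other.
source_w = normalize(raw_source)
target_w = normalize(raw_target)

print([round(w, 4) for w in source_w])                 # [0.7692, 1.2308]
print(round(sum(source_w), 10), round(sum(target_w), 10))  # 2.0 3.0
```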

For guidance on which mode fits your scenario, see Why importance weights stabilise shift detection.

Connecting weights to a shift test

Call contextual_weights to build per-sample weights, then pass the result as weights= to test_shift or test_adverse_shift:

import samesame
from samesame.weights import contextual_weights

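# source_domain_probs / target_domain_probs: a domain classifier's predicted
# probability that each source / target sample belongs to the target group.
weights = contextual_weights(
    source_prob=source_domain_probs,
    target_prob=target_domain_probs,
    mode="source",
)
# Pass the returned ContextualWeights object to the shift test.
result = samesame.test_shift(
    source=source_scores,
    target=target_scores,
    weights=weights,
)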

source_prob and target_prob hold, for each group separately, a domain classifier's predicted probability that each sample belongs to the target group. The prior ratio (n_source / n_target) is inferred automatically from their lengths. See Weighting strategies for a quick-reference comparison of all three approaches.
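Why the group sizes matter: by Bayes' rule, the density ratio p_target(x) / p_source(x) equals the classifier odds P(target | x) / P(source | x) multiplied by the prior ratio P(source) / P(target), estimated as n_source / n_target. This matches the group_balance = n_source / n_target computed in the source code, but the exact form of samesame's private `_density_ratio` helper is an assumption here. It also assumes that the classifier was trained on the pooled data with those class proportions. `density_ratio` is a hypothetical name.

```python
def density_ratio(prob: float, n_source: int, n_target: int) -> float:
    """Estimate p_target(x) / p_source(x) from a domain-classifier probability.

    prob is P(target | x); the odds are corrected by the prior ratio
    P(source) / P(target), estimated from the group sizes.
    """
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must lie in the open interval (0, 1).")
    odds = prob / (1.0 - prob)
    return odds * (n_source / n_target)


# With 300 source and 100 target samples, the classifier sees the target class
# only 25% of the time, so P(target | x) = 0.5 already marks a target-heavy region.
print(density_ratio(0.5, n_source=300, n_target=100))   # 3.0
print(density_ratio(0.25, n_source=300, n_target=100))  # 1.0
```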

For a worked end-to-end example, see the tutorial Adjust for covariate shift with importance weights. For the conceptual background — why density ratios can become extreme and how lambda_ tames them — see Why importance weights stabilise shift detection.


Sample weight builders for covariate shift adaptation.

ContextualWeights dataclass

Importance weights for source and target groups, used to correct for covariate shift between source and target during a shift test.

Attributes:

source (NDArray[float64]): Importance weights for source samples, normalized to sum to len(source).

target (NDArray[float64]): Importance weights for target samples, normalized to sum to len(target).

Source code in src/samesame/weights.py
@dataclass(frozen=True)
class ContextualWeights:
    """Importance weights for source and target groups, used to
    correct for covariate shift between source and target during a shift test.

    Attributes
    ----------
    source : NDArray[np.float64]
        Importance weights for source samples, normalized to sum to
        ``len(source)``.
    target : NDArray[np.float64]
        Importance weights for target samples, normalized to sum to
        ``len(target)``.
    """

    source: NDArray[np.float64]
    target: NDArray[np.float64]

contextual_weights(*, source_prob, target_prob, mode='source', lambda_=0.5)

Build context-aware sample weights for shift testing.

Computes relative importance weights (RIW) from domain probabilities. The prior ratio is always inferred from the lengths of source_prob and target_prob.

Parameters:

source_prob (NDArray, required): Domain probabilities for the source samples, i.e. the probability, output by a domain classifier, that each source observation belongs to the target group. Must lie in the open interval (0, 1).

target_prob (NDArray, required): Domain probabilities for the target samples, i.e. the probability, output by a domain classifier, that each target observation belongs to the target group. Must lie in the open interval (0, 1).

mode ({'source', 'target', 'both'}, default 'source'): Context-aware weighting mode; controls which group's samples are reweighted:

  • 'source': reweight source samples only (default). Use when correcting the source distribution to match the target.
  • 'target': reweight target samples only. Use when correcting the target distribution to match the source.
  • 'both': reweight both groups simultaneously. Use when both groups contain low-overlap outliers.

lambda_ (float, default 0.5): RIW blending coefficient in [0, 1] that trades correction strength against variance. 0.0 gives plain density-ratio weights (maximum correction, highest variance); 1.0 gives uniform weights (no correction). The default 0.5 is a balanced starting point.

Returns:

ContextualWeights: A frozen dataclass with .source and .target weight arrays. Weights for each active group are normalized so they sum to that group's sample size. Samples not targeted by mode receive weight 1.

Raises:

  • ValueError: if any value in source_prob or target_prob is outside (0, 1).
  • ValueError: if lambda_ is outside [0, 1].
  • ValueError: if mode is not one of 'source', 'target', 'both'.
  • ValueError: if source_prob or target_prob is empty.

Examples:

>>> import numpy as np
>>> from samesame.weights import contextual_weights
>>> source_prob = np.array([0.25, 0.4])
>>> target_prob = np.array([0.6, 0.75])
>>> w = contextual_weights(source_prob=source_prob, target_prob=target_prob)
>>> np.round(w.source, 4)
array([0.7692, 1.2308])
>>> np.round(w.target, 4)
array([1., 1.])
>>> w2 = contextual_weights(source_prob=source_prob, target_prob=target_prob, mode="both")
>>> np.round(w2.source, 4)
array([0.7692, 1.2308])
>>> np.round(w2.target, 4)
array([1.2308, 0.7692])
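The doctest values above can be reproduced by hand. The pure-Python sketch below follows the steps visible in the source code (density ratio with prior correction, relative weighting, then per-group normalization). It assumes forms for the private helpers `_density_ratio`, `_riw` and `_inverse_riw` that are consistent with the documented behaviour, but it is not samesame's implementation; `contextual_weights_sketch` and `WeightsSketch` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WeightsSketch:
    source: list[float]
    target: list[float]


def _relative(r: float, lam: float) -> float:
    # Assumed relative density ratio: r at lam=0, 1 at lam=1.
    return r / ((1.0 - lam) + lam * r)


def _normalize(w: list[float]) -> list[float]:
    # Rescale so the weights sum to the group size.
    total = sum(w)
    return [x * len(w) / total for x in w]


def contextual_weights_sketch(source_prob, target_prob, mode="source", lambda_=0.5):
    if len(source_prob) == 0 or len(target_prob) == 0:
        raise ValueError("source_prob and target_prob must both be non-empty.")
    if not 0.0 <= lambda_ <= 1.0:
        raise ValueError("lambda_ must be in [0, 1].")
    if mode not in ("source", "target", "both"):
        raise ValueError("mode must be 'source', 'target' or 'both'.")
    if any(not 0.0 < p < 1.0 for p in list(source_prob) + list(target_prob)):
        raise ValueError("probabilities must lie in the open interval (0, 1).")

    balance = len(source_prob) / len(target_prob)  # prior ratio n_source / n_target
    source_dr = [p / (1.0 - p) * balance for p in source_prob]
    target_dr = [p / (1.0 - p) * balance for p in target_prob]

    out_source = [1.0] * len(source_prob)  # inactive groups keep unit weight
    out_target = [1.0] * len(target_prob)
    if mode in ("source", "both"):
        # Source samples: weight increases with p_target / p_source.
        out_source = _normalize([_relative(r, lambda_) for r in source_dr])
    if mode in ("target", "both"):
        # Target samples: same transform applied to the inverse ratio.
        out_target = _normalize([_relative(1.0 / r, lambda_) for r in target_dr])
    return WeightsSketch(source=out_source, target=out_target)


w = contextual_weights_sketch([0.25, 0.4], [0.6, 0.75], mode="both")
print([round(x, 4) for x in w.source])  # [0.7692, 1.2308]
print([round(x, 4) for x in w.target])  # [1.2308, 0.7692]
```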
Source code in src/samesame/weights.py
def contextual_weights(
    *,
    source_prob: NDArray,
    target_prob: NDArray,
    mode: WeightingMode = "source",
    lambda_: float = 0.5,
) -> ContextualWeights:
    """Build context-aware sample weights for shift testing.

    Computes RIW weights from domain probabilities.
    The prior ratio is always inferred from the lengths of ``source_prob``
    and ``target_prob``.

    Parameters
    ----------
    source_prob : NDArray
        Domain probabilities for source samples — probability, output by a
        domain classifier, that each source observation belongs to the target
        group. Must be in the open interval (0, 1).
    target_prob : NDArray
        Domain probabilities for target samples — probability, output by a
        domain classifier, that each target observation belongs to the target
        group. Must be in the open interval (0, 1).
    mode : {'source', 'target', 'both'}, optional
        Context-aware weighting mode — controls which group's samples are
        reweighted:

        - ``'source'``: reweight source samples only (default). Use when
          correcting the source distribution to match target.
        - ``'target'``: reweight target samples only. Use when correcting
          the target distribution to match source.
        - ``'both'``: reweight both groups simultaneously. Use when both
          groups contain low-overlap outliers.

    lambda_ : float, optional
        RIW blending coefficient in [0, 1] controlling the trade-off between
        correction strength and variance stability. ``0.0`` gives plain
        density-ratio weights (maximum correction, highest variance); ``1.0``
        gives uniform weights (no correction). Default ``0.5`` is a balanced
        starting point for most applications.

    Returns
    -------
    ContextualWeights
        A frozen dataclass with ``.source`` and ``.target`` weight arrays.
        Weights for each active group are normalized so they sum to that
        group's sample size. Samples not targeted by ``mode`` receive weight 1.

    Raises
    ------
    ValueError
        If any value in ``source_prob`` or ``target_prob`` is outside (0, 1).
    ValueError
        If ``lambda_`` is outside [0, 1].
    ValueError
        If ``mode`` is not one of ``'source'``, ``'target'``, ``'both'``.
    ValueError
        If ``source_prob`` or ``target_prob`` is empty.

    Examples
    --------
    >>> import numpy as np
    >>> from samesame.weights import contextual_weights
    >>> source_prob = np.array([0.25, 0.4])
    >>> target_prob = np.array([0.6, 0.75])
    >>> w = contextual_weights(source_prob=source_prob, target_prob=target_prob)
    >>> np.round(w.source, 4)
    array([0.7692, 1.2308])
    >>> np.round(w.target, 4)
    array([1., 1.])
    >>> w2 = contextual_weights(source_prob=source_prob, target_prob=target_prob, mode="both")
    >>> np.round(w2.source, 4)
    array([0.7692, 1.2308])
    >>> np.round(w2.target, 4)
    array([1.2308, 0.7692])
    """
    source_prob = np.asarray(source_prob, dtype=np.float64)
    target_prob = np.asarray(target_prob, dtype=np.float64)
    n_source = len(source_prob)
    n_target = len(target_prob)
    if n_source == 0 or n_target == 0:
        raise ValueError("source_prob and target_prob must both be non-empty.")
    if lambda_ < 0.0 or lambda_ > 1.0:
        raise ValueError("lambda_ must be in [0, 1].")
    _validate_mode(mode)
    group_balance = n_source / n_target
    source_dr = _density_ratio(source_prob, group_balance=group_balance)
    target_dr = _density_ratio(target_prob, group_balance=group_balance)
    out_source = np.ones(n_source, dtype=np.float64)
    out_target = np.ones(n_target, dtype=np.float64)
    if mode in ("source", "both"):
        out_source = _riw(source_dr, lam=lambda_)
        out_source = out_source * (n_source / out_source.sum())
    if mode in ("target", "both"):
        out_target = _inverse_riw(target_dr, lam=lambda_)
        out_target = out_target * (n_target / out_target.sum())
    return ContextualWeights(source=out_source, target=out_target)