Importance weights¶
When to use importance weights¶
Use importance weights when you know that source and target have different feature
distributions — covariate shift — and you want the shift test to focus on the region
where both groups overlap rather than penalising samples that are simply foreign to the
other group. If you have no prior knowledge of covariate shift, omit weights.
Choosing a mode¶
| Mode | What it does |
|---|---|
| `mode="source"` | Down-weights source samples foreign to target. Target samples keep unit weight. |
| `mode="target"` | Down-weights target samples foreign to source. Source samples keep unit weight. |
| `mode="both"` | Down-weights outliers in both groups; focuses the test on common support. |
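The effect of each mode can be sketched with plain numpy. This is a hypothetical illustration, assuming the per-sample weight takes the relative (RIW) form `w = p / (lam * p + (1 - lam) * (1 - p))`, where `p` is the domain probability that a sample is foreign to its own group; the helper names `riw` and `norm` are illustrative, not part of the samesame API.

```python
import numpy as np

def riw(p, lam=0.5):
    # Assumed RIW form: w = p / (lam * p + (1 - lam) * (1 - p)).
    return p / (lam * p + (1 - lam) * (1 - p))

def norm(w):
    # Normalize weights to sum to the group's sample size.
    return w * len(w) / w.sum()

# P(target | x) from a domain classifier, for each group separately.
source_prob = np.array([0.25, 0.4])
target_prob = np.array([0.6, 0.75])

# mode="source": only source samples are reweighted; target keeps unit weight.
src_w, tgt_w = norm(riw(source_prob)), np.ones_like(target_prob)

# mode="target": only target samples are reweighted. A target sample is
# foreign to source when P(target | x) is high, so the probability is flipped.
src_w2, tgt_w2 = np.ones_like(source_prob), norm(riw(1 - target_prob))

# mode="both": both groups are pulled toward the region of common support.
src_w3, tgt_w3 = norm(riw(source_prob)), norm(riw(1 - target_prob))
```

In every mode, samples that look foreign to the other group end up with weights below 1, and samples that look native end up above 1, within their own group's normalization.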
lambda_ controls numerical stability: 0.0 is the plain density ratio (IWERM); 1.0 is
uniform weights (no correction). The default 0.5 is a safe starting point.
Weights for each active group are automatically normalized to sum to that group's sample
size. In mode="both", source and target are normalized independently. Non-active groups
always receive unit weights.
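To see how `lambda_` trades correction strength against stability, here is a small numpy sketch under the assumption that the weights follow the RIW form `w = p / (lam * p + (1 - lam) * (1 - p))` (the function name `riw` is illustrative):

```python
import numpy as np

# Assumed RIW form: w = p / (lam * p + (1 - lam) * (1 - p)),
# where p is the domain probability output by a classifier.
def riw(p, lam):
    return p / (lam * p + (1 - lam) * (1 - p))

p = np.array([0.05, 0.5, 0.95])  # near-native, ambiguous, near-foreign samples

ratio = riw(p, lam=0.0)    # lambda_ = 0.0: plain density ratio p / (1 - p), can explode
uniform = riw(p, lam=1.0)  # lambda_ = 1.0: uniform weights, no correction at all
blended = riw(p, lam=0.5)  # lambda_ = 0.5: bounded between the two extremes

print(np.round(ratio, 2))    # [ 0.05  1.   19.  ]
print(uniform)               # [1. 1. 1.]
print(np.round(blended, 2))  # [0.1 1.  1.9]
```

Under this form, no weight can exceed `1 / lambda_`, so the default of 0.5 caps every raw weight at 2 before normalization, which is the stability the blending buys.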
For guidance on which mode fits your scenario, see Why importance weights stabilise shift detection.
Connecting weights to a shift test¶
Call contextual_weights to build per-sample weights, then pass the result as weights=
to test_shift or test_adverse_shift:
```python
import samesame
from samesame.weights import contextual_weights

weights = contextual_weights(
    source_prob=source_domain_probs,
    target_prob=target_domain_probs,
    mode="source",
)
result = samesame.test_shift(
    source=source_scores,
    target=target_scores,
    weights=weights,
)
```
`source_prob` and `target_prob` are the domain probabilities for source and target samples, respectively. The prior ratio is inferred automatically from their lengths.
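The domain probabilities themselves can come from any probabilistic classifier trained to tell the two groups apart. Here is a sketch using scikit-learn's `LogisticRegression` as the domain classifier; the synthetic data, variable names, and clipping epsilon are illustrative assumptions, not part of the samesame API.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X_source = rng.normal(0.0, 1.0, size=(200, 3))  # source features
X_target = rng.normal(0.5, 1.0, size=(150, 3))  # covariate-shifted target features

# Stack both groups, label them by group, and fit a domain classifier.
X = np.vstack([X_source, X_target])
y = np.concatenate([np.zeros(len(X_source)), np.ones(len(X_target))])
clf = LogisticRegression().fit(X, y)

# P(target | x) for each group, clipped into the open interval (0, 1)
# as contextual_weights requires.
eps = 1e-6
source_prob = np.clip(clf.predict_proba(X_source)[:, 1], eps, 1 - eps)
target_prob = np.clip(clf.predict_proba(X_target)[:, 1], eps, 1 - eps)
```

These two arrays feed directly into `contextual_weights(source_prob=source_prob, target_prob=target_prob)`.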
See Weighting strategies for a quick-reference comparison of all three
approaches.
For a worked end-to-end example, see the tutorial
Adjust for covariate shift with importance weights.
For the conceptual background — why density ratios can become extreme and how lambda_
tames them — see
Why importance weights stabilise shift detection.
Sample weight builders for covariate shift adaptation.
ContextualWeights dataclass¶
Importance weights for source and target groups, used to correct for covariate shift between source and target during a shift test.
Attributes:
| Name | Type | Description |
|---|---|---|
| `source` | `NDArray[float64]` | Importance weights for source samples, normalized to sum to the number of source samples. |
| `target` | `NDArray[float64]` | Importance weights for target samples, normalized to sum to the number of target samples. |
Source code in `src/samesame/weights.py`
contextual_weights(*, source_prob, target_prob, mode='source', lambda_=0.5)¶
Build context-aware sample weights for shift testing.
Computes RIW weights from domain probabilities.
The prior ratio is always inferred from the lengths of source_prob
and target_prob.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source_prob` | `NDArray` | Domain probabilities for source samples — probability, output by a domain classifier, that each source observation belongs to the target group. Must be in the open interval (0, 1). | required |
| `target_prob` | `NDArray` | Domain probabilities for target samples — probability, output by a domain classifier, that each target observation belongs to the target group. Must be in the open interval (0, 1). | required |
| `mode` | `('source', 'target', 'both')` | Context-aware weighting mode — controls which group's samples are reweighted. | `'source'` |
| `lambda_` | `float` | RIW blending coefficient in [0, 1] controlling the trade-off between correction strength and variance stability. | `0.5` |
Returns:
| Type | Description |
|---|---|
| `ContextualWeights` | A frozen dataclass with `source` and `target` weight arrays. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any value in `source_prob` or `target_prob` lies outside the open interval (0, 1). |
| `ValueError` | If `mode` is not one of `'source'`, `'target'`, or `'both'`. |
| `ValueError` | If `lambda_` is outside `[0, 1]`. |
| `ValueError` | If … |
Examples:

```python
>>> import numpy as np
>>> from samesame.weights import contextual_weights
>>> source_prob = np.array([0.25, 0.4])
>>> target_prob = np.array([0.6, 0.75])
>>> w = contextual_weights(source_prob=source_prob, target_prob=target_prob)
>>> np.round(w.source, 4)
array([0.7692, 1.2308])
>>> np.round(w.target, 4)
array([1., 1.])
>>> w2 = contextual_weights(source_prob=source_prob, target_prob=target_prob, mode="both")
>>> np.round(w2.source, 4)
array([0.7692, 1.2308])
>>> np.round(w2.target, 4)
array([1.2308, 0.7692])
```
Source code in `src/samesame/weights.py`