Importance weights
Use samesame.weights when you want a shift test to focus on the part of feature space that source and target actually share.
Choose an approach
| Situation | What to do |
|---|---|
| No weighting needed | Omit weights |
| You already have sample weights | Wrap them in ImportanceWeights(source=..., target=...) |
| You have domain-classifier probabilities | Build weights with from_domain_probabilities(...) |
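```python
import samesame as ss
from samesame.weights import ImportanceWeights, from_domain_probabilities

# source_scores / target_scores: per-sample scores for each group, computed beforehand.

# 1. No weighting needed: omit `weights`.
result = ss.shift.detect_shift(source_scores, target_scores)

# 2. You already have sample weights: wrap them in ImportanceWeights.
result = ss.shift.detect_shift(
    source_scores,
    target_scores,
    weights=ImportanceWeights(
        source=source_weights,
        target=target_weights,
    ),
)

# 3. You have domain-classifier probabilities: build the weights from them.
#    source_domain_probs / target_domain_probs: P(target | x) for each source / target sample.
weights = from_domain_probabilities(
    source_prob=source_domain_probs,
    target_prob=target_domain_probs,
    mode="source",
)
result = ss.shift.detect_harm(
    source_scores,
    target_scores,
    direction="higher-is-worse",
    weights=weights,
)
```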
When from_domain_probabilities(...) helps
Use it when source and target do not overlap well and you want the comparison to emphasize common support rather than low-overlap outliers.
It takes three main controls:
- source_prob and target_prob, passed separately
- mode, which decides whether to reweight source, target, or both
- lambda_, which trades off correction strength against stability
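The two probability arrays usually come from a domain classifier trained to tell source rows from target rows. A minimal sketch, assuming scikit-learn and feature matrices X_source and X_target; the helper name domain_probabilities, the logistic-regression model, 5-fold out-of-fold prediction, and the eps clipping are illustrative choices, not part of samesame:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict


def domain_probabilities(X_source, X_target, cv=5, eps=1e-6):
    """Hypothetical helper: out-of-fold P(target | x), split back into source and target."""
    X = np.vstack([X_source, X_target])
    # Label 0 = source row, 1 = target row.
    y = np.concatenate([
        np.zeros(len(X_source), dtype=int),
        np.ones(len(X_target), dtype=int),
    ])
    # Out-of-fold predictions: no row is scored by a model that was fit on it.
    prob = cross_val_predict(
        LogisticRegression(max_iter=1000), X, y, cv=cv, method="predict_proba"
    )[:, 1]  # column 1 is P(y == 1), i.e. P(target | x)
    # Keep values strictly inside (0, 1), as from_domain_probabilities requires.
    prob = np.clip(prob, eps, 1.0 - eps)
    return prob[: len(X_source)], prob[len(X_source):]


# Synthetic demo data: the target group is shifted by 0.5 in every feature.
rng = np.random.default_rng(0)
X_source = rng.normal(0.0, 1.0, size=(200, 3))
X_target = rng.normal(0.5, 1.0, size=(150, 3))
source_prob, target_prob = domain_probabilities(X_source, X_target)
```

The two returned arrays can then be passed as source_prob= and target_prob=. Out-of-fold predictions are used because in-sample classifier probabilities tend to be overconfident, which would push weights toward extremes.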
Choosing a mode
| Mode | What it emphasizes |
|---|---|
| mode="source" | overlap from the source side |
| mode="target" | overlap from the target side |
| mode="both" | common support from both sides |
lambda_=0.5 is a practical default. Lower values correct more aggressively; higher values move closer to uniform weights.
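To make the effect of mode and lambda_ concrete, here is a minimal NumPy sketch. It assumes the weights are a relative density ratio r / (lambda_ * r + 1 - lambda_), rescaled to mean 1, with r estimated from classifier odds and the sample-size prior. That assumption reproduces the values in the from_domain_probabilities docstring example, but the function riw_sketch is hypothetical and is not samesame's implementation:

```python
import numpy as np


def riw_sketch(source_prob, target_prob, mode="source", lambda_=0.5):
    """Hypothetical re-derivation of RIW weights; not samesame's implementation."""
    if mode not in ("source", "target", "both"):
        raise ValueError("mode must be 'source', 'target', or 'both'")
    if not 0.0 <= lambda_ <= 1.0:
        raise ValueError("lambda_ must be in [0, 1]")
    source_prob = np.asarray(source_prob, dtype=float)
    target_prob = np.asarray(target_prob, dtype=float)
    n_s, n_t = len(source_prob), len(target_prob)

    def blend_and_rescale(ratio):
        # lambda_=0 keeps the raw density ratio; lambda_=1 flattens it to 1.
        w = ratio / (lambda_ * ratio + (1.0 - lambda_))
        # Rescale so the weights sum to the sample count (mean 1).
        return w * (len(w) / w.sum())

    # Source samples: p_target(x) / p_source(x) = classifier odds / prior odds (n_t / n_s).
    src_ratio = source_prob / (1.0 - source_prob) * (n_s / n_t)
    # Target samples: the inverse direction, p_source(x) / p_target(x).
    tgt_ratio = (1.0 - target_prob) / target_prob * (n_t / n_s)

    w_source = blend_and_rescale(src_ratio) if mode in ("source", "both") else np.ones(n_s)
    w_target = blend_and_rescale(tgt_ratio) if mode in ("target", "both") else np.ones(n_t)
    return w_source, w_target


source_prob = [0.25, 0.4]
target_prob = [0.6, 0.75]
for lam in (0.0, 0.5, 1.0):
    w_s, _ = riw_sketch(source_prob, target_prob, lambda_=lam)
    print(lam, np.round(w_s, 4))
# 0.0 [0.6667 1.3333]   raw ratio: strongest correction, most spread
# 0.5 [0.7692 1.2308]   same values as the docstring example
# 1.0 [1. 1.]           uniform weights
```

Under this assumption the groups not selected by mode keep uniform weights, which matches the docstring example where mode="source" leaves the target weights at 1.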
For a worked example, see Focus on shared support with importance weights. For the intuition behind the formulas, see When importance weights help.
API
Public importance-weight seam.
ImportanceWeights (dataclass)
Importance weights for the source and target groups.
Attributes:
| Name | Type | Description |
|---|---|---|
| source | NDArray[float64] | Importance weights for source samples, normalized to sum to |
| target | NDArray[float64] | Importance weights for target samples, normalized to sum to |
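In the from_domain_probabilities examples on this page, each weight vector sums to its sample count (mean weight 1). If you wrap weights you computed yourself, a hypothetical helper like to_mean_one can put them on that same scale first. Whether ImportanceWeights rescales its inputs on construction is not stated here, so treat this as an optional, assumed convention:

```python
import numpy as np


def to_mean_one(weights):
    """Hypothetical helper: rescale non-negative weights so they average 1."""
    w = np.asarray(weights, dtype=np.float64)
    if w.ndim != 1 or w.size == 0:
        raise ValueError("weights must be a non-empty 1-D array")
    if not np.all(np.isfinite(w)) or np.any(w < 0):
        raise ValueError("weights must be finite and non-negative")
    total = w.sum()
    if total <= 0:
        raise ValueError("weights must not all be zero")
    # Multiplying by n / sum keeps relative sizes and makes the weights sum to n.
    return w * (w.size / total)


source_weights = to_mean_one([2.0, 1.0, 1.0])  # -> [1.5, 0.75, 0.75]
target_weights = to_mean_one([0.5, 0.5])       # -> [1.0, 1.0]
# Then, for example: ImportanceWeights(source=source_weights, target=target_weights)
```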
from_domain_probabilities(*, source_prob, target_prob, mode='source', lambda_=0.5)
Build importance weights from domain probabilities using relative importance weighting (RIW).
The prior ratio is always inferred from the lengths of source_prob and target_prob.
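Written out for a source sample with classifier output π(x) = P(target | x), one common construction is shown below. This is an assumption, not taken from the samesame source: it combines the standard Bayes-rule odds correction, using the prior ratio from the sample counts, with a relative-density-ratio blend controlled by λ (lambda_), then rescales to mean 1. With equal group sizes it reproduces the docstring example that follows.

$$
\hat r(x) = \frac{\pi(x)}{1 - \pi(x)} \cdot \frac{n_{\mathrm{source}}}{n_{\mathrm{target}}},
\qquad
w_\lambda(x) = \frac{\hat r(x)}{\lambda\,\hat r(x) + (1 - \lambda)},
\qquad
w_i = \frac{n_{\mathrm{source}}\, w_\lambda(x_i)}{\sum_{j} w_\lambda(x_j)}
$$

Setting λ = 0 keeps the raw ratio r̂, and λ = 1 gives uniform weights. For source_prob = [0.25, 0.4] with λ = 0.5:

$$
\hat r = \left(\tfrac{1}{3},\ \tfrac{2}{3}\right)
\;\Rightarrow\;
w_{0.5} = (0.5,\ 0.8)
\;\Rightarrow\;
\tfrac{2}{1.3}\,(0.5,\ 0.8) \approx (0.7692,\ 1.2308)
$$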
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| source_prob | NDArray | Domain probabilities for source samples: the probability, output by a domain classifier, that each source observation belongs to the target group. Must be in the open interval (0, 1). | required |
| target_prob | NDArray | Domain probabilities for target samples: the probability, output by a domain classifier, that each target observation belongs to the target group. Must be in the open interval (0, 1). | required |
| mode | ('source', 'target', 'both') | Importance-weighting mode: controls which group's samples are reweighted: source only ('source'), target only ('target'), or both groups ('both'). | 'source' |
| lambda_ | float | RIW blending coefficient in [0, 1] controlling the trade-off between correction strength and variance stability. | 0.5 |
Returns:
| Type | Description |
|---|---|
| ImportanceWeights | A frozen dataclass with |
Raises:
| Type | Description |
|---|---|
| ValueError | If any value in |
| ValueError | If |
| ValueError | If |
| ValueError | If |
Examples:
```python
>>> import numpy as np
>>> from samesame.weights import from_domain_probabilities
>>> source_prob = np.array([0.25, 0.4])
>>> target_prob = np.array([0.6, 0.75])
>>> w = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob)
>>> np.round(w.source, 4)
array([0.7692, 1.2308])
>>> np.round(w.target, 4)
array([1., 1.])
>>> w2 = from_domain_probabilities(source_prob=source_prob, target_prob=target_prob, mode="both")
>>> np.round(w2.source, 4)
array([0.7692, 1.2308])
>>> np.round(w2.target, 4)
array([1.2308, 0.7692])
```