Shift testing

Use this page when you already have a numeric signal for a source group and a target group.

Choose the function

| Function | What it answers | Use it when |
|---|---|---|
| shift.detect_shift(...) | Did anything change? | you want to detect any difference between source and target |
| shift.detect_harm(...) | Did the target group move in a worse direction? | you know what "worse" means for your signal |

Examples of useful signals include predicted risk, prediction error, model confidence, and domain classifier probabilities.

Common controls

Both functions accept:

  • n_resamples to control the number of permutation resamples
  • batch to limit memory use during the permutation test
  • random_state for reproducibility
  • weights for weighted testing with ImportanceWeights
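To make n_resamples and random_state concrete, here is a minimal standard-library sketch of a one-sided permutation test on AUC. It is not the library's implementation: the helper names auc and permutation_pvalue are hypothetical, the +1 correction in the p-value is an assumed convention, and batch and weights are not modeled.

```python
import random


def auc(source, target):
    """Chance that a random target score exceeds a random source score (ties count 0.5)."""
    wins = 0.0
    for t in target:
        for s in source:
            if t > s:
                wins += 1.0
            elif t == s:
                wins += 0.5
    return wins / (len(source) * len(target))


def permutation_pvalue(source, target, n_resamples=999, random_state=None):
    """One-sided permutation p-value: is the target AUC larger than chance?"""
    rng = random.Random(random_state)  # same seed -> same resamples
    observed = auc(source, target)
    pooled = list(source) + list(target)
    n_source = len(source)
    exceed = 0
    for _ in range(n_resamples):
        rng.shuffle(pooled)  # reassign group membership at random
        if auc(pooled[:n_source], pooled[n_source:]) >= observed:
            exceed += 1
    # Assumed convention: count the observed statistic as one of the resamples.
    return observed, (exceed + 1) / (n_resamples + 1)


# Made-up scores, for illustration only.
source_scores = [0.10, 0.20, 0.25, 0.30, 0.35, 0.40]
target_scores = [0.45, 0.50, 0.60, 0.70, 0.80, 0.90]

stat_a, p_a = permutation_pvalue(source_scores, target_scores, n_resamples=499, random_state=0)
stat_b, p_b = permutation_pvalue(source_scores, target_scores, n_resamples=499, random_state=0)
print(stat_a, p_a == p_b)
```

More resamples give a finer-grained p-value at the cost of run time, and a fixed seed makes repeated runs return the same p-value.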

shift.detect_harm(...) also requires direction, which must be one of:

  • "higher-is-worse"
  • "higher-is-better"
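One plausible way to think about direction is as a sign convention applied before testing, so that larger values always mean "worse". The sketch below illustrates that idea only. The helper orient_scores is hypothetical, and how the library actually handles direction is not shown on this page.

```python
VALID_DIRECTIONS = ("higher-is-worse", "higher-is-better")


def orient_scores(scores, direction):
    """Return scores oriented so that larger always means worse (assumed convention)."""
    if direction not in VALID_DIRECTIONS:
        raise ValueError(f"direction must be one of {VALID_DIRECTIONS}, got {direction!r}")
    if direction == "higher-is-better":
        # Flip the sign so a drop in a "good" signal shows up as an increase.
        return [-s for s in scores]
    return list(scores)


confidence = [0.9, 0.1]  # made-up model confidences; higher is better
oriented = orient_scores(confidence, "higher-is-better")
print(oriented)
```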

What you get back

  • shift.detect_shift(...) returns ShiftResult
  • shift.detect_harm(...) returns HarmResult

In both cases, the fields most users look at first are:

  • .statistic
  • .pvalue

ShiftResult also includes .statistic_name. HarmResult also includes .direction. Both results include .null_distribution when you need the full permutation output.
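As a rough illustration of what .null_distribution can be used for, the sketch below measures how far into the upper tail of the permuted statistics an observed statistic falls. The values are made up, a plain list stands in for the NumPy array the library returns, and upper_tail_fraction is a hypothetical helper rather than part of samesame.

```python
def upper_tail_fraction(statistic, null_distribution):
    """Share of permuted statistics at least as large as the observed one."""
    n_extreme = sum(1 for value in null_distribution if value >= statistic)
    return n_extreme / len(null_distribution)


# Stand-ins for result.statistic and result.null_distribution.
observed = 0.60
null = [0.42, 0.45, 0.47, 0.49, 0.50, 0.51, 0.53, 0.55, 0.61, 0.64]

tail = upper_tail_fraction(observed, null)
print(f"{tail:.2f} of permuted statistics are >= the observed value")
```

This fraction will generally be close to, but not necessarily identical to, the reported .pvalue, since the exact p-value convention the library uses is not stated here.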

Posterior evidence for harmful shift

If you want posterior draws and a Bayes factor alongside the p-value, set include_posterior=True.

import samesame as ss

# source_scores and target_scores are array-like numeric signals, one per group.
result = ss.shift.detect_harm(
    source_scores,
    target_scores,
    direction="higher-is-worse",
    include_posterior=True,
)

print(f"p-value:      {result.pvalue:.4f}")
print(f"Bayes factor: {result.bayes_factor:.2f}")

The threshold argument is only valid together with include_posterior=True. Passing threshold without it makes detect_harm(...) raise a ValueError.
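The interaction between threshold and include_posterior, and one common way to turn posterior draws into a Bayes factor, can be sketched as follows. This is a guess at the general shape, not the library's code: resolve_threshold and bayes_factor_from_draws are hypothetical names, the default threshold of 0.5 (the AUC value meaning "no difference") is an assumption, and the Bayes factor shown is simply the posterior odds under assumed even prior odds.

```python
def resolve_threshold(include_posterior, threshold, default=0.5):
    """Reject a threshold that is passed without include_posterior=True."""
    if not include_posterior:
        if threshold is not None:
            raise ValueError("threshold requires include_posterior=True")
        return None
    # Assumption: fall back to 0.5 when no threshold is given.
    return default if threshold is None else threshold


def bayes_factor_from_draws(draws, threshold):
    """Odds that the statistic exceeds the threshold, assuming even prior odds."""
    above = sum(1 for d in draws if d > threshold)
    below = len(draws) - above
    return float("inf") if below == 0 else above / below


# Made-up posterior draws of the statistic, for illustration only.
draws = [0.48, 0.49, 0.52, 0.55, 0.57, 0.61, 0.63, 0.66]
threshold = resolve_threshold(include_posterior=True, threshold=None)
bf = bayes_factor_from_draws(draws, threshold)
print(f"Bayes factor: {bf:.2f}")
```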

API

Detect whether Source and Target Outlier score distributions differ.

Source code in src/samesame/shift.py
def detect_shift(
    source: ArrayLike,
    target: ArrayLike,
    *,
    statistic: ShiftStatistic = "roc_auc",
    alternative: Literal["less", "greater", "two-sided"] = "two-sided",
    n_resamples: int = 9999,
    batch: int | None = None,
    random_state: RandomState = None,
    weights: ImportanceWeights | None = None,
) -> ShiftResult:
    """Detect whether Source and Target Outlier score distributions differ."""
    prepared, _ = _prepare_two_sample_test(source, target, weights=weights)
    statistic_name, metric = _get_shift_statistic(statistic)
    _validate_shift_scores(statistic_name, prepared.scores)
    result = _run_permutation_test(
        prepared.labels,
        prepared.scores,
        metric,
        n_resamples=n_resamples,
        alternative=alternative,
        sample_weight=prepared.sample_weight,
        rng=_resolve_random_state(random_state),
        batch=batch,
    )
    return ShiftResult(
        statistic=float(result.statistic),
        pvalue=float(result.pvalue),
        statistic_name=statistic_name,
        null_distribution=np.asarray(result.null_distribution, dtype=np.float64),
    )

Detect whether Target is harmfully shifted relative to Source.

Source code in src/samesame/shift.py
def detect_harm(
    source: ArrayLike,
    target: ArrayLike,
    *,
    direction: Direction,
    n_resamples: int = 9999,
    batch: int | None = None,
    random_state: RandomState = None,
    weights: ImportanceWeights | None = None,
    include_posterior: bool = False,
    threshold: float | None = None,
) -> HarmResult:
    """Detect whether Target is harmfully shifted relative to Source."""
    prepared, validated_direction = _prepare_two_sample_test(
        source,
        target,
        weights=weights,
        direction=direction,
    )
    assert validated_direction is not None
    posterior_threshold = _resolve_posterior_threshold(
        include_posterior=include_posterior,
        threshold=threshold,
    )
    result = _run_permutation_test(
        prepared.labels,
        prepared.scores,
        _wauc,
        n_resamples=n_resamples,
        alternative="greater",
        sample_weight=prepared.sample_weight,
        rng=_resolve_random_state(random_state),
        batch=batch,
    )
    posterior = None
    bayes_factor = None
    if include_posterior:
        posterior = np.asarray(
            _bayesian_posterior(
                prepared.labels,
                prepared.scores,
                _wauc,
                n_resamples=n_resamples,
                rng=_resolve_random_state(random_state),
                base_weight=prepared.sample_weight,
            ),
            dtype=np.float64,
        )
        bayes_factor = float(_bayes_factor(posterior, posterior_threshold))
    return HarmResult(
        statistic=float(result.statistic),
        pvalue=float(result.pvalue),
        direction=validated_direction,
        null_distribution=np.asarray(result.null_distribution, dtype=np.float64),
        posterior=posterior,
        bayes_factor=bayes_factor,
    )
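Both listings pass sample_weight into the permutation test, and detect_harm uses a metric named _wauc whose body is not shown here. Assuming that name refers to a weighted AUC, the sketch below shows one straightforward way to compute such a statistic with per-sample weights. The function weighted_auc and the label convention (0 for source, 1 for target) are assumptions made for illustration.

```python
def weighted_auc(labels, scores, sample_weight=None):
    """Weighted chance that a target score (label 1) exceeds a source score (label 0)."""
    if sample_weight is None:
        sample_weight = [1.0] * len(labels)
    rows = list(zip(labels, scores, sample_weight))
    targets = [(s, w) for label, s, w in rows if label == 1]
    sources = [(s, w) for label, s, w in rows if label == 0]
    numerator = 0.0
    denominator = 0.0
    for t_score, t_weight in targets:
        for s_score, s_weight in sources:
            pair_weight = t_weight * s_weight  # each pair counts by the product of its weights
            denominator += pair_weight
            if t_score > s_score:
                numerator += pair_weight
            elif t_score == s_score:
                numerator += 0.5 * pair_weight
    return numerator / denominator


# Made-up example: two source rows (label 0) and two target rows (label 1).
labels = [0, 0, 1, 1]
scores = [0.10, 0.40, 0.35, 0.80]

unweighted = weighted_auc(labels, scores)
reweighted = weighted_auc(labels, scores, sample_weight=[1.0, 3.0, 1.0, 1.0])
print(unweighted, reweighted)
```

Up-weighting the source row with score 0.40 makes the one losing pair count more, so the weighted statistic drops below the unweighted one.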

Result types

Bases: TestResult

Result of generic shift detection.

Source code in src/samesame/shift.py
@dataclass(frozen=True)
class ShiftResult(TestResult):
    """Result of generic shift detection."""

    statistic_name: str
    null_distribution: NDArray[np.float64]

Bases: TestResult

Result of harmful-shift detection.

Source code in src/samesame/shift.py
@dataclass(frozen=True)
class HarmResult(TestResult):
    """Result of harmful-shift detection."""

    direction: Direction
    null_distribution: NDArray[np.float64]
    posterior: NDArray[np.float64] | None = None
    bayes_factor: float | None = None

Shared fields for all statistical test results.

Source code in src/samesame/shift.py
@dataclass(frozen=True)
class TestResult:
    """Shared fields for all statistical test results."""

    statistic: float
    pvalue: float