ctst

Classifier two-sample tests (CTST) from binary classification metrics.

The classifier two-sample test broadly consists of three steps: (1) training a classifier, (2) scoring the two samples, and (3) turning the resulting scores into a test statistic and p-value. The test statistic can be any performance metric of a binary classifier, such as the (weighted) area under the receiver operating characteristic curve, the Matthews correlation coefficient, or the (balanced) accuracy. This module tackles step (3).
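The three steps can be sketched end to end with scikit-learn and scipy alone; this module's `CTST` class automates step (3), while the classifier, data, and parameters below are illustrative assumptions:

```python
import numpy as np
from scipy.stats import permutation_test
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Two samples drawn from (slightly) different distributions.
x = rng.normal(0.0, 1.0, size=(200, 2))
y = rng.normal(0.5, 1.0, size=(200, 2))
features = np.vstack([x, y])
labels = np.concatenate([np.zeros(200), np.ones(200)])

# Step 1: train a classifier on half the data.
f_train, f_test, l_train, l_test = train_test_split(
    features, labels, test_size=0.5, random_state=0, stratify=labels
)
clf = LogisticRegression().fit(f_train, l_train)

# Step 2: score the held-out half.
scores = clf.predict_proba(f_test)[:, 1]

# Step 3: turn the metric into a p-value via a permutation null.
result = permutation_test(
    (l_test, scores),
    statistic=roc_auc_score,
    permutation_type="pairings",
    n_resamples=999,
    alternative="greater",
    random_state=rng,
)
print(result.statistic, result.pvalue)
```

A small p-value indicates the classifier separates the two samples better than chance, i.e. the samples likely come from different distributions.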

References

.. [1] Lopez-Paz, David, and Maxime Oquab. "Revisiting Classifier Two-Sample Tests." International Conference on Learning Representations. 2017.

.. [2] Friedman, Jerome. "On multivariate goodness-of-fit and two-sample testing." No. SLAC-PUB-10325. SLAC National Accelerator Laboratory (SLAC), Menlo Park, CA (United States), 2004.

.. [3] Kübler, Jonas M., et al. "Automl two-sample test." Advances in Neural Information Processing Systems 35 (2022): 15929-15941.

.. [4] Clémençon, Stéphan, Marine Depecker, and Nicolas Vayatis. "AUC optimization and the two-sample problem." Proceedings of the 23rd International Conference on Neural Information Processing Systems. 2009.

.. [5] Hediger, Simon, Loris Michel, and Jeffrey Näf. "On the use of random forest for two-sample testing." Computational Statistics & Data Analysis 170 (2022): 107435.

.. [6] Kim, Ilmun, et al. "Classification accuracy as a proxy for two-sample testing." Annals of Statistics 49.1 (2021): 411-434.

CTST dataclass

Classifier two-sample test (CTST) using a binary classification metric.

This test compares scores (predictions) from two independent samples. Rejecting the null implies that scoring is not random and that the classifier is able to distinguish between the two samples.

Attributes:

Name Type Description
actual NDArray

Binary indicator for sample membership.

predicted NDArray

Estimated (predicted) scores for corresponding samples in actual.

metric Callable

A callable that conforms to the scikit-learn metric API. This function must take two positional arguments, e.g. y_true and y_pred.

n_resamples (int, optional)

Number of resampling iterations, by default 9999.

rng (Generator, optional)

Random number generator, by default np.random.default_rng().

n_jobs (int, optional)

Number of parallel jobs, by default 1.

batch (int or None, optional)

Batch size for parallel processing, by default None.

alternative ({'less', 'greater', 'two-sided'}, optional)

Defines the alternative hypothesis. Default is 'two-sided'.
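Any callable with the standard scikit-learn metric signature (two positional arguments, true labels first) works as metric. A hypothetical example that thresholds continuous scores before computing balanced accuracy (the 0.5 cutoff and the function name are assumptions for illustration):

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

def thresholded_balanced_accuracy(y_true, y_pred):
    # balanced_accuracy_score expects hard labels, so threshold the
    # continuous scores at 0.5 first (the 0.5 cutoff is an assumption)
    return balanced_accuracy_score(y_true, np.asarray(y_pred) >= 0.5)

actual = np.array([0, 1, 1, 0])
scores = np.array([0.2, 0.8, 0.6, 0.4])
print(thresholded_balanced_accuracy(actual, scores))  # perfect separation here
```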

Notes

The null distribution is based on permutations. See scipy.stats.permutation_test for more details.
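A minimal sketch of what happens under the hood, calling scipy.stats.permutation_test directly with an illustrative metric and data; the three result fields correspond to the statistic, null, and pvalue properties of CTST:

```python
import numpy as np
from scipy.stats import permutation_test
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(42)
actual = np.repeat([0, 1], 50)
# scores for the second group are shifted upward, so the AUC should exceed 0.5
predicted = np.concatenate([rng.normal(0.0, 1.0, 50), rng.normal(1.0, 1.0, 50)])

res = permutation_test(
    (actual, predicted),
    statistic=roc_auc_score,
    permutation_type="pairings",  # permute the label/score pairings
    n_resamples=999,
    alternative="greater",
    random_state=rng,
)
# res.statistic, res.null_distribution, and res.pvalue map to the
# `statistic`, `null`, and `pvalue` cached properties of CTST
print(res.statistic, res.pvalue)
```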

Examples:

>>> import numpy as np
>>> from sklearn.metrics import matthews_corrcoef, roc_auc_score
>>> from samesame.ctst import CTST
>>> actual = np.array([0, 1, 1, 0])
>>> scores = np.array([0.2, 0.8, 0.6, 0.4])
>>> ctst_mcc = CTST(actual, scores, metric=matthews_corrcoef)
>>> ctst_auc = CTST(actual, scores, metric=roc_auc_score)
>>> print(ctst_mcc.pvalue)
>>> print(ctst_auc.pvalue)
>>> ctst_ = CTST.from_samples(scores, scores, metric=roc_auc_score)
>>> isinstance(ctst_, CTST)
True
Source code in src/samesame/ctst.py
@dataclass
class CTST:
    """
    Classifier two-sample test (CTST) using a binary classification metric.

    This test compares scores (predictions) from two independent samples.
    Rejecting the null implies that scoring is not random and that the
    classifier is able to distinguish between the two samples.

    Attributes
    ----------
    actual : NDArray
        Binary indicator for sample membership.
    predicted : NDArray
        Estimated (predicted) scores for corresponding samples in `actual`.
    metric : Callable
        A callable that conforms to scikit-learn metric API. This function
        must take two positional arguments e.g. `y_true` and `y_pred`.
    n_resamples : int, optional
        Number of resampling iterations, by default 9999.
    rng : np.random.Generator, optional
        Random number generator, by default np.random.default_rng().
    n_jobs : int, optional
        Number of parallel jobs, by default 1.
    batch : int or None, optional
        Batch size for parallel processing, by default None.
    alternative : {'less', 'greater', 'two-sided'}, optional
        Defines the alternative hypothesis. Default is 'two-sided'.

    Notes
    -----
    The null distribution is based on permutations.
    See `scipy.stats.permutation_test` for more details.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import matthews_corrcoef, roc_auc_score
    >>> from samesame.ctst import CTST
    >>> actual = np.array([0, 1, 1, 0])
    >>> scores = np.array([0.2, 0.8, 0.6, 0.4])
    >>> ctst_mcc = CTST(actual, scores, metric=matthews_corrcoef)
    >>> ctst_auc = CTST(actual, scores, metric=roc_auc_score)
    >>> print(ctst_mcc.pvalue) # doctest: +SKIP
    >>> print(ctst_auc.pvalue) # doctest: +SKIP
    >>> ctst_ = CTST.from_samples(scores, scores, metric=roc_auc_score)
    >>> isinstance(ctst_, CTST)
    True
    """

    actual: NDArray = field(repr=False)
    predicted: NDArray = field(repr=False)
    metric: Callable
    n_resamples: int = 9999
    rng: np.random.Generator = np.random.default_rng()
    n_jobs: int = 1
    batch: int | None = None
    alternative: Literal["less", "greater", "two-sided"] = "two-sided"

    def __post_init__(self):
        """Validate inputs."""
        self.actual = column_or_1d(self.actual)
        self.predicted = column_or_1d(self.predicted)
        check_consistent_length(self.actual, self.predicted)
        assert type_of_target(self.actual, "actual") == "binary"
        type_predicted = type_of_target(self.predicted, "predicted")
        assert type_predicted in (
            "binary",
            "continuous",
            "multiclass",
    ), f"Expected 'predicted' to be binary, continuous, or multiclass, got {type_predicted}."
        assert check_metric_function(self.metric), (
            f"'metric' expects a callable that conforms to scikit-learn metric. "
            f"{signature(self.metric)=} does not."
        )

    @cached_property
    def _result(self):
        def statistic(*args):
            return self.metric(args[0], args[1])

        return permutation_test(
            data=(self.actual, self.predicted),
            statistic=statistic,
            permutation_type="pairings",
            n_resamples=self.n_resamples,
            alternative=self.alternative,
            random_state=self.rng,
        )

    @cached_property
    def statistic(self) -> float:
        """
        Compute the observed test statistic.

        Returns
        -------
        float
            The test statistic.

        Notes
        -----
        The result is cached to avoid (expensive) recomputation.
        """
        return self._result.statistic

    @cached_property
    def null(self) -> NDArray:
        """
        Compute the null distribution of the test statistic.

        Notes
        -----
        The result is cached to avoid (expensive) recomputation since the
        null distribution requires permutations.
        """
        return self._result.null_distribution

    @cached_property
    def pvalue(self):
        """
        Compute the p-value using permutations.

        Notes
        -----
        The result is cached to avoid (expensive) recomputation.
        """
        return self._result.pvalue

    @classmethod
    def from_samples(
        cls,
        first_sample: NDArray,
        second_sample: NDArray,
        metric: Callable,
        n_resamples: int = 9999,
        rng: np.random.Generator = np.random.default_rng(),
        n_jobs: int = 1,
        batch: int | None = None,
        alternative: Literal["less", "greater", "two-sided"] = "two-sided",
    ):
        """
        Create a CTST instance from two samples.

        Parameters
        ----------
        first_sample : NDArray
            First sample of scores. These can be binary or continuous.
        second_sample : NDArray
            Second sample of scores. These can be binary or continuous.

        Returns
        -------
        CTST
            An instance of the CTST class.
        """
        assert type_of_target(first_sample) == type_of_target(second_sample)
        samples = (first_sample, second_sample)
        actual = assign_labels(samples)
        predicted = concat_samples(samples)
        return cls(
            actual,
            predicted,
            metric,
            n_resamples,
            rng,
            n_jobs,
            batch,
            alternative,
        )

null cached property

Compute the null distribution of the test statistic.

Notes

The result is cached to avoid (expensive) recomputation since the null distribution requires permutations.

pvalue cached property

Compute the p-value using permutations.

Notes

The result is cached to avoid (expensive) recomputation.

statistic cached property

Compute the observed test statistic.

Returns:

Type Description
float

The test statistic.

Notes

The result is cached to avoid (expensive) recomputation.

__post_init__()

Validate inputs.

Source code in src/samesame/ctst.py
def __post_init__(self):
    """Validate inputs."""
    self.actual = column_or_1d(self.actual)
    self.predicted = column_or_1d(self.predicted)
    check_consistent_length(self.actual, self.predicted)
    assert type_of_target(self.actual, "actual") == "binary"
    type_predicted = type_of_target(self.predicted, "predicted")
    assert type_predicted in (
        "binary",
        "continuous",
        "multiclass",
    ), f"Expected 'predicted' to be binary, continuous, or multiclass, got {type_predicted}."
    assert check_metric_function(self.metric), (
        f"'metric' expects a callable that conforms to scikit-learn metric. "
        f"{signature(self.metric)=} does not."
    )

from_samples(first_sample, second_sample, metric, n_resamples=9999, rng=np.random.default_rng(), n_jobs=1, batch=None, alternative='two-sided') classmethod

Create a CTST instance from two samples.

Parameters:

Name Type Description Default
first_sample NDArray

First sample of scores. These can be binary or continuous.

required
second_sample NDArray

Second sample of scores. These can be binary or continuous.

required

Returns:

Type Description
CTST

An instance of the CTST class.
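The labeling and concatenation that from_samples performs can be sketched with plain numpy, under the assumption (suggested by the helper names assign_labels and concat_samples in the source) that the first sample receives label 0 and the second label 1:

```python
import numpy as np

first_sample = np.array([0.2, 0.8, 0.6, 0.4])   # scores for sample A
second_sample = np.array([0.3, 0.9, 0.7, 0.5])  # scores for sample B

# assumed equivalent of assign_labels: 0 for the first sample, 1 for the second
actual = np.concatenate([
    np.zeros(len(first_sample), dtype=int),
    np.ones(len(second_sample), dtype=int),
])
# assumed equivalent of concat_samples: stack the scores in the same order
predicted = np.concatenate([first_sample, second_sample])

print(actual, predicted)
```

The resulting `actual` and `predicted` arrays are exactly what the CTST constructor expects.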

Source code in src/samesame/ctst.py
@classmethod
def from_samples(
    cls,
    first_sample: NDArray,
    second_sample: NDArray,
    metric: Callable,
    n_resamples: int = 9999,
    rng: np.random.Generator = np.random.default_rng(),
    n_jobs: int = 1,
    batch: int | None = None,
    alternative: Literal["less", "greater", "two-sided"] = "two-sided",
):
    """
    Create a CTST instance from two samples.

    Parameters
    ----------
    first_sample : NDArray
        First sample of scores. These can be binary or continuous.
    second_sample : NDArray
        Second sample of scores. These can be binary or continuous.

    Returns
    -------
    CTST
        An instance of the CTST class.
    """
    assert type_of_target(first_sample) == type_of_target(second_sample)
    samples = (first_sample, second_sample)
    actual = assign_labels(samples)
    predicted = concat_samples(samples)
    return cls(
        actual,
        predicted,
        metric,
        n_resamples,
        rng,
        n_jobs,
        batch,
        alternative,
    )