diff --git a/dirichletcal/calib/multinomial.py b/dirichletcal/calib/multinomial.py index f772a08..039b469 100644 --- a/dirichletcal/calib/multinomial.py +++ b/dirichletcal/calib/multinomial.py @@ -76,7 +76,7 @@ def fit(self, X, y, *args, **kwargs): self.reg_lambda = self.reg_lambda / (k * (k - 1)) self.reg_mu = self.reg_mu / k - target = label_binarize(y, self.classes) + target = label_binarize(y, classes=self.classes) if k == 2: target = np.hstack([1-target, target])