diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md index 9c9bb325277..11d19701ca6 100644 --- a/docs/site/builtins-reference.md +++ b/docs/site/builtins-reference.md @@ -50,6 +50,7 @@ limitations under the License. * [`img_brightness`-Function](#img_brightness-function) * [`img_crop`-Function](#img_crop-function) * [`img_mirror`-Function](#img_mirror-function) + * [`impurityMeasures`-Function](#impurityMeasures-function) * [`imputeByFD`-Function](#imputeByFD-function) * [`intersect`-Function](#intersect-function) * [`KMeans`-Function](#KMeans-function) @@ -1018,6 +1019,50 @@ B = img_mirror(img_in = A, horizontal_axis = TRUE) ``` +## `impurityMeasures`-Function + +`impurityMeasures()` computes the measure of impurity for each feature of the given dataset based on the passed method (gini or entropy). + +### Usage + +```r +IM = impurityMeasures(X = X, Y = Y, R = R, n_bins = 20, method = "gini"); +``` + +### Arguments + +| Name | Type | Default | Description | +| :--------- | :-------------- | :------ | :---------- | +| X | Matrix[Double] | --- | Feature matrix X | +| Y | Matrix[Double] | --- | Target vector Y containing only 0 or 1 values | +| R | Matrix[Double] | --- | Row vector R indicating whether a feature is categorical or continuous. 1 denotes a continuous feature, 2 denotes a categorical feature. | +| n_bins | Integer | `20` | Number of equi-width bins for binning in case of scale features. | +| method | String | --- | String indicating the method to use; either "entropy" or "gini". | + +### Returns + +| Name | Type | Description | +| :--- | :------------- | :---------- | +| IM | Matrix[Double] | (1 x ncol(X)) row vector containing information/gini gain for each feature of the dataset. In case of gini, the values denote the gini gains, i.e. how much impurity was removed with the respective split. The higher the value, the better the split. In case of entropy, the values denote the information gain, i.e. how much entropy was removed. 
The higher the information gain, the better the split. | + +### Example + +```r +X = matrix("4.0 3.0 2.8 3.5 + 2.4 1.0 3.4 2.9 + 1.1 1.0 4.9 3.4 + 5.0 2.0 1.4 1.8 + 1.1 3.0 1.0 1.9", rows=5, cols=4) +Y = matrix("1.0 + 0.0 + 0.0 + 1.0 + 0.0", rows=5, cols=1) +R = matrix("1.0 2.0 1.0 1.0", rows=1, cols=4) +IM = impurityMeasures(X = X, Y = Y, R = R, method = "entropy") +``` + + ## `imputeByFD`-Function The `imputeByFD`-function imputes missing values from observed values (if exist) diff --git a/scripts/builtin/impurityMeasures.dml b/scripts/builtin/impurityMeasures.dml new file mode 100644 index 00000000000..860bc629f6b --- /dev/null +++ b/scripts/builtin/impurityMeasures.dml @@ -0,0 +1,139 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# This function computes the measure of impurity for the given dataset based on the passed method (gini or entropy). +# The current version expects the target vector to contain only 0 or 1 values. 
+#
+# INPUT PARAMETERS:
+# ----------------------------------------------------------------------------------------------------------------------
+# NAME      TYPE              DEFAULT   MEANING
+# ----------------------------------------------------------------------------------------------------------------------
+# X         Matrix[Double]    ---       Feature matrix.
+# Y         Matrix[Double]    ---       Target vector containing only 0 and 1 values.
+# R         Matrix[Double]    ---       Row vector indicating whether a feature is categorical or continuous.
+#                                       1 denotes a continuous feature, 2 denotes a categorical feature.
+#                                       Categorical features are expected to be coded as integers 1..max(X[,i]),
+#                                       because the split is evaluated for every category code from 1 to the
+#                                       column maximum.
+# n_bins    Integer           20        Number of equi-width bins used to discretize continuous features.
+# method    String            ---       String indicating the method to use; either "entropy" or "gini".
+# ----------------------------------------------------------------------------------------------------------------------
+
+# OUTPUT:
+# ----------------------------------------------------------------------------------------------------------------------
+# NAME      TYPE              DEFAULT   MEANING
+# ----------------------------------------------------------------------------------------------------------------------
+# IM        Matrix[Double]    ---       (1 x ncol(X)) row vector containing information/gini gain for
+#                                       each feature of the dataset.
+#                                       In case of gini, the values denote the gini gains, i.e. how much
+#                                       impurity was removed with the respective split. The higher the
+#                                       value, the better the split.
+#                                       In case of entropy, the values denote the information gain, i.e.
+#                                       how much entropy was removed. The higher the information gain,
+#                                       the better the split.
+# ----------------------------------------------------------------------------------------------------------------------
+
+# Entry point: computes, for every column of X, the gain (parent impurity minus weighted
+# impurity of the children) obtained when splitting on that feature.
+# Continuous features (R == 1) are first discretized into equi-width bins; all other
+# features are treated as categorical with category codes 1..max(X[,i]).
+m_impurityMeasures = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] R, Integer n_bins = 20, String method)
+  return (Matrix[Double] IM)
+{
+  if (method != "entropy" & method != "gini") {
+    stop("Please specify the correct method - should be either entropy or gini.")
+  }
+  # Fail early with a readable message instead of an indexing/cast error further down.
+  if (nrow(Y) != nrow(X) | ncol(Y) != 1) {
+    stop("impurityMeasures: Y must be a column vector with one label per row of X.")
+  }
+  if (nrow(R) != 1 | ncol(R) != ncol(X)) {
+    stop("impurityMeasures: R must be a row vector with one entry per column of X.")
+  }
+
+  IM = matrix(0.0, rows = 1, cols = ncol(X))
+
+  # Features are independent of each other, hence the parallel loop over columns.
+  parfor (i in 1:ncol(X)) {
+    if (as.scalar(R[,i]) == 1) {
+      binned_feature = applyBinning(X[,i], n_bins)
+      IM[,i] = getImpurityMeasure(binned_feature, Y, n_bins, method)
+    } else {
+      IM[,i] = getImpurityMeasure(X[,i], Y, max(X[,i]), method)
+    }
+  }
+}
+
+# Computes the gain of a single (categorical or already binned) feature.
+#   feature  column vector of category codes; codes 1..max_cat are evaluated
+#   Y        column vector of labels; 0 is the negative class, any other value counts as positive
+#   max_cat  largest category code to evaluate
+#   method   "entropy" or "gini"
+# Returns the parent impurity minus the sum of the weighted child impurities.
+getImpurityMeasure = function(Matrix[Double] feature, Matrix[Double] Y, Double max_cat, String method)
+  return (Double gain)
+{
+  n_obs = length(feature)
+  n_true_labels = sum(Y)
+  n_false_labels = length(Y) - n_true_labels
+  parent_impurity = calcImpurity(n_true_labels, n_false_labels, n_obs, method)
+
+  # Indicator of the positive class (everything that is not 0).
+  is_true = (Y != 0)
+
+  # Impurity after the split: one child per category code. The counts are obtained with
+  # vectorized indicator sums instead of a scalar loop over all rows.
+  children_impurity = 0
+  for (i in 1:max_cat) {
+    in_cat = (feature == i)
+    n_in_cat = sum(in_cat)
+    if (n_in_cat > 0) { # empty categories do not contribute
+      count_true = sum(in_cat * is_true)
+      count_false = n_in_cat - count_true
+      children_impurity = children_impurity + calcImpurity(count_true, count_false, n_obs, method)
+    }
+  }
+  gain = parent_impurity - children_impurity
+}
+
+# Weighted impurity of one node.
+#   n_true, n_false  number of positive / negative observations in the node
+#   n_vars           total number of observations; the node is weighted by its share of them
+#   method           "entropy" (in bits) or "gini"
+# Pure and empty nodes have impurity 0.
+calcImpurity = function(Double n_true, Double n_false, Double n_vars, String method)
+  return (Double impurity)
+{
+  impurity = 0
+  n_node = n_true + n_false
+
+  if (n_node > 0) { # guard against 0/0 for an empty node
+    prob_true = n_true / n_node
+    prob_false = n_false / n_node
+    weight = n_node / n_vars
+
+    if (prob_true != 1 & prob_false != 1) { # if there is more than one class, calculate new impurity according to method.
+      if (method == "entropy") { # dividing by log(2) to obtain the information gain in bits
+        impurity = (-1) * weight * (prob_true * log(prob_true)/log(2) + prob_false * log(prob_false)/log(2))
+      } else if (method == "gini") {
+        impurity = weight * (1 - (prob_true^2 + prob_false^2))
+      }
+    }
+  }
+}
+
+# Equi-width binning of a continuous feature into bin ids 1..n_bins.
+# If the feature has fewer values than n_bins, the number of bins is reduced to the number of values.
+# A value v is assigned to bin ceil((v - min) / width); the result is clamped to [1, n_bins] so that
+# the minimum lands in bin 1 and floating-point rounding can never push the maximum past the last bin.
+applyBinning = function(Matrix[Double] feature, Double n_bins)
+  return (Matrix[Double] output_f)
+{
+  if (length(feature) < n_bins) {
+    n_bins = length(feature)
+  }
+  min_v = min(feature)
+  width = (max(feature) - min_v) / n_bins
+
+  if (width > 0) {
+    output_f = ceil((feature - min_v) / width)
+    output_f = max(min(output_f, n_bins), 1)
+  } else {
+    # constant feature (or no usable bin width): everything falls into the first bin
+    output_f = matrix(1, rows = nrow(feature), cols = 1)
+  }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index aa9c58c0b35..75188c1cbda 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -157,6 +157,7 @@ public enum Builtins {
 	IMG_SAMPLE_PAIRING("img_sample_pairing", true),
 	IMG_INVERT("img_invert", true),
 	IMG_POSTERIZE("img_posterize", true),
+	IMPURITY_MEASURES("impurityMeasures", true),
 	IMPUTE_BY_MEAN("imputeByMean", true),
 	IMPUTE_BY_MEAN_APPLY("imputeByMeanApply", true),
 	IMPUTE_BY_MEDIAN("imputeByMedian", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImpurityMeasuresTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImpurityMeasuresTest.java
new file mode 100644
index 00000000000..d67768b68d7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImpurityMeasuresTest.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the impurityMeasures builtin: every test writes a feature matrix X, a 0/1 label vector Y and a
+ * feature-type row vector R (1 = continuous, 2 = categorical), runs the DML test script with the
+ * requested method ("gini" or "entropy") and compares the resulting 1 x ncol(X) gain vector with
+ * hand-computed expectations.
+ */
+public class BuiltinImpurityMeasuresTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "impurityMeasures";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinImpurityMeasuresTest.class.getSimpleName() + "/";
+
+	private final static double eps = 1e-10;
+
+	// Datasets shared by the gini and the entropy variant of each test.
+
+	// two rows, two identical categorical features
+	private final static double[][] X_TINY = {{1, 1}, {2, 2}};
+	private final static double[][] Y_TINY = {{1}, {0}};
+	private final static double[][] R_TINY = {{2, 2}};
+
+	// ten rows, one categorical feature
+	private final static double[][] X_SINGLE = {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {2}, {2}};
+	private final static double[][] Y_SINGLE = {{0}, {0}, {0}, {0}, {0}, {1}, {1}, {1}, {1}, {1}};
+	private final static double[][] R_SINGLE = {{2}};
+
+	// five rows, four categorical features
+	private final static double[][] X_FIVE = {{1, 1, 2, 1}, {1, 3, 1, 2}, {2, 1, 1, 2}, {3, 2, 1, 1}, {1, 3, 2, 1}};
+	private final static double[][] Y_FIVE = {{0}, {0}, {1}, {1}, {1}};
+	private final static double[][] R_FIVE = {{2, 2, 2, 2}};
+
+	// fourteen rows, four categorical features ("play tennis" data)
+	private final static double[][] X_TENNIS = {
+		{1, 1, 1, 1},
+		{1, 1, 1, 2},
+		{2, 1, 1, 1},
+		{3, 2, 1, 1},
+		{3, 3, 2, 1},
+		{3, 3, 2, 2},
+		{2, 3, 2, 2},
+		{1, 2, 1, 1},
+		{1, 3, 2, 1},
+		{3, 2, 2, 1},
+		{1, 2, 2, 2},
+		{2, 2, 1, 2},
+		{2, 1, 2, 1},
+		{3, 2, 1, 2}};
+	private final static double[][] Y_TENNIS = {{0}, {0}, {1}, {1}, {1}, {0}, {1}, {0}, {1}, {1}, {1}, {1}, {1}, {0}};
+	private final static double[][] R_TENNIS = {{2, 2, 2, 2}};
+
+	// four rows, one continuous feature
+	private final static double[][] X_CONT1 = {{10.3}, {31.2}, {9.5}, {34.3}};
+	private final static double[][] Y_CONT1 = {{0}, {1}, {0}, {1}};
+	private final static double[][] R_CONT1 = {{1}};
+
+	// four rows, three continuous features
+	private final static double[][] X_CONT2 = {{1.5, 23.7, 2929.6}, {12.6, 80.2, 2823.3}, {3.4, 238.2, 832.2}, {14.2, 282.1, 23.1}};
+	private final static double[][] Y_CONT2 = {{0}, {1}, {0}, {1}};
+	private final static double[][] R_CONT2 = {{1, 1, 1}};
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+	}
+
+	@Test
+	public void GiniTest1() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_TINY, Y_TINY, R_TINY, "gini", expectedRow(0.5, 0.5));
+	}
+
+	@Test
+	public void GiniTest2() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_SINGLE, Y_SINGLE, R_SINGLE, "gini", expectedRow(0.3333333333));
+	}
+
+	@Test
+	public void GiniTest3() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_FIVE, Y_FIVE, R_FIVE, "gini",
+			expectedRow(0.2133333333, 0.0799999999, 0.0133333333, 0.0133333333));
+	}
+
+	@Test
+	public void GiniPlayTennisTest() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_TENNIS, Y_TENNIS, R_TENNIS, "gini",
+			expectedRow(0.1163265306, 0.0187074829, 0.0918367346, 0.0306122448));
+	}
+
+	@Test
+	public void GiniWithContinuousValues1() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_CONT1, Y_CONT1, R_CONT1, "gini", expectedRow(0.5));
+	}
+
+	@Test
+	public void GiniWithContinuousValues2() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_CONT2, Y_CONT2, R_CONT2, "gini", expectedRow(0.5, 0.0, 0.25));
+	}
+
+	// comparing with values from https://planetcalc.com/8421/
+	@Test
+	public void EntropyTest1() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_TINY, Y_TINY, R_TINY, "entropy", expectedRow(1.0, 1.0));
+	}
+
+	@Test
+	public void EntropyTest2() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_SINGLE, Y_SINGLE, R_SINGLE, "entropy", expectedRow(0.6099865470));
+	}
+
+	@Test
+	public void EntropyTest3() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_FIVE, Y_FIVE, R_FIVE, "entropy",
+			expectedRow(0.4199730940, 0.1709505945, 0.0199730940, 0.0199730940));
+	}
+
+	@Test
+	public void EntropyPlayTennisTest() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_TENNIS, Y_TENNIS, R_TENNIS, "entropy",
+			expectedRow(0.2467498198, 0.0292225657, 0.1518355014, 0.0481270304));
+	}
+
+	@Test
+	public void EntropyWithContinuousValues1() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_CONT1, Y_CONT1, R_CONT1, "entropy", expectedRow(1.0));
+	}
+
+	@Test
+	public void EntropyWithContinuousValues2() {
+		runImpurityMeasuresTest(ExecType.SPARK, X_CONT2, Y_CONT2, R_CONT2, "entropy", expectedRow(1.0, 0.0, 0.5));
+	}
+
+	/**
+	 * Builds the expected 1 x n result in the cell-index map form used by the matrix comparison.
+	 *
+	 * @param gains expected gain per feature, in column order
+	 * @return map from (row 1, column i) to the i-th expected gain
+	 */
+	private static HashMap<MatrixValue.CellIndex, Double> expectedRow(double... gains) {
+		HashMap<MatrixValue.CellIndex, Double> expected = new HashMap<>();
+		for(int col = 0; col < gains.length; col++)
+			expected.put(new MatrixValue.CellIndex(1, col + 1), gains[col]);
+		return expected;
+	}
+
+	/**
+	 * Writes the inputs, runs the DML test script and compares its output with the expectation.
+	 *
+	 * @param exec_type  execution type the script is run with
+	 * @param X          feature matrix
+	 * @param Y          label column vector (0/1)
+	 * @param R          feature-type row vector (1 = continuous, 2 = categorical)
+	 * @param method     "gini" or "entropy"
+	 * @param expected_m expected gains, keyed by cell index
+	 */
+	private void runImpurityMeasuresTest(ExecType exec_type, double[][] X, double[][] Y, double[][] R, String method,
+		HashMap<MatrixValue.CellIndex, Double> expected_m) {
+		Types.ExecMode platform_old = setExecMode(exec_type);
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("X"), input("Y"), input("R"), method, output("impurity_measures")};
+
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+			writeInputMatrixWithMTD("R", R, true);
+
+			runTest(true, false, null, -1);
+
+			HashMap<MatrixValue.CellIndex, Double> actual_measures = readDMLMatrixFromOutputDir("impurity_measures");
+
+			TestUtils.compareMatrices(expected_m, actual_measures, eps, "Expected measures", "Actual measures");
+		}
+		finally {
+			// restores the previous exec mode together with the local Spark config flag
+			resetExecMode(platform_old);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/impurityMeasures.dml b/src/test/scripts/functions/builtin/impurityMeasures.dml
new file mode 100644
index 00000000000..01ab4cbf214
--- /dev/null
+++ b/src/test/scripts/functions/builtin/impurityMeasures.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test driver for the impurityMeasures builtin.
+#   $1  path of the feature matrix
+#   $2  path of the label vector
+#   $3  path of the feature-type row vector
+#   $4  method name passed through to the builtin
+#   $5  path the resulting gain vector is written to
+features = read($1)
+labels = read($2)
+feature_types = read($3)
+
+# n_bins is not passed, so the builtin's default is used
+gains = impurityMeasures(X = features, Y = labels, R = feature_types, method = $4)
+
+write(gains, $5)