diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 59e792f6..956cf0b6 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -80,9 +80,9 @@ def measure_mses(self, C, X, individual_preds=False): mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples for i in range(X.shape[-1]): for j in range(X.shape[-1]): - residuals = np.tile(X[:, i], len(betas)) - ( - betas[:, :, i, j] * np.tile(X[:, j], len(betas)) + mus[:, :, i, j] - ) + tiled_xi = np.array([X[:, i] for _ in range(len(betas))]) + tiled_xj = np.array([X[:, j] for _ in range(len(betas))]) + residuals = tiled_xi - betas[:, :, i, j]*tiled_xj + mus[:, :, i ,j] mses += residuals**2 / (X.shape[-1] ** 2) if not individual_preds: mses = np.mean(mses, axis=0) @@ -164,7 +164,7 @@ def __init__(self, **kwargs): **kwargs, ) - def predict_params(self, C, individual_preds=False, **kwargs): + def predict_params(self, C, **kwargs): """ :param C: @@ -173,8 +173,12 @@ def predict_params(self, C, individual_preds=False, **kwargs): """ # Returns betas # TODO: No mus for NOTMAD at present. - return super().predict_params(C, individual_preds, - model_includes_mus=False, **kwargs) + return super().predict_params( + C, + individual_preds=kwargs.get("individual_preds", False), + model_includes_mus=False, + uses_y=False, + project_to_dag=kwargs.get("project_to_dag", True)) def predict_networks(self, C, with_offsets=False, project_to_dag=True, **kwargs): """