Skip to content
Merged

lime #47

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions lime.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMnPEht5MQJF/hLpcLpa95R",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/karthikab5/BizRecProject/blob/dev/lime.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_u28rTWlh3fy",
"outputId": "001c88db-9ab4-4cee-ee26-ce62b3b44b96"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"# prompt: code to drive mount\n",
"\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"source": [
"# Necessary imports\n",
"import pandas as pd\n",
"from lime.lime_tabular import LimeTabularExplainer\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# Train a Random Forest Regressor.\n",
"# NOTE(review): X_train, y_train, X_test and y_test are not defined in any cell of this\n",
"# notebook; they must already exist in the kernel before this cell runs.\n",
"model = RandomForestRegressor(n_estimators=100, random_state=42)\n",
"model.fit(X_train, y_train)\n",
"\n",
"# Predict on X_test so the metrics below compare against y_test\n",
"y_pred = model.predict(X_test)\n",
"\n",
"# Regression metrics\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"# RMSE is derived from MSE instead of calling mean_squared_error(..., squared=False):\n",
"# the `squared` keyword was deprecated in scikit-learn 1.4 and removed in 1.6, where it\n",
"# raises TypeError. For a single regression target the two forms give the same value.\n",
"rmse = mse ** 0.5  # Root Mean Squared Error\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"print(f\"Mean Squared Error (MSE): {mse}\")\n",
"print(f\"Root Mean Squared Error (RMSE): {rmse}\")\n",
"print(f\"Mean Absolute Error (MAE): {mae}\")\n",
"print(f\"R-squared (R2): {r2}\")\n",
"\n",
"# LIME interpretation setup.\n",
"# LIME builds each explanation from randomly perturbed samples, so the explainer is\n",
"# seeded (same seed as the model above) to make explanations repeatable across runs.\n",
"explainer = LimeTabularExplainer(\n",
"    training_data=X_train.values,\n",
"    feature_names=feature_columns,\n",
"    mode='regression',\n",
"    random_state=42\n",
")\n",
"\n",
"# Explain a single prediction with LIME\n",
"def explain_prediction(instance):\n",
"    \"\"\"Return the LIME explanation for one feature row.\n",
"\n",
"    Uses the module-level `explainer` and the module-level `model`; both must\n",
"    exist before this function is called.\n",
"\n",
"    Args:\n",
"        instance: The row to explain, forwarded to `explain_instance` as `data_row`.\n",
"\n",
"    Returns:\n",
"        The object returned by `explainer.explain_instance`.\n",
"    \"\"\"\n",
"    return explainer.explain_instance(data_row=instance, predict_fn=model.predict)\n",
"\n",
"# Example: explain the model's prediction for one row of X_test\n",
"instance_index = 0  # Change this index for different instances\n",
"selected_row = X_test.values[instance_index]\n",
"explanation = explain_prediction(selected_row)\n",
"\n",
"# Render the explanation inline in the notebook ...\n",
"print(\"LIME Explanation for instance:\")\n",
"explanation.show_in_notebook(show_table=True)\n",
"# ... and also print the same explanation as a plain list\n",
"explanation_pairs = explanation.as_list()\n",
"print(explanation_pairs)\n",
"\n",
"def predict_success_for_business_ids(business_ids, model, features_df, feature_columns):\n",
" # Filter the dataset for the provided business IDs\n",
" input_data = features_df[features_df[\"business_id\"].isin(business_ids)]\n",
" X_input = input_data[feature_columns]\n",
"\n",
" # Predict success probability\n",
" success_probabilities = model.predict(X_input)\n",
"\n",
" # Create a DataFrame to display results\n",
" result_df = pd.DataFrame({\n",
" \"business_id\": input_data[\"business_id\"],\n",
" \"success_probability\": success_probabilities * 10000\n",
" })\n",
"\n",
" return result_df\n",
"\n",
"# Example usage: score every business ID found in df\n",
"business_ids_list = df[\"business_id\"].tolist()  # all business ids as a plain list\n",
"business_ids_to_predict = business_ids_list  # IDs handed to the prediction helper\n",
"success_probabilities = predict_success_for_business_ids(\n",
"    business_ids=business_ids_to_predict,\n",
"    model=model,\n",
"    features_df=df,\n",
"    feature_columns=feature_columns,\n",
")\n",
"\n",
"print(\"Predicted Success Probabilities:\")\n",
"print(success_probabilities)\n",
"\n",
"# Look up one business in each Spark DataFrame through temporary SQL views.\n",
"# NOTE(review): spark, business_df, review_df and checkin_df are not defined in any cell\n",
"# of this notebook; they must already exist in the kernel before this code runs.\n",
"business_df.createOrReplaceTempView(\"business_temp\")\n",
"review_df.createOrReplaceTempView(\"review_temp\")\n",
"checkin_df.createOrReplaceTempView(\"checkin_temp\")\n",
"business_id_to_search = 'SuSEmi52lP8gquHV0XIB9g'  # Replace with your variable\n",
"\n",
"# Run the same filter against every view instead of repeating the query three times.\n",
"# NOTE(review): the ID is interpolated directly into the SQL text, so it must come from a\n",
"# trusted source; switch to a parameterized query before accepting user-supplied IDs.\n",
"for temp_view in (\"business_temp\", \"review_temp\", \"checkin_temp\"):\n",
"    spark.sql(f\"SELECT * FROM {temp_view} WHERE business_id = '{business_id_to_search}'\").show()\n"
],
"metadata": {
"id": "1aNFkM6Th-T-"
},
"execution_count": null,
"outputs": []
}
]
}