# blob: a79d4963f05a189af4269bb1076e2837ee382fe8 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module that allows for easily graphing feature importance data in
notebooks"""
import numpy
import numpy.typing
import shap
import json
from typing import Dict, List, Union, Optional
# Type of the dictionaries handled by this module: string keys mapped to either
# an array-like value or a list of strings. It is the return type of
# load_shap_values and the type of the `data` parameter of the graph_* helpers.
DataType = Dict[str, Union[numpy.typing.ArrayLike, List[str]]]
def load_shap_values(file_name: str) -> DataType:
  """Loads shap values written by the feature_importance.py script.

  The JSON file is parsed and its numeric fields are converted to numpy arrays
  so that the result can be used for creating graphs.

  Args:
    file_name: The name of the file in which the shap values are stored. What
      the --output_path flag was set to in the feature importance script.

  Returns:
    A dictionary with the keys 'expected_values', 'shap_values' and 'data'
    (each converted with numpy.asarray) and 'feature_names' (taken from the
    file unchanged). 'expected_values' is always at least one-dimensional: a
    bare scalar in the file is wrapped in a single-element list, while a value
    that is already a list is kept as is.

  Raises:
    OSError: If the file cannot be opened.
    json.JSONDecodeError: If the file does not contain valid JSON.
    KeyError: If one of the keys read below is missing from the file.
  """
  with open(file_name, encoding='utf-8') as file_to_load:
    data = json.load(file_to_load)

  expected_values = data['expected_values']
  # Use isinstance rather than an identity comparison against the `list` type:
  # `value is not list` is true for every value, which would wrap lists a
  # second time and add a spurious leading dimension to the resulting array.
  if not isinstance(expected_values, list):
    expected_values = [expected_values]

  return {
      'expected_values': numpy.asarray(expected_values),
      'shap_values': numpy.asarray(data['shap_values']),
      'data': numpy.asarray(data['data']),
      'feature_names': data['feature_names']
  }
def init_shap_for_notebook():
  """Initializes some JS code for interactive feature importance plots.

  This is a thin wrapper that only calls shap.initjs(); it takes no arguments
  and returns nothing.

  NOTE(review): presumably this has to be run in the notebook before the
  interactive plots are displayed -- confirm against the shap documentation.
  """
  shap.initjs()
def graph_individual_example(data: DataType, index: Optional[int]):
  """Builds a shap force plot for a single example.

  Args:
    data: The shap values and accompanying information needed for the plot.
      Should be created with load_shap_values.
    index: Which example to plot; it selects the row of both the shap values
      and the data.

  Returns:
    Whatever shap.force_plot returns for the selected example.
  """
  base_values = data['expected_values']
  # Pick the same row out of the shap values and out of the raw feature data.
  example_shap_values = data['shap_values'][index, :]
  example_features = data['data'][index, :]
  feature_names = data['feature_names']
  return shap.force_plot(
      base_values,
      example_shap_values,
      example_features,
      feature_names=feature_names)
def graph_summary_plot(data: DataType):
  """Builds a shap summary plot covering the whole dataset passed in.

  Args:
    data: The shap values and accompanying information needed for the plot.
      Should come from load_shap_values.

  Returns:
    Whatever shap.summary_plot returns.
  """
  shap_values = data['shap_values']
  feature_values = data['data']
  feature_names = data['feature_names']
  return shap.summary_plot(
      shap_values, feature_values, feature_names=feature_names)