nirdizati_light.explanation.wrappers.shap_wrapper

import numpy as np
import shap

from nirdizati_light.encoding.constants import TaskGenerationType
from nirdizati_light.predictive_model.predictive_model import drop_columns

def shap_explain(CONF, predictive_model, encoder, full_test_df, target_trace_id=None, prefix_columns=None):
    test_df = drop_columns(full_test_df)

    explainer = _init_explainer(predictive_model.model, test_df)
    if predictive_model.model_type == 'xgboost':
        prefix_columns = [col for col in full_test_df.columns if 'prefix' in col]
    if target_trace_id is not None:
        # keep only the requested trace (the filtered frame was previously discarded)
        full_test_df = full_test_df[full_test_df['trace_id'] == target_trace_id]
        exp = explainer(drop_columns(full_test_df))
        encoder.decode(full_test_df)
        if prefix_columns:
            full_test_df[prefix_columns] = full_test_df[prefix_columns].astype('category')
        exp.data = drop_columns(full_test_df)
        # shap.plots.waterfall(exp, show=False)
    else:
        exp = explainer(drop_columns(full_test_df))
        encoder.decode(full_test_df)
        if prefix_columns:
            full_test_df[prefix_columns] = full_test_df[prefix_columns].astype('category')
        exp.data = drop_columns(full_test_df)
        # shap.plots.bar(exp.values, show=False)
    return exp

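# Usage sketch (hypothetical names): `conf`, `trained_model`, `label_encoder` and
# `test_log_df` stand in for objects produced elsewhere in the pipeline; they are not
# defined by this module.
#
#   exp = shap_explain(conf, trained_model, label_encoder, test_log_df)
#   single = shap_explain(conf, trained_model, label_encoder, test_log_df,
#                         target_trace_id='17')
#
# The returned `exp` is a shap.Explanation whose `.data` has been replaced with the
# decoded feature values, so it can be fed to the commented-out shap.plots calls above.
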
def _init_explainer(model, df):
    # try the specialised explainers first, then fall back to the model-agnostic KernelExplainer
    try:
        return shap.TreeExplainer(model)
    except Exception:
        try:
            return shap.DeepExplainer(model, df)
        except Exception:
            try:
                return shap.KernelExplainer(model, df)
            except Exception as e:
                raise Exception('model not supported by explainer') from e

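# Minimal sketch of the fallback order, assuming scikit-learn is installed (the toy
# model and data below are illustrative, not part of this module):
#
#   from sklearn.ensemble import RandomForestClassifier
#   import pandas as pd
#
#   X = pd.DataFrame({'prefix_1': [0, 1, 1, 0], 'prefix_2': [1, 1, 0, 0]})
#   model = RandomForestClassifier(n_estimators=10).fit(X, [0, 1, 1, 0])
#   explainer = _init_explainer(model, X)   # resolves to shap.TreeExplainer here
#
# Models that neither TreeExplainer nor DeepExplainer accepts fall through to
# shap.KernelExplainer.
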
'''
def _get_explanation(CONF, explainer, target_df, encoder):
    if CONF['task_generation_type'] == TaskGenerationType.ALL_IN_ONE.value:
        trace_ids = list(target_df['trace_id'].values)
        return {
            str(trace_id): {
                prefix_size + 1:
                    np.column_stack((
                        target_df.columns[1:-1],
                        encoder.decode_row(row)[1:-1],
                        explainer.shap_values(drop_columns(row.to_frame(0).T))[row['label'] - 1].T
                    # list(row['label'])[0]
                    )).tolist()  # is the one vs all
                for prefix_size, row in enumerate(
                    [ row for _, row in target_df[target_df['trace_id'] == trace_id].iterrows() ]
                )
                if row['label'] is not '0'
            }
            for trace_id in trace_ids
        }
    else:
        return {
            str(row['trace_id']): {
                CONF['prefix_length_strategy']:
                    np.column_stack((
                        target_df.columns[1:-1],
                        encoder.decode_row(row)[1:-1],
                        explainer.shap_values(drop_columns(row.to_frame(0).T))[row['label'] - 1].T  # list(row['label'])[0]
                    )).tolist()                                                                     # is the one vs all
            }
            for _, row in target_df.iterrows()                                                  # method!
            if row['label'] is not '0'
        }
'''

def _get_explanation(CONF, explainer, target_df, encoder):
    def process_trace(trace_id, prefix_columns=None):
        # build {prefix_size: [[feature, decoded value, SHAP value], ...]} for one trace
        explanation = {}
        for prefix_size, (_, row) in enumerate(target_df[target_df['trace_id'] == trace_id].iterrows()):
            if row['label'] != '0':
                decoded_row = encoder.decode_row(row)[1:-1]
                row_df = row.to_frame(0).T
                if prefix_columns:
                    row_df[prefix_columns] = row_df[prefix_columns].astype('category')
                shap_values = explainer.shap_values(row_df)[row['label'] - 1].T
                explanation[prefix_size + 1] = np.column_stack(
                    (target_df.columns[1:-1], decoded_row, shap_values)).tolist()
        return {str(trace_id): explanation}

    if CONF['predictive_model'] == 'xgboost':
        prefix_columns = [col for col in target_df.columns if 'prefix' in col]
    else:
        prefix_columns = None
    if CONF['task_generation_type'] == TaskGenerationType.ALL_IN_ONE.value:
        trace_ids = list(target_df['trace_id'].values)
        result = {}
        for trace_id in trace_ids:
            result.update(process_trace(trace_id, prefix_columns))
        return result
    else:
        result = {}
        for _, row in target_df.iterrows():
            if row['label'] != '0':
                result.update(process_trace(row['trace_id'], prefix_columns))
        return result
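
# Shape of the value returned by _get_explanation, sketched with made-up numbers (the
# trace id, feature names and SHAP values below are illustrative only):
#
#   {
#       '17': {
#           1: [['prefix_1', 'Create Fine', 0.12], ['amount', 35.0, -0.04]],
#           2: [['prefix_1', 'Create Fine', 0.10], ['amount', 35.0, -0.02]],
#       }
#   }
#
# i.e. one entry per trace id, keyed by prefix length, each row pairing a feature with
# its decoded value and its SHAP contribution.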