nirdizati_light.confusion_matrix_feedback.confusion_matrix_feedback

  1import numpy as np
  2from pymining import itemmining
  3
  4from nirdizati_light.encoding.data_encoder import PADDING_VALUE
  5from nirdizati_light.predictive_model.common import ClassificationMethods, get_tensor
  6from nirdizati_light.predictive_model.predictive_model import drop_columns
  7
  8
  9def compute_feedback(CONF, explanations, predictive_model, test_df, encoder, top_k=None):
 10    if predictive_model.model_type not in ClassificationMethods:
 11        raise Exception('Only supported classification methods')
 12    
 13    if predictive_model.model_type in [ClassificationMethods.LSTM.value, ClassificationMethods.CUSTOM_PYTORCH.value]:
 14        probabilities = predictive_model.model.predict(get_tensor(CONF, drop_columns(test_df)))
 15        indices = np.argmax(probabilities, axis=1)
 16        onehot_enc = list(encoder._label_dict_decoder['label'].keys())
 17        predicted = []
 18        for i in indices:
 19            predicted.append(onehot_enc[i])
 20    else:
 21        predicted = predictive_model.model.predict(drop_columns(test_df))
 22
 23    actual = test_df['label']
 24
 25    trace_ids = test_df['trace_id']
 26
 27    confusion_matrix = _retrieve_confusion_matrix_ids(trace_ids, predicted, actual, encoder)
 28
 29    filtered_explanations = _filter_explanations(explanations, threshold=13)
 30
 31    frequent_patterns = _mine_frequent_patterns(confusion_matrix, filtered_explanations)
 32
 33    feedback = {
 34        classes: _subtract_patterns(
 35            sum([frequent_patterns[classes][cl] for cl in confusion_matrix.keys()], []),
 36            frequent_patterns[classes][classes]
 37        )
 38        for classes in confusion_matrix.keys()
 39    }
 40
 41    if top_k is not None:
 42        for classes in feedback:
 43            feedback[classes] = feedback[classes][:top_k]
 44
 45    return feedback
 46
 47
 48def _retrieve_confusion_matrix_ids(trace_ids, actual, predicted, encoder) -> dict:
 49    decoded_predicted = encoder.decode_column(predicted, 'label')
 50    decoded_actual = encoder.decode_column(actual, 'label')
 51    elements = np.column_stack((
 52        trace_ids,
 53        decoded_predicted,
 54        decoded_actual
 55    )).tolist()
 56
 57    # matrix format is (actual, predicted)
 58    confusion_matrix = {}
 59    classes = list(encoder.get_values('label')[0])
 60    if PADDING_VALUE in classes: classes.remove(PADDING_VALUE)
 61    for act in classes:
 62        confusion_matrix[act] = {}
 63        for pred in classes:
 64            confusion_matrix[act][pred] = {
 65                trace_id
 66                for trace_id, predicted, actual in elements
 67                if actual == act and predicted == pred
 68            }
 69
 70    return confusion_matrix
 71
 72
 73def _filter_explanations(explanations, threshold=None):
 74    if threshold is None:
 75        threshold = min(13, int(max(len(explanations[tid]) for tid in explanations) * 10 / 100) + 1)
 76    return {
 77        trace_id:
 78            sorted(explanations[trace_id], key=lambda x: x[2], reverse=True)[:threshold]
 79        for trace_id in explanations
 80    }
 81
 82
 83def _mine_frequent_patterns(confusion_matrix, filtered_explanations):
 84    mined_patterns = {}
 85    for actual in confusion_matrix:
 86        mined_patterns[actual] = {}
 87        for pred in confusion_matrix[actual]:
 88            mined_patterns[actual][pred] = itemmining.relim(itemmining.get_relim_input([
 89                [
 90                    str(feature_name) + '//' + str(value)  # + '_' + str(_tassellate_number(importance))
 91                    for feature_name, value, importance in filtered_explanations[tid]
 92                ]
 93                for tid in confusion_matrix[actual][pred]
 94                if tid in filtered_explanations
 95            ]), min_support=2)
 96            mined_patterns[actual][pred] = sorted(
 97                [
 98                    ([el.split('//') for el in list(key)], mined_patterns[actual][pred][key])
 99                    for key in mined_patterns[actual][pred]
100                ],
101                key=lambda x: x[1],
102                reverse=True
103            )
104
105    return mined_patterns
106
107
108def _tassellate_number(element):
109    element = str(element).split('.')
110    return element[0] + '.' + element[1][:3]
111
112
113def _subtract_patterns(list1, list2):
114
115    difference = [el[0] for el in list1]
116    for el, _ in list2:
117        if el in difference:
118            difference.remove(el)
119
120    return difference
def compute_feedback(CONF, explanations, predictive_model, test_df, encoder, top_k=None, filter_threshold=13):
    """Extract, for every class, the explanation patterns found in its misclassified traces.

    The test set is predicted and its traces are grouped into a confusion matrix indexed by
    ``(actual, predicted)`` label. Frequent patterns are then mined from the most important
    features of each trace explanation. For each actual class ``c``, the feedback is the list
    of patterns mined from the traces of ``c`` across every predicted class. Each pattern mined
    from the correctly predicted traces of ``c`` removes one equal pattern from that list.

    :param CONF: configuration passed to ``get_tensor``; only used for LSTM / custom PyTorch models.
    :param explanations: mapping ``trace_id -> [(feature_name, value, importance), ...]``.
    :param predictive_model: trained model wrapper exposing ``model_type`` and ``model.predict``.
    :param test_df: encoded test set; must contain the ``trace_id`` and ``label`` columns.
    :param encoder: encoder used to decode the ``label`` column.
    :param top_k: if not ``None``, keep only the first ``top_k`` patterns of every class.
    :param filter_threshold: number of most important features kept per explanation before
        mining (default 13, the previously hard-coded value).
    :return: dict mapping each class to a list of patterns; a pattern is a list of
        ``[feature_name, value]`` string pairs.
    :raises Exception: if ``predictive_model`` is not a classification model.
    """
    model_type = predictive_model.model_type
    # ``model_type`` holds enum *values* (see the LSTM check below), and ``value in EnumClass``
    # raises TypeError before Python 3.12, so compare against values and members explicitly.
    classification_types = [method.value for method in ClassificationMethods] + list(ClassificationMethods)
    if model_type not in classification_types:
        raise Exception('Only supported classification methods')

    tensor_methods = [ClassificationMethods.LSTM, ClassificationMethods.CUSTOM_PYTORCH]
    if model_type in tensor_methods + [method.value for method in tensor_methods]:
        probabilities = predictive_model.model.predict(get_tensor(CONF, drop_columns(test_df)))
        # NOTE(review): assumes the model's output columns follow the key order of the label
        # decoder - confirm against the training code.
        encoded_labels = list(encoder._label_dict_decoder['label'].keys())
        predicted = [encoded_labels[index] for index in np.argmax(probabilities, axis=1)]
    else:
        predicted = predictive_model.model.predict(drop_columns(test_df))

    # Keyword arguments: the former positional call passed ``predicted`` as ``actual`` (and
    # vice versa), which transposed the (actual, predicted) confusion matrix.
    confusion_matrix = _retrieve_confusion_matrix_ids(
        test_df['trace_id'],
        actual=test_df['label'],
        predicted=predicted,
        encoder=encoder,
    )

    filtered_explanations = _filter_explanations(explanations, threshold=filter_threshold)
    frequent_patterns = _mine_frequent_patterns(confusion_matrix, filtered_explanations)

    feedback = {}
    for cls in confusion_matrix:
        # patterns of every cell of the row (actual == cls), concatenated in column order
        row_patterns = [entry for pred in confusion_matrix for entry in frequent_patterns[cls][pred]]
        feedback[cls] = _subtract_patterns(row_patterns, frequent_patterns[cls][cls])
        if top_k is not None:
            feedback[cls] = feedback[cls][:top_k]
    return feedback