Source code for subwabbit.base
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Iterable, Any, Optional, Union
class VowpalWabbitError(Exception):
pass
[docs]class VowpalWabbitBaseFormatter(ABC):
"""
Formatter translates structured information about context and items to
Vowpal Wabbit's input format: https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format
It also can implement reverse translation, from Vowpal Wabbits feature names into human readable feature names.
"""
[docs] @abstractmethod
def format_common_features(self, common_features: Any,
debug_info: Any = None) -> str:
"""
Return part of VW line with features that are common for one call of predict/train.
This method will run just once per one call of
:class:`subwabbit.base.VowpalWabbitBaseModel`'s `predict()` or `train()` method.
:param common_features: Features common for all items
:param debug_info: Optional dict that can be filled by information useful for debugging
:return: Part of line that is common for each item in one call. Returned string has to start with '|' symbol.
"""
raise NotImplementedError()
[docs] @abstractmethod
def format_item_features(self, common_features: Any, item_features: Any,
debug_info: Any = None) -> str:
"""
Return part of VW line with features specific to each item.
This method will run for each item per one call of
:class:`subwabbit.base.VowpalWabbitBaseModel`'s `predict()` or `train()` method.
.. note::
It is a good idea to cache results of this method.
:param common_features: Features common for all items
:param item_features: Features for item
:param debug_info: Optional dict that can be filled by information useful for debugging
:return: Part of line that is specific for item. Depends on whether namespaces are used or not in
``format_common_features`` method:
- namespaces are used: returned string has to start with ``'|NAMESPACE_NAME'`` where `NAMESPACE_NAME`
is the name of some namespace
- namespaces are not used: returned string should not contain '|' symbol
"""
raise NotImplementedError()
# pylint: disable=too-many-arguments,no-self-use
[docs] def get_formatted_example(self, common_line_part: str, item_line_part: str,
label: Optional[float] = None, weight: Optional[float] = None,
debug_info: Optional[Dict[Any, Any]] = None): # pylint: disable=unused-argument
"""
Compose valid VW line from its common and item-dependent parts.
:param common_line_part: Part of line that is common for each item in one call.
:param item_line_part: Part of line specific for each item
:param label: Label of this row
:param weight: Optional weight of row
:param debug_info: Optional dict that can be filled by information useful for debugging
:return: One VW line in input format: https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format
"""
if label is not None:
return ' '.join((
str(label),
str(weight) if weight is not None else '',
common_line_part,
item_line_part
))
return ' '.join((common_line_part, item_line_part))
[docs] def get_human_readable_explanation(self, explanation_string: str,
feature_translator: Any = None) -> List[Dict]:
"""
Transform explanation string into more readable form.
Every feature used for prediction is translated into this structure:
.. code-block:: python
{
# For each feature used in higher interaction there is a 2-tuple
'names': [('Human readable namespace name 1', 'Human readable feature name 1'), ...],
'original_feature_name': 'c^c8*f^f102' # feature name how vowpal sees it,
'hashindex': 123, # Vowpal's internal hash of feature name
'value': 0.123, # value for feature in input line
'weight': -0.534, # weight learned by VW for this feature
'potential': value * weight,
'relative_potential': abs(potential) / sum_of_abs_potentials_for_all_features
}
:param explanation_string: Explanation string from :func:`~VowpalWabbitBaseModel.explain_vw_line`
:param feature_translator: Any object that can help you with translation of feature names into human readable
form, for example some database connection.
See :func:`~VowpalWabbitBaseFormatter.parse_element`
:return: List of dicts, sorted by contribution to final score
"""
parsed_features = []
potential_sum = 0.0
for feature in [f.split(':') for f in explanation_string.split('\t')]:
feature_name = feature[0]
hash_index = feature[1]
value = float(feature[2])
weight = float(feature[3].split('@')[0])
# quadratic and higher level interactions have multiple features for one weight
feature_name_parts = feature_name.split('*')
parsed_feature_name_parts = [self.parse_element(el, feature_translator) for el in feature_name_parts]
parsed_features.append({
'names': parsed_feature_name_parts,
'original_feature_name': feature_name,
'hashindex': int(hash_index),
'value': value,
'weight': weight,
'potential': value * weight
})
potential_sum += abs(value * weight)
if potential_sum == 0:
# can happen in case all features are unknown
potential_sum = 1
for parsed_feature in parsed_features:
parsed_feature['relative_potential'] = abs(parsed_feature['potential'] / potential_sum) # type: ignore
return list(sorted(parsed_features, key=lambda f: f['relative_potential'], reverse=True))
# pylint: disable=invalid-name
[docs] def get_human_readable_explanation_html(self, explanation_string: str, feature_translator: Any = None,
max_rows: Optional[int] = None):
"""
Visualize importance of features in Jupyter notebook.
:param explanation_string: Explanation string from :func:`~VowpalWabbitBaseModel.explain_vw_line`
:param feature_translator: Any object that can help you with translation, e.g. some database connection.
:param max_rows: Maximum number of most important features. None return all used features.
:return: `IPython.core.display.HTML`
"""
try:
from IPython.core.display import HTML
except ImportError:
raise ImportError('Please install IPython to use this method')
explanation = self.get_human_readable_explanation(explanation_string, feature_translator)
rows = []
for row_number, feature in enumerate(explanation):
if max_rows is not None and (row_number + 1) > max_rows:
break
feature_name = ''
for name in feature['names']:
if feature_name:
feature_name += '''
<span style="color: grey; margin-left: 10px; margin-right: 10px;">IN COMBINATION WITH</span>
'''
feature_name += name[0]
feature_name += ': <i>{}</i>'.format(name[1])
rows.append(
'''
<tr>
<td>
<div style="display: block; width: 100px; border: solid 1px;
-webkit-border-radius: 5px; -moz-border-radius: 5px; border-radius: 5px;">
<div style="display: block; width: {width}%; height: 20px; background-color: {color};
overflow: hidden;"></div>
</div>
</td>
<td>{potential:.4f}</td>
<td>
{feature_value:.4f}
</td>
<td>
{feature_weight:.4f}
</td>
<td>
{feature_name}
</td>
</tr>
'''.format(
width=feature['relative_potential'] * 100,
color='green' if feature['potential'] > 0 else 'red',
potential=feature['potential'],
feature_value=feature['value'],
feature_weight=feature['weight'],
feature_name=feature_name
)
)
return HTML('''
<table>
<thead>
<tr>
<th>Relative potential</th>
<th>Potential</th>
<th>Value</th>
<th>Weight</th>
<th>Feature name</th>
</tr>
</thead>
<tbody>
''' + ''.join(rows) + '''
</tbody>
</table>''')
# pylint: disable=unused-argument,no-self-use
[docs] def parse_element(self, element: str, feature_translator: Any = None) -> Tuple[str, str]:
"""
This method is supposed to translate namespace name and feature name to human readable form.
For example, element can be "a_item_id^i123" and result can be ('Item ID', 'News of the day: ID of item is 123')
:param element: namespace name and feature name, e.g. a_item_id^i123
:param feature_translator: Any object that can help you with translation, e.g. some database connection
:return: tuple(human understandable namespace name, human understandable feature name)
"""
splitted = element.split('^')
if len(splitted) == 1:
return '', splitted[0]
return splitted[0], splitted[1]
[docs]class VowpalWabbitDummyFormatter(VowpalWabbitBaseFormatter):
"""
Formatter that assumes that either common features and item features are already formatted VW input format strings.
"""
def format_common_features(self, common_features: str,
debug_info: Optional[Dict[Any, Any]] = None) -> str:
return common_features
def format_item_features(self, common_features: Any, item_features: str,
debug_info: Optional[Dict[Any, Any]] = None) -> str:
return item_features
[docs]class VowpalWabbitBaseModel(ABC):
"""
Declaration of Vowpal Wabbit model interface.
"""
def __init__(self, formatter: VowpalWabbitBaseFormatter):
self.formatter = formatter
super().__init__()
# pylint: disable=too-many-arguments
[docs] @abstractmethod
def predict(
self,
common_features: Any,
items_features: Iterable[Any],
timeout: Optional[float] = None,
debug_info: Any = None,
metrics: Optional[Dict] = None,
detailed_metrics: Optional[Dict] = None
) -> Iterable[Union[float,str]]:
"""
Transforms iterable with item features to iterator of predictions.
:param common_features: Features common for all items
:param items_features: Iterable with features for each item
:param timeout: Optionally specify how much time in seconds is desired for computing predictions.
In case timeout is passed, returned iterator can has less items that items features iterable.
:param debug_info: Some object that can be filled by information useful for debugging.
:param metrics: Optional dict that is populated with some metrics that are good to monitor.
:param detailed_metrics: Optional dict with more detailed (and more time consuming) metrics that are good
for debugging and profiling.
:return: Iterable with predictions for each item from ``items_features``
"""
raise NotImplementedError()
# pylint: disable=too-many-arguments
[docs] @abstractmethod
def train(
self,
common_features: Any,
items_features: Iterable[Any],
labels: Iterable[float],
weights: Iterable[Optional[float]],
debug_info: Any = None
) -> None:
"""
Transform features, label and weight into VW line format and send it to Vowpal.
:param common_features: Features common for all items
:param items_features: Iterable with features for each item
:param labels: Iterable with same length as items features with label for each item
:param weights: Iterable with same length as items features with optional weight for each item
:param debug_info: Some object that can be filled by information useful for debugging
"""
raise NotImplementedError()
[docs] @abstractmethod
def explain_vw_line(self, vw_line: str, link_function: bool = False):
"""
Uses VW audit mode to inspect weights used for prediction. Audit mode has to be turned on
by passing ``audit_mode=True`` to constructor.
:param vw_line: String in VW line format
:param link_function: If your model use link function, pass True
:return: (raw prediction without use of link function, explanation string)
"""
raise NotImplementedError()