diff --git a/optimizely/entities.py b/optimizely/entities.py index fed1a49a..7d257656 100644 --- a/optimizely/entities.py +++ b/optimizely/entities.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: # prevent circular dependenacy by skipping import at runtime - from .helpers.types import ExperimentDict, TrafficAllocation, VariableDict, VariationDict + from .helpers.types import ExperimentDict, TrafficAllocation, VariableDict, VariationDict, CmabDict class BaseEntity: @@ -84,6 +84,7 @@ def __init__( audienceConditions: Optional[Sequence[str | list[str]]] = None, groupId: Optional[str] = None, groupPolicy: Optional[str] = None, + cmab: Optional[CmabDict] = None, **kwargs: Any ): self.id = id @@ -97,6 +98,7 @@ def __init__( self.layerId = layerId self.groupId = groupId self.groupPolicy = groupPolicy + self.cmab = cmab def get_audience_conditions_or_ids(self) -> Sequence[str | list[str]]: """ Returns audienceConditions if present, otherwise audienceIds. """ diff --git a/optimizely/helpers/types.py b/optimizely/helpers/types.py index a28aca67..3cca45de 100644 --- a/optimizely/helpers/types.py +++ b/optimizely/helpers/types.py @@ -109,3 +109,9 @@ class IntegrationDict(BaseEntity): key: str host: str publicKey: str + + +class CmabDict(BaseEntity): + """Cmab dict from parsed datafile json.""" + attributeIds: list[str] + trafficAllocation: int diff --git a/optimizely/project_config.py b/optimizely/project_config.py index adfeee41..f2b1467b 100644 --- a/optimizely/project_config.py +++ b/optimizely/project_config.py @@ -94,7 +94,9 @@ def __init__(self, datafile: str | bytes, logger: Logger, error_handler: Any): self.attribute_key_map: dict[str, entities.Attribute] = self._generate_key_map( self.attributes, 'key', entities.Attribute ) - + self.attribute_id_to_key_map: dict[str, str] = {} + for attribute in self.attributes: + self.attribute_id_to_key_map[attribute['id']] = attribute['key'] self.audience_id_map: dict[str, entities.Audience] = self._generate_key_map( self.audiences, 'id', entities.Audience ) @@ -510,6 +512,34 @@ def get_attribute_id(self, attribute_key: str) -> Optional[str]: self.error_handler.handle_error(exceptions.InvalidAttributeException(enums.Errors.INVALID_ATTRIBUTE)) return None + def get_attribute_by_key(self, key: str) -> Optional[entities.Attribute]: + """ Get attribute for the provided attribute key. + + Args: + key: Attribute key for which attribute is to be fetched. + + Returns: + Attribute corresponding to the provided attribute key. + """ + if key in self.attribute_key_map: + return self.attribute_key_map[key] + self.logger.error(f'Attribute with key:"{key}" is not in datafile.') + return None + + def get_attribute_key_by_id(self, id: str) -> Optional[str]: + """ Get attribute key for the provided attribute id. + + Args: + id: Attribute id for which attribute is to be fetched. + + Returns: + Attribute key corresponding to the provided attribute id. + """ + if id in self.attribute_id_to_key_map: + return self.attribute_id_to_key_map[id] + self.logger.error(f'Attribute with id:"{id}" is not in datafile.') + return None + def get_feature_from_key(self, feature_key: str) -> Optional[entities.FeatureFlag]: """ Get feature for the provided feature key. diff --git a/tests/test_config.py b/tests/test_config.py index 9a16035d..9ec5c761 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -154,6 +154,23 @@ def test_init(self): self.assertEqual(expected_variation_key_map, self.project_config.variation_key_map) self.assertEqual(expected_variation_id_map, self.project_config.variation_id_map) + def test_cmab_field_population(self): + """ Test that the cmab field is populated correctly in experiments.""" + + # Deep copy existing datafile and add cmab config to the first experiment + config_dict = copy.deepcopy(self.config_dict_with_multiple_experiments) + config_dict['experiments'][0]['cmab'] = {'attributeIds': ['808797688', '808797689'], 'trafficAllocation': 4000} + config_dict['experiments'][0]['trafficAllocation'] = [] + + opt_obj = optimizely.Optimizely(json.dumps(config_dict)) + project_config = opt_obj.config_manager.get_config() + + experiment = project_config.get_experiment_from_key('test_experiment') + self.assertEqual(experiment.cmab, {'attributeIds': ['808797688', '808797689'], 'trafficAllocation': 4000}) + + experiment_2 = project_config.get_experiment_from_key('test_experiment_2') + self.assertIsNone(experiment_2.cmab) + def test_init__with_v4_datafile(self): """ Test that on creating object, properties are initiated correctly for version 4 datafile. """