diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 2613a3de3..bcc4b9390 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -53,6 +53,9 @@ _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') +_AUTO_ML_MODEL_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/locations/(?P[^/]+)/' + + r'models/(?P[A-Za-z0-9]+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -75,7 +78,7 @@ def _get_ml_service(app): def create_model(model, app=None): - """Creates a model in Firebase ML. + """Creates a model in the current Firebase project. Args: model: An ml.Model to create. @@ -89,7 +92,7 @@ def create_model(model, app=None): def update_model(model, app=None): - """Updates a model in Firebase ML. + """Updates a model's metadata or model file. Args: model: The ml.Model to update. @@ -103,7 +106,9 @@ def update_model(model, app=None): def publish_model(model_id, app=None): - """Publishes a model in Firebase ML. + """Publishes a Firebase ML model. + + A published model can be downloaded to client apps. Args: model_id: The id of the model to publish. @@ -117,7 +122,7 @@ def publish_model(model_id, app=None): def unpublish_model(model_id, app=None): - """Unpublishes a model in Firebase ML. + """Unpublishes a Firebase ML model. Args: model_id: The id of the model to unpublish. @@ -131,7 +136,7 @@ def unpublish_model(model_id, app=None): def get_model(model_id, app=None): - """Gets a model from Firebase ML. + """Gets the model specified by the given ID. Args: model_id: The id of the model to get. @@ -145,7 +150,7 @@ def get_model(model_id, app=None): def list_models(list_filter=None, page_size=None, page_token=None, app=None): - """Lists models from Firebase ML. + """Lists the current project's models. Args: list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models. @@ -164,7 +169,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): def delete_model(model_id, app=None): - """Deletes a model from Firebase ML. + """Deletes a model from the current project. Args: model_id: The id of the model you wish to delete. @@ -363,15 +368,10 @@ def __init__(self, model_source=None): def from_dict(cls, data): """Create an instance of the object from a dict.""" data_copy = dict(data) - model_source = None - gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) - if gcs_tflite_uri: - model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - tflite_format = TFLiteFormat(model_source=model_source) + tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -381,6 +381,16 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @staticmethod + def _init_model_source(data): + gcs_tflite_uri = data.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + auto_ml_model = data.pop('automlModel', None) + if auto_ml_model: + return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + return None + @property def model_source(self): """The TF Lite model's location.""" @@ -593,8 +603,38 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} +class TFLiteAutoMlSource(TFLiteModelSource): + """TFLite model source representing a tflite model created with AutoML.""" + + def __init__(self, auto_ml_model, app=None): + self._app = app + self.auto_ml_model = auto_ml_model + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.auto_ml_model == other.auto_ml_model + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def auto_ml_model(self): + """Resource name of the model, created by the AutoML API or Cloud console.""" + return self._auto_ml_model + + @auto_ml_model.setter + def auto_ml_model(self, auto_ml_model): + self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + # Upload is irrelevant for auto_ml models + return {'automlModel': self._auto_ml_model} + + class ListModelsPage: - """Represents a page of models in a firebase project. + """Represents a page of models in a Firebase project. Provides methods for traversing the models included in this page, as well as retrieving subsequent pages of models. The iterator returned by @@ -740,6 +780,11 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri +def _validate_auto_ml_model(model): + if not _AUTO_ML_MODEL_PATTERN.match(model): + raise ValueError('Model resource name format is invalid.') + return model + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): diff --git a/integration/test_ml.py b/integration/test_ml.py index 1d32ebed1..52cb1bb7e 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,6 +22,7 @@ import pytest +import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils @@ -34,6 +35,11 @@ except ImportError: _TF_ENABLED = False +try: + from google.cloud import automl_v1 + _AUTOML_ENABLED = True +except ImportError: + _AUTOML_ENABLED = False def _random_identifier(prefix): #pylint: disable=unused-variable @@ -62,7 +68,6 @@ def _random_identifier(prefix): 'file_name': 'invalid_model.tflite' } - @pytest.fixture def firebase_model(request): args = request.param @@ -101,6 +106,7 @@ def _clean_up_model(model): try: # Try to delete the model. # Some tests delete the model as part of the test. + model.wait_for_unlocked() ml.delete_model(model.model_id) except exceptions.NotFoundError: pass @@ -132,35 +138,45 @@ def check_model(model, args): assert model.locked is False assert model.etag is not None +# Model Format Checks -def check_model_format(model, has_model_format=False, validation_error=None): - if has_model_format: - assert model.validation_error == validation_error - assert model.published is False - assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') - if validation_error: - assert model.model_format.size_bytes is None - assert model.model_hash is None - else: - assert model.model_format.size_bytes is not None - assert model.model_hash is not None - else: - assert model.model_format is None - assert model.validation_error == 'No model file has been uploaded.' - assert model.published is False +def check_no_model_format(model): + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +def check_tflite_gcs_format(model, validation_error=None): + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + + +def check_tflite_automl_format(model): + assert model.validation_error is None + assert model.published is False + assert model.model_format.model_source.auto_ml_model.startswith('projects/') + # Automl models don't have validation errors since they are references + # to valid automl models. @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) - check_model_format(firebase_model) + check_no_model_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) def test_create_full_model(firebase_model): check_model(firebase_model, FULL_MODEL_ARGS) - check_model_format(firebase_model, True) + check_tflite_gcs_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model): @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) def test_create_invalid_model(firebase_model): check_model(firebase_model, INVALID_FULL_MODEL_ARGS) - check_model_format(firebase_model, True, 'Invalid flatbuffer format') + check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format') @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_get_model(firebase_model): get_model = ml.get_model(firebase_model.model_id) check_model(get_model, NAME_AND_TAGS_ARGS) - check_model_format(get_model) + check_no_model_format(get_model) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -201,12 +217,12 @@ def test_update_model(firebase_model): firebase_model.display_name = new_model_name updated_model = ml.update_model(firebase_model) check_model(updated_model, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model) + check_no_model_format(updated_model) # Second call with same model does not cause error updated_model2 = ml.update_model(updated_model) check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model2) + check_no_model_format(updated_model2) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -290,7 +306,7 @@ def test_delete_model(firebase_model): # Test tensor flow conversion functions if tensor flow is enabled. #'pip install tensorflow' in the environment if you want _TF_ENABLED = True -#'pip install tensorflow==2.0.0b' for version 2 etc. +#'pip install tensorflow==2.2.0' for version 2.2.0 etc. def _clean_up_directory(save_dir): @@ -334,6 +350,7 @@ def saved_model_dir(keras_model): _clean_up_directory(parent) + @pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') def test_from_keras_model(keras_model): source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') @@ -348,7 +365,7 @@ def test_from_keras_model(keras_model): try: check_model(created_model, {'display_name': model.display_name}) - check_model_format(created_model, True) + check_tflite_gcs_format(created_model) finally: _clean_up_model(created_model) @@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) + + +# Test AutoML functionality if AutoML is enabled. +#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True +# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the +# successful test. (Test is skipped otherwise) + +@pytest.fixture +def automl_model(): + assert _AUTOML_ENABLED + + # It takes > 20 minutes to train a model, so we expect a predefined AutoMl + # model named 'admin_sdk_integ_test1' to exist in the project, or we skip + # the test. + automl_client = automl_v1.AutoMlClient() + project_id = firebase_admin.get_app().project_id + parent = automl_client.location_path(project_id, 'us-central1') + models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") + # Expecting exactly one. (Ok to use last one if somehow more than 1) + automl_ref = None + for model in models: + automl_ref = model.name + + # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) + if automl_ref is None: + pytest.skip("No pre-existing AutoML model found. Skipping test") + + source = ml.TFLiteAutoMlSource(automl_ref) + tflite_format = ml.TFLiteFormat(model_source=source) + ml_model = ml.Model( + display_name=_random_identifier('TestModel_automl_'), + tags=['test_automl'], + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + +@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') +def test_automl_model(automl_model): + # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' + automl_model.wait_for_unlocked() + + check_model(automl_model, { + 'display_name': automl_model.display_name, + 'tags': ['test_automl'], + }) + check_tflite_automl_format(automl_model) diff --git a/requirements.txt b/requirements.txt index dbeaee3b6..1a55482da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 +google-auth == 1.18.0 # temporary workaround google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.18.0 diff --git a/tests/test_ml.py b/tests/test_ml.py index 10b0441db..abd6d06f9 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -122,6 +122,18 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' +AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) +TFLITE_FORMAT_JSON_3 = { + 'automlModel': AUTOML_MODEL_NAME, + 'sizeBytes': '3456789' +} +TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) + +AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' +AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} +AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) + CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -405,7 +417,15 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - def test_model_format_source_creation(self): + model.model_format = TFLITE_FORMAT_3 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_3 + } + + + def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) @@ -416,6 +436,17 @@ def test_model_format_source_creation(self): } } + def test_auto_ml_tflite_model_format_source_creation(self): + model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -423,12 +454,19 @@ def test_source_creation_from_tflite_file(self): 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } - def test_model_source_setters(self): + def test_gcs_tflite_model_source_setters(self): model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + def test_auto_ml_tflite_model_source_setters(self): + model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) + model_source.auto_ml_model = AUTOML_MODEL_NAME_2 + assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 + assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 + + def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -439,6 +477,14 @@ def test_model_format_setters(self): } } + model_format.model_source = AUTOML_MODEL_SOURCE + assert model_format.model_source == AUTOML_MODEL_SOURCE + assert model_format.as_dict() == { + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -524,6 +570,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) + @pytest.mark.parametrize('auto_ml_model, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), + ('projects/123546/models/ICN123456', ValueError), + ('projects//locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations//models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/', ValueError), + ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), + ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), + ]) + def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + check_error(excinfo, exc_type) + def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked()