diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py
index 91cedbedc..8e78a26ce 100644
--- a/firebase_admin/mlkit.py
+++ b/firebase_admin/mlkit.py
@@ -18,6 +18,7 @@
 deleting, publishing and unpublishing Firebase ML Kit models.
 """
 
+
 import datetime
 import numbers
 import re
@@ -30,13 +31,27 @@
 from firebase_admin import _utils
 from firebase_admin import exceptions
 
+# pylint: disable=import-error,no-name-in-module
+try:
+    from firebase_admin import storage
+    _GCS_ENABLED = True
+except ImportError:
+    _GCS_ENABLED = False
+
+# pylint: disable=import-error,no-name-in-module
+try:
+    import tensorflow as tf
+    _TF_ENABLED = True
+except ImportError:
+    _TF_ENABLED = False
 
 _MLKIT_ATTRIBUTE = '_mlkit'
 _MAX_PAGE_SIZE = 100
 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
 _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
 _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
-_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
+_GCS_TFLITE_URI_PATTERN = re.compile(
+    r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
 _RESOURCE_NAME_PATTERN = re.compile(
     r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
 _OPERATION_NAME_PATTERN = re.compile(
@@ -301,16 +316,16 @@ def model_format(self, model_format):
         self._model_format = model_format #Can be None
         return self
 
-    def as_dict(self):
+    def as_dict(self, for_upload=False):
         copy = dict(self._data)
         if self._model_format:
-            copy.update(self._model_format.as_dict())
+            copy.update(self._model_format.as_dict(for_upload=for_upload))
         return copy
 
 
 class ModelFormat(object):
     """Abstract base class representing a Model Format such as TFLite."""
-    def as_dict(self):
+    def as_dict(self, for_upload=False):
         raise NotImplementedError
 
 
@@ -364,22 +379,70 @@ def model_source(self, model_source):
     def size_bytes(self):
         return self._data.get('sizeBytes')
 
-    def as_dict(self):
+    def as_dict(self, for_upload=False):
         copy = dict(self._data)
         if self._model_source:
-            copy.update(self._model_source.as_dict())
+            copy.update(self._model_source.as_dict(for_upload=for_upload))
         return {'tfliteModel': copy}
 
 
 class TFLiteModelSource(object):
     """Abstract base class representing a model source for TFLite format models."""
-    def as_dict(self):
+    def as_dict(self, for_upload=False):
         raise NotImplementedError
 
 
+class _CloudStorageClient(object):
+    """Cloud Storage helper class"""
+
+    GCS_URI = 'gs://{0}/{1}'
+    BLOB_NAME = 'Firebase/MLKit/Models/{0}'
+
+    @staticmethod
+    def _assert_gcs_enabled():
+        if not _GCS_ENABLED:
+            raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
+                              'to install the "google-cloud-storage" module.')
+
+    @staticmethod
+    def _parse_gcs_tflite_uri(uri):
+        # GCS Bucket naming rules are complex. The regex is not comprehensive.
+        # See https://cloud.google.com/storage/docs/naming for full details.
+        matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
+        if not matcher:
+            raise ValueError('GCS TFLite URI format is invalid.')
+        return matcher.group('bucket_name'), matcher.group('blob_name')
+
+    @staticmethod
+    def upload(bucket_name, model_file_name, app):
+        _CloudStorageClient._assert_gcs_enabled()
+        bucket = storage.bucket(bucket_name, app=app)
+        blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name)
+        blob = bucket.blob(blob_name)
+        blob.upload_from_filename(model_file_name)
+        return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)
+
+    @staticmethod
+    def sign_uri(gcs_tflite_uri, app):
+        """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri."""
+        _CloudStorageClient._assert_gcs_enabled()
+        bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
+        bucket = storage.bucket(bucket_name, app=app)
+        blob = bucket.blob(blob_name)
+        return blob.generate_signed_url(
+            version='v4',
+            expiration=datetime.timedelta(minutes=10),
+            method='GET'
+        )
+
+
 class TFLiteGCSModelSource(TFLiteModelSource):
     """TFLite model source representing a tflite model file stored in GCS."""
-    def __init__(self, gcs_tflite_uri):
+
+    _STORAGE_CLIENT = _CloudStorageClient()
+
+    def __init__(self, gcs_tflite_uri, app=None):
+        self._app = app
         self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
 
     def __eq__(self, other):
@@ -391,6 +454,87 @@ def __eq__(self, other):
     def __ne__(self, other):
         return not self.__eq__(other)
 
+    @classmethod
+    def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
+        """Uploads the model file to an existing Google Cloud Storage bucket.
+
+        Args:
+            model_file_name: The name of the model file.
+            bucket_name: The name of an existing bucket. None to use the default bucket configured
+                in the app.
+            app: A Firebase app instance (or None to use the default app).
+
+        Returns:
+            TFLiteGCSModelSource: The source created from the model_file
+
+        Raises:
+            ImportError: If the Cloud Storage Library has not been installed.
+        """
+        gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app)
+        return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)
+
+    @staticmethod
+    def _assert_tf_version_1_enabled():
+        if not _TF_ENABLED:
+            raise ImportError('Failed to import the tensorflow library for Python. Make sure '
+                              'to install the tensorflow module.')
+        # tf.__version__ exists in every release, unlike tf.VERSION (removed in TF 2.x).
+        if not tf.__version__.startswith('1.'):
+            raise ImportError(
+                'Expected tensorflow version 1.x, but found {0}'.format(tf.__version__))
+
+    @classmethod
+    def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
+        """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.
+
+        Args:
+            saved_model_dir: The saved model directory.
+            bucket_name: The name of an existing bucket. None to use the default bucket configured
+                in the app.
+            app: Optional. A Firebase app instance (or None to use the default app)
+
+        Returns:
+            TFLiteGCSModelSource: The source created from the saved_model_dir
+
+        Raises:
+            ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
+        """
+        TFLiteGCSModelSource._assert_tf_version_1_enabled()
+        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+        tflite_model = converter.convert()
+        # Close the file before uploading so the model is fully flushed to disk.
+        with open('firebase_mlkit_model.tflite', 'wb') as model_file:
+            model_file.write(tflite_model)
+        return TFLiteGCSModelSource.from_tflite_model_file(
+            'firebase_mlkit_model.tflite', bucket_name, app)
+
+    @classmethod
+    def from_keras_model(cls, keras_model, bucket_name=None, app=None):
+        """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.
+
+        Args:
+            keras_model: A tf.keras model.
+            bucket_name: The name of an existing bucket. None to use the default bucket configured
+                in the app.
+            app: Optional. A Firebase app instance (or None to use the default app)
+
+        Returns:
+            TFLiteGCSModelSource: The source created from the keras_model
+
+        Raises:
+            ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
+        """
+        TFLiteGCSModelSource._assert_tf_version_1_enabled()
+        keras_file = 'keras_model.h5'
+        tf.keras.models.save_model(keras_model, keras_file)
+        converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
+        tflite_model = converter.convert()
+        # Close the file before uploading so the model is fully flushed to disk.
+        with open('firebase_mlkit_model.tflite', 'wb') as model_file:
+            model_file.write(tflite_model)
+        return TFLiteGCSModelSource.from_tflite_model_file(
+            'firebase_mlkit_model.tflite', bucket_name, app)
+
     @property
     def gcs_tflite_uri(self):
         return self._gcs_tflite_uri
@@ -399,10 +543,15 @@ def gcs_tflite_uri(self):
     def gcs_tflite_uri(self, gcs_tflite_uri):
         self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
 
-    def as_dict(self):
-        return {"gcsTfliteUri": self._gcs_tflite_uri}
+    def _get_signed_gcs_tflite_uri(self):
+        """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
+        return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)
+
+    def as_dict(self, for_upload=False):
+        if for_upload:
+            return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()}
 
-    #TODO(ifielker): implement from_saved_model etc.
+        return {'gcsTfliteUri': self._gcs_tflite_uri}
 
 
 class ListModelsPage(object):
@@ -671,13 +820,13 @@ def create_model(self, model):
         _validate_model(model)
         try:
             return self.handle_operation(
-                self._client.body('post', url='models', json=model.as_dict()))
+                self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
         except requests.exceptions.RequestException as error:
             raise _utils.handle_platform_error_from_requests(error)
 
     def update_model(self, model, update_mask=None):
         _validate_model(model, update_mask)
-        data = {'model': model.as_dict()}
+        data = {'model': model.as_dict(for_upload=True)}
         if update_mask is not None:
             data['updateMask'] = update_mask
         try:
diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py
index 50fed4e1b..26afdfa99 100644
--- a/tests/test_mlkit.py
+++ b/tests/test_mlkit.py
@@ -103,7 +103,9 @@
     }
 }
 
-GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite'
+GCS_BUCKET_NAME = 'my_bucket'
+GCS_BLOB_NAME = 'mymodel.tflite'
+GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
 GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI}
 GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
 TFLITE_FORMAT_JSON = {
@@ -112,6 +114,10 @@
 }
 TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON)
 
+GCS_TFLITE_SIGNED_URI_PATTERN = (
+    'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo')
+GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
+
 GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite'
 GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2}
 GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2)
@@ -325,6 +331,20 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non
         session_url, adapter(payload, status, recorder))
     return recorder
 
+
+class _TestStorageClient(object):
+    @staticmethod
+    def upload(bucket_name, model_file_name, app):
+        del app  # unused variable
+        blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name)
+        return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name)
+
+    @staticmethod
+    def sign_uri(gcs_tflite_uri, app):
+        del app  # unused variable
+        bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
+        return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name)
+
 
 class TestModel(object):
     """Tests mlkit.Model class."""
@@ -333,6 +353,7 @@ def setup_class(cls):
         cred = testutils.MockCredential()
         firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
         mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
+        mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()
 
     @classmethod
     def teardown_class(cls):
@@ -404,6 +425,13 @@ def test_model_format_source_creation(self):
             }
         }
 
+    def test_source_creation_from_tflite_file(self):
+        model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file(
+            "my_model.tflite", "my_bucket")
+        assert model_source.as_dict() == {
+            'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite'
+        }
+
     def test_model_source_setters(self):
         model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
         model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
@@ -420,6 +448,27 @@ def test_model_format_setters(self):
             }
         }
 
+    def test_model_as_dict_for_upload(self):
+        model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
+        model_format = mlkit.TFLiteFormat(model_source=model_source)
+        model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
+        assert model.as_dict(for_upload=True) == {
+            'displayName': DISPLAY_NAME_1,
+            'tfliteModel': {
+                'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
+            }
+        }
+
+    @pytest.mark.parametrize('helper_func', [
+        mlkit.TFLiteGCSModelSource.from_keras_model,
+        mlkit.TFLiteGCSModelSource.from_saved_model
+    ])
+    def test_tf_not_enabled(self, helper_func, monkeypatch):
+        monkeypatch.setattr(mlkit, '_TF_ENABLED', False)  # for reliability; undone after test
+        with pytest.raises(ImportError) as excinfo:
+            helper_func(None)
+        check_error(excinfo, ImportError)
+
     @pytest.mark.parametrize('display_name, exc_type', [
         ('', ValueError),
         ('&_*#@:/?', ValueError),
@@ -803,6 +852,7 @@ def test_rpc_error(self, publish_function):
         )
         assert len(create_recorder) == 1
 
+
 class TestGetModel(object):
     """Tests mlkit.get_model."""
     @classmethod