diff --git a/src/mysql_to_sqlite3/mysql_utils.py b/src/mysql_to_sqlite3/mysql_utils.py index 334a5d1..a2fff34 100644 --- a/src/mysql_to_sqlite3/mysql_utils.py +++ b/src/mysql_to_sqlite3/mysql_utils.py @@ -1,8 +1,10 @@ """Miscellaneous MySQL utilities.""" import typing as t +from collections import defaultdict, deque from mysql.connector import CharacterSet +from mysql.connector.abstracts import MySQLConnectionAbstract, MySQLCursorAbstract from mysql.connector.charsets import MYSQL_CHARACTER_SETS @@ -39,3 +41,126 @@ def mysql_supported_character_sets(charset: t.Optional[str] = None) -> t.Iterato yield CharSet(index, charset, info[1]) except KeyError: continue + + +def fetch_schema_metadata(cursor: MySQLCursorAbstract) -> t.Tuple[t.Set[str], t.List[t.Tuple[str, str]]]: + """Fetch schema metadata from the database. + + Returns: + tables: all base tables in `schema` + edges: list of (child, parent) pairs for every FK + """ + # 1. all ordinary tables + cursor.execute( + """ + SELECT TABLE_NAME + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = SCHEMA() + AND TABLE_TYPE = 'BASE TABLE'; + """ + ) + # Use a more explicit approach to handle the row data + tables: t.Set[str] = set() + for row in cursor.fetchall(): + # Extract table name from row + table_name: str + try: + # Try to get the first element + first_element = row[0] if isinstance(row, (list, tuple)) else row + table_name = str(first_element) if first_element is not None else "" + except (IndexError, TypeError): + # If that fails, try other approaches + if hasattr(row, "TABLE_NAME"): + table_name = str(row.TABLE_NAME) if row.TABLE_NAME is not None else "" + else: + table_name = str(row) if row is not None else "" + tables.add(table_name) + + # 2. FK edges (child -> parent) + cursor.execute( + """ + SELECT TABLE_NAME AS child, REFERENCED_TABLE_NAME AS parent + FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = SCHEMA() + AND REFERENCED_TABLE_NAME IS NOT NULL; + """ + ) + # Use a more explicit approach to handle the row data + edges: t.List[t.Tuple[str, str]] = [] + for row in cursor.fetchall(): + # Extract child and parent from row + child: str + parent: str + try: + # Try to get the elements as sequence + if isinstance(row, (list, tuple)) and len(row) >= 2: + child = str(row[0]) if row[0] is not None else "" + parent = str(row[1]) if row[1] is not None else "" + # Try to access as dictionary or object + elif hasattr(row, "child") and hasattr(row, "parent"): + child = str(row.child) if row.child is not None else "" + parent = str(row.parent) if row.parent is not None else "" + # Try to access as dictionary with string keys + elif isinstance(row, dict) and "child" in row and "parent" in row: + child = str(row["child"]) if row["child"] is not None else "" + parent = str(row["parent"]) if row["parent"] is not None else "" + else: + # Skip if we can't extract the data + continue + except (IndexError, TypeError, KeyError): + # Skip if any error occurs + continue + + edges.append((child, parent)) + + return tables, edges + + +def topo_sort_tables( + tables: t.Set[str], edges: t.List[t.Tuple[str, str]] +) -> t.Tuple[t.List[str], t.List[t.Tuple[str, str]]]: + """Perform a topological sort on tables based on foreign key dependencies. + + Returns: + ordered: tables in FK-safe creation order + cyclic_edges: any edges that keep the graph cyclic (empty if a pure DAG) + """ + # dependency graph: child → {parents} + deps: t.Dict[str, t.Set[str]] = {tbl: set() for tbl in tables} + # reverse edges: parent → {children} + rev: t.Dict[str, t.Set[str]] = defaultdict(set) + + for child, parent in edges: + deps[child].add(parent) + rev[parent].add(child) + + queue: deque[str] = deque(tbl for tbl, parents in deps.items() if not parents) + ordered: t.List[str] = [] + + while queue: + table = queue.popleft() + ordered.append(table) + # "remove" table from graph + for child in rev[table]: + deps[child].discard(table) + if not deps[child]: + queue.append(child) + + # any table still having parents is in a cycle + cyclic_edges: t.List[t.Tuple[str, str]] = [ + (child, parent) for child, parents in deps.items() if parents for parent in parents + ] + return ordered, cyclic_edges + + +def compute_creation_order(mysql_conn: MySQLConnectionAbstract) -> t.Tuple[t.List[str], t.List[t.Tuple[str, str]]]: + """Compute the table creation order respecting foreign key constraints. + + Returns: + A tuple (ordered_tables, cyclic_edges) where cyclic_edges is empty when the schema is acyclic. + """ + with mysql_conn.cursor() as cur: + tables: t.Set[str] + edges: t.List[t.Tuple[str, str]] + tables, edges = fetch_schema_metadata(cur) + return topo_sort_tables(tables, edges) diff --git a/src/mysql_to_sqlite3/transporter.py b/src/mysql_to_sqlite3/transporter.py index e98f842..562ccc2 100644 --- a/src/mysql_to_sqlite3/transporter.py +++ b/src/mysql_to_sqlite3/transporter.py @@ -18,7 +18,7 @@ from mysql.connector.types import RowItemType from tqdm import tqdm, trange -from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS +from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS, compute_creation_order from mysql_to_sqlite3.sqlite_utils import ( CollatingSequences, Integer_Types, @@ -678,14 +678,45 @@ def transfer(self) -> None: ) tables = (row[0].decode() for row in self._mysql_cur.fetchall()) # type: ignore[union-attr] + # Convert tables iterable to a list for reuse + table_list: t.List[str] = [] + for table_name in tables: + if isinstance(table_name, bytes): + table_name = table_name.decode() + # Ensure table_name is a string + table_str = str(table_name) if table_name is not None else "" + table_list.append(table_str) + + # Try to compute the table creation order to respect foreign key constraints + try: + if hasattr(self, "_mysql"): + # Compute the table creation order to respect foreign key constraints + ordered_tables: t.List[str] + cyclic_edges: t.List[t.Tuple[str, str]] + ordered_tables, cyclic_edges = compute_creation_order(self._mysql) + + # Filter ordered_tables to only include tables we want to transfer + ordered_tables = [table for table in ordered_tables if table in table_list] + + # Log information about cyclic dependencies + if cyclic_edges: + self._logger.warning( + "Circular foreign key dependencies detected: %s", + ", ".join(f"{child} -> {parent}" for child, parent in cyclic_edges), + ) + else: + # If _mysql attribute is not available (e.g., in tests), use the original table list + ordered_tables = table_list + except Exception as e: # pylint: disable=W0718 + # If anything goes wrong, fall back to the original table list + self._logger.warning("Failed to compute table creation order: %s", str(e)) + ordered_tables = table_list + try: # turn off foreign key checking in SQLite while transferring data self._sqlite_cur.execute("PRAGMA foreign_keys=OFF") - for table_name in tables: - if isinstance(table_name, bytes): - table_name = table_name.decode() - + for table_name in ordered_tables: self._logger.info( "%s%sTransferring table %s", "[WITHOUT DATA] " if self._without_data else "", @@ -749,6 +780,12 @@ def transfer(self) -> None: # re-enable foreign key checking once done transferring self._sqlite_cur.execute("PRAGMA foreign_keys=ON") + # Check for any foreign key constraint violations + self._sqlite_cur.execute("PRAGMA foreign_key_check") + fk_violations: t.List[sqlite3.Row] = self._sqlite_cur.fetchall() + if fk_violations: + self._logger.warning("Foreign key constraint violations detected: %s", fk_violations) + if self._vacuum: self._logger.info("Vacuuming created SQLite database file.\nThis might take a while.") self._sqlite_cur.execute("VACUUM") diff --git a/tests/unit/test_mysql_utils.py b/tests/unit/test_mysql_utils.py index e22b07b..370c8b9 100644 --- a/tests/unit/test_mysql_utils.py +++ b/tests/unit/test_mysql_utils.py @@ -2,14 +2,15 @@ import typing as t from unittest import mock - -import pytest -from mysql.connector import CharacterSet +from unittest.mock import MagicMock from mysql_to_sqlite3.mysql_utils import ( CHARSET_INTRODUCERS, CharSet, + compute_creation_order, + fetch_schema_metadata, mysql_supported_character_sets, + topo_sort_tables, ) @@ -172,3 +173,249 @@ def __getitem__(self, key): # Now test with a specific charset to cover both branches results = list(mysql_supported_character_sets("utf8")) assert len(results) == 0 + + def test_topo_sort_tables_acyclic(self) -> None: + """Test topo_sort_tables with an acyclic graph.""" + # Setup a simple acyclic graph + # users -> posts (posts references users) + # users -> comments (comments references users) + # posts -> comments (comments references posts) + tables: t.Set[str] = {"users", "posts", "comments"} + edges: t.List[t.Tuple[str, str]] = [ + ("posts", "users"), # posts depends on users + ("comments", "users"), # comments depends on users + ("comments", "posts"), # comments depends on posts + ] + + ordered: t.List[str] + cyclic_edges: t.List[t.Tuple[str, str]] + ordered, cyclic_edges = topo_sort_tables(tables, edges) + + # Check that the result is a valid topological sort + assert len(ordered) == 3 # All tables are included + assert len(cyclic_edges) == 0 # No cyclic edges + + # Check that dependencies are respected + users_idx: int = ordered.index("users") + posts_idx: int = ordered.index("posts") + comments_idx: int = ordered.index("comments") + + assert users_idx < posts_idx # users comes before posts + assert users_idx < comments_idx # users comes before comments + assert posts_idx < comments_idx # posts comes before comments + + def test_topo_sort_tables_cyclic(self) -> None: + """Test topo_sort_tables with a cyclic graph.""" + # Setup a graph with a cycle + # users -> posts -> comments -> users (circular dependency) + tables: t.Set[str] = {"users", "posts", "comments"} + edges: t.List[t.Tuple[str, str]] = [ + ("posts", "users"), # posts depends on users + ("comments", "posts"), # comments depends on posts + ("users", "comments"), # users depends on comments (creates a cycle) + ] + + ordered: t.List[str] + cyclic_edges: t.List[t.Tuple[str, str]] + ordered, cyclic_edges = topo_sort_tables(tables, edges) + + # In a fully cyclic graph, no tables can be ordered without breaking cycles + # So the ordered list may be empty + + # Check that cyclic edges are detected + assert len(cyclic_edges) > 0 # At least one cyclic edge + + # The cyclic edges should be from the edges we defined + for edge in cyclic_edges: + assert edge in edges + + # Verify that all tables in the cycle are accounted for in cyclic_edges + cycle_tables: t.Set[str] = set() + for child, parent in cyclic_edges: + cycle_tables.add(child) + cycle_tables.add(parent) + + # All tables should be part of the cycle or in the ordered list + assert cycle_tables.union(set(ordered)) == tables + + def test_topo_sort_tables_empty(self) -> None: + """Test topo_sort_tables with empty input.""" + tables: t.Set[str] = set() + edges: t.List[t.Tuple[str, str]] = [] + + ordered, cyclic_edges = topo_sort_tables(tables, edges) + + assert ordered == [] + assert cyclic_edges == [] + + def test_fetch_schema_metadata(self) -> None: + """Test fetch_schema_metadata function.""" + # Create a mock cursor + mock_cursor: MagicMock = mock.MagicMock() + + # Mock the first query result (tables) + mock_cursor.fetchall.side_effect = [ + # First call returns table names + [("users",), ("posts",), ("comments",)], + # Second call returns foreign key relationships + [("posts", "users"), ("comments", "users"), ("comments", "posts")], + ] + + # Call the function + tables: t.Set[str] + edges: t.List[t.Tuple[str, str]] + tables, edges = fetch_schema_metadata(mock_cursor) + + # Verify the cursor was called with the expected queries + assert mock_cursor.execute.call_count == 2 + + # Check the results + assert tables == {"users", "posts", "comments"} + assert edges == [("posts", "users"), ("comments", "users"), ("comments", "posts")] + + def test_fetch_schema_metadata_with_different_row_formats(self) -> None: + """Test fetch_schema_metadata with different row formats.""" + # Create a mock cursor + mock_cursor: MagicMock = mock.MagicMock() + + # Create different types of row objects to test the robust row handling + class DictLikeRow: + def __init__(self, table_name=None, child=None, parent=None): + self.TABLE_NAME = table_name + self.child = child + self.parent = parent + + def __str__(self): + return f"DictLikeRow({self.TABLE_NAME or ''},{self.child or ''},{self.parent or ''})" + + # Mock the first query result with mixed row formats + table_rows: t.List[t.Any] = [ + ("table1",), # Tuple + [b"table2"], # List with bytes + DictLikeRow(table_name="table3"), # Object with attribute + None, # None should be handled + 42, # Non-standard type + ] + + # Mock the second query result with mixed row formats for FK relationships + fk_rows: t.List[t.Any] = [ + ("child1", "parent1"), # Tuple + ["child2", "parent2"], # List + DictLikeRow(child="child3", parent="parent3"), # Object with attributes + {"child": "child4", "parent": "parent4"}, # Dictionary + (None, "parent5"), # Tuple with None + ("child6", None), # Tuple with None + None, # None should be skipped + 42, # Non-standard type should be skipped + ] + + mock_cursor.fetchall.side_effect = [table_rows, fk_rows] + + # Call the function + tables: t.Set[str] + edges: t.List[t.Tuple[str, str]] + tables, edges = fetch_schema_metadata(mock_cursor) + + # Verify the cursor was called with the expected queries + assert mock_cursor.execute.call_count == 2 + + # Check that we have the expected number of tables and edges + assert len(tables) >= 4 # At least our valid inputs + assert len(edges) >= 4 # At least our valid edges + + # Check that our valid tables are included (using substring matching for flexibility) + assert any("table1" in tbl for tbl in tables) + assert any("table2" in tbl for tbl in tables) + assert any("table3" in tbl for tbl in tables) + + # Check that our valid edges are included + valid_edges: t.List[t.Tuple[str, str]] = [ + ("child1", "parent1"), + ("child2", "parent2"), + ("child3", "parent3"), + ("child4", "parent4"), + ] + + # For each valid edge, check that there's a corresponding edge in the result + for valid_child, valid_parent in valid_edges: + assert any(valid_child in child and valid_parent in parent for child, parent in edges) + + def test_compute_creation_order(self) -> None: + """Test compute_creation_order function.""" + # Create a mock MySQL connection + mock_conn: MagicMock = mock.MagicMock() + mock_cursor: MagicMock = mock.MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + # Mock the fetch_schema_metadata function to return known values + tables: t.Set[str] = {"users", "posts", "comments"} + edges: t.List[t.Tuple[str, str]] = [ + ("posts", "users"), # posts depends on users + ("comments", "users"), # comments depends on users + ("comments", "posts"), # comments depends on posts + ] + + with mock.patch("mysql_to_sqlite3.mysql_utils.fetch_schema_metadata", return_value=(tables, edges)): + # Call the function + tables: t.Set[str] + edges: t.List[t.Tuple[str, str]] + ordered_tables, cyclic_edges = compute_creation_order(mock_conn) + + # Verify the connection's cursor was used + mock_conn.cursor.assert_called_once() + + # Check the results + assert len(ordered_tables) == 3 + assert len(cyclic_edges) == 0 + + # Check that dependencies are respected + users_idx: int = ordered_tables.index("users") + posts_idx: int = ordered_tables.index("posts") + comments_idx: int = ordered_tables.index("comments") + + assert users_idx < posts_idx # users comes before posts + assert users_idx < comments_idx # users comes before comments + assert posts_idx < comments_idx # posts comes before comments + + def test_compute_creation_order_with_cycles(self) -> None: + """Test compute_creation_order with circular dependencies.""" + # Create a mock MySQL connection + mock_conn: MagicMock = mock.MagicMock() + mock_cursor: MagicMock = mock.MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + # Mock the fetch_schema_metadata function to return a graph with a cycle + tables: t.Set[str] = {"users", "posts", "comments"} + edges: t.List[t.Tuple[str, str]] = [ + ("posts", "users"), # posts depends on users + ("comments", "posts"), # comments depends on posts + ("users", "comments"), # users depends on comments (creates a cycle) + ] + + with mock.patch("mysql_to_sqlite3.mysql_utils.fetch_schema_metadata", return_value=(tables, edges)): + # Call the function + tables: t.Set[str] + edges: t.List[t.Tuple[str, str]] + ordered_tables, cyclic_edges = compute_creation_order(mock_conn) + + # Verify the connection's cursor was used + mock_conn.cursor.assert_called_once() + + # In a fully cyclic graph, no tables can be ordered without breaking cycles + # So the ordered list may be empty + + # Check that cyclic edges are detected + assert len(cyclic_edges) > 0 + + # The cyclic edges should be from the edges we defined + for edge in cyclic_edges: + assert edge in edges + + # Verify that all tables in the cycle are accounted for in cyclic_edges + cycle_tables: t.Set[str] = set() + for child, parent in cyclic_edges: + cycle_tables.add(child) + cycle_tables.add(parent) + + # All tables should be part of the cycle or in the ordered list + assert cycle_tables.union(set(ordered_tables)) == tables diff --git a/tests/unit/test_transporter.py b/tests/unit/test_transporter.py index 3ff5b91..e203772 100644 --- a/tests/unit/test_transporter.py +++ b/tests/unit/test_transporter.py @@ -1,4 +1,5 @@ import sqlite3 +import typing as t from unittest.mock import MagicMock, patch import pytest @@ -150,7 +151,10 @@ def test_transfer_exception_handling(self, mock_sqlite_connect: MagicMock, mock_ assert "Test exception" in str(excinfo.value) # Verify that foreign keys are re-enabled in the finally block - mock_sqlite_cursor.execute.assert_called_with("PRAGMA foreign_keys=ON") + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=ON") + + # Verify that foreign key check is performed + mock_sqlite_cursor.execute.assert_called_with("PRAGMA foreign_key_check") def test_constructor_missing_mysql_database(self) -> None: """Test constructor raises ValueError if mysql_database is missing.""" @@ -225,3 +229,199 @@ def test_translate_default_from_mysql_to_sqlite_bytes(self) -> None: """Test _translate_default_from_mysql_to_sqlite with bytes default.""" result = MySQLtoSQLite._translate_default_from_mysql_to_sqlite(b"abc", column_type="BLOB") assert result.startswith("DEFAULT x'") + + @patch("mysql.connector.connect") + @patch("sqlite3.connect") + @patch("mysql_to_sqlite3.transporter.compute_creation_order") + def test_transfer_table_ordering( + self, mock_compute_creation_order: MagicMock, mock_sqlite_connect: MagicMock, mock_mysql_connect: MagicMock + ) -> None: + """Test that tables are transferred in the correct order respecting foreign key constraints.""" + # Setup mock SQLite cursor + mock_sqlite_cursor = MagicMock() + + # Setup mock SQLite connection + mock_sqlite_connection = MagicMock() + mock_sqlite_connection.cursor.return_value = mock_sqlite_cursor + mock_sqlite_connect.return_value = mock_sqlite_connection + + # Setup mock MySQL cursor + mock_mysql_cursor = MagicMock() + mock_mysql_cursor.fetchall.return_value = [(b"table1",), (b"table2",), (b"table3",)] + + # Setup mock MySQL connection + mock_mysql_connection = MagicMock() + mock_mysql_connection.cursor.return_value = mock_mysql_cursor + mock_mysql_connect.return_value = mock_mysql_connection + + # Mock compute_creation_order to return a specific order + ordered_tables: t.List[str] = ["table2", "table1", "table3"] # Specific order for testing + cyclic_edges: t.List[t.Tuple[str, str]] = [] # No cycles for this test + mock_compute_creation_order.return_value = (ordered_tables, cyclic_edges) + + # Create a minimal instance with just what we need for the test + with patch.object(MySQLtoSQLite, "__init__", return_value=None): + instance = MySQLtoSQLite() + instance._mysql = mock_mysql_connection + instance._mysql_tables = [] + instance._exclude_mysql_tables = [] + instance._mysql_cur = mock_mysql_cursor + instance._mysql_cur_dict = MagicMock() + instance._mysql_cur_prepared = MagicMock() + instance._sqlite_cur = mock_sqlite_cursor + instance._without_data = True # Skip data transfer for simplicity + instance._without_tables = False + instance._without_foreign_keys = False + instance._vacuum = False + instance._logger = MagicMock() + instance._create_table = MagicMock() # Mock table creation + + # Call the transfer method + instance.transfer() + + # Verify compute_creation_order was called + mock_compute_creation_order.assert_called_once_with(mock_mysql_connection) + + # Verify tables were created in the correct order + creation_calls: t.List[t.Any] = [call[0][0] for call in instance._create_table.call_args_list] + assert creation_calls == ordered_tables + + # Verify foreign keys were disabled at start and enabled at end + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=OFF") + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=ON") + + # Verify foreign key check was performed + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_key_check") + + @patch("mysql.connector.connect") + @patch("sqlite3.connect") + @patch("mysql_to_sqlite3.transporter.compute_creation_order") + def test_transfer_with_circular_dependencies( + self, mock_compute_creation_order: MagicMock, mock_sqlite_connect: MagicMock, mock_mysql_connect: MagicMock + ) -> None: + """Test transfer with circular foreign key dependencies.""" + # Setup mock SQLite cursor + mock_sqlite_cursor = MagicMock() + + # Setup mock SQLite connection + mock_sqlite_connection = MagicMock() + mock_sqlite_connection.cursor.return_value = mock_sqlite_cursor + mock_sqlite_connect.return_value = mock_sqlite_connection + + # Setup mock MySQL cursor + mock_mysql_cursor = MagicMock() + mock_mysql_cursor.fetchall.return_value = [(b"table1",), (b"table2",), (b"table3",)] + + # Setup mock MySQL connection + mock_mysql_connection = MagicMock() + mock_mysql_connection.cursor.return_value = mock_mysql_cursor + mock_mysql_connect.return_value = mock_mysql_connection + + # Mock compute_creation_order to return circular dependencies + ordered_tables: t.List[str] = ["table2", "table1", "table3"] + cyclic_edges: t.List[t.Tuple[str, str]] = [("table1", "table3"), ("table3", "table1")] # Circular dependency + mock_compute_creation_order.return_value = (ordered_tables, cyclic_edges) + + # Create a minimal instance with just what we need for the test + with patch.object(MySQLtoSQLite, "__init__", return_value=None): + instance = MySQLtoSQLite() + instance._mysql = mock_mysql_connection + instance._mysql_tables = [] + instance._exclude_mysql_tables = [] + instance._mysql_cur = mock_mysql_cursor + instance._mysql_cur_dict = MagicMock() + instance._mysql_cur_prepared = MagicMock() + instance._sqlite_cur = mock_sqlite_cursor + instance._without_data = True # Skip data transfer for simplicity + instance._without_tables = False + instance._without_foreign_keys = False + instance._vacuum = False + instance._logger = MagicMock() + instance._create_table = MagicMock() # Mock table creation + + # Call the transfer method + instance.transfer() + + # Verify compute_creation_order was called + mock_compute_creation_order.assert_called_once_with(mock_mysql_connection) + + # Verify warning was logged about circular dependencies + instance._logger.warning.assert_any_call( + "Circular foreign key dependencies detected: %s", "table1 -> table3, table3 -> table1" + ) + + # Verify tables were still created in the computed order + creation_calls: t.List[t.Any] = [call[0][0] for call in instance._create_table.call_args_list] + assert creation_calls == ordered_tables + + # Verify foreign keys were disabled at start and enabled at end + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=OFF") + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=ON") + + # Verify foreign key check was performed + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_key_check") + + @patch("mysql.connector.connect") + @patch("sqlite3.connect") + @patch("mysql_to_sqlite3.transporter.compute_creation_order") + def test_transfer_fallback_on_error( + self, mock_compute_creation_order: MagicMock, mock_sqlite_connect: MagicMock, mock_mysql_connect: MagicMock + ) -> None: + """Test transfer falls back to original table list if compute_creation_order fails.""" + # Setup mock SQLite cursor + mock_sqlite_cursor = MagicMock() + + # Setup mock SQLite connection + mock_sqlite_connection = MagicMock() + mock_sqlite_connection.cursor.return_value = mock_sqlite_cursor + mock_sqlite_connect.return_value = mock_sqlite_connection + + # Setup mock MySQL cursor + mock_mysql_cursor = MagicMock() + mock_mysql_cursor.fetchall.return_value = [(b"table1",), (b"table2",), (b"table3",)] + + # Setup mock MySQL connection + mock_mysql_connection = MagicMock() + mock_mysql_connection.cursor.return_value = mock_mysql_cursor + mock_mysql_connect.return_value = mock_mysql_connection + + # Mock compute_creation_order to raise an exception + mock_compute_creation_order.side_effect = Exception("Test error in compute_creation_order") + + # Create a minimal instance with just what we need for the test + with patch.object(MySQLtoSQLite, "__init__", return_value=None): + instance = MySQLtoSQLite() + instance._mysql = mock_mysql_connection + instance._mysql_tables = [] + instance._exclude_mysql_tables = [] + instance._mysql_cur = mock_mysql_cursor + instance._mysql_cur_dict = MagicMock() + instance._mysql_cur_prepared = MagicMock() + instance._sqlite_cur = mock_sqlite_cursor + instance._without_data = True # Skip data transfer for simplicity + instance._without_tables = False + instance._without_foreign_keys = False + instance._vacuum = False + instance._logger = MagicMock() + instance._create_table = MagicMock() # Mock table creation + + # Call the transfer method + instance.transfer() + + # Verify compute_creation_order was called + mock_compute_creation_order.assert_called_once_with(mock_mysql_connection) + + # Verify warning was logged about the error + instance._logger.warning.assert_any_call( + "Failed to compute table creation order: %s", "Test error in compute_creation_order" + ) + + # Verify tables were still created (using the fallback order) + assert instance._create_table.call_count == 3 + + # Verify foreign keys were disabled at start and enabled at end + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=OFF") + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_keys=ON") + + # Verify foreign key check was performed + mock_sqlite_cursor.execute.assert_any_call("PRAGMA foreign_key_check")