-
Notifications
You must be signed in to change notification settings - Fork 18
Add SQLAlchemy ORM plugin tests #1235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sqlalchemy-orm-mysql
Are you sure you want to change the base?
Changes from all commits
72457dc
6cd61c7
bc607f6
5daa970
5ee014a
4871ac7
068e2e9
f81b1db
a1ebf4c
f081a63
1b8d401
4429afd
4f6500c
9bd824f
ee53f8c
9f484bb
0ab9c66
a64d748
fa1f875
c981d2d
5c21749
d5b2b82
79f990f
f955a9c
c6958bc
1480c48
0f9859d
c75e9f7
771b4b0
7941fe0
159b6c3
7e2bb1c
edd4d02
9f80c4e
f0d7d91
cfebe5b
b45cd27
88dfb8e
c999602
9d2f642
560e67c
f186bb3
3e6e164
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,12 +12,31 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Optional | ||
|
|
||
| import mysql.connector | ||
| from mysql.connector import CMySQLConnection | ||
| from mysql.connector.errors import Error | ||
| from sqlalchemy.dialects.mysql.mysqlconnector import \ | ||
| MySQLDialect_mysqlconnector | ||
| from sqlalchemy.engine import default | ||
|
|
||
| from aws_advanced_python_wrapper import AwsWrapperConnection | ||
| from aws_advanced_python_wrapper.errors import AwsWrapperError | ||
| from aws_advanced_python_wrapper.utils.properties import (Properties, | ||
| PropertiesUtils) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sqlalchemy import Connection | ||
|
|
||
| from aws_advanced_python_wrapper.hostinfo import HostInfo | ||
|
|
||
|
|
||
| class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): | ||
| supports_statement_cache = True | ||
|
|
||
| """ | ||
| SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. | ||
| This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used | ||
|
|
@@ -27,3 +46,154 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): | |
|
|
||
| name = 'mysql' | ||
| driver = 'aws_wrapper_mysqlconnector' | ||
|
|
||
| @classmethod | ||
| def import_dbapi(cls): | ||
| """ | ||
| Return the DB-API 2.0 module. | ||
| SQLAlchemy calls this to get the driver module. | ||
| """ | ||
| import aws_advanced_python_wrapper | ||
| return aws_advanced_python_wrapper | ||
|
|
||
| def create_connect_args(self, url): | ||
| """ | ||
| Transform SQLAlchemy URL into connection arguments. | ||
| Must include the 'target' parameter for our wrapper driver. | ||
| """ | ||
| # Extract standard connection parameters | ||
| opts = url.translate_connect_args(username='user') | ||
|
|
||
| # Add query string parameters | ||
| opts.update(url.query) | ||
|
|
||
| # Add the required 'target' parameter for our wrapper | ||
| if 'target' not in opts: | ||
| opts['target'] = mysql.connector.Connect | ||
| if 'wrapper_plugins' not in opts: | ||
| opts['plugins'] = "aurora_connection_tracker,failover" | ||
| else: | ||
| opts['plugins'] = opts['wrapper_plugins'] | ||
| opts.pop('wrapper_plugins', None) | ||
| if 'connect_timeout' in opts: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we also need to check for other timeouts here
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The other timeouts set in the plugins tests (e.g. socket_timeout, monitoring-connect_timeout, monitoring-socket_timeout) don't seem to cause any problems. |
||
| opts['connect_timeout'] = int(opts['connect_timeout']) | ||
|
|
||
| # Return empty args list and kwargs dict | ||
| return [], opts | ||
|
|
||
| def _detect_charset(self, connection: Connection) -> str: | ||
| if isinstance(connection, CMySQLConnection): | ||
| return connection.charset | ||
| else: | ||
| raise Exception("Could not detect charset because connection was not a CMySQLConnection.") | ||
|
|
||
| def _extract_error_code(self, exception: BaseException) -> int: | ||
| if isinstance(exception, AwsWrapperError): | ||
| err = exception.driver_error | ||
| if err and isinstance(err, Error): | ||
| return err.errno | ||
| else: | ||
| raise Exception("Could not extract error code because driver_error was not a BaseException.") | ||
| else: | ||
| raise Exception("Could not extract error code because exception was not an AwsWrapperError.") | ||
|
|
||
| def initialize(self, connection): | ||
| """ | ||
| Override initialization to handle type introspection. | ||
| The parent class tries to use TypeInfo.fetch() which requires | ||
| a native SQLAlchemy connection, not AwsWrapperConnection. | ||
| """ | ||
|
|
||
| # Unwrap SQLAlchemy's connection object | ||
| wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) | ||
|
|
||
| # this is driver-based, does not need server version info | ||
| # and is fairly critical for even basic SQL operations | ||
| self._connection_charset: Optional[str] = self._detect_charset( | ||
| wrapper_conn.target_connection | ||
| ) | ||
|
|
||
| # call super().initialize() because we need to have | ||
| # server_version_info set up. in 1.4 under python 2 only this does the | ||
| # "check unicode returns" thing, which is the one area that some | ||
| # SQL gets compiled within initialize() currently | ||
| default.DefaultDialect.initialize(self, connection) | ||
|
|
||
| self._detect_sql_mode(connection) | ||
| self._detect_ansiquotes(connection) # depends on sql mode | ||
| self._detect_casing(connection) | ||
| if self._server_ansiquotes: | ||
| # if ansiquotes == True, build a new IdentifierPreparer | ||
| # with the new setting | ||
| self.identifier_preparer = self.preparer( | ||
| self, server_ansiquotes=self._server_ansiquotes | ||
| ) | ||
|
|
||
| self.supports_sequences = ( | ||
| self.is_mariadb and self.server_version_info >= (10, 3) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious why are we checking mariadb here?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't want to change too much of the existing logic this was based on from sqlalchemy. |
||
| ) | ||
|
|
||
| self.supports_for_update_of = ( | ||
| self._is_mysql and self.server_version_info >= (8,) | ||
| ) | ||
|
|
||
| self.use_mysql_for_share = ( | ||
| self._is_mysql and self.server_version_info >= (8, 0, 1) | ||
| ) | ||
|
|
||
| self._needs_correct_for_88718_96365 = ( | ||
| not self.is_mariadb and self.server_version_info >= (8,) | ||
| ) | ||
|
|
||
| self.delete_returning = ( | ||
| self.is_mariadb and self.server_version_info >= (10, 0, 5) | ||
| ) | ||
|
|
||
| self.insert_returning = ( | ||
| self.is_mariadb and self.server_version_info >= (10, 5) | ||
| ) | ||
|
|
||
| self._requires_alias_for_on_duplicate_key = ( | ||
| self._is_mysql and self.server_version_info >= (8, 0, 20) | ||
| ) | ||
|
|
||
| self._warn_for_known_db_issues() | ||
|
|
||
| def _get_wrapper_connection_and_parent(self, connection): | ||
| """ | ||
| Traverse the connection chain to find AwsWrapperConnection and its parent connection. | ||
|
|
||
| Args: | ||
| connection: SQLAlchemy Connection object | ||
|
|
||
| Returns: | ||
| AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None | ||
| """ | ||
| # Start with the DBAPI connection | ||
| parent = connection | ||
| child = connection.connection | ||
|
|
||
| # Traverse up to 5 levels deep (reasonable limit) | ||
| for _ in range(5): | ||
| if isinstance(child, AwsWrapperConnection): | ||
| return child, parent | ||
|
|
||
| # Try to go deeper if there's a .connection attribute | ||
| if hasattr(child, 'connection'): | ||
| parent = child | ||
| child = child.connection | ||
| else: | ||
| break | ||
|
|
||
| return None | ||
|
|
||
| def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties: | ||
| prop_copy: Properties = Properties(props.copy()) | ||
|
|
||
| prop_copy["host"] = host_info.host | ||
|
|
||
| if host_info.is_port_specified(): | ||
| prop_copy["port"] = str(host_info.port) | ||
|
|
||
| PropertiesUtils.remove_wrapper_props(prop_copy) | ||
| return prop_copy | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we switch the default plugin list to using failover2 instead of failover1