diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index fde1ab3e1d..61261aadf8 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -38,7 +38,7 @@ jobs: if: "!contains(github.event.pull_request.labels.*.name, 'disable-integration-tests')" runs-on: ubuntu-24.04 env: - SCYLLA_VERSION: release:2025.2 + SCYLLA_VERSION: release:2026.1 strategy: fail-fast: false matrix: diff --git a/benchmarks/base.py b/benchmarks/base.py index d9cd004474..ab4663c8e4 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -25,7 +25,7 @@ dirname = os.path.dirname(os.path.abspath(__file__)) sys.path.append(dirname) -sys.path.append(os.path.join(dirname, '..')) +sys.path.append(os.path.join(dirname, "..")) import cassandra from cassandra.cluster import Cluster @@ -33,25 +33,28 @@ log = logging.getLogger() handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") +) log.addHandler(handler) -logging.getLogger('cassandra').setLevel(logging.WARN) +logging.getLogger("cassandra").setLevel(logging.WARN) _log_levels = { - 'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARN': logging.WARNING, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG, - 'NOTSET': logging.NOTSET, + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARN": logging.WARNING, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "NOTSET": logging.NOTSET, } have_libev = False supported_reactors = [AsyncoreConnection] try: from cassandra.io.libevreactor import LibevConnection + have_libev = True supported_reactors.append(LibevConnection) except cassandra.DependencyException as exc: @@ -60,6 +63,7 @@ have_asyncio = False try: from cassandra.io.asyncioreactor import AsyncioConnection + have_asyncio 
= True supported_reactors.append(AsyncioConnection) except (ImportError, SyntaxError): @@ -68,6 +72,7 @@ have_twisted = False try: from cassandra.io.twistedreactor import TwistedConnection + have_twisted = True supported_reactors.append(TwistedConnection) except ImportError as exc: @@ -78,27 +83,32 @@ TABLE = "testtable" COLUMN_VALUES = { - 'int': 42, - 'text': "'42'", - 'float': 42.0, - 'uuid': uuid.uuid4(), - 'timestamp': "'2016-02-03 04:05+0000'" + "int": 42, + "text": "'42'", + "float": 42.0, + "uuid": uuid.uuid4(), + "timestamp": "'2016-02-03 04:05+0000'", } def setup(options): log.info("Using 'cassandra' package from %s", cassandra.__path__) - cluster = Cluster(options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False) + cluster = Cluster( + options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False + ) try: session = cluster.connect() log.debug("Creating keyspace...") try: - session.execute(""" + session.execute( + """ CREATE KEYSPACE %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } - """ % options.keyspace) + WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' } + """ + % options.keyspace + ) log.debug("Setting keyspace...") except cassandra.AlreadyExists: @@ -125,7 +135,9 @@ def setup(options): def teardown(options): - cluster = Cluster(options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False) + cluster = Cluster( + options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False + ) session = cluster.connect() if not options.keep_data: session.execute("DROP KEYSPACE " + options.keyspace) @@ -138,17 +150,18 @@ def benchmark(thread_class): setup(options) log.info("==== %s ====" % (conn_class.__name__,)) - kwargs = {'metrics_enabled': options.enable_metrics, - 'connection_class': conn_class} + kwargs = { + "metrics_enabled": options.enable_metrics, + "connection_class": conn_class, + } if options.protocol_version: - 
kwargs['protocol_version'] = options.protocol_version + kwargs["protocol_version"] = options.protocol_version cluster = Cluster(options.hosts, **kwargs) session = cluster.connect(options.keyspace) log.debug("Sleeping for two seconds...") time.sleep(2.0) - # Generate the query if options.read: query = "SELECT * FROM {0} WHERE thekey = '{{key}}'".format(TABLE) @@ -166,13 +179,19 @@ def benchmark(thread_class): per_thread = options.num_ops // options.threads threads = [] - log.debug("Beginning {0}...".format('reads' if options.read else 'inserts')) + log.debug("Beginning {0}...".format("reads" if options.read else "inserts")) start = time.time() try: for i in range(options.threads): thread = thread_class( - i, session, query, values, per_thread, - cluster.protocol_version, options.profile) + i, + session, + query, + values, + per_thread, + cluster.protocol_version, + options.profile, + ) thread.daemon = True threads.append(thread) @@ -192,73 +211,144 @@ def benchmark(thread_class): log.info("Total time: %0.2fs" % total) log.info("Average throughput: %0.2f/sec" % (options.num_ops / total)) if options.enable_metrics: - stats = getStats()['cassandra'] - log.info("Connection errors: %d", stats['connection_errors']) - log.info("Write timeouts: %d", stats['write_timeouts']) - log.info("Read timeouts: %d", stats['read_timeouts']) - log.info("Unavailables: %d", stats['unavailables']) - log.info("Other errors: %d", stats['other_errors']) - log.info("Retries: %d", stats['retries']) - - request_timer = stats['request_timer'] + stats = getStats()["cassandra"] + log.info("Connection errors: %d", stats["connection_errors"]) + log.info("Write timeouts: %d", stats["write_timeouts"]) + log.info("Read timeouts: %d", stats["read_timeouts"]) + log.info("Unavailables: %d", stats["unavailables"]) + log.info("Other errors: %d", stats["other_errors"]) + log.info("Retries: %d", stats["retries"]) + + request_timer = stats["request_timer"] log.info("Request latencies:") - log.info(" min: 
%0.4fs", request_timer['min']) - log.info(" max: %0.4fs", request_timer['max']) - log.info(" mean: %0.4fs", request_timer['mean']) - log.info(" stddev: %0.4fs", request_timer['stddev']) - log.info(" median: %0.4fs", request_timer['median']) - log.info(" 75th: %0.4fs", request_timer['75percentile']) - log.info(" 95th: %0.4fs", request_timer['95percentile']) - log.info(" 98th: %0.4fs", request_timer['98percentile']) - log.info(" 99th: %0.4fs", request_timer['99percentile']) - log.info(" 99.9th: %0.4fs", request_timer['999percentile']) + log.info(" min: %0.4fs", request_timer["min"]) + log.info(" max: %0.4fs", request_timer["max"]) + log.info(" mean: %0.4fs", request_timer["mean"]) + log.info(" stddev: %0.4fs", request_timer["stddev"]) + log.info(" median: %0.4fs", request_timer["median"]) + log.info(" 75th: %0.4fs", request_timer["75percentile"]) + log.info(" 95th: %0.4fs", request_timer["95percentile"]) + log.info(" 98th: %0.4fs", request_timer["98percentile"]) + log.info(" 99th: %0.4fs", request_timer["99percentile"]) + log.info(" 99.9th: %0.4fs", request_timer["999percentile"]) def parse_options(): parser = OptionParser() - parser.add_option('-H', '--hosts', default='127.0.0.1', - help='cassandra hosts to connect to (comma-separated list) [default: %default]') - parser.add_option('-t', '--threads', type='int', default=1, - help='number of threads [default: %default]') - parser.add_option('-n', '--num-ops', type='int', default=10000, - help='number of operations [default: %default]') - parser.add_option('--asyncore-only', action='store_true', dest='asyncore_only', - help='only benchmark with asyncore connections') - parser.add_option('--asyncio-only', action='store_true', dest='asyncio_only', - help='only benchmark with asyncio connections') - parser.add_option('--libev-only', action='store_true', dest='libev_only', - help='only benchmark with libev connections') - parser.add_option('--twisted-only', action='store_true', dest='twisted_only', - help='only benchmark 
with Twisted connections') - parser.add_option('-m', '--metrics', action='store_true', dest='enable_metrics', - help='enable and print metrics for operations') - parser.add_option('-l', '--log-level', default='info', - help='logging level: debug, info, warning, or error') - parser.add_option('-p', '--profile', action='store_true', dest='profile', - help='Profile the run') - parser.add_option('--protocol-version', type='int', dest='protocol_version', default=4, - help='Native protocol version to use') - parser.add_option('-c', '--num-columns', type='int', dest='num_columns', default=2, - help='Specify the number of columns for the schema') - parser.add_option('-k', '--keyspace', type='str', dest='keyspace', default=KEYSPACE, - help='Specify the keyspace name for the schema') - parser.add_option('--keep-data', action='store_true', dest='keep_data', default=False, - help='Keep the data after the benchmark') - parser.add_option('--column-type', type='str', dest='column_type', default='text', - help='Specify the column type for the schema (supported: int, text, float, uuid, timestamp)') - parser.add_option('--read', action='store_true', dest='read', default=False, - help='Read mode') - + parser.add_option( + "-H", + "--hosts", + default="127.0.0.1", + help="cassandra hosts to connect to (comma-separated list) [default: %default]", + ) + parser.add_option( + "-t", + "--threads", + type="int", + default=1, + help="number of threads [default: %default]", + ) + parser.add_option( + "-n", + "--num-ops", + type="int", + default=10000, + help="number of operations [default: %default]", + ) + parser.add_option( + "--asyncore-only", + action="store_true", + dest="asyncore_only", + help="only benchmark with asyncore connections", + ) + parser.add_option( + "--asyncio-only", + action="store_true", + dest="asyncio_only", + help="only benchmark with asyncio connections", + ) + parser.add_option( + "--libev-only", + action="store_true", + dest="libev_only", + help="only benchmark 
with libev connections", + ) + parser.add_option( + "--twisted-only", + action="store_true", + dest="twisted_only", + help="only benchmark with Twisted connections", + ) + parser.add_option( + "-m", + "--metrics", + action="store_true", + dest="enable_metrics", + help="enable and print metrics for operations", + ) + parser.add_option( + "-l", + "--log-level", + default="info", + help="logging level: debug, info, warning, or error", + ) + parser.add_option( + "-p", "--profile", action="store_true", dest="profile", help="Profile the run" + ) + parser.add_option( + "--protocol-version", + type="int", + dest="protocol_version", + default=4, + help="Native protocol version to use", + ) + parser.add_option( + "-c", + "--num-columns", + type="int", + dest="num_columns", + default=2, + help="Specify the number of columns for the schema", + ) + parser.add_option( + "-k", + "--keyspace", + type="str", + dest="keyspace", + default=KEYSPACE, + help="Specify the keyspace name for the schema", + ) + parser.add_option( + "--keep-data", + action="store_true", + dest="keep_data", + default=False, + help="Keep the data after the benchmark", + ) + parser.add_option( + "--column-type", + type="str", + dest="column_type", + default="text", + help="Specify the column type for the schema (supported: int, text, float, uuid, timestamp)", + ) + parser.add_option( + "--read", action="store_true", dest="read", default=False, help="Read mode" + ) options, args = parser.parse_args() - options.hosts = options.hosts.split(',') + options.hosts = options.hosts.split(",") level = options.log_level.upper() try: log.setLevel(_log_levels[level]) except KeyError: - log.warning("Unknown log level specified: %s; specify one of %s", options.log_level, _log_levels.keys()) + log.warning( + "Unknown log level specified: %s; specify one of %s", + options.log_level, + _log_levels.keys(), + ) if options.asyncore_only: options.supported_reactors = [AsyncoreConnection] @@ -283,8 +373,9 @@ def parse_options(): 
class BenchmarkThread(Thread): - - def __init__(self, thread_num, session, query, values, num_queries, protocol_version, profile): + def __init__( + self, thread_num, session, query, values, num_queries, protocol_version, profile + ): Thread.__init__(self) self.thread_num = thread_num self.session = session @@ -304,4 +395,4 @@ def run_query(self, key, **kwargs): def finish_profile(self): if self.profiler: self.profiler.disable() - self.profiler.dump_stats('profile-%d' % self.thread_num) + self.profiler.dump_stats("profile-%d" % self.thread_num) diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index d6dc44119a..94b85d437e 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -27,14 +27,14 @@ from cassandra.cqlengine.named import NamedTable from cassandra.cqlengine.usertype import UserType -CQLENG_ALLOW_SCHEMA_MANAGEMENT = 'CQLENG_ALLOW_SCHEMA_MANAGEMENT' +CQLENG_ALLOW_SCHEMA_MANAGEMENT = "CQLENG_ALLOW_SCHEMA_MANAGEMENT" -Field = namedtuple('Field', ['name', 'type']) +Field = namedtuple("Field", ["name", "type"]) log = logging.getLogger(__name__) # system keyspaces -schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') +schema_columnfamilies = NamedTable("system", "schema_columnfamilies") def _get_context(keyspaces, connections): @@ -42,11 +42,11 @@ def _get_context(keyspaces, connections): if keyspaces: if not isinstance(keyspaces, (list, tuple)): - raise ValueError('keyspaces must be a list or a tuple.') + raise ValueError("keyspaces must be a list or a tuple.") if connections: if not isinstance(connections, (list, tuple)): - raise ValueError('connections must be a list or a tuple.') + raise ValueError("connections must be a list or a tuple.") keyspaces = keyspaces if keyspaces else [None] connections = connections if connections else [None] @@ -54,9 +54,11 @@ def _get_context(keyspaces, connections): return product(connections, keyspaces) -def create_keyspace_simple(name, 
replication_factor, durable_writes=True, connections=None): +def create_keyspace_simple( + name, replication_factor, durable_writes=True, connections=None +): """ - Creates a keyspace with SimpleStrategy for replica placement + Creates a keyspace with NetworkTopologyStrategy for replica placement If the keyspace already exists, it will not be modified. @@ -66,15 +68,22 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True, connec *There are plans to guard schema-modifying functions with an environment-driven conditional.* :param str name: name of keyspace to create - :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` + :param int replication_factor: keyspace replication factor, used with :attr:`~.NetworkTopologyStrategy` :param bool durable_writes: Write log is bypassed if set to False :param list connections: List of connection names """ - _create_keyspace(name, durable_writes, 'SimpleStrategy', - {'replication_factor': replication_factor}, connections=connections) - - -def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None): + _create_keyspace( + name, + durable_writes, + "NetworkTopologyStrategy", + {"replication_factor": replication_factor}, + connections=connections, + ) + + +def create_keyspace_network_topology( + name, dc_replication_map, durable_writes=True, connections=None +): """ Creates a keyspace with NetworkTopologyStrategy for replica placement @@ -90,30 +99,56 @@ def create_keyspace_network_topology(name, dc_replication_map, durable_writes=Tr :param bool durable_writes: Write log is bypassed if set to False :param list connections: List of connection names """ - _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections) - - -def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None): + _create_keyspace( + name, + durable_writes, + 
"NetworkTopologyStrategy", + dc_replication_map, + connections=connections, + ) + + +def _create_keyspace( + name, durable_writes, strategy_class, strategy_options, connections=None +): if not _allow_schema_modification(): return if connections: if not isinstance(connections, (list, tuple)): - raise ValueError('Connections must be a list or a tuple.') + raise ValueError("Connections must be a list or a tuple.") - def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None): + def __create_keyspace( + name, durable_writes, strategy_class, strategy_options, connection=None + ): cluster = get_cluster(connection) if name not in cluster.metadata.keyspaces: - log.info(format_log_context("Creating keyspace %s", connection=connection), name) - ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + log.info( + format_log_context("Creating keyspace %s", connection=connection), name + ) + ks_meta = metadata.KeyspaceMetadata( + name, durable_writes, strategy_class, strategy_options + ) execute(ks_meta.as_cql_query(), connection=connection) else: - log.info(format_log_context("Not creating keyspace %s because it already exists", connection=connection), name) + log.info( + format_log_context( + "Not creating keyspace %s because it already exists", + connection=connection, + ), + name, + ) if connections: for connection in connections: - __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection) + __create_keyspace( + name, + durable_writes, + strategy_class, + strategy_options, + connection=connection, + ) else: __create_keyspace(name, durable_writes, strategy_class, strategy_options) @@ -135,12 +170,15 @@ def drop_keyspace(name, connections=None): if connections: if not isinstance(connections, (list, tuple)): - raise ValueError('Connections must be a list or a tuple.') + raise ValueError("Connections must be a list or a tuple.") def _drop_keyspace(name, 
connection=None): cluster = get_cluster(connection) if name in cluster.metadata.keyspaces: - execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection) + execute( + "DROP KEYSPACE {0}".format(metadata.protect_name(name)), + connection=connection, + ) if connections: for connection in connections: @@ -148,6 +186,7 @@ def _drop_keyspace(name, connection=None): else: _drop_keyspace(name) + def _get_index_name_by_column(table, column_name): """ Find the index name for a given table and column. @@ -156,7 +195,7 @@ def _get_index_name_by_column(table, column_name): possible_index_values = [protected_name, "values(%s)" % protected_name] for index_metadata in table.indexes.values(): options = dict(index_metadata.index_options) - if options.get('target') in possible_index_values: + if options.get("target") in possible_index_values: return index_metadata.name @@ -210,7 +249,9 @@ def _sync_table(model, connection=None): try: keyspace = cluster.metadata.keyspaces[ks_name] except KeyError: - msg = format_log_context("Keyspace '{0}' for model {1} does not exist.", connection=connection) + msg = format_log_context( + "Keyspace '{0}' for model {1} does not exist.", connection=connection + ) raise CQLEngineException(msg.format(ks_name, model)) tables = keyspace.tables @@ -223,7 +264,14 @@ def _sync_table(model, connection=None): _sync_type(ks_name, udt, syncd_types, connection=connection) if raw_cf_name not in tables: - log.debug(format_log_context("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name) + log.debug( + format_log_context( + "sync_table creating new table %s", + keyspace=ks_name, + connection=connection, + ), + cf_name, + ) qs = _get_create_table(model) try: @@ -234,7 +282,14 @@ def _sync_table(model, connection=None): if "Cannot add already existing column family" not in str(ex): raise else: - log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), 
cf_name) + log.debug( + format_log_context( + "sync_table checking existing table %s", + keyspace=ks_name, + connection=connection, + ), + cf_name, + ) table_meta = tables[raw_cf_name] _validate_pk(model, table_meta) @@ -248,8 +303,12 @@ def _sync_table(model, connection=None): if db_name in table_columns: col_meta = table_columns[db_name] if col_meta.cql_type != col.db_type: - msg = format_log_context('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' - ' Model should be updated.', keyspace=ks_name, connection=connection) + msg = format_log_context( + 'Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' + " Model should be updated.", + keyspace=ks_name, + connection=connection, + ) msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type) warnings.warn(msg) log.warning(msg) @@ -257,7 +316,11 @@ def _sync_table(model, connection=None): continue if col.primary_key or col.partition_key: - msg = format_log_context("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection) + msg = format_log_context( + "Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", + keyspace=ks_name, + connection=connection, + ) raise CQLEngineException(msg.format(model_name, db_name, cf_name)) query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def()) @@ -265,7 +328,11 @@ def _sync_table(model, connection=None): db_fields_not_in_model = model_fields.symmetric_difference(table_columns) if db_fields_not_in_model: - msg = format_log_context("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection) + msg = format_log_context( + "Table {0} has fields not referenced by model: {1}", + keyspace=ks_name, + connection=connection, + ) log.info(msg.format(cf_name, db_fields_not_in_model)) _update_options(model, connection=connection) @@ -280,14 +347,14 @@ def _sync_table(model, 
connection=None): if index_name: continue - qs = ['CREATE INDEX'] - qs += ['ON {0}'.format(cf_name)] + qs = ["CREATE INDEX"] + qs += ["ON {0}".format(cf_name)] # Use FULL index for frozen collections, VALUES index (implicit) for non-frozen if isinstance(column, columns.BaseContainerColumn) and column.frozen: qs += ['(FULL("{0}"))'.format(column.db_field_name)] else: qs += ['("{0}")'.format(column.db_field_name)] - qs = ' '.join(qs) + qs = " ".join(qs) execute(qs, connection=connection) @@ -298,13 +365,22 @@ def _validate_pk(model, table_meta): meta_clustering = [c.name for c in table_meta.clustering_key] if model_partition != meta_partition or model_clustering != meta_clustering: + def _pk_string(partition, clustering): - return "PRIMARY KEY (({0}){1})".format(', '.join(partition), ', ' + ', '.join(clustering) if clustering else '') - raise CQLEngineException("Model {0} PRIMARY KEY composition does not match existing table {1}. " - "Model: {2}; Table: {3}. " - "Update model or drop the table.".format(model, model.column_family_name(), - _pk_string(model_partition, model_clustering), - _pk_string(meta_partition, meta_clustering))) + return "PRIMARY KEY (({0}){1})".format( + ", ".join(partition), ", " + ", ".join(clustering) if clustering else "" + ) + + raise CQLEngineException( + "Model {0} PRIMARY KEY composition does not match existing table {1}. " + "Model: {2}; Table: {3}. 
" + "Update model or drop the table.".format( + model, + model.column_family_name(), + _pk_string(model_partition, model_clustering), + _pk_string(meta_partition, meta_clustering), + ) + ) def sync_type(ks_name, type_model, connection=None): @@ -329,7 +405,6 @@ def sync_type(ks_name, type_model, connection=None): def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): - syncd_sub_types = omit_subtypes or set() for field in type_model._fields.values(): udts = [] @@ -347,7 +422,14 @@ def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): defined_types = keyspace.user_types if type_name not in defined_types: - log.debug(format_log_context("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified) + log.debug( + format_log_context( + "sync_type creating new type %s", + keyspace=ks_name, + connection=connection, + ), + type_name_qualified, + ) cql = get_create_type(type_model, ks_name) execute(cql, connection=connection) cluster.refresh_user_type_metadata(ks_name, type_name) @@ -359,39 +441,68 @@ def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): for field in type_model._fields.values(): model_fields.add(field.db_field_name) if field.db_field_name not in defined_fields: - execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection) + execute( + "ALTER TYPE {0} ADD {1}".format( + type_name_qualified, field.get_column_def() + ), + connection=connection, + ) else: - field_type = type_meta.field_types[defined_fields.index(field.db_field_name)] + field_type = type_meta.field_types[ + defined_fields.index(field.db_field_name) + ] if field_type != field.db_type: - msg = format_log_context('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' 
- ' UserType should be updated.', keyspace=ks_name, connection=connection) - msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type) + msg = format_log_context( + 'Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' + " UserType should be updated.", + keyspace=ks_name, + connection=connection, + ) + msg = msg.format( + type_name_qualified, + field.db_field_name, + field_type, + field.db_type, + ) warnings.warn(msg) log.warning(msg) type_model.register_for_keyspace(ks_name, connection=connection) if len(defined_fields) == len(model_fields): - log.info(format_log_context("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified) + log.info( + format_log_context( + "Type %s did not require synchronization", + keyspace=ks_name, + connection=connection, + ), + type_name_qualified, + ) return db_fields_not_in_model = model_fields.symmetric_difference(defined_fields) if db_fields_not_in_model: - msg = format_log_context("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection) + msg = format_log_context( + "Type %s has fields not referenced by model: %s", + keyspace=ks_name, + connection=connection, + ) log.info(msg, type_name_qualified, db_fields_not_in_model) def get_create_type(type_model, keyspace): - type_meta = metadata.UserType(keyspace, - type_model.type_name(), - (f.db_field_name for f in type_model._fields.values()), - (v.db_type for v in type_model._fields.values())) + type_meta = metadata.UserType( + keyspace, + type_model.type_name(), + (f.db_field_name for f in type_model._fields.values()), + (v.db_type for v in type_model._fields.values()), + ) return type_meta.as_cql_query() def _get_create_table(model): ks_table_name = model.column_family_name() - query_strings = ['CREATE TABLE {0}'.format(ks_table_name)] + query_strings = ["CREATE TABLE {0}".format(ks_table_name)] # add column types pkeys = [] # 
primary keys @@ -401,30 +512,39 @@ def _get_create_table(model): def add_column(col): s = col.get_column_def() if col.primary_key: - keys = (pkeys if col.partition_key else ckeys) + keys = pkeys if col.partition_key else ckeys keys.append('"{0}"'.format(col.db_field_name)) qtypes.append(s) for name, col in model._columns.items(): add_column(col) - qtypes.append('PRIMARY KEY (({0}){1})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) + qtypes.append( + "PRIMARY KEY (({0}){1})".format( + ", ".join(pkeys), ckeys and ", " + ", ".join(ckeys) or "" + ) + ) - query_strings += ['({0})'.format(', '.join(qtypes))] + query_strings += ["({0})".format(", ".join(qtypes))] property_strings = [] - _order = ['"{0}" {1}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] + _order = [ + '"{0}" {1}'.format(c.db_field_name, c.clustering_order or "ASC") + for c in model._clustering_keys.values() + ] if _order: - property_strings.append('CLUSTERING ORDER BY ({0})'.format(', '.join(_order))) + property_strings.append("CLUSTERING ORDER BY ({0})".format(", ".join(_order))) # options strings use the V3 format, which matches CQL more closely and does not require mapping - property_strings += metadata.TableMetadataV3._make_option_strings(model.__options__ or {}) + property_strings += metadata.TableMetadataV3._make_option_strings( + model.__options__ or {} + ) if property_strings: - query_strings += ['WITH {0}'.format(' AND '.join(property_strings))] + query_strings += ["WITH {0}".format(" AND ".join(property_strings))] - return ' '.join(query_strings) + return " ".join(query_strings) def _get_table_metadata(model, connection=None): @@ -440,10 +560,14 @@ def _options_map_from_strings(option_strings): # converts options strings to a mapping to strings or dict options = {} for option in option_strings: - name, value = option.split('=') - i = value.find('{') + name, value = option.split("=") + i = value.find("{") if i >= 0: - value = 
value[i:value.rfind('}') + 1].replace("'", '"') # from cql single quotes to json double; not aware of any values that would be escaped right now + value = value[ + i : value.rfind("}") + 1 + ].replace( + "'", '"' + ) # from cql single quotes to json double; not aware of any values that would be escaped right now value = json.loads(value) else: value = value.strip() @@ -462,7 +586,9 @@ def _update_options(model, connection=None): :rtype: bool """ ks_name = model._get_keyspace() - msg = format_log_context("Checking %s for option differences", keyspace=ks_name, connection=connection) + msg = format_log_context( + "Checking %s for option differences", keyspace=ks_name, connection=connection + ) log.debug(msg, model) model_options = model.__options__ or {} @@ -478,7 +604,11 @@ def _update_options(model, connection=None): try: existing_value = existing_options[name] except KeyError: - msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection) + msg = format_log_context( + "Invalid table option: '%s'; known options: %s", + keyspace=ks_name, + connection=connection, + ) raise KeyError(msg % (name, existing_options.keys())) if isinstance(existing_value, str): if value != existing_value: @@ -489,8 +619,11 @@ def _update_options(model, connection=None): # When creating table with compaction 'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy' in Scylla, # it will be silently changed to 'class': 'LeveledCompactionStrategy' - same for at least SizeTieredCompactionStrategy, # probably others too. We need to handle this case here. 
- if k == 'class' and name == 'compaction': - if existing_value[k] != v and existing_value[k] != v.split('.')[-1]: + if k == "class" and name == "compaction": + if ( + existing_value[k] != v + and existing_value[k] != v.split(".")[-1] + ): update_options[name] = value break else: @@ -501,7 +634,9 @@ def _update_options(model, connection=None): update_options[name] = value if update_options: - options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options)) + options = " AND ".join( + metadata.TableMetadataV3._make_option_strings(update_options) + ) query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options) execute(query, connection=connection) return True @@ -545,14 +680,19 @@ def _drop_table(model, connection=None): try: meta.keyspaces[ks_name].tables[raw_cf_name] - execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection) + execute( + "DROP TABLE {0};".format(model.column_family_name()), connection=connection + ) except KeyError: pass def _allow_schema_modification(): if not os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT): - msg = CQLENG_ALLOW_SCHEMA_MANAGEMENT + " environment variable is not set. Future versions of this package will require this variable to enable management functions." + msg = ( + CQLENG_ALLOW_SCHEMA_MANAGEMENT + + " environment variable is not set. Future versions of this package will require this variable to enable management functions." 
+ ) warnings.warn(msg) log.warning(msg) diff --git a/docs/scylla-specific.rst b/docs/scylla-specific.rst index e9fe695f8f..4b28781f1c 100644 --- a/docs/scylla-specific.rst +++ b/docs/scylla-specific.rst @@ -91,7 +91,7 @@ New Error Types session = cluster.connect() session.execute(""" CREATE KEYSPACE IF NOT EXISTS keyspace1 - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} """) session.execute("USE keyspace1") diff --git a/examples/concurrent_executions/execute_async_with_queue.py b/examples/concurrent_executions/execute_async_with_queue.py index 72d2c101cb..ea8e818677 100644 --- a/examples/concurrent_executions/execute_async_with_queue.py +++ b/examples/concurrent_executions/execute_async_with_queue.py @@ -30,10 +30,16 @@ cluster = Cluster() session = cluster.connect() -session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) +session.execute( + ( + "CREATE KEYSPACE IF NOT EXISTS examples " + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }" + ) +) session.execute("USE examples") -session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") +session.execute( + "CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))" +) prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") @@ -61,5 +67,8 @@ def clear_queue(): clear_queue() end = time.time() -print("Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.". 
- format(TOTAL_QUERIES, CONCURRENCY_LEVEL, (end-start))) +print( + "Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.".format( + TOTAL_QUERIES, CONCURRENCY_LEVEL, (end - start) + ) +) diff --git a/examples/concurrent_executions/execute_with_threads.py b/examples/concurrent_executions/execute_with_threads.py index e3c80f5d6b..30078cd8f1 100644 --- a/examples/concurrent_executions/execute_with_threads.py +++ b/examples/concurrent_executions/execute_with_threads.py @@ -33,15 +33,20 @@ cluster = Cluster() session = cluster.connect() -session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) +session.execute( + ( + "CREATE KEYSPACE IF NOT EXISTS examples " + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }" + ) +) session.execute("USE examples") -session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") +session.execute( + "CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))" +) prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") class SimpleQueryExecutor(threading.Thread): - def run(self): global COUNTER @@ -68,5 +73,8 @@ def run(self): thread.join() end = time.time() -print("Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.". 
- format(TOTAL_QUERIES, CONCURRENCY_LEVEL, (end-start))) +print( + "Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.".format( + TOTAL_QUERIES, CONCURRENCY_LEVEL, (end - start) + ) +) diff --git a/examples/example_core.py b/examples/example_core.py index 01c766e109..50808679ff 100644 --- a/examples/example_core.py +++ b/examples/example_core.py @@ -17,9 +17,11 @@ import logging log = logging.getLogger() -log.setLevel('DEBUG') +log.setLevel("DEBUG") handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) +handler.setFormatter( + logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") +) log.addHandler(handler) from cassandra import ConsistencyLevel @@ -30,14 +32,17 @@ def main(): - cluster = Cluster(['127.0.0.1']) + cluster = Cluster(["127.0.0.1"]) session = cluster.connect() log.info("creating keyspace...") - session.execute(""" + session.execute( + """ CREATE KEYSPACE IF NOT EXISTS %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } - """ % KEYSPACE) + WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' } + """ + % KEYSPACE + ) log.info("setting keyspace...") session.set_keyspace(KEYSPACE) @@ -52,10 +57,13 @@ def main(): ) """) - query = SimpleStatement(""" + query = SimpleStatement( + """ INSERT INTO mytable (thekey, col1, col2) VALUES (%(key)s, %(a)s, %(b)s) - """, consistency_level=ConsistencyLevel.ONE) + """, + consistency_level=ConsistencyLevel.ONE, + ) prepared = session.prepare(""" INSERT INTO mytable (thekey, col1, col2) @@ -64,8 +72,8 @@ def main(): for i in range(10): log.info("inserting row %d" % i) - session.execute(query, dict(key="key%d" % i, a='a', b='b')) - session.execute(prepared, ("key%d" % i, 'b', 'b')) + session.execute(query, dict(key="key%d" % i, a="a", b="b")) + session.execute(prepared, ("key%d" % i, "b", "b")) future = session.execute_async("SELECT * FROM mytable") 
log.info("key\tcol1\tcol2") @@ -78,9 +86,10 @@ def main(): return for row in rows: - log.info('\t'.join(row)) + log.info("\t".join(row)) session.execute("DROP KEYSPACE " + KEYSPACE) + if __name__ == "__main__": main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 6a809bded4..a954b949ec 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -38,8 +38,15 @@ import pytest -from cassandra import OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure, AlreadyExists,\ - InvalidRequest +from cassandra import ( + OperationTimedOut, + ReadTimeout, + ReadFailure, + WriteTimeout, + WriteFailure, + AlreadyExists, + InvalidRequest, +) from cassandra.protocol import ConfigurationException from cassandra import ProtocolVersion @@ -54,9 +61,9 @@ log = logging.getLogger(__name__) -CLUSTER_NAME = 'test_cluster' -SINGLE_NODE_CLUSTER_NAME = 'single_node' -MULTIDC_CLUSTER_NAME = 'multidc_test_cluster' +CLUSTER_NAME = "test_cluster" +SINGLE_NODE_CLUSTER_NAME = "single_node" +MULTIDC_CLUSTER_NAME = "multidc_test_cluster" # When use_single_interface is specified ccm will assign distinct port numbers to each # node in the cluster. This value specifies the default port value used for the first @@ -64,11 +71,11 @@ # # TODO: In the future we may want to make this configurable, but this should only apply # if a non-standard port were specified when starting up the cluster. 
-DEFAULT_SINGLE_INTERFACE_PORT=9046 +DEFAULT_SINGLE_INTERFACE_PORT = 9046 CCM_CLUSTER = None -path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm') +path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "ccm") if not os.path.exists(path): os.mkdir(path) @@ -88,7 +95,9 @@ def get_server_versions(): c = TestCluster() s = c.connect() - row = s.execute("SELECT cql_version, release_version FROM system.local WHERE key='local'").one() + row = s.execute( + "SELECT cql_version, release_version FROM system.local WHERE key='local'" + ).one() cass_version = _tuple_version(row.release_version) cql_version = _tuple_version(row.cql_version) @@ -99,41 +108,46 @@ def get_server_versions(): def get_scylla_version(scylla_ccm_version_string): - """ get scylla version from ccm before starting a cluster""" - ccm_repo_cache_dir, _ = ccmlib.scylla_repository.setup(version=scylla_ccm_version_string) - return ccmlib.common.get_version_from_build(ccm_repo_cache_dir) + """get scylla version from ccm before starting a cluster""" + ccm_repo_cache_dir, _ = ccmlib.scylla_repository.setup( + version=scylla_ccm_version_string + ) + return ccmlib.common.get_version_from_build(ccm_repo_cache_dir) def _tuple_version(version_string): - if '-' in version_string: - version_string = version_string[:version_string.index('-')] + if "-" in version_string: + version_string = version_string[: version_string.index("-")] - return tuple([int(p) for p in version_string.split('.')]) + return tuple([int(p) for p in version_string.split(".")]) def cmd_line_args_to_dict(env_var): cmd_args_env = os.environ.get(env_var, None) args = {} if cmd_args_env: - cmd_args = cmd_args_env.strip().split(' ') + cmd_args = cmd_args_env.strip().split(" ") while cmd_args: cmd_arg = cmd_args.pop(0) - cmd_arg_value = True if cmd_arg.startswith('--') else cmd_args.pop(0) - args[cmd_arg.lstrip('-')] = cmd_arg_value + cmd_arg_value = True if cmd_arg.startswith("--") else cmd_args.pop(0) + args[cmd_arg.lstrip("-")] 
= cmd_arg_value return args -USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False)) -KEEP_TEST_CLUSTER = bool(os.getenv('KEEP_TEST_CLUSTER', False)) -SIMULACRON_JAR = os.getenv('SIMULACRON_JAR', None) + +USE_CASS_EXTERNAL = bool(os.getenv("USE_CASS_EXTERNAL", False)) +KEEP_TEST_CLUSTER = bool(os.getenv("KEEP_TEST_CLUSTER", False)) +SIMULACRON_JAR = os.getenv("SIMULACRON_JAR", None) # Supported Clusters: Cassandra, Scylla -SCYLLA_VERSION = os.getenv('SCYLLA_VERSION', None) +SCYLLA_VERSION = os.getenv("SCYLLA_VERSION", None) if SCYLLA_VERSION: cv_string = SCYLLA_VERSION - mcv_string = os.getenv('MAPPED_SCYLLA_VERSION', '3.11.4') # Assume that scylla matches cassandra `3.11.4` behavior + mcv_string = os.getenv( + "MAPPED_SCYLLA_VERSION", "3.11.4" + ) # Assume that scylla matches cassandra `3.11.4` behavior else: - cv_string = os.getenv('CASSANDRA_VERSION', None) - mcv_string = os.getenv('MAPPED_CASSANDRA_VERSION', None) + cv_string = os.getenv("CASSANDRA_VERSION", None) + mcv_string = os.getenv("MAPPED_CASSANDRA_VERSION", None) try: cassandra_version = Version(cv_string) # env var is set to test-dse for DDAC except: @@ -143,67 +157,71 @@ def cmd_line_args_to_dict(env_var): CASSANDRA_VERSION = Version(mcv_string) if mcv_string else cassandra_version CCM_VERSION = mcv_string if mcv_string else cv_string -CASSANDRA_IP = os.getenv('CLUSTER_IP', '127.0.0.1') -CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None) +CASSANDRA_IP = os.getenv("CLUSTER_IP", "127.0.0.1") +CASSANDRA_DIR = os.getenv("CASSANDRA_DIR", None) CCM_KWARGS = {} if CASSANDRA_DIR: log.info("Using Cassandra dir: %s", CASSANDRA_DIR) - CCM_KWARGS['install_dir'] = CASSANDRA_DIR -elif os.getenv('SCYLLA_VERSION'): - CCM_KWARGS['cassandra_version'] = os.path.join(os.getenv('SCYLLA_VERSION')) + CCM_KWARGS["install_dir"] = CASSANDRA_DIR +elif os.getenv("SCYLLA_VERSION"): + CCM_KWARGS["cassandra_version"] = os.path.join(os.getenv("SCYLLA_VERSION")) else: - log.info('Using Cassandra version: %s', CCM_VERSION) - 
CCM_KWARGS['version'] = CCM_VERSION
+    log.info("Using Cassandra version: %s", CCM_VERSION)
+    CCM_KWARGS["version"] = CCM_VERSION
 
 
 ALLOW_BETA_PROTOCOL = False
 
 
 def get_default_protocol():
-    if CASSANDRA_VERSION >= Version('4.0-a'):
+    if CASSANDRA_VERSION >= Version("4.0-a"):
         return ProtocolVersion.V5
-    if CASSANDRA_VERSION >= Version('3.10'):
+    if CASSANDRA_VERSION >= Version("3.10"):
         return 4
-    if CASSANDRA_VERSION >= Version('2.2'):
+    if CASSANDRA_VERSION >= Version("2.2"):
         return 4
-    elif CASSANDRA_VERSION >= Version('2.1'):
+    elif CASSANDRA_VERSION >= Version("2.1"):
         return 3
     else:
-        raise Exception("Running tests with an unsupported Cassandra version: {0}".format(CASSANDRA_VERSION))
+        raise Exception(
+            "Running tests with an unsupported Cassandra version: {0}".format(
+                CASSANDRA_VERSION
+            )
+        )
 
 
 def get_scylla_default_protocol():
     if len(CASSANDRA_VERSION.release) == 4:
         # An enterprise, i.e. 2021.1.6
-        if CASSANDRA_VERSION > Version('2019'):
+        if CASSANDRA_VERSION > Version("2019"):
             return 4
         return 3
-    if CASSANDRA_VERSION >= Version('3.0'):
+    if CASSANDRA_VERSION >= Version("3.0"):
         return 4
     return 3
 
 
 def get_supported_protocol_versions():
     """
-    2.1 -> 3
-    2.2 -> 4, 3
-    3.X -> 4, 3
-    3.10(C*) -> 5(beta),4,3
-    4.0(C*) -> 6(beta),5,4,3
-` """
-    if CASSANDRA_VERSION >= Version('4.0-beta5'):
+    2.1 -> 3
+    2.2 -> 4, 3
+    3.X -> 4, 3
+    3.10(C*) -> 5(beta),4,3
+    4.0(C*) -> 6(beta),5,4,3
+    """
+    if CASSANDRA_VERSION >= Version("4.0-beta5"):
+        return (3, 4, 5)
+    if CASSANDRA_VERSION >= Version("4.0-a"):
         return (3, 4, 5)
-    if CASSANDRA_VERSION >= Version('4.0-a'):
-        return (3, 4, 5)
-    elif CASSANDRA_VERSION >= Version('3.10'):
+    elif CASSANDRA_VERSION >= Version("3.10"):
         return (3, 4)
-    elif CASSANDRA_VERSION >= Version('3.0'):
+    elif CASSANDRA_VERSION >= Version("3.0"):
         return (3, 4)
-    elif CASSANDRA_VERSION >= Version('2.2'):
+    elif CASSANDRA_VERSION >= Version("2.2"):
         return (3, 4)
-    elif CASSANDRA_VERSION >= 
Version("2.1"): + return 3 else: return (3,) @@ -215,7 +233,7 @@ def get_unsupported_lower_protocol(): """ if SCYLLA_VERSION is not None: return 2 - if CASSANDRA_VERSION >= Version('3.0'): + if CASSANDRA_VERSION >= Version("3.0"): return 2 else: return None @@ -229,29 +247,31 @@ def get_unsupported_upper_protocol(): if SCYLLA_VERSION is not None: return 5 - if CASSANDRA_VERSION >= Version('4.0-a'): + if CASSANDRA_VERSION >= Version("4.0-a"): return ProtocolVersion.DSE_V1 - if CASSANDRA_VERSION >= Version('3.10'): + if CASSANDRA_VERSION >= Version("3.10"): return 5 - if CASSANDRA_VERSION >= Version('2.2'): + if CASSANDRA_VERSION >= Version("2.2"): return 5 - elif CASSANDRA_VERSION >= Version('2.1'): + elif CASSANDRA_VERSION >= Version("2.1"): return 4 - elif CASSANDRA_VERSION >= Version('2.0'): + elif CASSANDRA_VERSION >= Version("2.0"): return 3 else: return 2 -default_protocol_version = get_scylla_default_protocol() if SCYLLA_VERSION else get_default_protocol() +default_protocol_version = ( + get_scylla_default_protocol() if SCYLLA_VERSION else get_default_protocol() +) -PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version)) +PROTOCOL_VERSION = int(os.getenv("PROTOCOL_VERSION", default_protocol_version)) def local_decorator_creator(): if USE_CASS_EXTERNAL or not CASSANDRA_IP.startswith("127.0.0."): - return unittest.skip('Tests only runs against local C*') + return unittest.skip("Tests only runs against local C*") def _id_and_mark(f): f.local = True @@ -259,64 +279,130 @@ def _id_and_mark(f): return _id_and_mark -def xfail_scylla_version(filter: Callable[[Version], bool], reason: str, *args, **kwargs): + +def xfail_scylla_version( + filter: Callable[[Version], bool], reason: str, *args, **kwargs +): if SCYLLA_VERSION is None: - return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + return pytest.mark.skipif( + False, reason="It is just a NoOP Decor, should not skip anything" + ) current_version = 
Version(get_scylla_version(SCYLLA_VERSION)) return pytest.mark.xfail(filter(current_version), reason=reason, *args, **kwargs) + local = local_decorator_creator() -notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') -greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') - -greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required') -greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required') -greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.0'), 'Cassandra version 3.0 or greater required') -greaterthanorequalcass31 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.1'), 'Cassandra version 3.1 or greater required') -greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required') -greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required') -greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required') -greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required') -greaterthanorequalcass50 = unittest.skipUnless(CASSANDRA_VERSION >= Version('5.0-beta'), 'Cassandra version 5.0 or greater required') +notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, "Protocol v1 not supported") +greaterthanprotocolv3 = unittest.skipUnless( + PROTOCOL_VERSION >= 4, "Protocol versions less than 4 are not supported" +) + +greaterthancass20 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("2.1"), "Cassandra version 2.1 or greater required" +) +greaterthancass21 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("2.2"), "Cassandra version 2.2 or greater 
required" +) +greaterthanorequalcass30 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("3.0"), "Cassandra version 3.0 or greater required" +) +greaterthanorequalcass31 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("3.1"), "Cassandra version 3.1 or greater required" +) +greaterthanorequalcass36 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("3.6"), "Cassandra version 3.6 or greater required" +) +greaterthanorequalcass3_10 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("3.10"), "Cassandra version 3.10 or greater required" +) +greaterthanorequalcass3_11 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("3.11"), "Cassandra version 3.11 or greater required" +) +greaterthanorequalcass40 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("4.0"), "Cassandra version 4.0 or greater required" +) +greaterthanorequalcass50 = unittest.skipUnless( + CASSANDRA_VERSION >= Version("5.0-beta"), + "Cassandra version 5.0 or greater required", +) + + def _has_vector_type(): if SCYLLA_VERSION is not None: - return Version(get_scylla_version(SCYLLA_VERSION)) >= Version('2025.4') - return CASSANDRA_VERSION >= Version('5.0-beta') + return Version(get_scylla_version(SCYLLA_VERSION)) >= Version("2025.4") + return CASSANDRA_VERSION >= Version("5.0-beta") -lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0'), 'Cassandra version less or equal to 4.0 required') -lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0'), 'Cassandra version less than 4.0 required') -lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less then 3.0 required') + +lessthanorequalcass40 = unittest.skipUnless( + CASSANDRA_VERSION <= Version("4.0"), + "Cassandra version less or equal to 4.0 required", +) +lessthancass40 = unittest.skipUnless( + CASSANDRA_VERSION < Version("4.0"), "Cassandra version less than 4.0 required" +) +lessthancass30 = unittest.skipUnless( + CASSANDRA_VERSION < Version("3.0"), 
"Cassandra version less then 3.0 required" +) # pytest.mark.xfail instead of unittest.expectedFailure because # 1. unittest doesn't skip setUpClass when used on class and we need it sometimes # 2. unittest doesn't have conditional xfail, and I prefer to use pytest than custom decorator # 3. unittest doesn't have a reason argument, so you don't see the reason in pytest report -requires_collection_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None and Version(get_scylla_version(SCYLLA_VERSION)) < Version('5.2'), - reason='Scylla supports collection indexes from 5.2 onwards') -requires_custom_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None, - reason='Scylla does not support SASI or any other CUSTOM INDEX class') -requires_java_udf = pytest.mark.skipif(SCYLLA_VERSION is not None, - reason='Scylla does not support UDFs written in Java') -requires_composite_type = pytest.mark.skipif(SCYLLA_VERSION is not None, - reason='Scylla does not support composite types') -requires_custom_payload = pytest.mark.skipif(SCYLLA_VERSION is not None or PROTOCOL_VERSION < 4, - reason='Scylla does not support custom payloads. Cassandra requires native protocol v4.0+') +requires_collection_indexes = pytest.mark.skipif( + SCYLLA_VERSION is not None + and Version(get_scylla_version(SCYLLA_VERSION)) < Version("5.2"), + reason="Scylla supports collection indexes from 5.2 onwards", +) +requires_custom_indexes = pytest.mark.skipif( + SCYLLA_VERSION is not None, + reason="Scylla does not support SASI or any other CUSTOM INDEX class", +) +requires_java_udf = pytest.mark.skipif( + SCYLLA_VERSION is not None, reason="Scylla does not support UDFs written in Java" +) +requires_composite_type = pytest.mark.skipif( + SCYLLA_VERSION is not None, reason="Scylla does not support composite types" +) +requires_custom_payload = pytest.mark.skipif( + SCYLLA_VERSION is not None or PROTOCOL_VERSION < 4, + reason="Scylla does not support custom payloads. 
Cassandra requires native protocol v4.0+",
+)
 requires_vector_type = unittest.skipUnless(
-    _has_vector_type(),
-    'Cassandra >= 5.0 or Scylla >= 2025.4 required')
-xfail_scylla = lambda reason, *args, **kwargs: pytest.mark.xfail(SCYLLA_VERSION is not None, reason=reason, *args, **kwargs)
-incorrect_test = lambda reason='This test seems to be incorrect and should be fixed', *args, **kwargs: pytest.mark.xfail(reason=reason, *args, **kwargs)
-
-pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy")
-requiresmallclockgranularity = unittest.skipIf("Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER,
-                                               "This test is not suitible for environments with large clock granularity")
-requiressimulacron = unittest.skipIf(SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"), "Simulacron jar hasn't been specified or C* version is 2.0")
-requirescompactstorage = xfail_scylla_version(lambda v: v >= Version('2025.1.0'), reason="ScyllaDB deprecated compact storage", raises=InvalidRequest)
-libevtest = unittest.skipUnless(EVENT_LOOP_MANAGER=="libev", "Test timing designed for libev loop")
+    _has_vector_type(), "Cassandra >= 5.0 or Scylla >= 2025.4 required"
+)
+xfail_scylla = lambda reason, *args, **kwargs: pytest.mark.xfail(
+    SCYLLA_VERSION is not None, reason=reason, *args, **kwargs
+)
+incorrect_test = (
+    lambda reason="This test seems to be incorrect and should be fixed",
+    *args,
+    **kwargs: pytest.mark.xfail(reason=reason, *args, **kwargs)
+)
+
+pypy = unittest.skipUnless(
+    platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy"
+)
+requiresmallclockgranularity = unittest.skipIf(
+    "Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER,
+    "This test is not suitable for environments with large clock granularity",
+)
+requiressimulacron = unittest.skipIf(
+    SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"),
+    "Simulacron jar hasn't been specified or C* 
version is 2.0", +) +requirescompactstorage = xfail_scylla_version( + lambda v: v >= Version("2025.1.0"), + reason="ScyllaDB deprecated compact storage", + raises=InvalidRequest, +) +libevtest = unittest.skipUnless( + EVENT_LOOP_MANAGER == "libev", "Test timing designed for libev loop" +) + def wait_for_node_socket(node, timeout): - binary_itf = node.network_interfaces['binary'] + binary_itf = node.network_interfaces["binary"] if not common.check_socket_listening(binary_itf, timeout=timeout): log.warning("Unable to connect to binary socket for node " + node.name) else: @@ -333,12 +419,12 @@ def check_socket_listening(itf, timeout=60): return True except socket.error: # Try again in another 200ms - time.sleep(.2) + time.sleep(0.2) continue return False -USE_SINGLE_INTERFACE = os.getenv('USE_SINGLE_INTERFACE', False) +USE_SINGLE_INTERFACE = os.getenv("USE_SINGLE_INTERFACE", False) def get_cluster(): @@ -346,7 +432,7 @@ def get_cluster(): def get_node(node_id): - return CCM_CLUSTER.nodes['node%s' % node_id] + return CCM_CLUSTER.nodes["node%s" % node_id] def use_multidc(dc_list, workloads=None): @@ -354,22 +440,36 @@ def use_multidc(dc_list, workloads=None): def use_singledc(start=True, workloads=None, use_single_interface=USE_SINGLE_INTERFACE): - use_cluster(CLUSTER_NAME, [3], start=start, workloads=workloads, use_single_interface=use_single_interface) - - -def use_single_node(start=True, workloads=None, configuration_options=None, dse_options=None): - use_cluster(SINGLE_NODE_CLUSTER_NAME, [1], start=start, workloads=workloads, - configuration_options=configuration_options, dse_options=dse_options) + use_cluster( + CLUSTER_NAME, + [3], + start=start, + workloads=workloads, + use_single_interface=use_single_interface, + ) + + +def use_single_node( + start=True, workloads=None, configuration_options=None, dse_options=None +): + use_cluster( + SINGLE_NODE_CLUSTER_NAME, + [1], + start=start, + workloads=workloads, + configuration_options=configuration_options, + 
dse_options=dse_options, + ) def check_log_error(): global CCM_CLUSTER log.debug("Checking log error of cluster {0}".format(CCM_CLUSTER.name)) for node in CCM_CLUSTER.nodelist(): - errors = node.grep_log_for_errors() - for error in errors: - for line in error: - print(line) + errors = node.grep_log_for_errors() + for error in errors: + for line in error: + print(line) def remove_cluster(): @@ -388,7 +488,11 @@ def remove_cluster(): return except OSError: ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 time.sleep(1) @@ -399,10 +503,12 @@ def remove_cluster(): def is_current_cluster(cluster_name, node_counts, workloads): global CCM_CLUSTER if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name: - if [len(list(nodes)) for dc, nodes in - groupby(CCM_CLUSTER.nodelist(), lambda n: n.data_center)] == node_counts: + if [ + len(list(nodes)) + for dc, nodes in groupby(CCM_CLUSTER.nodelist(), lambda n: n.data_center) + ] == node_counts: for node in CCM_CLUSTER.nodelist(): - if set(getattr(node, 'workloads', [])) != set(workloads): + if set(getattr(node, "workloads", [])) != set(workloads): print("node workloads don't match creating new cluster") return False return True @@ -418,8 +524,18 @@ def start_cluster_wait_for_up(cluster): log.debug("Binary port are open") -def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, set_keyspace=True, ccm_options=None, - configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE): +def use_cluster( + cluster_name, + nodes, + ipformat=None, + start=True, + workloads=None, + set_keyspace=True, + ccm_options=None, + configuration_options=None, + dse_options=None, + use_single_interface=USE_SINGLE_INTERFACE, +): configuration_options = configuration_options or {} dse_options = dse_options or 
{} workloads = workloads or [] @@ -427,7 +543,7 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, if ccm_options is None: ccm_options = CCM_KWARGS.copy() - cassandra_version = ccm_options.get('version', CCM_VERSION) + cassandra_version = ccm_options.get("version", CCM_VERSION) global CCM_CLUSTER if USE_CASS_EXTERNAL: @@ -449,7 +565,11 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, log.debug("Using existing cluster, matching topology: {0}".format(cluster_name)) else: if CCM_CLUSTER: - log.debug("Stopping existing cluster, topology mismatch: {0}".format(CCM_CLUSTER.name)) + log.debug( + "Stopping existing cluster, topology mismatch: {0}".format( + CCM_CLUSTER.name + ) + ) CCM_CLUSTER.stop() try: @@ -461,12 +581,20 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, CCM_CLUSTER.set_dse_configuration_options(dse_options) except Exception: ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb - ccm_options.update(cmd_line_args_to_dict('CCM_ARGS')) + ccm_options.update(cmd_line_args_to_dict("CCM_ARGS")) - log.debug("Creating new CCM cluster, {0}, with args {1}".format(cluster_name, ccm_options)) + log.debug( + "Creating new CCM cluster, {0}, with args {1}".format( + cluster_name, ccm_options + ) + ) # Make sure we cleanup old cluster dir if it exists cluster_path = os.path.join(path, cluster_name) @@ -478,30 +606,48 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, # CDC is causing an issue (can't start cluster with multiple seeds) # Selecting only features we need for tests, i.e. anything but CDC. 
CCM_CLUSTER = CCMScyllaCluster(path, cluster_name, **ccm_options) - CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf'], 'start_native_transport': True}) - - CCM_CLUSTER.set_configuration_options({'skip_wait_for_gossip_to_settle': 0}) + CCM_CLUSTER.set_configuration_options( + { + "experimental_features": ["lwt", "udf"], + "start_native_transport": True, + } + ) + + CCM_CLUSTER.set_configuration_options( + {"skip_wait_for_gossip_to_settle": 0} + ) # Permit IS NOT NULL restriction on non-primary key columns of a materialized view # This allows `test_metadata_with_quoted_identifiers` to run - CCM_CLUSTER.set_configuration_options({'strict_is_not_null_in_views': False}) + CCM_CLUSTER.set_configuration_options( + {"strict_is_not_null_in_views": False} + ) else: - ccm_cluster_clz = CCMCluster if Version(cassandra_version) < Version( - '4.1') else Cassandra41CCMCluster + ccm_cluster_clz = ( + CCMCluster + if Version(cassandra_version) < Version("4.1") + else Cassandra41CCMCluster + ) CCM_CLUSTER = ccm_cluster_clz(path, cluster_name, **ccm_options) - CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) - if Version(cassandra_version) >= Version('2.2'): - CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) - if Version(cassandra_version) >= Version('3.0'): + CCM_CLUSTER.set_configuration_options({"start_native_transport": True}) + if Version(cassandra_version) >= Version("2.2"): + CCM_CLUSTER.set_configuration_options( + {"enable_user_defined_functions": True} + ) + if Version(cassandra_version) >= Version("3.0"): # The config.yml option below is deprecated in C* 4.0 per CASSANDRA-17280 - if Version(cassandra_version) < Version('4.0'): - CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) + if Version(cassandra_version) < Version("4.0"): + CCM_CLUSTER.set_configuration_options( + {"enable_scripted_user_defined_functions": True} + ) else: # Cassandra version >= 4.0 - 
CCM_CLUSTER.set_configuration_options({ - 'enable_materialized_views': True, - 'enable_sasi_indexes': True, - 'enable_transient_replication': True, - }) + CCM_CLUSTER.set_configuration_options( + { + "enable_materialized_views": True, + "enable_sasi_indexes": True, + "enable_transient_replication": True, + } + ) common.switch_cluster(path, cluster_name) CCM_CLUSTER.set_configuration_options(configuration_options) @@ -513,17 +659,21 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, # This will enable the Mirroring query handler which will echo our custom payload k,v pairs back - if 'graph' in workloads: - jvm_args += ['-Xms1500M', '-Xmx1500M'] + if "graph" in workloads: + jvm_args += ["-Xms1500M", "-Xmx1500M"] else: if PROTOCOL_VERSION >= 4 and not SCYLLA_VERSION: - jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"] + jvm_args = [ + " -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" + ] if len(workloads) > 0: for node in CCM_CLUSTER.nodes.values(): node.set_workloads(workloads) if start: log.debug("Starting CCM cluster: {0}".format(cluster_name)) - CCM_CLUSTER.start(jvm_args=jvm_args, wait_for_binary_proto=True, wait_other_notice=True) + CCM_CLUSTER.start( + jvm_args=jvm_args, wait_for_binary_proto=True, wait_other_notice=True + ) # Added to wait for slow nodes to start up log.debug("Cluster started waiting for binary ports") for node in CCM_CLUSTER.nodes.values(): @@ -559,12 +709,12 @@ def teardown_package(): cluster = CCMClusterFactory.load(path, cluster_name) try: cluster.remove() - log.info('Removed cluster: %s' % cluster_name) + log.info("Removed cluster: %s" % cluster_name) except Exception: - log.exception('Failed to remove cluster: %s' % cluster_name) + log.exception("Failed to remove cluster: %s" % cluster_name) except Exception: - log.warning('Did not find cluster: %s' % cluster_name) + log.warning("Did not 
find cluster: %s" % cluster_name) def execute_until_pass(session, query): @@ -573,12 +723,24 @@ def execute_until_pass(session, query): try: return session.execute(query) except (ConfigurationException, AlreadyExists, InvalidRequest): - log.warning("Received already exists from query {0} not exiting".format(query)) + log.warning( + "Received already exists from query {0} not exiting".format(query) + ) # keyspace/table was already created/dropped return - except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): + except ( + OperationTimedOut, + ReadTimeout, + ReadFailure, + WriteTimeout, + WriteFailure, + ): ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 @@ -591,12 +753,24 @@ def execute_with_long_wait_retry(session, query, timeout=30): try: return session.execute(query, timeout=timeout) except (ConfigurationException, AlreadyExists): - log.warning("Received already exists from query {0} not exiting".format(query)) + log.warning( + "Received already exists from query {0} not exiting".format(query) + ) # keyspace/table was already created/dropped return - except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): + except ( + OperationTimedOut, + ReadTimeout, + ReadFailure, + WriteTimeout, + WriteFailure, + ): ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 @@ -614,7 +788,7 @@ def execute_with_retry_tolerant(session, query, retry_exceptions, escape_excepti except escape_exception: return except retry_exceptions: - time.sleep(.1) + time.sleep(0.1) raise RuntimeError("Failed to execute query after 100 attempts: 
{0}".format(query)) @@ -625,7 +799,11 @@ def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): except: log.warning("Error encountered when droping keyspace {0}".format(keyspace_name)) ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb finally: log.warning("Shutting down cluster") @@ -641,39 +819,41 @@ def setup_keyspace(ipformat=None, protocol_version=None, port=9042): if not ipformat: cluster = TestCluster(protocol_version=_protocol_version, port=port) else: - cluster = TestCluster(contact_points=["::1"], protocol_version=_protocol_version, port=port) + cluster = TestCluster( + contact_points=["::1"], protocol_version=_protocol_version, port=port + ) session = cluster.connect() try: - for ksname in ('test1rf', 'test2rf', 'test3rf'): + for ksname in ("test1rf", "test2rf", "test3rf"): if ksname in cluster.metadata.keyspaces: execute_until_pass(session, "DROP KEYSPACE %s" % ksname) - ddl = ''' + ddl = """ CREATE KEYSPACE test3rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'}""" execute_with_long_wait_retry(session, ddl) - ddl = ''' + ddl = """ CREATE KEYSPACE test2rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '2'}""" execute_with_long_wait_retry(session, ddl) - ddl = ''' + ddl = """ CREATE KEYSPACE test1rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}""" execute_with_long_wait_retry(session, ddl) - ddl_3f = ''' + ddl_3f = """ CREATE TABLE test3rf.test ( k int PRIMARY KEY, - v int )''' + v int )""" 
execute_with_long_wait_retry(session, ddl_3f) - ddl_1f = ''' + ddl_1f = """ CREATE TABLE test1rf.test ( k int PRIMARY KEY, - v int )''' + v int )""" execute_with_long_wait_retry(session, ddl_1f) except Exception: @@ -684,7 +864,7 @@ def setup_keyspace(ipformat=None, protocol_version=None, port=9042): def is_scylla_enterprise(version: Version) -> bool: - return version > Version('2000.1.1') + return version > Version("2000.1.1") def xfail_scylla_version_lt(reason, scylla_version, *args, **kwargs): @@ -693,18 +873,27 @@ def xfail_scylla_version_lt(reason, scylla_version, *args, **kwargs): :param reason: message to fail test with :param scylla_version: str, version from which test supposed to succeed """ - if not (reason.startswith("scylladb/scylladb#") or reason.startswith("scylladb/scylla-enterprise#")): - raise ValueError('reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo') + if not ( + reason.startswith("scylladb/scylladb#") + or reason.startswith("scylladb/scylla-enterprise#") + ): + raise ValueError( + "reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo" + ) if not isinstance(scylla_version, str): - raise ValueError('scylla_version should be a str') + raise ValueError("scylla_version should be a str") if SCYLLA_VERSION is None: - return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + return pytest.mark.skipif( + False, reason="It is just a NoOP Decor, should not skip anything" + ) current_version = Version(get_scylla_version(SCYLLA_VERSION)) - return pytest.mark.xfail(current_version < Version(scylla_version), reason=reason, *args, **kwargs) + return pytest.mark.xfail( + current_version < Version(scylla_version), reason=reason, *args, **kwargs + ) def skip_scylla_version_lt(reason, scylla_version): @@ -713,14 +902,21 @@ def skip_scylla_version_lt(reason, scylla_version): :param reason: message explaining why 
the test is skipped :param scylla_version: str, version from which test supposed to work """ - if not (reason.startswith("scylladb/scylladb#") or reason.startswith("scylladb/scylla-enterprise#")): - raise ValueError('reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo') + if not ( + reason.startswith("scylladb/scylladb#") + or reason.startswith("scylladb/scylla-enterprise#") + ): + raise ValueError( + "reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo" + ) if not isinstance(scylla_version, str): - raise ValueError('scylla_version should be a str') + raise ValueError("scylla_version should be a str") if SCYLLA_VERSION is None: - return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + return pytest.mark.skipif( + False, reason="It is just a NoOP Decor, should not skip anything" + ) current_version = Version(get_scylla_version(SCYLLA_VERSION)) @@ -728,7 +924,6 @@ def skip_scylla_version_lt(reason, scylla_version): class UpDownWaiter(object): - def __init__(self, host): self.down_event = Event() self.up_event = Event() @@ -752,6 +947,7 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase): This is basic unit test case that provides various utility methods that can be leveraged for testcase setup and tear down """ + @property def keyspace_name(self): return self.ks_name @@ -770,15 +966,21 @@ def keyspace_table_name(self): @classmethod def drop_keyspace(cls): - execute_with_long_wait_retry(cls.session, "DROP KEYSPACE {0}".format(cls.ks_name)) + execute_with_long_wait_retry( + cls.session, "DROP KEYSPACE {0}".format(cls.ks_name) + ) @classmethod def create_keyspace(cls, rf): - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.ks_name, rf) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': 
'{1}'}}".format( + cls.ks_name, rf + ) execute_with_long_wait_retry(cls.session, ddl) @classmethod - def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, **cluster_kwargs): + def common_setup( + cls, rf, keyspace_creation=True, create_class_table=False, **cluster_kwargs + ): cls.cluster = TestCluster(**cluster_kwargs) cls.session = cls.cluster.connect(wait_for_all_pools=True) cls.ks_name = cls.__name__.lower() @@ -787,23 +989,22 @@ def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, **cl cls.cass_version, cls.cql_version = get_server_versions() if create_class_table: - - ddl = ''' + ddl = """ CREATE TABLE {0}.{1} ( k int PRIMARY KEY, - v int )'''.format(cls.ks_name, cls.ks_name) + v int )""".format(cls.ks_name, cls.ks_name) execute_until_pass(cls.session, ddl) def create_function_table(self): - ddl = ''' + ddl = """ CREATE TABLE {0}.{1} ( k int PRIMARY KEY, - v int )'''.format(self.keyspace_name, self.function_table_name) - execute_until_pass(self.session, ddl) + v int )""".format(self.keyspace_name, self.function_table_name) + execute_until_pass(self.session, ddl) def drop_function_table(self): - ddl = "DROP TABLE {0}.{1} ".format(self.keyspace_name, self.function_table_name) - execute_until_pass(self.session, ddl) + ddl = "DROP TABLE {0}.{1} ".format(self.keyspace_name, self.function_table_name) + execute_until_pass(self.session, ddl) class MockLoggingHandler(logging.Handler): @@ -818,18 +1019,18 @@ def emit(self, record): def reset(self): self.messages = { - 'debug': [], - 'info': [], - 'warning': [], - 'error': [], - 'critical': [], + "debug": [], + "info": [], + "warning": [], + "error": [], + "critical": [], } def get_message_count(self, level, sub_string): count = 0 for msg in self.messages.get(level): if sub_string in msg: - count+=1 + count += 1 return count def set_module_name(self, module_name): @@ -853,6 +1054,7 @@ class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): """ This is basic unit 
test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test. """ + @classmethod def setUpClass(cls): cls.common_setup(1, keyspace_creation=False) @@ -867,6 +1069,7 @@ class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. creates a keyspace named after the testclass with a rf of 1. """ + @classmethod def setUpClass(cls): cls.common_setup(1) @@ -881,6 +1084,7 @@ class BasicSharedKeyspaceUnitTestCaseRF1(BasicSharedKeyspaceUnitTestCase): This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. creates a keyspace named after the testclass with a rf of 1 """ + @classmethod def setUpClass(self): self.common_setup(1, True) @@ -891,6 +1095,7 @@ class BasicSharedKeyspaceUnitTestCaseRF2(BasicSharedKeyspaceUnitTestCase): This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. creates a keyspace named after the test class with a rf of 2, and a table named after the class """ + @classmethod def setUpClass(self): self.common_setup(2) @@ -901,6 +1106,7 @@ class BasicSharedKeyspaceUnitTestCaseRF3(BasicSharedKeyspaceUnitTestCase): This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. creates a keyspace named after the test class with a rf of 3 """ + @classmethod def setUpClass(self): self.common_setup(3) @@ -911,6 +1117,7 @@ class BasicSharedKeyspaceUnitTestCaseRF3WM(BasicSharedKeyspaceUnitTestCase): This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. 
creates a keyspace named after the test class with a rf of 3 with metrics enabled """ + @classmethod def setUpClass(self): self.common_setup(3, True, True, metrics_enabled=True) @@ -921,12 +1128,13 @@ def tearDownClass(cls): class BasicSharedKeyspaceUnitTestCaseWFunctionTable(BasicSharedKeyspaceUnitTestCase): - """" + """ " This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. creates a keyspace named after the test class with a rf of 3 and a table named after the class the table is scoped to just the unit test and will be removed. """ + def setUp(self): self.create_function_table() @@ -940,6 +1148,7 @@ class BasicSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): It has overhead and should only be used with complex unit test were sharing a keyspace will cause issues. """ + def setUp(self): self.common_setup(1) @@ -953,6 +1162,7 @@ class BasicExistingSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): or created as part of a test. It has some overhead and should only be used when sharing cluster/session is not feasible. """ + def setUp(self): self.common_setup(1, keyspace_creation=False) @@ -968,22 +1178,23 @@ class TestCluster(object): DEFAULT_ALLOW_BETA = ALLOW_BETA_PROTOCOL def __new__(cls, **kwargs): - if 'protocol_version' not in kwargs: - kwargs['protocol_version'] = cls.DEFAULT_PROTOCOL_VERSION - if 'contact_points' not in kwargs: - kwargs['contact_points'] = [cls.DEFAULT_CASSANDRA_IP] - if 'allow_beta_protocol_version' not in kwargs: - kwargs['allow_beta_protocol_version'] = cls.DEFAULT_ALLOW_BETA + if "protocol_version" not in kwargs: + kwargs["protocol_version"] = cls.DEFAULT_PROTOCOL_VERSION + if "contact_points" not in kwargs: + kwargs["contact_points"] = [cls.DEFAULT_CASSANDRA_IP] + if "allow_beta_protocol_version" not in kwargs: + kwargs["allow_beta_protocol_version"] = cls.DEFAULT_ALLOW_BETA return Cluster(**kwargs) + # Subclass of CCMCluster (i.e. 
ccmlib.cluster.Cluster) which transparently performs # conversion of cassandra.yml directives into something matching the new syntax # introduced by CASSANDRA-15234 class Cassandra41CCMCluster(CCMCluster): __test__ = False - IN_MS_REGEX = re.compile('^(\w+)_in_ms$') - IN_KB_REGEX = re.compile('^(\w+)_in_kb$') - ENABLE_REGEX = re.compile('^enable_(\w+)$') + IN_MS_REGEX = re.compile("^(\w+)_in_ms$") + IN_KB_REGEX = re.compile("^(\w+)_in_kb$") + ENABLE_REGEX = re.compile("^enable_(\w+)$") def _get_config_key(self, k, v): if "." in k: @@ -1009,5 +1220,10 @@ def _get_config_val(self, k, v): return v def set_configuration_options(self, values=None, *args, **kwargs): - new_values = {self._get_config_key(k, str(v)):self._get_config_val(k, str(v)) for (k,v) in values.items()} - super(Cassandra41CCMCluster, self).set_configuration_options(values=new_values, *args, **kwargs) + new_values = { + self._get_config_key(k, str(v)): self._get_config_val(k, str(v)) + for (k, v) in values.items() + } + super(Cassandra41CCMCluster, self).set_configuration_options( + values=new_values, *args, **kwargs + ) diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py index 78d5133e63..06858f15f5 100644 --- a/tests/integration/cqlengine/connections/test_connection.py +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -19,11 +19,22 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns, connection, models from cassandra.cqlengine.management import sync_table -from cassandra.cluster import ExecutionProfile, _clusters_for_shutdown, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.cluster import ( + ExecutionProfile, + _clusters_for_shutdown, + _ConfigMode, + EXEC_PROFILE_DEFAULT, +) from cassandra.policies import RoundRobinPolicy from cassandra.query import dict_factory -from tests.integration import CASSANDRA_IP, PROTOCOL_VERSION, execute_with_long_wait_retry, local, 
TestCluster +from tests.integration import ( + CASSANDRA_IP, + PROTOCOL_VERSION, + execute_with_long_wait_retry, + local, + TestCluster, +) from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE, setup_connection @@ -42,12 +53,18 @@ def tearDown(self): @local def test_connection_setup_with_setup(self): connection.setup(hosts=None, default_keyspace=None) - assert connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") is not None + assert ( + connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") + is not None + ) @local def test_connection_setup_with_default(self): connection.default() - assert connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") is not None + assert ( + connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") + is not None + ) def test_only_one_connection_is_created(self): """ @@ -67,24 +84,31 @@ def test_only_one_connection_is_created(self): class SeveralConnectionsTest(BaseCassEngTestCase): - @classmethod def setUpClass(cls): - connection.unregister_connection('default') - cls.keyspace1 = 'ctest1' - cls.keyspace2 = 'ctest2' + connection.unregister_connection("default") + cls.keyspace1 = "ctest1" + cls.keyspace2 = "ctest2" super(SeveralConnectionsTest, cls).setUpClass() cls.setup_cluster = TestCluster() cls.setup_session = cls.setup_cluster.connect() - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format( + cls.keyspace1, 1 + ) execute_with_long_wait_retry(cls.setup_session, ddl) - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 
'replication_factor': '{1}'}}".format( + cls.keyspace2, 1 + ) execute_with_long_wait_retry(cls.setup_session, ddl) @classmethod def tearDownClass(cls): - execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace1)) - execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace2)) + execute_with_long_wait_retry( + cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace1) + ) + execute_with_long_wait_retry( + cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace2) + ) models.DEFAULT_KEYSPACE = DEFAULT_KEYSPACE cls.setup_cluster.shutdown() setup_connection(DEFAULT_KEYSPACE) @@ -141,13 +165,17 @@ def test_connection_with_legacy_settings(self): connection.setup( hosts=[CASSANDRA_IP], default_keyspace=DEFAULT_KEYSPACE, - consistency=ConsistencyLevel.LOCAL_ONE + consistency=ConsistencyLevel.LOCAL_ONE, ) conn = connection.get_connection() assert conn.cluster._config_mode == _ConfigMode.LEGACY def test_connection_from_session_with_execution_profile(self): - cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + cluster = TestCluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } + ) session = cluster.connect() connection.default() connection.set_session(session) @@ -174,23 +202,27 @@ def test_legacy_insert_query(self): connection.setup( hosts=[CASSANDRA_IP], default_keyspace=DEFAULT_KEYSPACE, - consistency=ConsistencyLevel.LOCAL_ONE + consistency=ConsistencyLevel.LOCAL_ONE, ) assert connection.get_connection().cluster._config_mode == _ConfigMode.LEGACY sync_table(ConnectionModel) - ConnectionModel.objects.create(key=0, some_data='text0') - ConnectionModel.objects.create(key=1, some_data='text1') - assert ConnectionModel.objects(key=0)[0].some_data == 'text0' + ConnectionModel.objects.create(key=0, some_data="text0") + ConnectionModel.objects.create(key=1, some_data="text1") + assert 
ConnectionModel.objects(key=0)[0].some_data == "text0" def test_execution_profile_insert_query(self): - cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + cluster = TestCluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } + ) session = cluster.connect() connection.default() connection.set_session(session) assert connection.get_connection().cluster._config_mode == _ConfigMode.PROFILES sync_table(ConnectionModel) - ConnectionModel.objects.create(key=0, some_data='text0') - ConnectionModel.objects.create(key=1, some_data='text1') - assert ConnectionModel.objects(key=0)[0].some_data == 'text0' + ConnectionModel.objects.create(key=0, some_data="text0") + ConnectionModel.objects.create(key=1, some_data="text1") + assert ConnectionModel.objects(key=0)[0].some_data == "text0" diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index beb10f02c0..b83e965669 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -21,16 +21,31 @@ from cassandra.policies import HostFilterPolicy, RoundRobinPolicy from cassandra import ( - ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, ReadFailure, WriteFailure, - FunctionFailure, ProtocolVersion, + ConsistencyLevel, + OperationTimedOut, + ReadTimeout, + WriteTimeout, + ReadFailure, + WriteFailure, + FunctionFailure, + ProtocolVersion, ) from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.concurrent import execute_concurrent_with_args from cassandra.query import SimpleStatement from tests.integration import ( - use_singledc, PROTOCOL_VERSION, get_cluster, setup_keyspace, remove_cluster, - get_node, start_cluster_wait_for_up, requiresmallclockgranularity, - local, CASSANDRA_VERSION, TestCluster) + use_singledc, + PROTOCOL_VERSION, + get_cluster, + setup_keyspace, + remove_cluster, + 
get_node, + start_cluster_wait_for_up, + requiresmallclockgranularity, + local, + CASSANDRA_VERSION, + TestCluster, +) from tests.integration import requires_java_udf @@ -54,8 +69,8 @@ def setup_module(): ccm_cluster = get_cluster() ccm_cluster.stop() config_options = { - 'tombstone_failure_threshold': 2000, - 'tombstone_warn_threshold': 1000, + "tombstone_failure_threshold": 2000, + "tombstone_warn_threshold": 1000, } ccm_cluster.set_configuration_options(config_options) start_cluster_wait_for_up(ccm_cluster) @@ -72,7 +87,6 @@ def teardown_module(): class ClientExceptionTests(unittest.TestCase): - def setUp(self): """ Test is skipped if run with native protocol version <4 @@ -80,14 +94,14 @@ def setUp(self): if PROTOCOL_VERSION < 4: raise unittest.SkipTest( "Native protocol 4,0+ is required for custom payloads, currently using %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) self.cluster = TestCluster() self.session = self.cluster.connect() self.nodes_currently_failing = [] self.node1, self.node2, self.node3 = get_cluster().nodes.values() def tearDown(self): - self.cluster.shutdown() failing_nodes = [] @@ -101,24 +115,44 @@ def execute_helper(self, session, query): return session.execute(query) except OperationTimedOut: ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 - raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + raise RuntimeError( + "Failed to execute query after 100 attempts: {0}".format(query) + ) def execute_concurrent_args_helper(self, session, query, params): tries = 0 while tries < 100: try: - return execute_concurrent_with_args(session, query, params, concurrency=50) - except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): + return execute_concurrent_with_args( + session, 
query, params, concurrency=50 + ) + except ( + ReadTimeout, + WriteTimeout, + OperationTimedOut, + ReadFailure, + WriteFailure, + ): ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 - raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + raise RuntimeError( + "Failed to execute query after 100 attempts: {0}".format(query) + ) def setFailingNodes(self, failing_nodes, keyspace): """ @@ -133,8 +167,11 @@ def setFailingNodes(self, failing_nodes, keyspace): for node in failing_nodes: if node not in self.nodes_currently_failing: node.stop(wait_other_notice=True, gently=False) - node.start(jvm_args=[" -Dcassandra.test.fail_writes_ks=" + keyspace], wait_for_binary_proto=True, - wait_other_notice=True) + node.start( + jvm_args=[" -Dcassandra.test.fail_writes_ks=" + keyspace], + wait_for_binary_proto=True, + wait_other_notice=True, + ) self.nodes_currently_failing.append(node) # Ensure all nodes not on the list, but that are currently set to failing are enabled @@ -144,7 +181,9 @@ def setFailingNodes(self, failing_nodes, keyspace): node.start(wait_for_binary_proto=True, wait_other_notice=True) self.nodes_currently_failing.remove(node) - def _perform_cql_statement(self, text, consistency_level, expected_exception, session=None): + def _perform_cql_statement( + self, text, consistency_level, expected_exception, session=None + ): """ Simple helper method to preform cql statements and check for expected exception @param text CQl statement to execute @@ -187,8 +226,11 @@ def test_write_failures_from_coordinator(self): self._perform_cql_statement( """ CREATE KEYSPACE testksfail - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + WITH replication = 
{'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) # create table self._perform_cql_statement( @@ -196,7 +238,10 @@ def test_write_failures_from_coordinator(self): CREATE TABLE testksfail.test ( k int PRIMARY KEY, v int ) - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) # Disable one node failing_nodes = [self.node1] @@ -206,13 +251,19 @@ def test_write_failures_from_coordinator(self): self._perform_cql_statement( """ INSERT INTO testksfail.test (k, v) VALUES (1, 0 ) - """, consistency_level=ConsistencyLevel.ALL, expected_exception=WriteFailure) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=WriteFailure, + ) # We have two nodes left so a write with consistency level of QUORUM should complete as expected self._perform_cql_statement( """ INSERT INTO testksfail.test (k, v) VALUES (1, 0 ) - """, consistency_level=ConsistencyLevel.QUORUM, expected_exception=None) + """, + consistency_level=ConsistencyLevel.QUORUM, + expected_exception=None, + ) failing_nodes = [] @@ -223,7 +274,10 @@ def test_write_failures_from_coordinator(self): self._perform_cql_statement( """ DROP KEYSPACE testksfail - """, consistency_level=ConsistencyLevel.ANY, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ANY, + expected_exception=None, + ) def test_tombstone_overflow_read_failure(self): """ @@ -250,26 +304,39 @@ def test_tombstone_overflow_read_failure(self): k int, v0 int, v1 int, PRIMARY KEY (k,v0)) - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) - statement = self.session.prepare("INSERT INTO test3rf.test2 (k, v0,v1) VALUES (1,?,1)") + statement = self.session.prepare( + "INSERT INTO test3rf.test2 (k, v0,v1) VALUES (1,?,1)" + ) parameters = [(x,) for x 
in range(3000)] self.execute_concurrent_args_helper(self.session, statement, parameters) - column = 'v1' if CASSANDRA_VERSION < Version('4.0') else '' - statement = self.session.prepare("DELETE {} FROM test3rf.test2 WHERE k = 1 AND v0 =?".format(column)) + column = "v1" if CASSANDRA_VERSION < Version("4.0") else "" + statement = self.session.prepare( + "DELETE {} FROM test3rf.test2 WHERE k = 1 AND v0 =?".format(column) + ) parameters = [(x,) for x in range(2001)] self.execute_concurrent_args_helper(self.session, statement, parameters) self._perform_cql_statement( """ SELECT * FROM test3rf.test2 WHERE k = 1 - """, consistency_level=ConsistencyLevel.ALL, expected_exception=ReadFailure) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=ReadFailure, + ) self._perform_cql_statement( """ DROP TABLE test3rf.test2; - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) @requires_java_udf def test_user_function_failure(self): @@ -294,35 +361,53 @@ def test_user_function_failure(self): RETURNS NULL ON NULL INPUT RETURNS double LANGUAGE java AS 'throw new RuntimeException("failure");'; - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) # Create test table self._perform_cql_statement( """ CREATE TABLE test3rf.d (k int PRIMARY KEY , d double); - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) # Insert some values self._perform_cql_statement( """ INSERT INTO test3rf.d (k,d) VALUES (0, 5.12); - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) # Run the function expect a function failure exception self._perform_cql_statement( """ SELECT test_failure(d) FROM test3rf.d 
WHERE k = 0; - """, consistency_level=ConsistencyLevel.ALL, expected_exception=FunctionFailure) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=FunctionFailure, + ) self._perform_cql_statement( """ DROP FUNCTION test3rf.test_failure; - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) self._perform_cql_statement( """ DROP TABLE test3rf.d; - """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + """, + consistency_level=ConsistencyLevel.ALL, + expected_exception=None, + ) @requiresmallclockgranularity @@ -346,10 +431,10 @@ def setUp(self): self.control_connection_host_number = 1 self.node_to_stop = get_node(self.control_connection_host_number) - ddl = ''' + ddl = """ CREATE TABLE test3rf.timeout ( k int PRIMARY KEY, - v int )''' + v int )""" self.session.execute(ddl) self.node_to_stop.pause() @@ -378,7 +463,9 @@ def test_async_timeouts(self): """ # Because node1 is stopped these statements will all timeout - ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ALL) + ss = SimpleStatement( + "SELECT * FROM test3rf.test", consistency_level=ConsistencyLevel.ALL + ) # Test with default timeout (should be 10) start_time = time.time() @@ -386,10 +473,10 @@ def test_async_timeouts(self): with pytest.raises(OperationTimedOut): future.result() end_time = time.time() - total_time = end_time-start_time + total_time = end_time - start_time expected_time = self.cluster.profile_manager.default.request_timeout # check timeout and ensure it's within a reasonable range - assert expected_time == pytest.approx(total_time, abs=.05) + assert expected_time == pytest.approx(total_time, abs=0.05) # Test with user defined timeout (Should be 1) expected_time = 1 @@ -403,8 +490,8 @@ def test_async_timeouts(self): with pytest.raises(OperationTimedOut): future.result() end_time = time.time() - total_time = end_time-start_time 
+ total_time = end_time - start_time # check timeout and ensure it's within a reasonable range - assert expected_time == pytest.approx(total_time, abs=.05) + assert expected_time == pytest.approx(total_time, abs=0.05) assert mock_errorback.called assert not mock_callback.called diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py index ab8d125ab1..26c5f867d9 100644 --- a/tests/integration/long/test_policies.py +++ b/tests/integration/long/test_policies.py @@ -22,15 +22,16 @@ def setup_module(): - use_cluster('test_cluster', [4]) + use_cluster("test_cluster", [4]) class RetryPolicyTests(unittest.TestCase): - @classmethod def tearDownClass(cls): cluster = get_cluster() - cluster.start(wait_for_binary_proto=True, wait_other_notice=True) # make sure other nodes are restarted + cluster.start( + wait_for_binary_proto=True, wait_other_notice=True + ) # make sure other nodes are restarted def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): """ @@ -42,15 +43,24 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): @test_category policy """ - ep = ExecutionProfile(consistency_level=ConsistencyLevel.ALL, - serial_consistency_level=ConsistencyLevel.SERIAL) + ep = ExecutionProfile( + consistency_level=ConsistencyLevel.ALL, + serial_consistency_level=ConsistencyLevel.SERIAL, + ) cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ep}) session = cluster.connect() - session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'SimpleStrategy','replication_factor': 3};") - session.execute("CREATE TABLE test_retry_policy_cas.t (id int PRIMARY KEY, data text);") - session.execute('INSERT INTO test_retry_policy_cas.t ("id", "data") VALUES (%(0)s, %(1)s)', {'0': 42, '1': 'testing'}) + session.execute( + "CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'NetworkTopologyStrategy','replication_factor': 3};" + ) + session.execute( + "CREATE TABLE 
test_retry_policy_cas.t (id int PRIMARY KEY, data text);" + ) + session.execute( + 'INSERT INTO test_retry_policy_cas.t ("id", "data") VALUES (%(0)s, %(1)s)', + {"0": 42, "1": "testing"}, + ) get_node(2).stop() get_node(4).stop() @@ -60,7 +70,9 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): # after fix: cassandra.Unavailable (expected since replicas are down) with pytest.raises(Unavailable) as cm: - session.execute("update test_retry_policy_cas.t set data = 'staging' where id = 42 if data ='testing'") + session.execute( + "update test_retry_policy_cas.t set data = 'staging' where id = 42 if data ='testing'" + ) exception = cm.value assert exception.consistency == ConsistencyLevel.SERIAL diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index 3b4dcd33d5..1b96af2f9d 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -31,7 +31,6 @@ def setup_module(): class SchemaTests(unittest.TestCase): - @classmethod def setup_class(cls): cls.cluster = TestCluster() @@ -57,11 +56,15 @@ def test_recreates(self): log.debug(drop) execute_until_pass(session, drop) - create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 3}}".format(keyspace) + create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': 3}}".format( + keyspace + ) log.debug(create) execute_until_pass(session, create) - create = "CREATE TABLE {0}.cf (k int PRIMARY KEY, i int)".format(keyspace) + create = "CREATE TABLE {0}.cf (k int PRIMARY KEY, i int)".format( + keyspace + ) log.debug(create) execute_until_pass(session, create) @@ -82,11 +85,24 @@ def test_for_schema_disagreements_different_keyspaces(self): session = self.session for i in range(30): - execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i)) - 
execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i)) + execute_until_pass( + session, + "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': 1}}".format( + i + ), + ) + execute_until_pass( + session, + "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i), + ) for j in range(100): - execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j)) + execute_until_pass( + session, + "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format( + i, j + ), + ) execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i)) @@ -100,15 +116,26 @@ def test_for_schema_disagreements_same_keyspace(self): for i in range(30): try: - execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + execute_until_pass( + session, + "CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}", + ) except AlreadyExists: execute_until_pass(session, "DROP KEYSPACE test") - execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + execute_until_pass( + session, + "CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}", + ) - execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)") + execute_until_pass( + session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)" + ) for j in range(100): - execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j)) + execute_until_pass( + session, + "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j), + ) execute_until_pass(session, "DROP KEYSPACE test") cluster.shutdown() @@ -132,22 +159,34 @@ def test_for_schema_disagreement_attribute(self): cluster = TestCluster(max_schema_agreement_wait=0.001) session = 
cluster.connect(wait_for_all_pools=True) - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + rs = session.execute( + "CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}" + ) self.check_and_wait_for_agreement(session, rs, False) - rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", - consistency_level=ConsistencyLevel.ALL)) + rs = session.execute( + SimpleStatement( + "CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL, + ) + ) self.check_and_wait_for_agreement(session, rs, False) rs = session.execute("DROP KEYSPACE test_schema_disagreement") self.check_and_wait_for_agreement(session, rs, False) cluster.shutdown() - + # These should have schema agreement cluster = TestCluster(max_schema_agreement_wait=100) session = cluster.connect() - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + rs = session.execute( + "CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}" + ) self.check_and_wait_for_agreement(session, rs, True) - rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", - consistency_level=ConsistencyLevel.ALL)) + rs = session.execute( + SimpleStatement( + "CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL, + ) + ) self.check_and_wait_for_agreement(session, rs, True) rs = session.execute("DROP KEYSPACE test_schema_disagreement") self.check_and_wait_for_agreement(session, rs, True) diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index 56dc6a5c2d..b7c2568ac3 100644 
--- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -23,13 +23,19 @@ from OpenSSL import SSL, crypto from tests.integration import ( - get_cluster, remove_cluster, use_single_node, start_cluster_wait_for_up, EVENT_LOOP_MANAGER, TestCluster + get_cluster, + remove_cluster, + use_single_node, + start_cluster_wait_for_up, + EVENT_LOOP_MANAGER, + TestCluster, ) import pytest -if not hasattr(ssl, 'match_hostname'): +if not hasattr(ssl, "match_hostname"): try: from ssl import match_hostname + ssl.match_hostname = match_hostname except ImportError: pass # tests will fail @@ -40,25 +46,28 @@ # Server keystore trust store locations SERVER_KEYSTORE_PATH = os.path.abspath("tests/integration/long/ssl/127.0.0.1.keystore") -SERVER_TRUSTSTORE_PATH = os.path.abspath("tests/integration/long/ssl/cassandra.truststore") +SERVER_TRUSTSTORE_PATH = os.path.abspath( + "tests/integration/long/ssl/cassandra.truststore" +) # Client specific keys/certs CLIENT_CA_CERTS = os.path.abspath("tests/integration/long/ssl/rootCa.crt") DRIVER_KEYFILE = os.path.abspath("tests/integration/long/ssl/client.key") -DRIVER_KEYFILE_ENCRYPTED = os.path.abspath("tests/integration/long/ssl/client_encrypted.key") +DRIVER_KEYFILE_ENCRYPTED = os.path.abspath( + "tests/integration/long/ssl/client_encrypted.key" +) DRIVER_CERTFILE = os.path.abspath("tests/integration/long/ssl/client.crt_signed") DRIVER_CERTFILE_BAD = os.path.abspath("tests/integration/long/ssl/client_bad.key") USES_PYOPENSSL = "twisted" in EVENT_LOOP_MANAGER or "eventlet" in EVENT_LOOP_MANAGER if "twisted" in EVENT_LOOP_MANAGER: import OpenSSL + ssl_version = OpenSSL.SSL.TLS_METHOD - verify_certs = {'cert_reqs': SSL.VERIFY_PEER, - 'check_hostname': True} + verify_certs = {"cert_reqs": SSL.VERIFY_PEER, "check_hostname": True} else: ssl_version = ssl.PROTOCOL_TLS - verify_certs = {'cert_reqs': ssl.CERT_REQUIRED, - 'check_hostname': True} + verify_certs = {"cert_reqs": ssl.CERT_REQUIRED, "check_hostname": True} def 
verify_callback(connection, x509, errnum, errdepth, ok): @@ -76,62 +85,69 @@ def setup_cluster_ssl(client_auth=False): ccm_cluster.stop() # Configure ccm to use ssl. - config_options = {'client_encryption_options': {'enabled': True, - 'keystore': SERVER_KEYSTORE_PATH, - 'keystore_password': DEFAULT_PASSWORD}} + config_options = { + "client_encryption_options": { + "enabled": True, + "keystore": SERVER_KEYSTORE_PATH, + "keystore_password": DEFAULT_PASSWORD, + } + } - if(client_auth): - client_encyrption_options = config_options['client_encryption_options'] - client_encyrption_options['require_client_auth'] = True - client_encyrption_options['truststore'] = SERVER_TRUSTSTORE_PATH - client_encyrption_options['truststore_password'] = DEFAULT_PASSWORD + if client_auth: + client_encyrption_options = config_options["client_encryption_options"] + client_encyrption_options["require_client_auth"] = True + client_encyrption_options["truststore"] = SERVER_TRUSTSTORE_PATH + client_encyrption_options["truststore_password"] = DEFAULT_PASSWORD ccm_cluster.set_configuration_options(config_options) start_cluster_wait_for_up(ccm_cluster) def validate_ssl_options(**kwargs): - ssl_options = kwargs.get('ssl_options', None) - ssl_context = kwargs.get('ssl_context', None) - hostname = kwargs.get('hostname', '127.0.0.1') - - # find absolute path to client CA_CERTS - tries = 0 - while True: - if tries > 5: - raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") - try: - cluster = TestCluster( - contact_points=[DefaultEndPoint(hostname)], - ssl_options=ssl_options, - ssl_context=ssl_context + ssl_options = kwargs.get("ssl_options", None) + ssl_context = kwargs.get("ssl_context", None) + hostname = kwargs.get("hostname", "127.0.0.1") + + # find absolute path to client CA_CERTS + tries = 0 + while True: + if tries > 5: + raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") + try: + cluster = TestCluster( + contact_points=[DefaultEndPoint(hostname)], + 
ssl_options=ssl_options, + ssl_context=ssl_context, + ) + session = cluster.connect(wait_for_all_pools=True) + break + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) ) - session = cluster.connect(wait_for_all_pools=True) - break - except Exception: - ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) - del tb - tries += 1 + ) + del tb + tries += 1 - # attempt a few simple commands. - insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + # attempt a few simple commands. + insert_keyspace = """CREATE KEYSPACE ssltest + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} """ - statement = SimpleStatement(insert_keyspace) - statement.consistency_level = 3 - session.execute(statement) + statement = SimpleStatement(insert_keyspace) + statement.consistency_level = 3 + session.execute(statement) - drop_keyspace = "DROP KEYSPACE ssltest" - statement = SimpleStatement(drop_keyspace) - statement.consistency_level = ConsistencyLevel.ANY - session.execute(statement) + drop_keyspace = "DROP KEYSPACE ssltest" + statement = SimpleStatement(drop_keyspace) + statement.consistency_level = ConsistencyLevel.ANY + session.execute(statement) - cluster.shutdown() + cluster.shutdown() class SSLConnectionTests(unittest.TestCase): - @classmethod def setUpClass(cls): setup_cluster_ssl() @@ -159,7 +175,7 @@ def test_can_connect_with_ssl_ca(self): """ # find absolute path to client CA_CERTS - ssl_options = {'ca_certs': CLIENT_CA_CERTS,'ssl_version': ssl_version} + ssl_options = {"ca_certs": CLIENT_CA_CERTS, "ssl_version": ssl_version} validate_ssl_options(ssl_options=ssl_options) def test_can_connect_with_ssl_long_running(self): @@ -175,8 +191,7 @@ def test_can_connect_with_ssl_long_running(self): # find absolute 
path to client CA_CERTS abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) - ssl_options = {'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl_version} + ssl_options = {"ca_certs": abs_path_ca_cert_path, "ssl_version": ssl_version} tries = 0 while True: if tries > 5: @@ -187,7 +202,11 @@ def test_can_connect_with_ssl_long_running(self): break except Exception: ex_type, ex, tb = sys.exc_info() - log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning( + "{0}: {1} Backtrace: {2}".format( + ex_type.__name__, ex, traceback.extract_tb(tb) + ) + ) del tb tries += 1 @@ -213,15 +232,13 @@ def test_can_connect_with_ssl_ca_host_match(self): @test_category connection:ssl """ - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version} + ssl_options = {"ca_certs": CLIENT_CA_CERTS, "ssl_version": ssl_version} ssl_options.update(verify_certs) validate_ssl_options(ssl_options=ssl_options) class SSLConnectionAuthTests(unittest.TestCase): - @classmethod def setUpClass(cls): setup_cluster_ssl(client_auth=True) @@ -246,10 +263,12 @@ def test_can_connect_with_ssl_client_auth(self): @test_category connection:ssl """ - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version, - 'keyfile': DRIVER_KEYFILE, - 'certfile': DRIVER_CERTFILE} + ssl_options = { + "ca_certs": CLIENT_CA_CERTS, + "ssl_version": ssl_version, + "keyfile": DRIVER_KEYFILE, + "certfile": DRIVER_CERTFILE, + } validate_ssl_options(ssl_options=ssl_options) def test_can_connect_with_ssl_client_auth_host_name(self): @@ -267,10 +286,12 @@ def test_can_connect_with_ssl_client_auth_host_name(self): @test_category connection:ssl """ - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version, - 'keyfile': DRIVER_KEYFILE, - 'certfile': DRIVER_CERTFILE} + ssl_options = { + "ca_certs": CLIENT_CA_CERTS, + "ssl_version": ssl_version, + "keyfile": DRIVER_KEYFILE, + "certfile": DRIVER_CERTFILE, + } 
ssl_options.update(verify_certs) validate_ssl_options(ssl_options=ssl_options) @@ -288,8 +309,9 @@ def test_cannot_connect_without_client_auth(self): @test_category connection:ssl """ - cluster = TestCluster(ssl_options={'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version}) + cluster = TestCluster( + ssl_options={"ca_certs": CLIENT_CA_CERTS, "ssl_version": ssl_version} + ) with pytest.raises(NoHostAvailable): cluster.connect() @@ -309,18 +331,22 @@ def test_cannot_connect_with_bad_client_auth(self): @test_category connection:ssl """ - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version, - 'keyfile': DRIVER_KEYFILE} + ssl_options = { + "ca_certs": CLIENT_CA_CERTS, + "ssl_version": ssl_version, + "keyfile": DRIVER_KEYFILE, + } if not USES_PYOPENSSL: # I don't set the bad certfile for pyopenssl because it hangs - ssl_options['certfile'] = DRIVER_CERTFILE_BAD + ssl_options["certfile"] = DRIVER_CERTFILE_BAD cluster = TestCluster( - ssl_options={'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version, - 'keyfile': DRIVER_KEYFILE} + ssl_options={ + "ca_certs": CLIENT_CA_CERTS, + "ssl_version": ssl_version, + "keyfile": DRIVER_KEYFILE, + } ) with pytest.raises(NoHostAvailable): @@ -328,18 +354,19 @@ def test_cannot_connect_with_bad_client_auth(self): cluster.shutdown() def test_cannot_connect_with_invalid_hostname(self): - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version, - 'keyfile': DRIVER_KEYFILE, - 'certfile': DRIVER_CERTFILE} + ssl_options = { + "ca_certs": CLIENT_CA_CERTS, + "ssl_version": ssl_version, + "keyfile": DRIVER_KEYFILE, + "certfile": DRIVER_CERTFILE, + } ssl_options.update(verify_certs) with pytest.raises(Exception): - validate_ssl_options(ssl_options=ssl_options, hostname='localhost') + validate_ssl_options(ssl_options=ssl_options, hostname="localhost") class SSLSocketErrorTests(unittest.TestCase): - @classmethod def setUpClass(cls): setup_cluster_ssl() @@ -360,24 +387,26 @@ def 
test_ssl_want_write_errors_are_retried(self): @test_category connection:ssl """ - ssl_options = {'ca_certs': CLIENT_CA_CERTS, - 'ssl_version': ssl_version} + ssl_options = {"ca_certs": CLIENT_CA_CERTS, "ssl_version": ssl_version} cluster = TestCluster(ssl_options=ssl_options) session = cluster.connect(wait_for_all_pools=True) try: - session.execute('drop keyspace ssl_error_test') + session.execute("drop keyspace ssl_error_test") except: pass session.execute( - "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'SimpleStrategy','replication_factor':1};") - session.execute("CREATE TABLE ssl_error_test.big_text (id uuid PRIMARY KEY, data text);") + "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'NetworkTopologyStrategy','replication_factor':1};" + ) + session.execute( + "CREATE TABLE ssl_error_test.big_text (id uuid PRIMARY KEY, data text);" + ) - params = { - '0': uuid.uuid4(), - '1': "0" * int(math.pow(10, 7)) - } + params = {"0": uuid.uuid4(), "1": "0" * int(math.pow(10, 7))} - session.execute('INSERT INTO ssl_error_test.big_text ("id", "data") VALUES (%(0)s, %(1)s)', params) + session.execute( + 'INSERT INTO ssl_error_test.big_text ("id", "data") VALUES (%(0)s, %(1)s)', + params, + ) class SSLConnectionWithSSLContextTests(unittest.TestCase): @@ -429,14 +458,18 @@ def test_can_connect_with_ssl_client_auth_password_private_key(self): ssl_context = SSL.Context(SSL.TLS_CLIENT_METHOD) ssl_context.use_certificate_file(abs_driver_certfile) with open(abs_driver_keyfile) as keyfile: - key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b'cassandra') + key = crypto.load_privatekey( + crypto.FILETYPE_PEM, keyfile.read(), b"cassandra" + ) ssl_context.use_privatekey(key) ssl_context.set_verify(SSL.VERIFY_NONE, verify_callback) else: ssl_context = ssl.SSLContext(ssl_version) - ssl_context.load_cert_chain(certfile=abs_driver_certfile, - keyfile=abs_driver_keyfile, - password="cassandra") + ssl_context.load_cert_chain( + 
certfile=abs_driver_certfile, + keyfile=abs_driver_keyfile, + password="cassandra", + ) ssl_context.verify_mode = ssl.CERT_NONE validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options) @@ -450,7 +483,9 @@ def test_can_connect_with_ssl_context_ca_host_match(self): ssl_context = SSL.Context(SSL.TLS_CLIENT_METHOD) ssl_context.use_certificate_file(DRIVER_CERTFILE) with open(DRIVER_KEYFILE_ENCRYPTED) as keyfile: - key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b'cassandra') + key = crypto.load_privatekey( + crypto.FILETYPE_PEM, keyfile.read(), b"cassandra" + ) ssl_context.use_privatekey(key) ssl_context.load_verify_locations(CLIENT_CA_CERTS) ssl_options["check_hostname"] = True @@ -473,7 +508,9 @@ def test_cannot_connect_ssl_context_with_invalid_hostname(self): ssl_context = SSL.Context(SSL.TLS_CLIENT_METHOD) ssl_context.use_certificate_file(DRIVER_CERTFILE) with open(DRIVER_KEYFILE_ENCRYPTED) as keyfile: - key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b"cassandra") + key = crypto.load_privatekey( + crypto.FILETYPE_PEM, keyfile.read(), b"cassandra" + ) ssl_context.use_privatekey(key) ssl_context.load_verify_locations(CLIENT_CA_CERTS) ssl_options["check_hostname"] = True @@ -489,7 +526,9 @@ def test_cannot_connect_ssl_context_with_invalid_hostname(self): ssl_context.verify_mode = ssl.CERT_REQUIRED ssl_options["check_hostname"] = True with pytest.raises(Exception): - validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options, hostname="localhost") + validate_ssl_options( + ssl_context=ssl_context, ssl_options=ssl_options, hostname="localhost" + ) @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") def test_can_connect_with_sslcontext_default_context(self): diff --git a/tests/integration/long/utils.py b/tests/integration/long/utils.py index 93464df8ff..1441ba1025 100644 --- a/tests/integration/long/utils.py +++ b/tests/integration/long/utils.py @@ -19,63 +19,83 @@ from collections 
import defaultdict from packaging.version import Version -from tests.integration import (get_node, get_cluster, wait_for_node_socket, - CASSANDRA_VERSION) +from tests.integration import ( + get_node, + get_cluster, + wait_for_node_socket, + CASSANDRA_VERSION, +) -IP_FORMAT = '127.0.0.%s' +IP_FORMAT = "127.0.0.%s" log = logging.getLogger(__name__) -class CoordinatorStats(): - +class CoordinatorStats: def __init__(self): self.coordinator_counts = defaultdict(int) def add_coordinator(self, future): - log.debug('adding coordinator from {}'.format(future)) + log.debug("adding coordinator from {}".format(future)) future.result() coordinator = future._current_host.address self.coordinator_counts[coordinator] += 1 if future._errors: - log.error('future._errors: %s', future._errors) + log.error("future._errors: %s", future._errors) def reset_counts(self): self.coordinator_counts = defaultdict(int) def get_query_count(self, node): - ip = '127.0.0.%d' % node + ip = "127.0.0.%d" % node return self.coordinator_counts[ip] def assert_query_count_equals(self, node, expected): - ip = '127.0.0.%d' % node + ip = "127.0.0.%d" % node if self.get_query_count(node) != expected: - pytest.fail('Expected %d queries to %s, but got %d. Query counts: %s' % ( - expected, ip, self.coordinator_counts[ip], dict(self.coordinator_counts))) - - -def create_schema(cluster, session, keyspace, simple_strategy=True, - replication_factor=1, replication_strategy=None): + pytest.fail( + "Expected %d queries to %s, but got %d. 
Query counts: %s" + % ( + expected, + ip, + self.coordinator_counts[ip], + dict(self.coordinator_counts), + ) + ) + + +def create_schema( + cluster, + session, + keyspace, + simple_strategy=True, + replication_factor=1, + replication_strategy=None, +): if keyspace in cluster.metadata.keyspaces.keys(): - session.execute('DROP KEYSPACE %s' % keyspace, timeout=20) + session.execute("DROP KEYSPACE %s" % keyspace, timeout=20) if simple_strategy: - ddl = "CREATE KEYSPACE %s WITH replication" \ - " = {'class': 'SimpleStrategy', 'replication_factor': '%s'}" + ddl = ( + "CREATE KEYSPACE %s WITH replication" + " = {'class': 'NetworkTopologyStrategy', 'replication_factor': '%s'}" + ) session.execute(ddl % (keyspace, replication_factor), timeout=10) else: if not replication_strategy: - raise Exception('replication_strategy is not set') + raise Exception("replication_strategy is not set") - ddl = "CREATE KEYSPACE %s" \ - " WITH replication = { 'class' : 'NetworkTopologyStrategy', %s }" + ddl = ( + "CREATE KEYSPACE %s" + " WITH replication = { 'class' : 'NetworkTopologyStrategy', %s }" + ) session.execute(ddl % (keyspace, str(replication_strategy)[1:-1]), timeout=10) - ddl = 'CREATE TABLE %s.cf (k int PRIMARY KEY, i int)' + ddl = "CREATE TABLE %s.cf (k int PRIMARY KEY, i int)" session.execute(ddl % keyspace, timeout=10) - session.execute('USE %s' % keyspace) + session.execute("USE %s" % keyspace) def start(node): @@ -102,14 +122,12 @@ def decommission(node): def bootstrap(node, data_center=None, token=None): - log.debug('called bootstrap(' - 'node={node}, data_center={data_center}, ' - 'token={token})') + log.debug("called bootstrap(node={node}, data_center={data_center}, token={token})") cluster = get_cluster() # for now assumes cluster has at least one node node_type = type(next(iter(cluster.nodes.values()))) node_instance = node_type( - 'node%s' % node, + "node%s" % node, cluster, auto_bootstrap=False, thrift_interface=(IP_FORMAT % node, 9160), @@ -117,25 +135,25 @@ def 
bootstrap(node, data_center=None, token=None): binary_interface=(IP_FORMAT % node, 9042), jmx_port=str(7000 + 100 * node), remote_debug_port=0, - initial_token=token if token else node * 10 + initial_token=token if token else node * 10, ) cluster.add(node_instance, is_seed=False, data_center=data_center) try: node_instance.start() except Exception as e0: - log.debug('failed 1st bootstrap attempt with: \n{}'.format(e0)) + log.debug("failed 1st bootstrap attempt with: \n{}".format(e0)) # Try only twice try: node_instance.start() except Exception as e1: - log.debug('failed 2nd bootstrap attempt with: \n{}'.format(e1)) - log.error('Added node failed to start twice.') + log.debug("failed 2nd bootstrap attempt with: \n{}".format(e1)) + log.error("Added node failed to start twice.") raise e1 def ring(node): - get_node(node).nodetool('ring') + get_node(node).nodetool("ring") def wait_for_up(cluster, node): diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py index 2dbf3985ad..c7d1d347a0 100644 --- a/tests/integration/simulacron/test_empty_column.py +++ b/tests/integration/simulacron/test_empty_column.py @@ -17,8 +17,12 @@ from cassandra import ProtocolVersion from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT -from cassandra.query import (named_tuple_factory, tuple_factory, - dict_factory, ordered_dict_factory) +from cassandra.query import ( + named_tuple_factory, + tuple_factory, + dict_factory, + ordered_dict_factory, +) from cassandra.cqlengine import columns from cassandra.cqlengine.connection import set_session @@ -38,6 +42,7 @@ class EmptyColumnTests(SimulacronCluster): @jira_ticket PYTHON-1082 @expected_result the driver supports those columns """ + connect = False def tearDown(self): @@ -48,22 +53,14 @@ def tearDown(self): def _prime_testtable_query(): queries = [ 'SELECT "", " " FROM testks.testtable', - 'SELECT "", " " FROM testks.testtable LIMIT 10000' # cqlengine + 'SELECT "", " " FROM 
testks.testtable LIMIT 10000', # cqlengine ] then = { - 'result': 'success', - 'delay_in_ms': 0, - 'rows': [ - { - "": "testval", - " ": "testval1" - } - ], - 'column_types': { - "": "ascii", - " ": "ascii" - }, - 'ignore_on_prepare': False + "result": "success", + "delay_in_ms": 0, + "rows": [{"": "testval", " ": "testval1"}], + "column_types": {"": "ascii", " ": "ascii"}, + "ignore_on_prepare": False, } for query in queries: prime_request(PrimeQuery(query, then=then)) @@ -76,28 +73,40 @@ def test_empty_columns_with_all_row_factories(self): self.session = self.cluster.connect(wait_for_all_pools=True) # Test all row factories - self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = named_tuple_factory - assert list(self.session.execute(query)) == [namedtuple('Row', ['field_0_', 'field_1_'])('testval', 'testval1')] - - self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = tuple_factory - assert list(self.session.execute(query)) == [('testval', 'testval1')] - - self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = dict_factory - assert list(self.session.execute(query)) == [{'': 'testval', ' ': 'testval1'}] + self.cluster.profile_manager.profiles[ + EXEC_PROFILE_DEFAULT + ].row_factory = named_tuple_factory + assert list(self.session.execute(query)) == [ + namedtuple("Row", ["field_0_", "field_1_"])("testval", "testval1") + ] - self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = ordered_dict_factory - assert list(self.session.execute(query)) == [OrderedDict((('', 'testval'), (' ', 'testval1')))] + self.cluster.profile_manager.profiles[ + EXEC_PROFILE_DEFAULT + ].row_factory = tuple_factory + assert list(self.session.execute(query)) == [("testval", "testval1")] + + self.cluster.profile_manager.profiles[ + EXEC_PROFILE_DEFAULT + ].row_factory = dict_factory + assert list(self.session.execute(query)) == [{"": "testval", " ": "testval1"}] + + self.cluster.profile_manager.profiles[ + 
EXEC_PROFILE_DEFAULT + ].row_factory = ordered_dict_factory + assert list(self.session.execute(query)) == [ + OrderedDict((("", "testval"), (" ", "testval1"))) + ] def test_empty_columns_in_system_schema(self): queries = [ "SELECT * FROM system_schema.tables", "SELECT * FROM system.schema.tables", - "SELECT * FROM system.schema_columnfamilies" + "SELECT * FROM system.schema_columnfamilies", ] then = { - 'result': 'success', - 'delay_in_ms': 0, - 'rows': [ + "result": "success", + "delay_in_ms": 0, + "rows": [ { "compression": dict(), "compaction": dict(), @@ -109,10 +118,10 @@ def test_empty_columns_in_system_schema(self): "table_name": "testtable", "columnfamily_name": "testtable", # C* 2.2 "flags": ["compound"], - "comparator": "none" # C* 2.2 + "comparator": "none", # C* 2.2 } ], - 'column_types': { + "column_types": { "compression": "map", "compaction": "map", "bloom_filter_fp_chance": "double", @@ -123,9 +132,9 @@ def test_empty_columns_in_system_schema(self): "table_name": "ascii", "columnfamily_name": "ascii", "flags": "set", - "comparator": "ascii" + "comparator": "ascii", }, - 'ignore_on_prepare': False + "ignore_on_prepare": False, } for query in queries: query = PrimeQuery(query, then=then) @@ -133,28 +142,31 @@ def test_empty_columns_in_system_schema(self): queries = [ "SELECT * FROM system_schema.keyspaces", - "SELECT * FROM system.schema_keyspaces" + "SELECT * FROM system.schema_keyspaces", ] then = { - 'result': 'success', - 'delay_in_ms': 0, - 'rows': [ + "result": "success", + "delay_in_ms": 0, + "rows": [ { - "strategy_class": "SimpleStrategy", # C* 2.2 - "strategy_options": '{}', # C* 2.2 - "replication": {'strategy': 'SimpleStrategy', 'replication_factor': 1}, + "strategy_class": "NetworkTopologyStrategy", # C* 2.2 + "strategy_options": "{}", # C* 2.2 + "replication": { + "strategy": "NetworkTopologyStrategy", + "replication_factor": 1, + }, "durable_writes": True, - "keyspace_name": "testks" + "keyspace_name": "testks", } ], - 'column_types': { 
+ "column_types": { "strategy_class": "ascii", "strategy_options": "ascii", "replication": "map", "keyspace_name": "ascii", - "durable_writes": "boolean" + "durable_writes": "boolean", }, - 'ignore_on_prepare': False + "ignore_on_prepare": False, } for query in queries: query = PrimeQuery(query, then=then) @@ -163,15 +175,15 @@ def test_empty_columns_in_system_schema(self): queries = [ "SELECT * FROM system_schema.columns", "SELECT * FROM system.schema.columns", - "SELECT * FROM system.schema_columns" + "SELECT * FROM system.schema_columns", ] then = { - 'result': 'success', - 'delay_in_ms': 0, - 'rows': [ + "result": "success", + "delay_in_ms": 0, + "rows": [ { - "table_name": 'testtable', - "columnfamily_name": 'testtable', # C* 2.2 + "table_name": "testtable", + "columnfamily_name": "testtable", # C* 2.2 "column_name": "", "keyspace_name": "testks", "kind": "partition_key", @@ -179,11 +191,11 @@ def test_empty_columns_in_system_schema(self): "position": 0, "type": "text", "column_name_bytes": 0x12, - "validator": "none" # C* 2.2 + "validator": "none", # C* 2.2 }, { - "table_name": 'testtable', - "columnfamily_name": 'testtable', # C* 2.2 + "table_name": "testtable", + "columnfamily_name": "testtable", # C* 2.2 "column_name": " ", "keyspace_name": "testks", "kind": "regular", @@ -191,10 +203,10 @@ def test_empty_columns_in_system_schema(self): "position": -1, "type": "text", "column_name_bytes": 0x13, - "validator": "none" # C* 2.2 - } + "validator": "none", # C* 2.2 + }, ], - 'column_types': { + "column_types": { "table_name": "ascii", "columnfamily_name": "ascii", "column_name": "ascii", @@ -204,9 +216,9 @@ def test_empty_columns_in_system_schema(self): "kind": "ascii", "position": "int", "type": "ascii", - "validator": "ascii" # C* 2.2 + "validator": "ascii", # C* 2.2 }, - 'ignore_on_prepare': False + "ignore_on_prepare": False, } for query in queries: query = PrimeQuery(query, then=then) @@ -215,10 +227,10 @@ def test_empty_columns_in_system_schema(self): 
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) self.session = self.cluster.connect(wait_for_all_pools=True) - table_metadata = self.cluster.metadata.keyspaces['testks'].tables['testtable'] + table_metadata = self.cluster.metadata.keyspaces["testks"].tables["testtable"] assert len(table_metadata.columns) == 2 - assert '' in table_metadata.columns - assert ' ' in table_metadata.columns + assert "" in table_metadata.columns + assert " " in table_metadata.columns def test_empty_columns_with_cqlengine(self): self._prime_testtable_query() @@ -228,9 +240,11 @@ def test_empty_columns_with_cqlengine(self): set_session(self.session) class TestModel(Model): - __keyspace__ = 'testks' - __table_name__ = 'testtable' - empty = columns.Text(db_field='', primary_key=True) - space = columns.Text(db_field=' ') - - assert [TestModel(empty='testval', space='testval1')] == list(TestModel.objects.only(['empty', 'space']).all()) + __keyspace__ = "testks" + __table_name__ = "testtable" + empty = columns.Text(db_field="", primary_key=True) + space = columns.Text(db_field=" ") + + assert [TestModel(empty="testval", space="testval1")] == list( + TestModel.objects.only(["empty", "space"]).all() + ) diff --git a/tests/integration/standard/column_encryption/test_policies.py b/tests/integration/standard/column_encryption/test_policies.py index 9a1d186895..2b96743100 100644 --- a/tests/integration/standard/column_encryption/test_policies.py +++ b/tests/integration/standard/column_encryption/test_policies.py @@ -19,28 +19,37 @@ from cassandra.policies import ColDesc -from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, \ - AES256_KEY_SIZE_BYTES, AES256_BLOCK_SIZE_BYTES +from cassandra.column_encryption.policies import ( + AES256ColumnEncryptionPolicy, + AES256_KEY_SIZE_BYTES, + AES256_BLOCK_SIZE_BYTES, +) + def setup_module(): use_singledc() -@unittest.skip("Skip until https://github.com/scylladb/python-driver/issues/365 is sorted out") -class 
ColumnEncryptionPolicyTest(unittest.TestCase): +@unittest.skip( + "Skip until https://github.com/scylladb/python-driver/issues/365 is sorted out" +) +class ColumnEncryptionPolicyTest(unittest.TestCase): def _recreate_keyspace(self, session): session.execute("drop keyspace if exists foo") - session.execute("CREATE KEYSPACE foo WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") - session.execute("CREATE TABLE foo.bar(encrypted blob, unencrypted int, primary key(unencrypted))") - - def _create_policy(self, key, iv = None): + session.execute( + "CREATE KEYSPACE foo WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) + session.execute( + "CREATE TABLE foo.bar(encrypted blob, unencrypted int, primary key(unencrypted))" + ) + + def _create_policy(self, key, iv=None): cl_policy = AES256ColumnEncryptionPolicy() - col_desc = ColDesc('foo','bar','encrypted') + col_desc = ColDesc("foo", "bar", "encrypted") cl_policy.add_column(col_desc, key, "int") return (col_desc, cl_policy) def test_end_to_end_prepared(self): - # We only currently perform testing on a single type/expected value pair since CLE functionality is essentially # independent of the underlying type. We intercept data after it's been encoded when it's going out and before it's # encoded when coming back; the actual types of the data involved don't impact us. @@ -52,24 +61,30 @@ def test_end_to_end_prepared(self): session = cluster.connect() self._recreate_keyspace(session) - prepared = session.prepare("insert into foo.bar (encrypted, unencrypted) values (?,?)") + prepared = session.prepare( + "insert into foo.bar (encrypted, unencrypted) values (?,?)" + ) for i in range(100): session.execute(prepared, (i, i)) # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. 
- (encrypted,unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + (encrypted, unencrypted) = session.execute( + "select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", + (expected,), + ).one() assert expected == encrypted assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well - prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") - (encrypted,unencrypted) = session.execute(prepared, [expected]).one() + prepared = session.prepare( + "select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering" + ) + (encrypted, unencrypted) = session.execute(prepared, [expected]).one() assert expected == encrypted assert expected == unencrypted def test_end_to_end_simple(self): - expected = 1 key = os.urandom(AES256_KEY_SIZE_BYTES) @@ -79,20 +94,28 @@ def test_end_to_end_simple(self): self._recreate_keyspace(session) # Use encode_and_encrypt helper function to populate date - for i in range(1,100): + for i in range(1, 100): assert i is not None encrypted = cl_policy.encode_and_encrypt(col_desc, i) - session.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) + session.execute( + "insert into foo.bar (encrypted, unencrypted) values (%s,%s)", + (encrypted, i), + ) # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. 
- (encrypted,unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + (encrypted, unencrypted) = session.execute( + "select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", + (expected,), + ).one() assert expected == encrypted assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well - prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") - (encrypted,unencrypted) = session.execute(prepared, [expected]).one() + prepared = session.prepare( + "select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering" + ) + (encrypted, unencrypted) = session.execute(prepared, [expected]).one() assert expected == encrypted assert expected == unencrypted @@ -118,10 +141,13 @@ def test_end_to_end_different_cle_contexts_different_ivs(self): self._recreate_keyspace(session1) # Use encode_and_encrypt helper function to populate date - for i in range(1,100): + for i in range(1, 100): assert i is not None encrypted = cl_policy1.encode_and_encrypt(col_desc1, i) - session1.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) + session1.execute( + "insert into foo.bar (encrypted, unencrypted) values (%s,%s)", + (encrypted, i), + ) session1.shutdown() cluster1.shutdown() @@ -135,7 +161,10 @@ def test_end_to_end_different_cle_contexts_different_ivs(self): (_, cl_policy2) = self._create_policy(key, iv=iv2) cluster2 = TestCluster(column_encryption_policy=cl_policy2) session2 = cluster2.connect() - (encrypted,unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + (encrypted, unencrypted) = session2.execute( + "select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", + (expected,), + ).one() assert expected == encrypted assert 
expected == unencrypted @@ -153,7 +182,10 @@ def test_end_to_end_different_cle_contexts_different_policies(self): self._recreate_keyspace(session) # Use encode_and_encrypt helper function to populate date - session.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)",(cl_policy.encode_and_encrypt(col_desc, expected), expected)) + session.execute( + "insert into foo.bar (encrypted, unencrypted) values (%s,%s)", + (cl_policy.encode_and_encrypt(col_desc, expected), expected), + ) # We now open a new session _without_ the CLE policy specified. We should _not_ be able to read decrypted bits from this session. cluster2 = TestCluster() @@ -161,11 +193,16 @@ def test_end_to_end_different_cle_contexts_different_policies(self): # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. - (encrypted,unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + (encrypted, unencrypted) = session2.execute( + "select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", + (expected,), + ).one() assert cl_policy.encode_and_encrypt(col_desc, expected) == encrypted assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well - prepared = session2.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") - (encrypted,unencrypted) = session2.execute(prepared, [expected]).one() + prepared = session2.prepare( + "select encrypted, unencrypted from foo.bar where unencrypted = ? 
allow filtering" + ) + (encrypted, unencrypted) = session2.execute(prepared, [expected]).one() assert cl_policy.encode_and_encrypt(col_desc, expected) == encrypted diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..3e41092d9b 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -54,6 +54,7 @@ log = logging.getLogger(__name__) + class TcpProxy: """ A simple TCP proxy that forwards connections from a local listen port @@ -84,12 +85,19 @@ def start(self): self._server_sock.listen(128) self._server_sock.setblocking(False) self._running = True - self._thread = threading.Thread(target=self._run, daemon=True, - name="proxy-%s:%d" % (self.listen_host, self.listen_port)) + self._thread = threading.Thread( + target=self._run, + daemon=True, + name="proxy-%s:%d" % (self.listen_host, self.listen_port), + ) self._thread.start() - log.info("TcpProxy started %s:%d -> %s:%d", - self.listen_host, self.listen_port, - self.target_host, self.target_port) + log.info( + "TcpProxy started %s:%d -> %s:%d", + self.listen_host, + self.listen_port, + self.target_host, + self.target_port, + ) def stop(self): self._running = False @@ -115,8 +123,13 @@ def retarget(self, new_host, new_port): """Change the backend target for new connections (existing ones keep the old target).""" self.target_host = new_host self.target_port = new_port - log.info("TcpProxy %s:%d retargeted to %s:%d", - self.listen_host, self.listen_port, new_host, new_port) + log.info( + "TcpProxy %s:%d retargeted to %s:%d", + self.listen_host, + self.listen_port, + new_host, + new_port, + ) def drop_connections(self): """Forcibly close all active connections.""" @@ -124,7 +137,9 @@ def drop_connections(self): for csock, tsock in list(self._connections): self._close_pair(csock, tsock) self._connections.clear() - log.info("TcpProxy %s:%d dropped all connections", self.listen_host, 
self.listen_port) + log.info( + "TcpProxy %s:%d dropped all connections", self.listen_host, self.listen_port + ) def _run(self): while self._running: @@ -147,9 +162,14 @@ def _handle_new_connection(self, client_sock, target_host=None, target_port=None target_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) target_sock.connect((target_host, target_port)) except Exception as e: - log.warning("TcpProxy %s:%d failed to connect to target %s:%d: %s", - self.listen_host, self.listen_port, - target_host, target_port, e) + log.warning( + "TcpProxy %s:%d failed to connect to target %s:%d: %s", + self.listen_host, + self.listen_port, + target_host, + target_port, + e, + ) client_sock.close() return @@ -157,9 +177,9 @@ def _handle_new_connection(self, client_sock, target_host=None, target_port=None self._connections.add((client_sock, target_sock)) self.total_connections += 1 - t = threading.Thread(target=self._forward_loop, - args=(client_sock, target_sock), - daemon=True) + t = threading.Thread( + target=self._forward_loop, args=(client_sock, target_sock), daemon=True + ) t.start() def _forward_loop(self, client_sock, target_sock): @@ -215,10 +235,9 @@ class NLBEmulator: LISTEN_HOST = "127.254.254.101" - def __init__(self, discovery_port=0, - per_node_base=0, - native_port=9042, - node_addresses=None): + def __init__( + self, discovery_port=0, per_node_base=0, native_port=9042, node_addresses=None + ): self.discovery_port = discovery_port self.per_node_base = per_node_base self.native_port = native_port @@ -244,8 +263,10 @@ def start(self, node_addresses): first_addr = list(node_addresses.values())[0] self._discovery_proxy = TcpProxy( - self.LISTEN_HOST, self.discovery_port, - first_addr, self.native_port, + self.LISTEN_HOST, + self.discovery_port, + first_addr, + self.native_port, ) self._discovery_proxy.start() self.discovery_port = self._discovery_proxy.listen_port @@ -262,12 +283,18 @@ def rr_handler(client_sock): idx = self._rr_index % len(addrs) self._rr_index 
+= 1 addr = addrs[idx] - original_handler(client_sock, target_host=addr, target_port=self.native_port) + original_handler( + client_sock, target_host=addr, target_port=self.native_port + ) self._discovery_proxy._handle_new_connection = rr_handler - log.info("NLB started: discovery=%s:%d, %d node proxies", - self.LISTEN_HOST, self.discovery_port, len(self._node_proxies)) + log.info( + "NLB started: discovery=%s:%d, %d node proxies", + self.LISTEN_HOST, + self.discovery_port, + len(self._node_proxies), + ) return self def __enter__(self): @@ -324,13 +351,20 @@ def _add_node_proxy(self, node_id, addr): proxy.start() with self._lock: self._node_proxies[node_id] = proxy - log.info("NLB added node %d: %s:%d -> %s:%d", - node_id, self.LISTEN_HOST, port, addr, self.native_port) + log.info( + "NLB added node %d: %s:%d -> %s:%d", + node_id, + self.LISTEN_HOST, + port, + addr, + self.native_port, + ) def _live_addresses(self): """IPs of nodes with active proxies.""" return [p.target_host for p in self._node_proxies.values()] + def post_client_routes(contact_point, routes): """ Post client routes to Scylla's REST API. 
@@ -395,12 +429,14 @@ def build_routes_for_nlb(connection_id, host_id_map, nlb): for ip, host_id in host_id_map.items(): node_id = int(ip.split(".")[-1]) port = nlb.node_port(node_id) - routes.append({ - "connection_id": connection_id, - "host_id": host_id, - "address": NLBEmulator.LISTEN_HOST, - "port": port, - }) + routes.append( + { + "connection_id": connection_id, + "host_id": host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": port, + } + ) return routes @@ -410,7 +446,10 @@ def post_routes_for_nlb(contact_point, connection_id, host_id_map, nlb): post_client_routes(contact_point, routes) return routes -def wait_for_routes_visible(session, connection_id, expected_count, timeout=10, poll_interval=0.1): + +def wait_for_routes_visible( + session, connection_id, expected_count, timeout=10, poll_interval=0.1 +): """ Poll system.client_routes on **every** node until each one sees at least *expected_count* rows for *connection_id*. @@ -431,11 +470,13 @@ def wait_for_routes_visible(session, connection_id, expected_count, timeout=10, while True: pending_hosts = [] for host in all_hosts: - rows = list(session.execute( - "SELECT * FROM system.client_routes WHERE connection_id = %s", - (connection_id,), - host=host, - )) + rows = list( + session.execute( + "SELECT * FROM system.client_routes WHERE connection_id = %s", + (connection_id,), + host=host, + ) + ) if len(rows) < expected_count: pending_hosts.append((host, len(rows))) if not pending_hosts: @@ -475,18 +516,21 @@ def assert_routes_via_nlb(test, cluster, nlb, expected_node_ids): continue resolved_addr, resolved_port = ep.resolve() test.assertEqual( - resolved_addr, nlb_listen_host, + resolved_addr, + nlb_listen_host, "Node %d endpoint should resolve to NLB address %s, got %s" % (node_id, nlb_listen_host, resolved_addr), ) test.assertEqual( - resolved_port, nlb.node_port(node_id), + resolved_port, + nlb.node_port(node_id), "Node %d endpoint should resolve to NLB port %d, got %d" % (node_id, 
nlb.node_port(node_id), resolved_port), ) seen_node_ids.add(node_id) test.assertEqual( - seen_node_ids, expected_node_ids, + seen_node_ids, + expected_node_ids, "Not all expected nodes found in metadata endpoints", ) @@ -508,12 +552,14 @@ def assert_routes_direct(test, cluster, expected_node_ids, direct_port=9042): resolved_addr, resolved_port = ep.resolve() expected_ip = "127.0.0.%d" % node_id test.assertEqual( - resolved_addr, expected_ip, + resolved_addr, + expected_ip, "Node %d endpoint should resolve to direct address %s, got %s" % (node_id, expected_ip, resolved_addr), ) test.assertEqual( - resolved_port, direct_port, + resolved_port, + direct_port, "Node %d endpoint should resolve to direct port %d, got %d" % (node_id, direct_port, resolved_port), ) @@ -524,19 +570,22 @@ def assert_routes_direct(test, cluster, expected_node_ids, direct_port=9042): def setup_module(): global _saved_scylla_ext_opts - _saved_scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS') - os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" - use_cluster('shared_aware', [3], start=True) + _saved_scylla_ext_opts = os.environ.get("SCYLLA_EXT_OPTS") + os.environ["SCYLLA_EXT_OPTS"] = "--smp 2 --memory 2048M" + use_cluster("shared_aware", [3], start=True) def teardown_module(): if _saved_scylla_ext_opts is None: - os.environ.pop('SCYLLA_EXT_OPTS', None) + os.environ.pop("SCYLLA_EXT_OPTS", None) else: - os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + os.environ["SCYLLA_EXT_OPTS"] = _saved_scylla_ext_opts + -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestGetHostPortMapping(unittest.TestCase): """ Test _query_all_routes_for_connections and _query_routes_for_change_event @@ -545,8 +594,11 @@ class TestGetHostPortMapping(unittest.TestCase): @classmethod def 
setUpClass(cls): - cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + cls.cluster = TestCluster( + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("conn_id", "127.0.0.1")] + ) + ) cls.session = cls.cluster.connect() cls.host_ids = [uuid.uuid4() for _ in range(3)] @@ -556,13 +608,15 @@ def setUpClass(cls): for idx, host_id in enumerate(cls.host_ids): ip = f"127.0.0.{idx + 1}" for connection_id in cls.connection_ids: - cls.expected.append({ - 'connection_id': connection_id, - 'host_id': host_id, - 'address': ip, - 'port': 9042, - 'tls_port': 9142, - }) + cls.expected.append( + { + "connection_id": connection_id, + "host_id": host_id, + "address": ip, + "port": 9042, + "tls_port": 9142, + } + ) cls._sort_routes(cls.expected) post_client_routes(cls.cluster.contact_points[0], cls.expected) @@ -573,29 +627,31 @@ def tearDownClass(cls): @staticmethod def _sort_routes(routes): - routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + routes.sort(key=lambda r: (str(r["connection_id"]), str(r["host_id"]))) def _routes_to_dicts(self, routes): """Convert _Route objects to comparable dicts, adjusting port for ssl_enabled.""" return [ { - 'connection_id': route.connection_id, - 'host_id': route.host_id, - 'address': route.address, - 'port': route.port, + "connection_id": route.connection_id, + "host_id": route.host_id, + "address": route.address, + "port": route.port, } for route in routes ] def _expected_dicts(self, expected): """Build expected dicts with tls_port or port based on ssl_enabled.""" - port_key = 'tls_port' if self.cluster._client_routes_handler.ssl_enabled else 'port' + port_key = ( + "tls_port" if self.cluster._client_routes_handler.ssl_enabled else "port" + ) return [ { - 'connection_id': e['connection_id'], - 'host_id': e['host_id'], - 'address': e['address'], - 'port': e[port_key], + "connection_id": e["connection_id"], + "host_id": e["host_id"], + 
"address": e["address"], + "port": e[port_key], } for e in expected ] @@ -604,7 +660,9 @@ def test_get_all_routes_for_all_connections(self): """Querying all connection IDs returns every route.""" cc = self.cluster.control_connection routes = self.cluster._client_routes_handler._query_all_routes_for_connections( - cc._connection, cc._timeout, self.connection_ids, + cc._connection, + cc._timeout, + self.connection_ids, ) got = self._routes_to_dicts(routes) self._sort_routes(got) @@ -616,12 +674,15 @@ def test_get_routes_for_single_connection(self): """Querying a single connection ID returns only its routes.""" cc = self.cluster.control_connection routes = self.cluster._client_routes_handler._query_all_routes_for_connections( - cc._connection, cc._timeout, [self.connection_ids[0]], + cc._connection, + cc._timeout, + [self.connection_ids[0]], ) got = self._routes_to_dicts(routes) self._sort_routes(got) - filtered = [r for r in self.expected - if r['connection_id'] == self.connection_ids[0]] + filtered = [ + r for r in self.expected if r["connection_id"] == self.connection_ids[0] + ] expected = self._expected_dicts(filtered) self._sort_routes(expected) self.assertEqual(got, expected) @@ -629,9 +690,11 @@ def test_get_routes_for_single_connection(self): def test_get_routes_for_change_event_all_pairs(self): """Querying all (connection_id, host_id) pairs returns every route.""" cc = self.cluster.control_connection - pairs = [(r['connection_id'], r['host_id']) for r in self.expected] + pairs = [(r["connection_id"], r["host_id"]) for r in self.expected] routes = self.cluster._client_routes_handler._query_routes_for_change_event( - cc._connection, cc._timeout, pairs, + cc._connection, + cc._timeout, + pairs, ) got = self._routes_to_dicts(routes) self._sort_routes(got) @@ -645,19 +708,26 @@ def test_get_routes_for_change_event_single_pair(self): target_conn_id = self.connection_ids[0] target_host_id = self.host_ids[0] routes = 
self.cluster._client_routes_handler._query_routes_for_change_event( - cc._connection, cc._timeout, [(target_conn_id, target_host_id)], + cc._connection, + cc._timeout, + [(target_conn_id, target_host_id)], ) got = self._routes_to_dicts(routes) self._sort_routes(got) - filtered = [r for r in self.expected - if r['connection_id'] == target_conn_id - and r['host_id'] == target_host_id] + filtered = [ + r + for r in self.expected + if r["connection_id"] == target_conn_id and r["host_id"] == target_host_id + ] expected = self._expected_dicts(filtered) self._sort_routes(expected) self.assertEqual(got, expected) -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") + +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestPrivateLinkConnectivity(unittest.TestCase): """ Verifies the driver connects to all cluster nodes exclusively through @@ -690,7 +760,9 @@ def setUpClass(cls): cls.connection_id = str(uuid.uuid4()) post_routes_for_nlb("127.0.0.1", cls.connection_id, cls.host_id_map, cls.nlb) - wait_for_routes_visible(cls.direct_session, cls.connection_id, len(cls.host_id_map)) + wait_for_routes_visible( + cls.direct_session, cls.connection_id, len(cls.host_id_map) + ) @classmethod def tearDownClass(cls): @@ -718,22 +790,26 @@ def test_all_connections_through_proxy(self): session.execute("SELECT key FROM system.local") pool_state = session.get_pool_state() - self.assertEqual(len(pool_state), len(self.node_addrs), - "Driver should have pools for all nodes") + self.assertEqual( + len(pool_state), + len(self.node_addrs), + "Driver should have pools for all nodes", + ) for host, state in pool_state.items(): node_id = node_id_from_ip(host.address) proxy = self.nlb.get_node_proxy(node_id) self.assertIsNotNone(proxy, f"No proxy for node {node_id}") - open_count = state['open_count'] + open_count = 
state["open_count"] self.assertGreaterEqual( - proxy.total_connections, open_count, + proxy.total_connections, + open_count, f"Node {node_id} proxy saw {proxy.total_connections} " f"connections but pool has {open_count} open — " - f"some connections bypassed the proxy") + f"some connections bypassed the proxy", + ) - assert_routes_via_nlb(self, cluster, self.nlb, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, self.nlb, self.node_addrs.keys()) def test_queries_succeed_through_proxy(self): """Queries should work normally through the proxy.""" @@ -741,7 +817,7 @@ def test_queries_succeed_through_proxy(self): session = cluster.connect() session.execute( "CREATE KEYSPACE IF NOT EXISTS test_cr_ks " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}" + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}" ) session.execute( "CREATE TABLE IF NOT EXISTS test_cr_ks.t (k int PRIMARY KEY, v text)" @@ -750,8 +826,7 @@ def test_queries_succeed_through_proxy(self): row = session.execute("SELECT v FROM test_cr_ks.t WHERE k = 1").one() self.assertEqual(row.v, "hello") - assert_routes_via_nlb(self, cluster, self.nlb, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, self.nlb, self.node_addrs.keys()) def test_connection_recovery_after_proxy_drop(self): """ @@ -762,8 +837,7 @@ def test_connection_recovery_after_proxy_drop(self): session = cluster.connect(wait_for_all_pools=True) session.execute("SELECT key FROM system.local") - assert_routes_via_nlb(self, cluster, self.nlb, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, self.nlb, self.node_addrs.keys()) self.nlb.drop_all_connections() @@ -772,11 +846,13 @@ def query_ok(): wait_until_not_raised(query_ok, 1, 30) - assert_routes_via_nlb(self, cluster, self.nlb, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, self.nlb, self.node_addrs.keys()) -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - 
system.client_routes is not yet supported', - scylla_version="2026.1.0") + +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestDynamicRouteUpdates(unittest.TestCase): """ Verify that when routes are updated (e.g. port changes), the driver @@ -808,20 +884,28 @@ def test_route_update_causes_reconnect_to_new_port(self): 3. Drop v1 connections. 4. Driver should reconnect through v2 ports. """ - with NLBEmulator( - node_addresses=self.node_addrs, - ) as nlb_v1, NLBEmulator( - node_addresses=self.node_addrs, - ) as nlb_v2: - post_routes_for_nlb("127.0.0.1", self.connection_id, - self.host_id_map, nlb_v1) - wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + with ( + NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v1, + NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v2, + ): + post_routes_for_nlb( + "127.0.0.1", self.connection_id, self.host_id_map, nlb_v1 + ) + wait_for_routes_visible( + self.direct_session, self.connection_id, len(self.host_id_map) + ) with Cluster( contact_points=[NLBEmulator.LISTEN_HOST], port=nlb_v1.discovery_port, client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + proxies=[ + ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST) + ], ), load_balancing_policy=RoundRobinPolicy(), ) as cluster: @@ -830,12 +914,13 @@ def test_route_update_causes_reconnect_to_new_port(self): for node_id in self.node_addrs: self.assertGreater( - nlb_v1.get_node_proxy(node_id).total_connections, 0) - assert_routes_via_nlb(self, cluster, nlb_v1, - self.node_addrs.keys()) + nlb_v1.get_node_proxy(node_id).total_connections, 0 + ) + assert_routes_via_nlb(self, cluster, nlb_v1, self.node_addrs.keys()) - post_routes_for_nlb("127.0.0.1", self.connection_id, - self.host_id_map, nlb_v2) + post_routes_for_nlb( + "127.0.0.1", self.connection_id, 
self.host_id_map, nlb_v2 + ) time.sleep(2) # let CLIENT_ROUTES_CHANGE propagate # Stop v1 per-node proxies entirely so v1 ports become @@ -849,13 +934,13 @@ def test_route_update_causes_reconnect_to_new_port(self): def all_nodes_via_v2(): session.execute("SELECT key FROM system.local") for nid in self.node_addrs: - assert nlb_v2.get_node_proxy(nid).total_connections > 0, \ + assert nlb_v2.get_node_proxy(nid).total_connections > 0, ( "NLB v2 node %d proxy has no connections yet" % nid + ) wait_until_not_raised(all_nodes_via_v2, 1, 30) - assert_routes_via_nlb(self, cluster, nlb_v2, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, nlb_v2, self.node_addrs.keys()) def _generate_ssl_certs(cert_dir, node_ips): @@ -873,7 +958,9 @@ def _generate_ssl_certs(cert_dir, node_ips): :param node_ips: list of IP strings to include as SANs (e.g. ["127.0.0.1", "127.0.0.2"]) """ if shutil.which("openssl") is None: - raise unittest.SkipTest("openssl not found on PATH; skipping SSL cert generation") + raise unittest.SkipTest( + "openssl not found on PATH; skipping SSL cert generation" + ) san_cnf = os.path.join(cert_dir, "san.cnf") san_value = ",".join("IP:%s" % ip for ip in node_ips) @@ -883,26 +970,73 @@ def _generate_ssl_certs(cert_dir, node_ips): def _run(cmd): result = subprocess.run(cmd, cwd=cert_dir, capture_output=True, text=True) if result.returncode != 0: - raise RuntimeError("Command failed: %s\n%s" % (" ".join(cmd), result.stderr)) + raise RuntimeError( + "Command failed: %s\n%s" % (" ".join(cmd), result.stderr) + ) - _run(["openssl", "req", "-x509", "-newkey", "rsa:2048", - "-keyout", "ca.key", "-out", "ca.crt", - "-days", "1", "-nodes", "-subj", "/CN=Test CA"]) + _run( + [ + "openssl", + "req", + "-x509", + "-newkey", + "rsa:2048", + "-keyout", + "ca.key", + "-out", + "ca.crt", + "-days", + "1", + "-nodes", + "-subj", + "/CN=Test CA", + ] + ) - _run(["openssl", "req", "-newkey", "rsa:2048", - "-keyout", "ccm_node.key", "-out", "ccm_node.csr", - "-nodes", 
"-subj", "/CN=Test Server"]) + _run( + [ + "openssl", + "req", + "-newkey", + "rsa:2048", + "-keyout", + "ccm_node.key", + "-out", + "ccm_node.csr", + "-nodes", + "-subj", + "/CN=Test Server", + ] + ) - _run(["openssl", "x509", "-req", - "-in", "ccm_node.csr", "-CA", "ca.crt", "-CAkey", "ca.key", - "-CAcreateserial", "-out", "ccm_node.pem", - "-days", "1", "-extfile", "san.cnf"]) + _run( + [ + "openssl", + "x509", + "-req", + "-in", + "ccm_node.csr", + "-CA", + "ca.crt", + "-CAkey", + "ca.key", + "-CAcreateserial", + "-out", + "ccm_node.pem", + "-days", + "1", + "-extfile", + "san.cnf", + ] + ) log.info("Generated SSL certs in %s with SANs: %s", cert_dir, san_value) -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestMixedDirectAndNlbConnections(unittest.TestCase): """ Verify the cluster works when some nodes are accessed through the NLB @@ -940,19 +1074,23 @@ def test_mixed_direct_and_nlb_connections(self): node_addresses={proxied_node_id: proxied_ip}, ) as nlb: proxied_host_id = self.host_id_map[proxied_ip] - routes = [{ - "connection_id": self.connection_id, - "host_id": proxied_host_id, - "address": NLBEmulator.LISTEN_HOST, - "port": nlb.node_port(proxied_node_id), - }] + routes = [ + { + "connection_id": self.connection_id, + "host_id": proxied_host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": nlb.node_port(proxied_node_id), + } + ] post_client_routes("127.0.0.1", routes) time.sleep(1) with Cluster( contact_points=["127.0.0.1"], client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + proxies=[ + ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST) + ], ), load_balancing_policy=RoundRobinPolicy(), ) as cluster: @@ -961,19 +1099,23 @@ def 
test_mixed_direct_and_nlb_connections(self): for _ in range(50): session.execute("SELECT key FROM system.local") - assert_routes_via_nlb(self, cluster, nlb, - [proxied_node_id]) + assert_routes_via_nlb(self, cluster, nlb, [proxied_node_id]) direct_node_ids = set(self.node_addrs.keys()) - {proxied_node_id} assert_routes_direct(self, cluster, direct_node_ids) proxy = nlb.get_node_proxy(proxied_node_id) - self.assertGreater(proxy.total_connections, 0, - "Proxied node should have connections through NLB") + self.assertGreater( + proxy.total_connections, + 0, + "Proxied node should have connections through NLB", + ) -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestSslThroughNlb(unittest.TestCase): """ Verify SSL with check_hostname=False works through the NLB proxy. 
@@ -1007,23 +1149,27 @@ def setUpClass(cls): cls.ccm_cluster = get_cluster() cls.ccm_cluster.stop() - cls.ccm_cluster.set_configuration_options({ - 'client_encryption_options': { - 'enabled': True, - 'certificate': os.path.join(cls.cert_dir, "ccm_node.pem"), - 'keyfile': os.path.join(cls.cert_dir, "ccm_node.key"), + cls.ccm_cluster.set_configuration_options( + { + "client_encryption_options": { + "enabled": True, + "certificate": os.path.join(cls.cert_dir, "ccm_node.pem"), + "keyfile": os.path.join(cls.cert_dir, "ccm_node.key"), + } } - }) + ) cls.ccm_cluster.start(wait_for_binary_proto=True) @classmethod def tearDownClass(cls): cls.ccm_cluster.stop() - cls.ccm_cluster.set_configuration_options({ - 'client_encryption_options': { - 'enabled': False, + cls.ccm_cluster.set_configuration_options( + { + "client_encryption_options": { + "enabled": False, + } } - }) + ) cls.ccm_cluster.start(wait_for_binary_proto=True) shutil.rmtree(cls.cert_dir, ignore_errors=True) @@ -1041,7 +1187,9 @@ def test_ssl_without_hostname_verification_through_nlb(self): node_addresses=self.node_addrs, ) as nlb: routes = build_routes_for_nlb( - self.connection_id, self.host_id_map, nlb, + self.connection_id, + self.host_id_map, + nlb, ) for route in routes: route["tls_port"] = route["port"] @@ -1049,29 +1197,35 @@ def test_ssl_without_hostname_verification_through_nlb(self): ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_ctx.check_hostname = False - ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, "ca.crt")) - self.assertFalse(ssl_ctx.check_hostname, - "check_hostname must be False for this test") - self.assertEqual(ssl_ctx.verify_mode, ssl.CERT_REQUIRED, - "verify_mode must be CERT_REQUIRED") + self.assertFalse( + ssl_ctx.check_hostname, "check_hostname must be False for this test" + ) + self.assertEqual( + ssl_ctx.verify_mode, + ssl.CERT_REQUIRED, + "verify_mode must be CERT_REQUIRED", + ) def 
routes_visible(): with TestCluster( contact_points=["127.0.0.1"], - ssl_context=ssl_ctx, connect_timeout=30, + ssl_context=ssl_ctx, + connect_timeout=30, ) as c: session = c.connect() rs = session.execute( "SELECT * FROM system.client_routes " "WHERE connection_id = %s ALLOW FILTERING", - (self.connection_id,) + (self.connection_id,), ) return len(list(rs)) >= len(self.host_id_map) wait_until_not_raised( lambda: self.assertTrue(routes_visible()), - 1, 30, + 1, + 30, ) with Cluster( @@ -1079,7 +1233,9 @@ def routes_visible(): port=nlb.discovery_port, ssl_context=ssl_ctx, client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + proxies=[ + ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST) + ], ), load_balancing_policy=RoundRobinPolicy(), ) as cluster: @@ -1091,8 +1247,7 @@ def routes_visible(): ).one() self.assertIsNotNone(row) - assert_routes_via_nlb(self, cluster, nlb, - self.node_addrs.keys()) + assert_routes_via_nlb(self, cluster, nlb, self.node_addrs.keys()) def test_ssl_with_hostname_verification_raises_error(self): """ @@ -1100,7 +1255,7 @@ def test_ssl_with_hostname_verification_raises_error(self): is used with SSL hostname verification enabled. 
""" ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, "ca.crt")) self.assertTrue(ssl_ctx.check_hostname) with self.assertRaises(ValueError) as cm: @@ -1113,8 +1268,11 @@ def test_ssl_with_hostname_verification_raises_error(self): ) self.assertIn("check_hostname", str(cm.exception)) -@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") + +@skip_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", +) class TestFullNodeReplacementThroughNlb(unittest.TestCase): """ End-to-end test: creates a session through an NLB proxy with client routes, @@ -1128,9 +1286,9 @@ class TestFullNodeReplacementThroughNlb(unittest.TestCase): @classmethod def setUpClass(cls): - cls._saved_scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS') - os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" - use_cluster('test_client_routes_replacement', [3], start=True) + cls._saved_scylla_ext_opts = os.environ.get("SCYLLA_EXT_OPTS") + os.environ["SCYLLA_EXT_OPTS"] = "--smp 2 --memory 2048M" + use_cluster("test_client_routes_replacement", [3], start=True) cls.direct_cluster = TestCluster() cls.direct_session = cls.direct_cluster.connect() @@ -1147,9 +1305,9 @@ def setUpClass(cls): def tearDownClass(cls): cls.direct_cluster.shutdown() if cls._saved_scylla_ext_opts is None: - os.environ.pop('SCYLLA_EXT_OPTS', None) + os.environ.pop("SCYLLA_EXT_OPTS", None) else: - os.environ['SCYLLA_EXT_OPTS'] = cls._saved_scylla_ext_opts + os.environ["SCYLLA_EXT_OPTS"] = cls._saved_scylla_ext_opts def test_should_survive_full_node_replacement_through_nlb(self): """ @@ -1163,10 +1321,14 @@ def test_should_survive_full_node_replacement_through_nlb(self): node_addresses=self.node_addrs, ) as nlb: # ---- Stage 1: Set up NLB for initial nodes 
---- - log.info("Stage 1: Setting up NLB for %d initial nodes", len(original_node_ids)) + log.info( + "Stage 1: Setting up NLB for %d initial nodes", len(original_node_ids) + ) post_routes_for_nlb("127.0.0.1", self.connection_id, self.host_id_map, nlb) - wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + wait_for_routes_visible( + self.direct_session, self.connection_id, len(self.host_id_map) + ) # ---- Stage 2: Create session through NLB ---- log.info("Stage 2: Creating session through NLB") @@ -1174,7 +1336,9 @@ def test_should_survive_full_node_replacement_through_nlb(self): contact_points=[NLBEmulator.LISTEN_HOST], port=nlb.discovery_port, client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + proxies=[ + ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST) + ], ), load_balancing_policy=RoundRobinPolicy(), ) as cluster: @@ -1184,10 +1348,11 @@ def test_should_survive_full_node_replacement_through_nlb(self): handler = cluster._client_routes_handler self.assertIsNotNone(handler) - assert_routes_via_nlb(self, cluster, nlb, - original_node_ids) - log.info("Stage 2: Session created, all %d nodes via NLB", - len(original_node_ids)) + assert_routes_via_nlb(self, cluster, nlb, original_node_ids) + log.info( + "Stage 2: Session created, all %d nodes via NLB", + len(original_node_ids), + ) # ---- Stage 3: Bootstrap new nodes ---- new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2] @@ -1195,7 +1360,7 @@ def test_should_survive_full_node_replacement_through_nlb(self): ccm_cluster = get_cluster() for node_id in new_node_ids: - self._bootstrap_node(ccm_cluster, node_id, data_center='dc1') + self._bootstrap_node(ccm_cluster, node_id, data_center="dc1") expected_total = len(original_node_ids) + len(new_node_ids) self._wait_for_condition( @@ -1213,10 +1378,12 @@ def test_should_survive_full_node_replacement_through_nlb(self): handler.initialize( 
cluster.control_connection._connection, - cluster.control_connection._timeout) + cluster.control_connection._timeout, + ) self._wait_for_condition( - lambda: sum(1 for h in cluster.metadata.all_hosts() if h.is_up) >= expected_total, + lambda: sum(1 for h in cluster.metadata.all_hosts() if h.is_up) + >= expected_total, timeout_seconds=60, description="all %d nodes up" % expected_total, ) @@ -1225,11 +1392,14 @@ def test_should_survive_full_node_replacement_through_nlb(self): all_node_ids = set(original_node_ids) | set(new_node_ids) assert_routes_via_nlb(self, cluster, nlb, all_node_ids) - log.info("Stage 3: All %d nodes via NLB after expansion", - len(all_node_ids)) + log.info( + "Stage 3: All %d nodes via NLB after expansion", len(all_node_ids) + ) # ---- Stage 4: Decommission original nodes ---- - log.info("Stage 4: Decommissioning original nodes %s", original_node_ids) + log.info( + "Stage 4: Decommissioning original nodes %s", original_node_ids + ) remaining_node_ids = set(all_node_ids) remaining_host_ids = dict(all_host_ids) @@ -1245,11 +1415,15 @@ def test_should_survive_full_node_replacement_through_nlb(self): surviving_ips = list(remaining_host_ids.keys()) if surviving_ips: post_routes_for_nlb( - surviving_ips[0], self.connection_id, - remaining_host_ids, nlb, + surviving_ips[0], + self.connection_id, + remaining_host_ids, + nlb, ) - expected_remaining = expected_total - (original_node_ids.index(node_id) + 1) + expected_remaining = expected_total - ( + original_node_ids.index(node_id) + 1 + ) self._wait_for_condition( lambda er=expected_remaining: ( len(cluster.metadata.all_hosts()) <= er @@ -1264,32 +1438,43 @@ def test_should_survive_full_node_replacement_through_nlb(self): # killed the old control connection). 
handler.initialize( cluster.control_connection._connection, - cluster.control_connection._timeout) + cluster.control_connection._timeout, + ) - assert_routes_via_nlb(self, cluster, nlb, - remaining_node_ids) - log.info("Node %d decommissioned, %d nodes still via NLB", - node_id, len(remaining_node_ids)) + assert_routes_via_nlb(self, cluster, nlb, remaining_node_ids) + log.info( + "Node %d decommissioned, %d nodes still via NLB", + node_id, + len(remaining_node_ids), + ) # ---- Stage 5: Verify with only new nodes ---- - log.info("Stage 5: Verifying session works with only new nodes %s", new_node_ids) + log.info( + "Stage 5: Verifying session works with only new nodes %s", + new_node_ids, + ) self._assert_query_works(session) hosts = cluster.metadata.all_hosts() self.assertEqual( - len(hosts), len(new_node_ids), - "Expected %d hosts, got %d" % (len(new_node_ids), len(hosts)) + len(hosts), + len(new_node_ids), + "Expected %d hosts, got %d" % (len(new_node_ids), len(hosts)), ) for _ in range(10): self._assert_query_works(session) assert_routes_via_nlb(self, cluster, nlb, new_node_ids) - log.info("PASS: Full node replacement, all %d new nodes via NLB", - len(new_node_ids)) + log.info( + "PASS: Full node replacement, all %d new nodes via NLB", + len(new_node_ids), + ) def _assert_query_works(self, session): - rs = session.execute("SELECT release_version FROM system.local WHERE key='local'") + rs = session.execute( + "SELECT release_version FROM system.local WHERE key='local'" + ) row = rs.one() self.assertIsNotNone(row, "Query via NLB should return a result") @@ -1304,7 +1489,7 @@ def _bootstrap_node(self, ccm_cluster, node_id, data_center=None, rack=None): node_type = type(next(iter(ccm_cluster.nodes.values()))) ip = "127.0.0.%d" % node_id node_instance = node_type( - 'node%s' % node_id, + "node%s" % node_id, ccm_cluster, auto_bootstrap=True, thrift_interface=(ip, 9160), @@ -1318,14 +1503,17 @@ def _bootstrap_node(self, ccm_cluster, node_id, data_center=None, 
rack=None): # cassandra-rackdc.properties is written correctly. Without this the # snitch fails to parse the empty properties file and the node crashes # on startup. - ccm_cluster.add(node_instance, is_seed=False, - data_center=data_center, rack=rack) + ccm_cluster.add( + node_instance, is_seed=False, data_center=data_center, rack=rack + ) node_instance.start(wait_for_binary_proto=True, wait_other_notice=True) wait_for_node_socket(node_instance, 120) log.info("Node %d bootstrapped successfully", node_id) @staticmethod - def _wait_for_condition(predicate, timeout_seconds, poll_interval=2, description="condition"): + def _wait_for_condition( + predicate, timeout_seconds, poll_interval=2, description="condition" + ): deadline = time.time() + timeout_seconds while time.time() < deadline: if predicate(): diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 08b823d716..7f2823429f 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -27,11 +27,24 @@ import pytest import cassandra -from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, ControlConnection, Cluster +from cassandra.cluster import ( + NoHostAvailable, + ExecutionProfile, + EXEC_PROFILE_DEFAULT, + ControlConnection, + Cluster, +) from cassandra.concurrent import execute_concurrent -from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, - RetryPolicy, SimpleConvictionPolicy, HostDistance, - AddressTranslator, TokenAwarePolicy, HostFilterPolicy) +from cassandra.policies import ( + RoundRobinPolicy, + ExponentialReconnectionPolicy, + RetryPolicy, + SimpleConvictionPolicy, + HostDistance, + AddressTranslator, + TokenAwarePolicy, + HostFilterPolicy, +) from cassandra import ConsistencyLevel from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory @@ -40,10 +53,25 @@ from cassandra.connection import DefaultEndPoint from tests import 
notwindows, notasyncio -from tests.integration import use_cluster, get_server_versions, CASSANDRA_VERSION, \ - execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ - get_unsupported_upper_protocol, local, CASSANDRA_IP, greaterthanorequalcass30, \ - lessthanorequalcass40, TestCluster, PROTOCOL_VERSION, xfail_scylla, incorrect_test +from tests.integration import ( + use_cluster, + get_server_versions, + CASSANDRA_VERSION, + execute_until_pass, + execute_with_long_wait_retry, + get_node, + MockLoggingHandler, + get_unsupported_lower_protocol, + get_unsupported_upper_protocol, + local, + CASSANDRA_IP, + greaterthanorequalcass30, + lessthanorequalcass40, + TestCluster, + PROTOCOL_VERSION, + xfail_scylla, + incorrect_test, +) from tests.integration.util import assert_quiescent_pool_state from tests.util import assertListEqual import sys @@ -56,27 +84,26 @@ def setup_module(): global _saved_scylla_ext_opts - _saved_scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS') - os.environ['SCYLLA_EXT_OPTS'] = "--smp 2" + _saved_scylla_ext_opts = os.environ.get("SCYLLA_EXT_OPTS") + os.environ["SCYLLA_EXT_OPTS"] = "--smp 2" use_cluster("cluster_tests", [3], start=True, workloads=None) warnings.simplefilter("always") def teardown_module(): if _saved_scylla_ext_opts is None: - os.environ.pop('SCYLLA_EXT_OPTS', None) + os.environ.pop("SCYLLA_EXT_OPTS", None) else: - os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + os.environ["SCYLLA_EXT_OPTS"] = _saved_scylla_ext_opts class IgnoredHostPolicy(RoundRobinPolicy): - def __init__(self, ignored_hosts): self.ignored_hosts = ignored_hosts RoundRobinPolicy.__init__(self) def distance(self, host): - if(host.address in self.ignored_hosts): + if host.address in self.ignored_hosts: return HostDistance.IGNORED else: return HostDistance.LOCAL @@ -96,7 +123,11 @@ def test_ignored_host_up(self): """ ignored_host_policy = IgnoredHostPolicy(["127.0.0.2", "127.0.0.3"]) cluster = 
TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=ignored_host_policy)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=ignored_host_policy + ) + } ) cluster.connect() for host in cluster.metadata.all_hosts(): @@ -118,7 +149,7 @@ def test_host_resolution(self): @test_category connection """ cluster = TestCluster(contact_points=["localhost"], connect_timeout=1) - assert DefaultEndPoint('127.0.0.1') in cluster.endpoints_resolved + assert DefaultEndPoint("127.0.0.1") in cluster.endpoints_resolved @local def test_host_duplication(self): @@ -132,13 +163,21 @@ def test_host_duplication(self): @test_category connection """ cluster = TestCluster( - contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], - connect_timeout=1 + contact_points=[ + "localhost", + "127.0.0.1", + "localhost", + "localhost", + "localhost", + ], + connect_timeout=1, ) cluster.connect(wait_for_all_pools=True) assert len(cluster.metadata.all_hosts()) == 3 cluster.shutdown() - cluster = TestCluster(contact_points=["127.0.0.1", "localhost"], connect_timeout=1) + cluster = TestCluster( + contact_points=["127.0.0.1", "localhost"], connect_timeout=1 + ) cluster.connect(wait_for_all_pools=True) assert len(cluster.metadata.all_hosts()) == 3 cluster.shutdown() @@ -162,9 +201,12 @@ def test_raise_error_on_control_connection_timeout(self): """ get_node(1).pause() - cluster = TestCluster(contact_points=['127.0.0.1'], connect_timeout=1) + cluster = TestCluster(contact_points=["127.0.0.1"], connect_timeout=1) - with pytest.raises(NoHostAvailable, match=r"OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): + with pytest.raises( + NoHostAvailable, + match=r"OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)", + ): cluster.connect() cluster.shutdown() @@ -177,14 +219,17 @@ def test_basic(self): cluster = TestCluster() session = cluster.connect() - result = 
execute_until_pass(session, + result = execute_until_pass( + session, """ CREATE KEYSPACE clustertests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """) + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} + """, + ) assert not result - result = execute_with_long_wait_retry(session, + result = execute_with_long_wait_retry( + session, """ CREATE TABLE clustertests.cf0 ( a text, @@ -192,17 +237,19 @@ def test_basic(self): c text, PRIMARY KEY (a, b) ) - """) + """, + ) assert not result result = session.execute( """ INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') - """) + """ + ) assert not result result = session.execute("SELECT * FROM clustertests.cf0") - assert [('a', 'b', 'c')] == result + assert [("a", "b", "c")] == result execute_with_long_wait_retry(session, "DROP KEYSPACE clustertests") @@ -222,12 +269,14 @@ def test_session_host_parameter(self): @test_category connection """ + def cleanup(): """ When this test fails, the inline .shutdown() calls don't get called, so we register this as a cleanup. 
""" self.cluster_to_shutdown.shutdown() + self.addCleanup(cleanup) # Test with empty list @@ -237,13 +286,15 @@ def cleanup(): self.cluster_to_shutdown.shutdown() # Test with only invalid - self.cluster_to_shutdown = TestCluster(contact_points=('1.2.3.4',)) + self.cluster_to_shutdown = TestCluster(contact_points=("1.2.3.4",)) with pytest.raises(NoHostAvailable): self.cluster_to_shutdown.connect() self.cluster_to_shutdown.shutdown() # Test with valid and invalid hosts - self.cluster_to_shutdown = TestCluster(contact_points=("127.0.0.1", "127.0.0.2", "1.2.3.4")) + self.cluster_to_shutdown = TestCluster( + contact_points=("127.0.0.1", "127.0.0.2", "1.2.3.4") + ) self.cluster_to_shutdown.connect() self.cluster_to_shutdown.shutdown() @@ -269,25 +320,25 @@ def test_protocol_negotiation(self): updated_protocol_version = session._protocol_version updated_cluster_version = cluster.protocol_version # Make sure the correct protocol was selected by default - if CASSANDRA_VERSION >= Version('4.0-beta5'): + if CASSANDRA_VERSION >= Version("4.0-beta5"): assert updated_protocol_version == cassandra.ProtocolVersion.V5 assert updated_cluster_version == cassandra.ProtocolVersion.V5 - elif CASSANDRA_VERSION >= Version('4.0-a'): + elif CASSANDRA_VERSION >= Version("4.0-a"): assert updated_protocol_version == cassandra.ProtocolVersion.V4 assert updated_cluster_version == cassandra.ProtocolVersion.V4 - elif CASSANDRA_VERSION >= Version('3.11'): + elif CASSANDRA_VERSION >= Version("3.11"): assert updated_protocol_version == cassandra.ProtocolVersion.V4 assert updated_cluster_version == cassandra.ProtocolVersion.V4 - elif CASSANDRA_VERSION >= Version('3.0'): + elif CASSANDRA_VERSION >= Version("3.0"): assert updated_protocol_version == cassandra.ProtocolVersion.V4 assert updated_cluster_version == cassandra.ProtocolVersion.V4 - elif CASSANDRA_VERSION >= Version('2.2'): + elif CASSANDRA_VERSION >= Version("2.2"): assert updated_protocol_version == 4 assert updated_cluster_version == 4 - 
elif CASSANDRA_VERSION >= Version('2.1'): + elif CASSANDRA_VERSION >= Version("2.1"): assert updated_protocol_version == 3 assert updated_cluster_version == 3 - elif CASSANDRA_VERSION >= Version('2.0'): + elif CASSANDRA_VERSION >= Version("2.0"): assert updated_protocol_version == 2 assert updated_cluster_version == 2 else: @@ -319,7 +370,7 @@ def test_invalid_protocol_negotation(self): """ upper_bound = get_unsupported_upper_protocol() - log.debug('got upper_bound of {}'.format(upper_bound)) + log.debug("got upper_bound of {}".format(upper_bound)) if upper_bound is not None: cluster = TestCluster(protocol_version=upper_bound) with pytest.raises(NoHostAvailable): @@ -327,7 +378,7 @@ def test_invalid_protocol_negotation(self): cluster.shutdown() lower_bound = get_unsupported_lower_protocol() - log.debug('got lower_bound of {}'.format(lower_bound)) + log.debug("got lower_bound of {}".format(lower_bound)) if lower_bound is not None: cluster = TestCluster(protocol_version=lower_bound) with pytest.raises(NoHostAvailable): @@ -344,14 +395,17 @@ def test_connect_on_keyspace(self): result = session.execute( """ INSERT INTO test1rf.test (k, v) VALUES (8889, 8889) - """) + """ + ) assert not result result = session.execute("SELECT * FROM test1rf.test") - assert [(8889, 8889)] == result, "Rows in ResultSet are {0}".format(result.current_rows) + assert [(8889, 8889)] == result, "Rows in ResultSet are {0}".format( + result.current_rows + ) # test_connect_on_keyspace - session2 = cluster.connect('test1rf') + session2 = cluster.connect("test1rf") result2 = session2.execute("SELECT * FROM test") assert result == result2 cluster.shutdown() @@ -371,7 +425,7 @@ def test_default_connections(self): TestCluster( reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), conviction_policy_factory=SimpleConvictionPolicy, - protocol_version=PROTOCOL_VERSION + protocol_version=PROTOCOL_VERSION, ) def test_connect_to_already_shutdown_cluster(self): @@ -391,18 +445,18 @@ def 
test_auth_provider_is_callable(self): Cluster(auth_provider=1, protocol_version=1) c = TestCluster(protocol_version=1) with pytest.raises(TypeError): - setattr(c, 'auth_provider', 1) + setattr(c, "auth_provider", 1) def test_v2_auth_provider(self): """ Check for v2 auth_provider compliance """ - bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'} + bad_auth_provider = lambda x: {"username": "foo", "password": "bar"} with pytest.raises(TypeError): Cluster(auth_provider=bad_auth_provider, protocol_version=2) c = TestCluster(protocol_version=2) with pytest.raises(TypeError): - setattr(c, 'auth_provider', bad_auth_provider) + setattr(c, "auth_provider", bad_auth_provider) def test_conviction_policy_factory_is_callable(self): """ @@ -418,8 +472,10 @@ def test_connect_to_bad_hosts(self): when a cluster cannot connect to given hosts """ - cluster = TestCluster(contact_points=['127.1.2.9', '127.1.2.10'], - protocol_version=PROTOCOL_VERSION) + cluster = TestCluster( + contact_points=["127.1.2.9", "127.1.2.10"], + protocol_version=PROTOCOL_VERSION, + ) with pytest.raises(NoHostAvailable): cluster.connect() @@ -440,13 +496,13 @@ def test_refresh_schema_keyspace(self): session = cluster.connect() original_meta = cluster.metadata.keyspaces - original_system_meta = original_meta['system'] + original_system_meta = original_meta["system"] # only refresh one keyspace - cluster.refresh_keyspace_metadata('system') + cluster.refresh_keyspace_metadata("system") current_meta = cluster.metadata.keyspaces assert original_meta is current_meta - current_system_meta = current_meta['system'] + current_system_meta = current_meta["system"] assert original_system_meta is not current_system_meta assert original_system_meta.as_cql_query() == current_system_meta.as_cql_query() cluster.shutdown() @@ -456,46 +512,58 @@ def test_refresh_schema_table(self): session = cluster.connect() original_meta = cluster.metadata.keyspaces - original_system_meta = original_meta['system'] - 
original_system_schema_meta = original_system_meta.tables['local'] + original_system_meta = original_meta["system"] + original_system_schema_meta = original_system_meta.tables["local"] # only refresh one table - cluster.refresh_table_metadata('system', 'local') + cluster.refresh_table_metadata("system", "local") current_meta = cluster.metadata.keyspaces - current_system_meta = current_meta['system'] - current_system_schema_meta = current_system_meta.tables['local'] + current_system_meta = current_meta["system"] + current_system_schema_meta = current_system_meta.tables["local"] assert original_meta is current_meta assert original_system_meta is current_system_meta assert original_system_schema_meta is not current_system_schema_meta - assert original_system_schema_meta.as_cql_query() == current_system_schema_meta.as_cql_query() + assert ( + original_system_schema_meta.as_cql_query() + == current_system_schema_meta.as_cql_query() + ) cluster.shutdown() def test_refresh_schema_type(self): if get_server_versions()[0] < (2, 1, 0): - raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') + raise unittest.SkipTest("UDTs were introduced in Cassandra 2.1") if PROTOCOL_VERSION < 3: - raise unittest.SkipTest('UDTs are not specified in change events for protocol v2') + raise unittest.SkipTest( + "UDTs are not specified in change events for protocol v2" + ) # We may want to refresh types on keyspace change events in that case(?) 
cluster = TestCluster() session = cluster.connect() - keyspace_name = 'test1rf' + keyspace_name = "test1rf" type_name = self._testMethodName - execute_until_pass(session, 'CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name)) + execute_until_pass( + session, + "CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)" + % (keyspace_name, type_name), + ) original_meta = cluster.metadata.keyspaces original_test1rf_meta = original_meta[keyspace_name] original_type_meta = original_test1rf_meta.user_types[type_name] # only refresh one type - cluster.refresh_user_type_metadata('test1rf', type_name) + cluster.refresh_user_type_metadata("test1rf", type_name) current_meta = cluster.metadata.keyspaces current_test1rf_meta = current_meta[keyspace_name] current_type_meta = current_test1rf_meta.user_types[type_name] assert original_meta is current_meta - assert original_test1rf_meta.export_as_string() == current_test1rf_meta.export_as_string() + assert ( + original_test1rf_meta.export_as_string() + == current_test1rf_meta.export_as_string() + ) assert original_type_meta is not current_type_meta assert original_type_meta.as_cql_query() == current_type_meta.as_cql_query() cluster.shutdown() @@ -508,12 +576,19 @@ def test_refresh_schema_no_wait(self): def patched_wait_for_responses(*args, **kwargs): # When selecting schema version, replace the real schema UUID with an unexpected UUID response = original_wait_for_responses(*args, **kwargs) - if len(args) > 2 and hasattr(args[2], "query") and "SELECT schema_version FROM system.local WHERE key='local'" in args[2].query: + if ( + len(args) > 2 + and hasattr(args[2], "query") + and "SELECT schema_version FROM system.local WHERE key='local'" + in args[2].query + ): new_uuid = uuid4() response[1].parsed_rows[0] = (new_uuid,) return response - with patch.object(connection.Connection, "wait_for_responses", patched_wait_for_responses): + with patch.object( + connection.Connection, "wait_for_responses", 
patched_wait_for_responses + ): agreement_timeout = 1 # cluster agreement wait exceeded @@ -577,7 +652,9 @@ def test_trace(self): cluster = TestCluster() session = cluster.connect() - result = session.execute( "SELECT * FROM system.local WHERE key='local'", trace=True) + result = session.execute( + "SELECT * FROM system.local WHERE key='local'", trace=True + ) self._check_trace(result.get_query_trace()) query = "SELECT * FROM system.local WHERE key='local'" @@ -619,7 +696,7 @@ def test_trace_unavailable(self): @expected_result TraceUnavailable is arisen in both cases @test_category query - """ + """ cluster = TestCluster() self.addCleanup(cluster.shutdown) session = cluster.connect() @@ -638,8 +715,11 @@ def test_trace_unavailable(self): except TraceUnavailable: break else: - raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) - + raise Exception( + "get_query_trace didn't raise TraceUnavailable after {} tries".format( + max_retry_count + ) + ) for i in range(max_retry_count): future = session.execute_async(statement, trace=True) @@ -650,7 +730,11 @@ def test_trace_unavailable(self): except TraceUnavailable: break else: - raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) + raise Exception( + "get_query_trace didn't raise TraceUnavailable after {} tries".format( + max_retry_count + ) + ) def test_one_returns_none(self): """ @@ -664,7 +748,12 @@ def test_one_returns_none(self): """ with TestCluster() as cluster: session = cluster.connect() - assert session.execute("SELECT * from system.local WHERE key='madeup_key'").one() is None + assert ( + session.execute( + "SELECT * from system.local WHERE key='madeup_key'" + ).one() + is None + ) def test_string_coverage(self): """ @@ -682,7 +771,7 @@ def test_string_coverage(self): future.result() assert query in str(future) - assert 'result' in str(future) + assert "result" in str(future) cluster.shutdown() def 
test_can_connect_with_plainauth(self): @@ -699,8 +788,7 @@ def test_can_connect_with_plainauth(self): @test_category auth """ auth_provider = PlainTextAuthProvider( - username="made_up_username", - password="made_up_password" + username="made_up_username", password="made_up_password" ) self._warning_are_issued_when_auth(auth_provider) @@ -717,11 +805,13 @@ def test_can_connect_with_sslauth(self): @test_category auth """ - sasl_kwargs = {'service': 'cassandra', - 'mechanism': 'PLAIN', - 'qops': ['auth'], - 'username': "made_up_username", - 'password': "made_up_password"} + sasl_kwargs = { + "service": "cassandra", + "mechanism": "PLAIN", + "qops": ["auth"], + "username": "made_up_username", + "password": "made_up_password", + } auth_provider = SaslAuthProvider(**sasl_kwargs) self._warning_are_issued_when_auth(auth_provider) @@ -730,37 +820,50 @@ def _warning_are_issued_when_auth(self, auth_provider): with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: with TestCluster(auth_provider=auth_provider) as cluster: session = cluster.connect() - assert session.execute("SELECT * from system.local WHERE key='local'") is not None + assert ( + session.execute("SELECT * from system.local WHERE key='local'") + is not None + ) # Verify that auth warnings are issued for connections where # auth is configured but the server does not send a challenge. # At minimum one warning per node connection (3 for a 3-node # cluster). The control connection and shard-aware connections # may add more, so we only assert a lower bound. 
- auth_warning = mock_handler.get_message_count('warning', "An authentication challenge was not sent") + auth_warning = mock_handler.get_message_count( + "warning", "An authentication challenge was not sent" + ) assert auth_warning >= 3 def _wait_for_all_shard_connections(self, cluster, timeout=30): """Wait until all shard-aware connections are fully established.""" from cassandra.pool import HostConnection + deadline = time.time() + timeout while time.time() < deadline: all_connected = True for holder in cluster.get_connection_holders(): if not isinstance(holder, HostConnection): continue - if holder.host.sharding_info and len(holder._connections) < holder.host.sharding_info.shards_count: + if ( + holder.host.sharding_info + and len(holder._connections) + < holder.host.sharding_info.shards_count + ): all_connected = False break if all_connected: return time.sleep(0.1) - raise RuntimeError("Timed out waiting for all shard connections to be established") + raise RuntimeError( + "Timed out waiting for all shard connections to be established" + ) def test_idle_heartbeat(self): interval = 2 - cluster = TestCluster(idle_heartbeat_interval=interval, - monitor_reporting_enabled=False) + cluster = TestCluster( + idle_heartbeat_interval=interval, monitor_reporting_enabled=False + ) session = cluster.connect(wait_for_all_pools=True) # wait_for_all_pools only waits for the first connection per host; @@ -776,12 +879,18 @@ def test_idle_heartbeat(self): # make sure none are idle (should have startup messages assert not c.is_idle with c.lock: - connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids + connection_request_ids[id(c)] = deque( + c.request_ids + ) # copy of request ids # let two heatbeat intervals pass (first one had startup messages in it) - time.sleep(2 * interval + interval/2) + time.sleep(2 * interval + interval / 2) - connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + connections = [ + c + 
for holders in cluster.get_connection_holders() + for c in holders.get_connections() + ] # make sure requests were sent on all connections for c in connections: @@ -797,27 +906,35 @@ def test_idle_heartbeat(self): # (with shard-aware routing, each query only hits one shard per host, # so we need more queries than just len(hosts) to cover all connections) num_connections = len([c for c in connections if not c.is_control_connection]) - statements_and_params = [("SELECT release_version FROM system.local WHERE key='local'", ())] * max(num_connections * 2, len(cluster.metadata.all_hosts())) + statements_and_params = [ + ("SELECT release_version FROM system.local WHERE key='local'", ()) + ] * max(num_connections * 2, len(cluster.metadata.all_hosts())) results = execute_concurrent(session, statements_and_params) for success, result in results: assert success # assert at least some non-control connections are no longer idle # (shard-aware routing may not distribute queries to every connection) - non_idle = [c for c in connections if not c.is_control_connection and not c.is_idle] + non_idle = [ + c for c in connections if not c.is_control_connection and not c.is_idle + ] assert len(non_idle) > 0 # holders include session pools and cc holders = cluster.get_connection_holders() assert cluster.control_connection in holders - assert len(holders) == len(cluster.metadata.all_hosts()) + 1 # hosts pools, 1 for cc + assert ( + len(holders) == len(cluster.metadata.all_hosts()) + 1 + ) # hosts pools, 1 for cc # include additional sessions session2 = cluster.connect(wait_for_all_pools=True) holders = cluster.get_connection_holders() assert cluster.control_connection in holders - assert len(holders) == 2 * len(cluster.metadata.all_hosts()) + 1 # 2 sessions' hosts pools, 1 for cc + assert ( + len(holders) == 2 * len(cluster.metadata.all_hosts()) + 1 + ) # 2 sessions' hosts pools, 1 for cc cluster._idle_heartbeat.stop() cluster._idle_heartbeat.join() @@ -825,7 +942,7 @@ def 
test_idle_heartbeat(self): cluster.shutdown() - @patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1) + @patch("cassandra.cluster.Cluster.idle_heartbeat_interval", new=0.1) def test_idle_heartbeat_disabled(self): assert Cluster.idle_heartbeat_interval @@ -837,7 +954,11 @@ def test_idle_heartbeat_disabled(self): # let two heatbeat intervals pass (first one had startup messages in it) time.sleep(2 * Cluster.idle_heartbeat_interval) - connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + connections = [ + c + for holders in cluster.get_connection_holders() + for c in holders.get_connections() + ] # assert not idle status (should never get reset because there is not heartbeat) assert not any(c.is_idle for c in connections) @@ -846,24 +967,26 @@ def test_idle_heartbeat_disabled(self): def test_pool_management(self): # Ensure that in_flight and request_ids quiesce after cluster operations - cluster = TestCluster(idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat + cluster = TestCluster( + idle_heartbeat_interval=0 + ) # no idle heartbeat here, pool management is tested in test_idle_heartbeat session = cluster.connect() session2 = cluster.connect() # prepare p = session.prepare("SELECT * FROM system.local WHERE key=?") - assert session.execute(p, ('local',)) + assert session.execute(p, ("local",)) # simple assert session.execute("SELECT * FROM system.local WHERE key='local'") # set keyspace - session.set_keyspace('system') - session.set_keyspace('system_traces') + session.set_keyspace("system") + session.set_keyspace("system_traces") # use keyspace - session.execute('USE system') - session.execute('USE system_traces') + session.execute("USE system") + session.execute("USE system_traces") # refresh schema cluster.refresh_schema_metadata() @@ -890,7 +1013,9 @@ def test_profile_load_balancing(self): RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP ) 
) - with TestCluster(execution_profiles={'node1': node1}, monitor_reporting_enabled=False) as cluster: + with TestCluster( + execution_profiles={"node1": node1}, monitor_reporting_enabled=False + ) as cluster: session = cluster.connect(wait_for_all_pools=True) # default is DCA RR for all hosts @@ -902,10 +1027,12 @@ def test_profile_load_balancing(self): assert queried_hosts == expected_hosts # by name we should only hit the one - expected_hosts = set(h for h in cluster.metadata.all_hosts() if h.address == CASSANDRA_IP) + expected_hosts = set( + h for h in cluster.metadata.all_hosts() if h.address == CASSANDRA_IP + ) queried_hosts = set() for _ in cluster.metadata.all_hosts(): - rs = session.execute(query, execution_profile='node1') + rs = session.execute(query, execution_profile="node1") queried_hosts.add(rs.response_future._current_host) assert queried_hosts == expected_hosts @@ -928,7 +1055,9 @@ def test_profile_load_balancing(self): tuple_row.release_version # make sure original profile is not impacted - assert session.execute(query, execution_profile='node1').one().release_version + assert ( + session.execute(query, execution_profile="node1").one().release_version + ) def test_setting_lbp_legacy(self): cluster = TestCluster() @@ -955,7 +1084,7 @@ def test_profile_lb_swap(self): query = "select release_version from system.local where key='local'" rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) - exec_profiles = {'rr1': rr1, 'rr2': rr2} + exec_profiles = {"rr1": rr1, "rr2": rr2} with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) @@ -967,9 +1096,9 @@ def test_profile_lb_swap(self): rr2_queried_hosts = [] for _ in range(num_hosts * 2): - rs = session.execute(query, execution_profile='rr1') + rs = session.execute(query, execution_profile="rr1") rr1_queried_hosts.append(rs.response_future._current_host) - rs = 
session.execute(query, execution_profile='rr2') + rs = session.execute(query, execution_profile="rr2") rr2_queried_hosts.append(rs.response_future._current_host) # Both policies should have queried all hosts @@ -997,7 +1126,7 @@ def test_ta_lbp(self): with TestCluster() as cluster: session = cluster.connect() cluster.add_execution_profile("ta1", ta1) - rs = session.execute(query, execution_profile='ta1') + rs = session.execute(query, execution_profile="ta1") def test_clone_shared_lbp(self): """ @@ -1014,18 +1143,22 @@ def test_clone_shared_lbp(self): """ query = "select release_version from system.local where key='local'" rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) - exec_profiles = {'rr1': rr1} + exec_profiles = {"rr1": rr1} with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) - assert len(cluster.metadata.all_hosts()) > 1, "We only have one host connected at this point" + assert len(cluster.metadata.all_hosts()) > 1, ( + "We only have one host connected at this point" + ) - rr1_clone = session.execution_profile_clone_update('rr1', row_factory=tuple_factory) + rr1_clone = session.execution_profile_clone_update( + "rr1", row_factory=tuple_factory + ) cluster.add_execution_profile("rr1_clone", rr1_clone) rr1_queried_hosts = set() rr1_clone_queried_hosts = set() - rs = session.execute(query, execution_profile='rr1') + rs = session.execute(query, execution_profile="rr1") rr1_queried_hosts.add(rs.response_future._current_host) - rs = session.execute(query, execution_profile='rr1_clone') + rs = session.execute(query, execution_profile="rr1_clone") rr1_clone_queried_hosts.add(rs.response_future._current_host) assert rr1_clone_queried_hosts != rr1_queried_hosts @@ -1042,11 +1175,11 @@ def test_missing_exec_prof(self): query = "select release_version from system.local where key='local'" rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) rr2 = 
ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) - exec_profiles = {'rr1': rr1, 'rr2': rr2} + exec_profiles = {"rr1": rr1, "rr2": rr2} with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect() with pytest.raises(ValueError): - session.execute(query, execution_profile='rr3') + session.execute(query, execution_profile="rr3") @local def test_profile_pool_management(self): @@ -1070,12 +1203,14 @@ def test_profile_pool_management(self): RoundRobinPolicy(), lambda host: host.address == "127.0.0.2" ) ) - with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: + with TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: node1, "node2": node2} + ) as cluster: session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() # there are more hosts, but we connected to the ones in the lbp aggregate assert len(cluster.metadata.all_hosts()) > 2 - assert set(h.address for h in pools) == set(('127.0.0.1', '127.0.0.2')) + assert set(h.address for h in pools) == set(("127.0.0.1", "127.0.0.2")) # dynamically update pools on add node3 = ExecutionProfile( @@ -1083,9 +1218,11 @@ def test_profile_pool_management(self): RoundRobinPolicy(), lambda host: host.address == "127.0.0.3" ) ) - cluster.add_execution_profile('node3', node3) + cluster.add_execution_profile("node3", node3) pools = session.get_pool_state() - assert set(h.address for h in pools) == set(('127.0.0.1', '127.0.0.2', '127.0.0.3')) + assert set(h.address for h in pools) == set( + ("127.0.0.1", "127.0.0.2", "127.0.0.3") + ) @local def test_add_profile_timeout(self): @@ -1105,29 +1242,39 @@ def test_add_profile_timeout(self): RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" ) ) - with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: + with TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: node1} + ) as cluster: session = cluster.connect(wait_for_all_pools=True) pools = 
session.get_pool_state() assert len(cluster.metadata.all_hosts()) > 2 - assert set(h.address for h in pools) == set(('127.0.0.1',)) + assert set(h.address for h in pools) == set(("127.0.0.1",)) node2 = ExecutionProfile( load_balancing_policy=HostFilterPolicy( - RoundRobinPolicy(), lambda host: host.address in ["127.0.0.2", "127.0.0.3"] + RoundRobinPolicy(), + lambda host: host.address in ["127.0.0.2", "127.0.0.3"], ) ) start = time.time() try: with pytest.raises(cassandra.OperationTimedOut): - cluster.add_execution_profile('profile_{0}'.format(i), - node2, pool_wait_timeout=sys.float_info.min) + cluster.add_execution_profile( + "profile_{0}".format(i), + node2, + pool_wait_timeout=sys.float_info.min, + ) break except AssertionError: end = time.time() assert start == pytest.approx(end, abs=1e-1) else: - raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) + raise Exception( + "add_execution_profile didn't timeout after {0} retries".format( + max_retry_count + ) + ) def test_stale_connections_after_shutdown(self): """ @@ -1136,7 +1283,9 @@ def test_stale_connections_after_shutdown(self): """ for _ in range(10): with TestCluster(protocol_version=3) as cluster: - cluster.connect(wait_for_all_pools=True).execute("SELECT * FROM system_schema.keyspaces") + cluster.connect(wait_for_all_pools=True).execute( + "SELECT * FROM system_schema.keyspaces" + ) with TestCluster(protocol_version=3) as cluster: session = cluster.connect() @@ -1151,10 +1300,14 @@ def test_stale_connections_after_shutdown(self): with TestCluster(protocol_version=3) as cluster: cluster.connect() - result = subprocess.run(["lsof -nP | awk '$3 ~ \":9042\" {print $0}' | grep ''"], shell=True, capture_output=True) + result = subprocess.run( + ["lsof -nP | awk '$3 ~ \":9042\" {print $0}' | grep ''"], + shell=True, + capture_output=True, + ) if result.returncode: continue - assert False, f'Found stale connections: {result.stdout}' + assert False, f"Found stale 
connections: {result.stdout}" @notwindows @notasyncio # asyncio can't do timeouts smaller than 1ms, as this test requires @@ -1181,11 +1334,16 @@ def test_execute_query_timeout(self): break except: import traceback + traceback.print_exc() end = time.time() assert start == pytest.approx(end, abs=1e-1) else: - raise Exception("session.execute didn't time out in {0} tries".format(max_retry_count)) + raise Exception( + "session.execute didn't time out in {0} tries".format( + max_retry_count + ) + ) def test_replicas_are_queried(self): """ @@ -1205,43 +1363,57 @@ def test_replicas_are_queried(self): tap_profile = ExecutionProfile( load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()) ) - with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: tap_profile}) as cluster: + with TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: tap_profile} + ) as cluster: session = cluster.connect(wait_for_all_pools=True) - session.execute(''' + session.execute(""" CREATE TABLE test1rf.table_with_big_key ( k1 int, k2 int, k3 int, k4 int, - PRIMARY KEY((k1, k2, k3), k4))''') + PRIMARY KEY((k1, k2, k3), k4))""") prepared = session.prepare("""SELECT * from test1rf.table_with_big_key WHERE k1 = ? AND k2 = ? AND k3 = ? 
AND k4 = ?""") for i in range(10): result = session.execute(prepared, (i, i, i, i), trace=True) - trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + trace = result.response_future.get_query_trace( + query_cl=ConsistencyLevel.ALL + ) queried_hosts = self._assert_replica_queried(trace, only_replicas=True) last_i = i hfp_profile = ExecutionProfile( - load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), - predicate=lambda host: host.address != only_replica) + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), predicate=lambda host: host.address != only_replica + ) ) only_replica = queried_hosts.pop() log = logging.getLogger(__name__) log.info("The only replica found was: {}".format(only_replica)) - available_hosts = [host for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] if host != only_replica] - with TestCluster(contact_points=available_hosts, - execution_profiles={EXEC_PROFILE_DEFAULT: hfp_profile}) as cluster: - + available_hosts = [ + host + for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] + if host != only_replica + ] + with TestCluster( + contact_points=available_hosts, + execution_profiles={EXEC_PROFILE_DEFAULT: hfp_profile}, + ) as cluster: session = cluster.connect(wait_for_all_pools=True) prepared = session.prepare("""SELECT * from test1rf.table_with_big_key WHERE k1 = ? AND k2 = ? AND k3 = ? 
AND k4 = ?""") for _ in range(10): - result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) - trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + result = session.execute( + prepared, (last_i, last_i, last_i, last_i), trace=True + ) + trace = result.response_future.get_query_trace( + query_cl=ConsistencyLevel.ALL + ) self._assert_replica_queried(trace, only_replicas=False) - session.execute('''DROP TABLE test1rf.table_with_big_key''') + session.execute("""DROP TABLE test1rf.table_with_big_key""") @greaterthanorequalcass30 @lessthanorequalcass40 @@ -1270,40 +1442,51 @@ def test_compact_option(self): session.set_keyspace("test3rf") nc_session.execute( - "CREATE TABLE IF NOT EXISTS compact_table (k int PRIMARY KEY, v1 int, v2 int) WITH COMPACT STORAGE;") + "CREATE TABLE IF NOT EXISTS compact_table (k int PRIMARY KEY, v1 int, v2 int) WITH COMPACT STORAGE;" + ) for i in range(1, 5): nc_session.execute( "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " - "({i}, 'a{i}', {i}, {i}, textAsBlob('b{i}'))".format(i=i)) + "({i}, 'a{i}', {i}, {i}, textAsBlob('b{i}'))".format(i=i) + ) nc_session.execute( "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " - "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i)) + "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i) + ) nc_results = nc_session.execute("SELECT * FROM compact_table") - assert set(nc_results.current_rows) == {(1, u'a1', 11, 11, 'b1'), - (1, u'a11', 11, 11, 'b11'), - (2, u'a2', 22, 22, 'b2'), - (2, u'a22', 22, 22, 'b22'), - (3, u'a3', 33, 33, 'b3'), - (3, u'a33', 33, 33, 'b33'), - (4, u'a4', 44, 44, 'b4'), - (4, u'a44', 44, 44, 'b44')} + assert set(nc_results.current_rows) == { + (1, "a1", 11, 11, "b1"), + (1, "a11", 11, 11, "b11"), + (2, "a2", 22, 22, "b2"), + (2, "a22", 22, 22, "b22"), + (3, "a3", 33, 33, "b3"), + (3, "a33", 33, 33, "b33"), + (4, "a4", 44, 44, "b4"), + (4, "a44", 44, 44, "b44"), + } 
results = session.execute("SELECT * FROM compact_table") - assert set(results.current_rows) == {(1, 11, 11), - (2, 22, 22), - (3, 33, 33), - (4, 44, 44)} + assert set(results.current_rows) == { + (1, 11, 11), + (2, 22, 22), + (3, 33, 33), + (4, 44, 44), + } def _assert_replica_queried(self, trace, only_replicas=True): queried_hosts = set() for row in trace.events: queried_hosts.add(row.source) if only_replicas: - assert len(queried_hosts) == 1, "The hosts queried where {}".format(queried_hosts) + assert len(queried_hosts) == 1, "The hosts queried where {}".format( + queried_hosts + ) else: - assert len(queried_hosts) > 1, "The host queried was {}".format(queried_hosts) + assert len(queried_hosts) > 1, "The host queried was {}".format( + queried_hosts + ) return queried_hosts def _check_trace(self, trace): @@ -1315,7 +1498,6 @@ def _check_trace(self, trace): class LocalHostAdressTranslator(AddressTranslator): - def __init__(self, addr_map=None): self.addr_map = addr_map @@ -1323,9 +1505,9 @@ def translate(self, addr): new_addr = self.addr_map.get(addr) return new_addr + @local class TestAddressTranslation(unittest.TestCase): - def test_address_translator_basic(self): """ Test host address translation @@ -1340,7 +1522,13 @@ def test_address_translator_basic(self): @test_category metadata """ - lh_ad = LocalHostAdressTranslator({'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.1', '127.0.0.3': '127.0.0.1'}) + lh_ad = LocalHostAdressTranslator( + { + "127.0.0.1": "127.0.0.1", + "127.0.0.2": "127.0.0.1", + "127.0.0.3": "127.0.0.1", + } + ) c = TestCluster(address_translator=lh_ad) c.connect() assert len(c.metadata.all_hosts()) == 1 @@ -1359,7 +1547,11 @@ def test_address_translator_with_mixed_nodes(self): @test_category metadata """ - adder_map = {'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.3', '127.0.0.3': '127.0.0.2'} + adder_map = { + "127.0.0.1": "127.0.0.1", + "127.0.0.2": "127.0.0.3", + "127.0.0.3": "127.0.0.2", + } lh_ad = LocalHostAdressTranslator(adder_map) 
c = TestCluster(address_translator=lh_ad) c.connect() @@ -1367,15 +1559,21 @@ def test_address_translator_with_mixed_nodes(self): assert adder_map.get(host.address) == host.broadcast_address c.shutdown() + @local class ContextManagementTest(unittest.TestCase): load_balancing_policy = HostFilterPolicy( RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP ) - cluster_kwargs = {'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy= - load_balancing_policy)}, - 'schema_metadata_enabled': False, - 'token_metadata_enabled': False} + cluster_kwargs = { + "execution_profiles": { + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=load_balancing_policy + ) + }, + "schema_metadata_enabled": False, + "token_metadata_enabled": False, + } def test_no_connect(self): """ @@ -1405,7 +1603,7 @@ def test_simple_nested(self): with cluster.connect() as session: assert not cluster.is_shutdown assert not session.is_shutdown - assert session.execute('select release_version from system.local').one() + assert session.execute("select release_version from system.local").one() assert session.is_shutdown assert cluster.is_shutdown @@ -1423,7 +1621,7 @@ def test_cluster_no_session(self): session = cluster.connect() assert not cluster.is_shutdown assert not session.is_shutdown - assert session.execute('select release_version from system.local').one() + assert session.execute("select release_version from system.local").one() assert session.is_shutdown assert cluster.is_shutdown @@ -1443,7 +1641,7 @@ def test_session_no_cluster(self): assert not cluster.is_shutdown assert not session.is_shutdown assert not unmanaged_session.is_shutdown - assert session.execute('select release_version from system.local').one() + assert session.execute("select release_version from system.local").one() assert session.is_shutdown assert not cluster.is_shutdown assert not unmanaged_session.is_shutdown @@ -1455,7 +1653,6 @@ def test_session_no_cluster(self): class 
HostStateTest(unittest.TestCase): - def test_down_event_with_active_connection(self): """ Test to ensure that on down calls to clusters with connections still active don't result in @@ -1475,7 +1672,7 @@ def test_down_event_with_active_connection(self): for _ in range(10): new_host = cluster.metadata.all_hosts()[0] assert new_host.is_up, "Host was not up on iteration {0}".format(_) - time.sleep(.01) + time.sleep(0.01) pool = session._pools.get(random_host) pool.shutdown() @@ -1486,27 +1683,32 @@ def test_down_event_with_active_connection(self): if not new_host.is_up: was_marked_down = True break - time.sleep(.01) + time.sleep(0.01) assert was_marked_down @local class DontPrepareOnIgnoredHostsTest(unittest.TestCase): - ignored_addresses = ['127.0.0.3'] + ignored_addresses = ["127.0.0.3"] ignore_node_3_policy = IgnoredHostPolicy(ignored_addresses) def test_prepare_on_ignored_hosts(self): - cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=self.ignore_node_3_policy)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=self.ignore_node_3_policy + ) + } ) session = cluster.connect() cluster.reprepare_on_up, cluster.prepare_on_all_hosts = True, False hosts = cluster.metadata.all_hosts() - session.execute("CREATE KEYSPACE clustertests " - "WITH replication = " - "{'class': 'SimpleStrategy', 'replication_factor': '1'}") + session.execute( + "CREATE KEYSPACE clustertests " + "WITH replication = " + "{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) session.execute("CREATE TABLE clustertests.tab (a text, PRIMARY KEY (a))") # assign to an unused variable so cluster._prepared_statements retains # reference @@ -1514,7 +1716,7 @@ def test_prepare_on_ignored_hosts(self): cluster.connection_factory = Mock(wraps=cluster.connection_factory) - unignored_address = '127.0.0.1' + unignored_address = "127.0.0.1" unignored_host = next(h for h in hosts if h.address == 
unignored_address) ignored_host = next(h for h in hosts if h.address in self.ignored_addresses) unignored_host.is_up = ignored_host.is_up = False @@ -1545,9 +1747,13 @@ def test_deprecation_warnings_legacy_parameters(self): TestCluster(load_balancing_policy=RoundRobinPolicy()) logging.info(w) assert len(w) >= 1 - assert any(["Legacy execution parameters will be removed in 4.0. " - "Consider using execution profiles." in - str(wa.message) for wa in w]) + assert any( + [ + "Legacy execution parameters will be removed in 4.0. " + "Consider using execution profiles." in str(wa.message) + for wa in w + ] + ) def test_deprecation_warnings_meta_refreshed(self): """ @@ -1565,8 +1771,13 @@ def test_deprecation_warnings_meta_refreshed(self): cluster.set_meta_refresh_enabled(True) logging.info(w) assert len(w) >= 1 - assert any(["Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0." in - str(wa.message) for wa in w]) + assert any( + [ + "Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0." 
+ in str(wa.message) + for wa in w + ] + ) def test_deprecation_warning_default_consistency_level(self): """ @@ -1584,5 +1795,10 @@ def test_deprecation_warning_default_consistency_level(self): session = cluster.connect() session.default_consistency_level = ConsistencyLevel.ONE assert len(w) >= 1 - assert any(["Setting the consistency level at the session level will be removed in 4.0" in - str(wa.message) for wa in w]) + assert any( + [ + "Setting the consistency level at the session level will be removed in 4.0" + in str(wa.message) + for wa in w + ] + ) diff --git a/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py b/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py index 910dcaa9fe..31e2d6ea69 100644 --- a/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py +++ b/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py @@ -8,7 +8,8 @@ def setup_module(): - use_cluster('test_schema_kill', [3], start=True) + use_cluster("test_schema_kill", [3], start=True) + @local class TestConcurrentSchemaChangeAndNodeKill(unittest.TestCase): @@ -23,13 +24,15 @@ def teardown_class(cls): def test_schema_change_after_node_kill(self): node2 = get_node(2) - self.session.execute( - "DROP KEYSPACE IF EXISTS ks_deadlock;") + self.session.execute("DROP KEYSPACE IF EXISTS ks_deadlock;") self.session.execute( "CREATE KEYSPACE IF NOT EXISTS ks_deadlock " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2' };") - self.session.set_keyspace('ks_deadlock') - self.session.execute("CREATE TABLE IF NOT EXISTS some_table(k int, c int, v int, PRIMARY KEY (k, v));") + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '2' };" + ) + self.session.set_keyspace("ks_deadlock") + self.session.execute( + "CREATE TABLE IF NOT EXISTS some_table(k int, c int, v int, PRIMARY KEY (k, v));" + ) self.session.execute("INSERT INTO some_table (k, c, v) VALUES (1, 2, 3);") 
node2.stop(wait=False, gently=False) self.session.execute("ALTER TABLE some_table ADD v2 int;", timeout=180) diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index c4463e17fd..5f698427b7 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -23,8 +23,13 @@ from cassandra.protocol import ConfigurationException -from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster, greaterthanorequalcass40, \ - xfail_scylla_version_lt +from tests.integration import ( + use_singledc, + PROTOCOL_VERSION, + TestCluster, + greaterthanorequalcass40, + xfail_scylla_version_lt, +) from tests.integration.datatype_utils import update_datatypes @@ -38,7 +43,8 @@ def setUp(self): if PROTOCOL_VERSION < 3: raise unittest.SkipTest( "Native protocol 3,0+ is required for UDTs using %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) self.cluster = TestCluster() def tearDown(self): @@ -68,7 +74,7 @@ def test_drop_keyspace(self): self.session = self.cluster.connect() self.session.execute(""" CREATE KEYSPACE keyspacetodrop - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) self.session.set_keyspace("keyspacetodrop") self.session.execute("CREATE TYPE user (age int, name text)") @@ -134,8 +140,10 @@ def test_control_connection_port_discovery(self): assert 9042 == host.broadcast_rpc_port assert 7000 == host.broadcast_port - @xfail_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', - scylla_version="2026.1.0") + @xfail_scylla_version_lt( + reason="scylladb/scylladb#26992 - system.client_routes is not yet supported", + scylla_version="2026.1.0", + ) def test_client_routes_change_event(self): cluster = TestCluster() @@ -145,7 +153,10 @@ def 
test_client_routes_change_event(self): flag = Event() connection_ids = ["anytext", "11510f50-f906-4844-8c74-49ddab9ac6a9"] - host_ids = ["1a13fa42-c45b-410f-8ba5-58b42ada9c12", "aa13fa42-c45b-410f-8ba5-58b42ada9c12"] + host_ids = [ + "1a13fa42-c45b-410f-8ba5-58b42ada9c12", + "aa13fa42-c45b-410f-8ba5-58b42ada9c12", + ] got_connection_ids = [] got_host_ids = [] @@ -159,12 +170,16 @@ def on_event(event): finally: flag.set() - self.session.cluster.control_connection._connection.register_watchers({"CLIENT_ROUTES_CHANGE": on_event}) + self.session.cluster.control_connection._connection.register_watchers( + {"CLIENT_ROUTES_CHANGE": on_event} + ) try: payload = [ { - "connection_id": connection_ids[0], # Should be a UUID if API requires that + "connection_id": connection_ids[ + 0 + ], # Should be a UUID if API requires that "host_id": host_ids[0], "address": "localhost", "port": 9042, @@ -174,7 +189,7 @@ def on_event(event): "host_id": host_ids[1], "address": "localhost", "port": 9042, - } + }, ] response = requests.post( "http://" + cluster.contact_points[0] + ":10000/v2/client-routes", @@ -182,9 +197,12 @@ def on_event(event): headers={ "Content-Type": "application/json", "Accept": "application/json", - }) + }, + ) assert response.status_code == 200 - assert flag.wait(20), "Schema change event was not received after registering watchers" + assert flag.wait(20), ( + "Schema change event was not received after registering watchers" + ) assert set(got_connection_ids) == set(connection_ids) assert set(got_host_ids) == set(host_ids) finally: diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index e123f2050e..2a37eb6794 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -14,17 +14,37 @@ import unittest -from cassandra.protocol import ProtocolHandler, ResultMessage, QueryMessage, UUIDType, read_int +from 
cassandra.protocol import ( + ProtocolHandler, + ResultMessage, + QueryMessage, + UUIDType, + read_int, +) from cassandra.query import tuple_factory, SimpleStatement -from cassandra.cluster import (ResponseFuture, ExecutionProfile, EXEC_PROFILE_DEFAULT, - ContinuousPagingOptions, NoHostAvailable) +from cassandra.cluster import ( + ResponseFuture, + ExecutionProfile, + EXEC_PROFILE_DEFAULT, + ContinuousPagingOptions, + NoHostAvailable, +) from cassandra import ProtocolVersion, ConsistencyLevel -from tests.integration import use_single_node, drop_keyspace_shutdown_cluster, \ - greaterthanorequalcass30, execute_with_long_wait_retry, greaterthanorequalcass3_10, \ - TestCluster, greaterthanorequalcass40 +from tests.integration import ( + use_single_node, + drop_keyspace_shutdown_cluster, + greaterthanorequalcass30, + execute_with_long_wait_retry, + greaterthanorequalcass3_10, + TestCluster, + greaterthanorequalcass40, +) from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES -from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params +from tests.integration.standard.utils import ( + create_table_with_all_types, + get_all_primitive_params, +) import uuid from unittest import mock @@ -37,12 +57,13 @@ def setup_module(): class CustomProtocolHandlerTest(unittest.TestCase): - @classmethod def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE custserdes WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.execute( + "CREATE KEYSPACE custserdes WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) cls.session.set_keyspace("custserdes") @classmethod @@ -66,24 +87,32 @@ def test_custom_raw_uuid_row_results(self): # Ensure that we get normal uuid back first cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + 
execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="custserdes") - result = session.execute("SELECT schema_version FROM system.local WHERE key='local'") + result = session.execute( + "SELECT schema_version FROM system.local WHERE key='local'" + ) uuid_type = result.one()[0] assert type(uuid_type) == uuid.UUID # use our custom protocol handlder session.client_protocol_handler = CustomTestRawRowType - result_set = session.execute("SELECT schema_version FROM system.local WHERE key='local'") + result_set = session.execute( + "SELECT schema_version FROM system.local WHERE key='local'" + ) raw_value = result_set.one()[0] assert isinstance(raw_value, bytes) assert len(raw_value) == 16 # Ensure that we get normal uuid back when we re-connect session.client_protocol_handler = ProtocolHandler - result_set = session.execute("SELECT schema_version FROM system.local WHERE key='local'") + result_set = session.execute( + "SELECT schema_version FROM system.local WHERE key='local'" + ) uuid_type = result_set.one()[0] assert type(uuid_type) == uuid.UUID cluster.shutdown() @@ -104,7 +133,9 @@ def test_custom_raw_row_results_all_types(self): """ # Connect using a custom protocol handler that tracks the various types the result message is used with. 
cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="custserdes") session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked @@ -114,11 +145,16 @@ def test_custom_raw_row_results_all_types(self): # verify data params = get_all_primitive_params(0) - results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string)).one() + results = session.execute( + "SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected # Ensure we have covered the various primitive types - assert len(CustomResultMessageTracked.checked_rev_row_set) == len(PRIMITIVE_DATATYPES)-1 + assert ( + len(CustomResultMessageTracked.checked_rev_row_set) + == len(PRIMITIVE_DATATYPES) - 1 + ) cluster.shutdown() @greaterthanorequalcass30 @@ -133,8 +169,9 @@ def test_protocol_divergence_v4_fail_by_flag_uses_int(self): @test_category connection """ - self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V4, uses_int_query_flag=False, - int_flag=True) + self._protocol_divergence_fail_by_flag_uses_int( + ProtocolVersion.V4, uses_int_query_flag=False, int_flag=True + ) @unittest.expectedFailure @greaterthanorequalcass40 @@ -147,8 +184,9 @@ def test_protocol_v5_uses_flag_int(self): @test_category connection """ - self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V5, uses_int_query_flag=True, beta=True, - int_flag=True) + self._protocol_divergence_fail_by_flag_uses_int( + ProtocolVersion.V5, uses_int_query_flag=True, beta=True, int_flag=True + ) @unittest.expectedFailure @greaterthanorequalcass40 @@ -161,8 +199,9 @@ def test_protocol_divergence_v5_fail_by_flag_uses_int(self): @test_category connection """ - self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V5, 
uses_int_query_flag=False, beta=True, - int_flag=False) + self._protocol_divergence_fail_by_flag_uses_int( + ProtocolVersion.V5, uses_int_query_flag=False, beta=True, int_flag=False + ) def _send_query_message(self, session, timeout, **kwargs): query = "SELECT * FROM test3rf.test" @@ -171,8 +210,12 @@ def _send_query_message(self, session, timeout, **kwargs): future.send_request() return future - def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_flag, int_flag = True, beta=False): - cluster = TestCluster(protocol_version=version, allow_beta_protocol_version=beta) + def _protocol_divergence_fail_by_flag_uses_int( + self, version, uses_int_query_flag, int_flag=True, beta=False + ): + cluster = TestCluster( + protocol_version=version, allow_beta_protocol_version=beta + ) session = cluster.connect() query_one = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") @@ -181,9 +224,13 @@ def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_fla execute_with_long_wait_retry(session, query_one) execute_with_long_wait_retry(session, query_two) - with mock.patch('cassandra.protocol.ProtocolVersion.uses_int_query_flags', new=mock.Mock(return_value=int_flag)): - future = self._send_query_message(session, 10, - consistency_level=ConsistencyLevel.ONE, fetch_size=1) + with mock.patch( + "cassandra.protocol.ProtocolVersion.uses_int_query_flags", + new=mock.Mock(return_value=int_flag), + ): + future = self._send_query_message( + session, 10, consistency_level=ConsistencyLevel.ONE, fetch_size=1 + ) response = future.result() @@ -199,17 +246,27 @@ class CustomResultMessageRaw(ResultMessage): This is a custom Result Message that is used to return raw results, rather then results which contain objects. 
""" + my_type_codes = ResultMessage.type_codes.copy() - my_type_codes[0xc] = UUIDType + my_type_codes[0xC] = UUIDType type_codes = my_type_codes - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): - self.recv_results_metadata(f, user_type_map) - column_metadata = self.column_metadata or result_metadata - rowcount = read_int(f) - self.parsed_rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] - self.column_names = [c[2] for c in column_metadata] - self.column_types = [c[3] for c in column_metadata] + def recv_results_rows( + self, + f, + protocol_version, + user_type_map, + result_metadata, + column_encryption_policy, + ): + self.recv_results_metadata(f, user_type_map) + column_metadata = self.column_metadata or result_metadata + rowcount = read_int(f) + self.parsed_rows = [ + self.recv_row(f, len(column_metadata)) for _ in range(rowcount) + ] + self.column_names = [c[2] for c in column_metadata] + self.column_types = [c[3] for c in column_metadata] class CustomTestRawRowType(ProtocolHandler): @@ -217,6 +274,7 @@ class CustomTestRawRowType(ProtocolHandler): This is the a custom protocol handler that will substitute the the customResultMesageRowRaw Result message for our own implementation """ + my_opcodes = ProtocolHandler.message_types_by_opcode.copy() my_opcodes[CustomResultMessageRaw.opcode] = CustomResultMessageRaw message_types_by_opcode = my_opcodes @@ -227,12 +285,20 @@ class CustomResultMessageTracked(ResultMessage): This is a custom Result Message that is use to track what primitive types have been processed when it receives results """ + my_type_codes = ResultMessage.type_codes.copy() - my_type_codes[0xc] = UUIDType + my_type_codes[0xC] = UUIDType type_codes = my_type_codes checked_rev_row_set = set() - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows( + self, + f, + protocol_version, + 
user_type_map, + result_metadata, + column_encryption_policy, + ): self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) @@ -241,9 +307,12 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, self.column_types = [c[3] for c in column_metadata] self.checked_rev_row_set.update(self.column_types) self.parsed_rows = [ - tuple(ctype.from_binary(val, protocol_version) - for ctype, val in zip(self.column_types, row)) - for row in rows] + tuple( + ctype.from_binary(val, protocol_version) + for ctype, val in zip(self.column_types, row) + ) + for row in rows + ] class CustomProtocolHandlerResultMessageTracked(ProtocolHandler): @@ -251,6 +320,7 @@ class CustomProtocolHandlerResultMessageTracked(ProtocolHandler): This is the a custom protocol handler that will substitute the the CustomTestRawRowTypeTracked Result message for our own implementation """ + my_opcodes = ProtocolHandler.message_types_by_opcode.copy() my_opcodes[CustomResultMessageTracked.opcode] = CustomResultMessageTracked message_types_by_opcode = my_opcodes diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 9c94b2ac77..cd497f9e2a 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -9,14 +9,27 @@ from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.concurrent import execute_concurrent_with_args from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY -from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler +from cassandra.protocol import ( + ProtocolHandler, + LazyProtocolHandler, + NumpyProtocolHandler, +) from cassandra.query import tuple_factory from tests import VERIFY_CYTHON -from tests.integration import use_single_node, notprotocolv1, \ - 
drop_keyspace_shutdown_cluster, BasicSharedKeyspaceUnitTestCase, greaterthancass21, TestCluster +from tests.integration import ( + use_single_node, + notprotocolv1, + drop_keyspace_shutdown_cluster, + BasicSharedKeyspaceUnitTestCase, + greaterthancass21, + TestCluster, +) from tests.integration.datatype_utils import update_datatypes from tests.integration.standard.utils import ( - create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) + create_table_with_all_types, + get_all_primitive_params, + get_primitive_datatypes, +) from tests.unit.cython.utils import cythontest, numpytest @@ -26,17 +39,20 @@ def setup_module(): class CythonProtocolHandlerTest(unittest.TestCase): - N_ITEMS = 10 @classmethod def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE testspace WITH replication = " - "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.execute( + "CREATE KEYSPACE testspace WITH replication = " + "{ 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) cls.session.set_keyspace("testspace") - cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS) + cls.colnames = create_table_with_all_types( + "test_table", cls.session, cls.N_ITEMS + ) @classmethod def tearDownClass(cls): @@ -63,7 +79,9 @@ def test_cython_lazy_results_paged(self): """ # arrays = { 'a': arr1, 'b': arr2, ... 
} cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="testspace") session.client_protocol_handler = LazyProtocolHandler @@ -74,7 +92,9 @@ def test_cython_lazy_results_paged(self): results = session.execute("SELECT * FROM test_table") assert results.has_more_pages - assert verify_iterator_data(results) == self.N_ITEMS # make sure we see all rows + assert ( + verify_iterator_data(results) == self.N_ITEMS + ) # make sure we see all rows cluster.shutdown() @@ -97,13 +117,17 @@ def test_numpy_results_paged(self): """ # arrays = { 'a': arr1, 'b': arr2, ... } cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="testspace") session.client_protocol_handler = NumpyProtocolHandler session.default_fetch_size = 2 - expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size + expected_pages = ( + self.N_ITEMS + session.default_fetch_size - 1 + ) // session.default_fetch_size assert session.default_fetch_size < self.N_ITEMS @@ -151,18 +175,18 @@ def _verify_numpy_page(self, page): def match_dtype(self, datatype, dtype): """Match a string cqltype (e.g. 
'int' or 'blob') with a numpy dtype""" - if datatype == 'smallint': - self.match_dtype_props(dtype, 'i', 2) - elif datatype == 'int': - self.match_dtype_props(dtype, 'i', 4) - elif datatype in ('bigint', 'counter'): - self.match_dtype_props(dtype, 'i', 8) - elif datatype == 'float': - self.match_dtype_props(dtype, 'f', 4) - elif datatype == 'double': - self.match_dtype_props(dtype, 'f', 8) + if datatype == "smallint": + self.match_dtype_props(dtype, "i", 2) + elif datatype == "int": + self.match_dtype_props(dtype, "i", 4) + elif datatype in ("bigint", "counter"): + self.match_dtype_props(dtype, "i", 8) + elif datatype == "float": + self.match_dtype_props(dtype, "f", 4) + elif datatype == "double": + self.match_dtype_props(dtype, "f", 8) else: - assert dtype.kind == 'O', (dtype, datatype) + assert dtype.kind == "O", (dtype, datatype) def match_dtype_props(self, dtype, kind, size, signed=None): assert dtype.kind == kind, dtype @@ -172,8 +196,10 @@ def match_dtype_props(self, dtype, kind, size, signed=None): def arrays_to_list_of_tuples(arrays, colnames): """Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples""" first_array = arrays[colnames[0]] - return [tuple(arrays[colname][i] for colname in colnames) - for i in range(len(first_array))] + return [ + tuple(arrays[colname][i] for colname in colnames) + for i in range(len(first_array)) + ] def get_data(protocol_handler): @@ -181,7 +207,9 @@ def get_data(protocol_handler): Get data from the test table. 
""" cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="testspace") @@ -224,16 +252,20 @@ class NumpyWideTableTest(unittest.TestCase): def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE IF NOT EXISTS test_wide_table WITH replication = " - "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.execute( + "CREATE KEYSPACE IF NOT EXISTS test_wide_table WITH replication = " + "{ 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) cls.session.set_keyspace("test_wide_table") # Create a wide table with many int columns columns = ["pk int", "ck int"] columns += ["col{0} int".format(i) for i in range(cls.N_COLUMNS)] cls.session.execute( - "CREATE TABLE wide_table ({0}, PRIMARY KEY (pk, ck))".format(", ".join(columns)), - timeout=120 + "CREATE TABLE wide_table ({0}, PRIMARY KEY (pk, ck))".format( + ", ".join(columns) + ), + timeout=120, ) # Insert test data @@ -262,7 +294,9 @@ def test_numpy_wide_table_paging(self): that all data is still returned correctly across multiple pages. 
""" cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="test_wide_table") session.client_protocol_handler = NumpyProtocolHandler @@ -276,14 +310,18 @@ def test_numpy_wide_table_paging(self): for page in results: page_count += 1 # Get row count from first column array - arr = page.get('pk') + arr = page.get("pk") if arr is not None: total_rows += len(arr) # Verify all rows were returned - self.assertEqual(total_rows, self.N_ROWS, - "Expected {0} rows total, got {1} across {2} pages".format( - self.N_ROWS, total_rows, page_count)) + self.assertEqual( + total_rows, + self.N_ROWS, + "Expected {0} rows total, got {1} across {2} pages".format( + self.N_ROWS, total_rows, page_count + ), + ) cluster.shutdown() @@ -296,7 +334,9 @@ def test_numpy_wide_table_no_fetch_size(self): This is the recommended workaround for getting larger pages with wide tables. 
""" cluster = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + } ) session = cluster.connect(keyspace="test_wide_table") session.client_protocol_handler = NumpyProtocolHandler @@ -309,23 +349,31 @@ def test_numpy_wide_table_no_fetch_size(self): page_count = 0 for page in results: page_count += 1 - arr = page.get('pk') + arr = page.get("pk") if arr is not None: total_rows += len(arr) # Verify all rows were returned - self.assertEqual(total_rows, self.N_ROWS, - "Expected {0} rows total, got {1} across {2} pages".format( - self.N_ROWS, total_rows, page_count)) + self.assertEqual( + total_rows, + self.N_ROWS, + "Expected {0} rows total, got {1} across {2} pages".format( + self.N_ROWS, total_rows, page_count + ), + ) cluster.shutdown() class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): - @classmethod def setUpClass(cls): - cls.common_setup(1, execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)}) + cls.common_setup( + 1, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory) + }, + ) @numpytest @greaterthancass21 @@ -345,7 +393,9 @@ def test_null_types(self): table = "%s.%s" % (self.keyspace_name, self.function_table_name) create_table_with_all_types(table, s, 10) - begin_unset = max(s.execute('select primkey from %s' % (table,)).one()['primkey']) + 1 + begin_unset = ( + max(s.execute("select primkey from %s" % (table,)).one()["primkey"]) + 1 + ) keys_null = range(begin_unset, begin_unset + 10) # scatter some emptry rows in here @@ -355,7 +405,8 @@ def test_null_types(self): result = s.execute("select * from %s" % (table,)).one() from numpy.ma import masked, MaskedArray - result_keys = result.pop('primkey') + + result_keys = result.pop("primkey") mapped_index = [v[1] for v in sorted(zip(result_keys, count()))] had_masked = had_none = False diff --git 
a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 6e64401a75..203d4a2f63 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -26,28 +26,67 @@ from unittest.mock import Mock, patch import pytest -from cassandra import AlreadyExists, SignatureDescriptor, UserFunctionDescriptor, UserAggregateDescriptor +from cassandra import ( + AlreadyExists, + SignatureDescriptor, + UserFunctionDescriptor, + UserAggregateDescriptor, +) from cassandra.connection import Connection from cassandra.encoder import Encoder -from cassandra.metadata import (IndexMetadata, Token, murmur3, Function, Aggregate, protect_name, protect_names, - RegisteredTableExtension, _RegisteredExtensionType, get_schema_parser, - group_keys_by_replica, NO_VALID_REPLICA) +from cassandra.metadata import ( + IndexMetadata, + Token, + murmur3, + Function, + Aggregate, + protect_name, + protect_names, + RegisteredTableExtension, + _RegisteredExtensionType, + get_schema_parser, + group_keys_by_replica, + NO_VALID_REPLICA, +) from cassandra.protocol import QueryMessage, ProtocolHandler from cassandra.util import SortedSet -from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, execute_until_pass, - BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, - BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, - greaterthanorequalcass30, lessthancass30, local, - get_supported_protocol_versions, greaterthancass20, - greaterthancass21, greaterthanorequalcass40, - lessthancass40, - TestCluster, requires_java_udf, requires_composite_type, - requires_collection_indexes, SCYLLA_VERSION, xfail_scylla, xfail_scylla_version_lt, - requirescompactstorage) - -from tests.util import wait_until, assertRegex, assertDictEqual, assertListEqual, assert_startswith_diff +from tests.integration import ( + get_cluster, + use_singledc, + PROTOCOL_VERSION, + 
execute_until_pass, + BasicSegregatedKeyspaceUnitTestCase, + BasicSharedKeyspaceUnitTestCase, + BasicExistingKeyspaceUnitTestCase, + drop_keyspace_shutdown_cluster, + CASSANDRA_VERSION, + greaterthanorequalcass30, + lessthancass30, + local, + get_supported_protocol_versions, + greaterthancass20, + greaterthancass21, + greaterthanorequalcass40, + lessthancass40, + TestCluster, + requires_java_udf, + requires_composite_type, + requires_collection_indexes, + SCYLLA_VERSION, + xfail_scylla, + xfail_scylla_version_lt, + requirescompactstorage, +) + +from tests.util import ( + wait_until, + assertRegex, + assertDictEqual, + assertListEqual, + assert_startswith_diff, +) log = logging.getLogger(__name__) @@ -75,7 +114,7 @@ def test_host_addresses(self): assert host.broadcast_rpc_address is not None assert host.host_id is not None - if CASSANDRA_VERSION >= Version('4-a'): + if CASSANDRA_VERSION >= Version("4-a"): assert host.broadcast_port is not None assert host.broadcast_rpc_port is not None @@ -85,16 +124,21 @@ def test_host_addresses(self): # The control connection node should have the listen address set. # Note: Scylla does not populate listen_address in system.local if SCYLLA_VERSION is None: - listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] + listen_addrs = [ + host.listen_address for host in self.cluster.metadata.all_hosts() + ] assert local_host in listen_addrs # The control connection node should have the broadcast_rpc_address set. 
- rpc_addrs = [host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts()] + rpc_addrs = [ + host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts() + ] assert local_host in rpc_addrs @unittest.skipUnless( - os.getenv('MAPPED_CASSANDRA_VERSION', None) is not None, - "Don't check the host version for test-dse") + os.getenv("MAPPED_CASSANDRA_VERSION", None) is not None, + "Don't check the host version for test-dse", + ) def test_host_release_version(self): """ Checks the hosts release version and validates that it is equal to the @@ -110,12 +154,12 @@ def test_host_release_version(self): assert host.release_version.startswith(CASSANDRA_VERSION.base_version) - @local class MetaDataRemovalTest(unittest.TestCase): - def setUp(self): - self.cluster = TestCluster(contact_points=['127.0.0.1', '127.0.0.2', '127.0.0.3', '126.0.0.186']) + self.cluster = TestCluster( + contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3", "126.0.0.186"] + ) self.cluster.connect() def tearDown(self): @@ -132,15 +176,18 @@ def test_bad_contact_point(self): @test_category metadata """ # wait until we have only 3 hosts - wait_until(condition=lambda: len(self.cluster.metadata.all_hosts()) == 3, delay=0.5, max_attempts=5) + wait_until( + condition=lambda: len(self.cluster.metadata.all_hosts()) == 3, + delay=0.5, + max_attempts=5, + ) # verify the un-existing host was filtered for host in self.cluster.metadata.all_hosts(): - assert host.endpoint.address != '126.0.0.186' + assert host.endpoint.address != "126.0.0.186" class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): - def test_schema_metadata_disable(self): """ Checks to ensure that schema metadata_enabled, and token_metadata_enabled @@ -157,7 +204,7 @@ def test_schema_metadata_disable(self): no_schema = TestCluster(schema_metadata_enabled=False) no_schema_session = no_schema.connect() assert len(no_schema.metadata.keyspaces) == 0 - assert no_schema.metadata.export_schema_as_string() == '' + assert 
no_schema.metadata.export_schema_as_string() == "" no_token = TestCluster(token_metadata_enabled=False) no_token_session = no_token.connect() assert len(no_token.metadata.token_map.token_to_host_owner) == 0 @@ -171,18 +218,27 @@ def test_schema_metadata_disable(self): no_schema.shutdown() no_token.shutdown() - def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None): + def make_create_statement( + self, partition_cols, clustering_cols=None, other_cols=None + ): clustering_cols = clustering_cols or [] other_cols = other_cols or [] - statement = "CREATE TABLE %s.%s (" % (self.keyspace_name, self.function_table_name) + statement = "CREATE TABLE %s.%s (" % ( + self.keyspace_name, + self.function_table_name, + ) if len(partition_cols) == 1 and not clustering_cols: statement += "%s text PRIMARY KEY, " % protect_name(partition_cols[0]) else: - statement += ", ".join("%s text" % protect_name(col) for col in partition_cols) + statement += ", ".join( + "%s text" % protect_name(col) for col in partition_cols + ) statement += ", " - statement += ", ".join("%s text" % protect_name(col) for col in clustering_cols + other_cols) + statement += ", ".join( + "%s text" % protect_name(col) for col in clustering_cols + other_cols + ) if len(partition_cols) != 1 or clustering_cols: statement += ", PRIMARY KEY (" @@ -204,18 +260,28 @@ def make_create_statement(self, partition_cols, clustering_cols=None, other_cols def check_create_statement(self, tablemeta, original): recreate = tablemeta.as_cql_query(formatted=False) - assert original == recreate[:len(original)] - execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + assert original == recreate[: len(original)] + execute_until_pass( + self.session, + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name), + ) execute_until_pass(self.session, recreate) # create the table again, but with formatting enabled - 
execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + execute_until_pass( + self.session, + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name), + ) recreate = tablemeta.as_cql_query(formatted=True) execute_until_pass(self.session, recreate) def get_table_metadata(self): - self.cluster.refresh_table_metadata(self.keyspace_name, self.function_table_name) - return self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name] + self.cluster.refresh_table_metadata( + self.keyspace_name, self.function_table_name + ) + return self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_table_name + ] def test_basic_table_meta_properties(self): create_statement = self.make_create_statement(["a"], [], ["b", "c"]) @@ -230,7 +296,7 @@ def test_basic_table_meta_properties(self): assert ksmeta.name == self.keyspace_name assert ksmeta.durable_writes - assert ksmeta.replication_strategy.name == 'SimpleStrategy' + assert ksmeta.replication_strategy.name == "NetworkTopologyStrategy" assert ksmeta.replication_strategy.replication_factor == 1 assert self.function_table_name in ksmeta.tables @@ -239,9 +305,9 @@ def test_basic_table_meta_properties(self): assert tablemeta.name == self.function_table_name assert tablemeta.name == self.function_table_name - assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert ["a"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) cc = self.cluster.control_connection._connection parser = get_schema_parser( @@ -263,9 +329,9 @@ def test_compound_primary_keys(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b'] == [c.name for c in tablemeta.clustering_key] - assert 
[u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -275,21 +341,23 @@ def test_compound_primary_keys_protected(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'Aa'] == [c.name for c in tablemeta.partition_key] - assert [u'Bb'] == [c.name for c in tablemeta.clustering_key] - assert [u'Aa', u'Bb', u'Cc'] == sorted(tablemeta.columns.keys()) + assert ["Aa"] == [c.name for c in tablemeta.partition_key] + assert ["Bb"] == [c.name for c in tablemeta.clustering_key] + assert ["Aa", "Bb", "Cc"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) def test_compound_primary_keys_more_columns(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b', u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd', u'e', u'f'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b", "c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d", "e", "f"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -298,9 +366,9 @@ def test_composite_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] assert [] == 
tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -310,9 +378,9 @@ def test_composite_in_compound_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] - assert [u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd', u'e'] == sorted(tablemeta.columns.keys()) + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] + assert ["c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d", "e"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -322,9 +390,9 @@ def test_compound_primary_keys_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -345,9 +413,9 @@ def test_cluster_column_ordering_reversed_metadata(self): create_statement += " WITH CLUSTERING ORDER BY (b ASC, c DESC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() - b_column = tablemeta.columns['b'] + b_column = tablemeta.columns["b"] assert not b_column.is_reversed - c_column = tablemeta.columns['c'] + c_column = tablemeta.columns["c"] assert c_column.is_reversed def test_compound_primary_keys_more_columns_compact(self): @@ -356,9 +424,9 @@ def test_compound_primary_keys_more_columns_compact(self): self.session.execute(create_statement) tablemeta = 
self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b', u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b", "c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -367,9 +435,9 @@ def test_composite_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -379,24 +447,23 @@ def test_composite_in_compound_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] - assert [u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] + assert ["c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @lessthancass30 def test_cql_compatibility(self): - # having more than one non-PK column is okay if there aren't any # clustering columns create_statement = self.make_create_statement(["a"], [], ["b", "c", "d"]) self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert ["a"] == [c.name for c in tablemeta.partition_key] assert [] == 
tablemeta.clustering_key - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) assert tablemeta.is_cql_compatible @@ -415,7 +482,9 @@ def test_compound_primary_keys_ordering(self): self.check_create_statement(tablemeta, create_statement) def test_compound_primary_keys_more_columns_ordering(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b DESC, c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -450,8 +519,7 @@ def test_dense_compact_storage(self): def test_counter(self): create_statement = ( - "CREATE TABLE {keyspace}.{table} (" - "key text PRIMARY KEY, a1 counter)" + "CREATE TABLE {keyspace}.{table} (key text PRIMARY KEY, a1 counter)" ).format(keyspace=self.keyspace_name, table=self.function_table_name) self.session.execute(create_statement) @@ -461,7 +529,7 @@ def test_counter(self): @lessthancass40 @requirescompactstorage def test_counter_with_compact_storage(self): - """ PYTHON-1100 """ + """PYTHON-1100""" create_statement = ( "CREATE TABLE {keyspace}.{table} (" "key text PRIMARY KEY, a1 counter) WITH COMPACT STORAGE" @@ -483,20 +551,28 @@ def test_counter_with_dense_compact_storage(self): tablemeta = self.get_table_metadata() self.check_create_statement(tablemeta, create_statement) - @pytest.mark.skip(reason='https://github.com/scylladb/scylladb/issues/6058') + @pytest.mark.skip(reason="https://github.com/scylladb/scylladb/issues/6058") def test_indexes(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" execute_until_pass(self.session, create_statement) - d_index = "CREATE 
INDEX d_index ON %s.%s (d)" % (self.keyspace_name, self.function_table_name) - e_index = "CREATE INDEX e_index ON %s.%s (e)" % (self.keyspace_name, self.function_table_name) + d_index = "CREATE INDEX d_index ON %s.%s (d)" % ( + self.keyspace_name, + self.function_table_name, + ) + e_index = "CREATE INDEX e_index ON %s.%s (e)" % ( + self.keyspace_name, + self.function_table_name, + ) execute_until_pass(self.session, d_index) execute_until_pass(self.session, e_index) tablemeta = self.get_table_metadata() statements = tablemeta.export_as_string().strip() - statements = [s.strip() for s in statements.split(';')] + statements = [s.strip() for s in statements.split(";")] statements = list(filter(bool, statements)) assert 3 == len(statements) assert d_index in statements @@ -505,40 +581,55 @@ def test_indexes(self): # make sure indexes are included in KeyspaceMetadata.export_as_string() ksmeta = self.cluster.metadata.keyspaces[self.keyspace_name] statement = ksmeta.export_as_string() - assert 'CREATE INDEX d_index' in statement - assert 'CREATE INDEX e_index' in statement + assert "CREATE INDEX d_index" in statement + assert "CREATE INDEX e_index" in statement @greaterthancass21 @requires_collection_indexes - @xfail_scylla('scylladb/scylladb#22013 - scylla does not show full index in system_schema.indexes') + @xfail_scylla( + "scylladb/scylladb#22013 - scylla does not show full index in system_schema.indexes" + ) def test_collection_indexes(self): - - self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" - % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index1 ON %s.%s (keys(b))" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" + % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE INDEX index1 ON %s.%s (keys(b))" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - 
assert '(keys(b))' in tablemeta.export_as_string() + assert "(keys(b))" in tablemeta.export_as_string() self.session.execute("DROP INDEX %s.index1" % (self.keyspace_name,)) - self.session.execute("CREATE INDEX index2 ON %s.%s (b)" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE INDEX index2 ON %s.%s (b)" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - target = ' (b)' if CASSANDRA_VERSION < Version("3.0") else 'values(b))' # explicit values in C* 3+ + target = ( + " (b)" if CASSANDRA_VERSION < Version("3.0") else "values(b))" + ) # explicit values in C* 3+ assert target in tablemeta.export_as_string() # test full indexes on frozen collections, if available if CASSANDRA_VERSION >= Version("2.1.3"): - self.session.execute("DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" - % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index3 ON %s.%s (full(b))" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" + % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE INDEX index3 ON %s.%s (full(b))" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - assert '(full(b))' in tablemeta.export_as_string() + assert "(full(b))" in tablemeta.export_as_string() def test_compression_disabled(self): create_statement = self.make_create_statement(["a"], ["b"], ["c"]) @@ -572,7 +663,7 @@ def test_non_size_tiered_compaction(self): assert "'tombstone_threshold': '0.3'" in cql assert "LeveledCompactionStrategy" in cql # formerly legacy options; reintroduced in 4.0 - if CASSANDRA_VERSION < Version('4.0-a'): + if CASSANDRA_VERSION < 
Version("4.0-a"): assert "min_threshold" not in cql assert "max_threshold" not in cql @@ -601,56 +692,89 @@ def test_refresh_schema_metadata(self): assert "new_keyspace" not in cluster2.metadata.keyspaces # Cluster metadata modification - self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + self.session.execute( + "CREATE KEYSPACE new_keyspace WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + ) assert "new_keyspace" not in cluster2.metadata.keyspaces cluster2.refresh_schema_metadata() assert "new_keyspace" in cluster2.metadata.keyspaces # Keyspace metadata modification - self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.session.execute( + "ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_schema_metadata() assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes # Table metadata modification table_name = "test" - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format( + self.keyspace_name, table_name + ) + ) cluster2.refresh_schema_metadata() - self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + self.session.execute( + "ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name) + ) + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.refresh_schema_metadata() - assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + in 
cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) if PROTOCOL_VERSION >= 3: # UDT metadata modification - self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_schema_metadata() assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types if PROTOCOL_VERSION >= 4: # UDF metadata modification - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key+val;';""".format( + self.keyspace_name + ) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} cluster2.refresh_schema_metadata() - assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions + assert ( + "sum_int(int,int)" + in cluster2.metadata.keyspaces[self.keyspace_name].functions + ) # UDA metadata modification - self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + self.session.execute( + """CREATE AGGREGATE {0}.sum_agg(int) SFUNC sum_int STYPE int - INITCOND 0""" - .format(self.keyspace_name)) + INITCOND 0""".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} cluster2.refresh_schema_metadata() - assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + assert ( + "sum_agg(int)" + in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + ) # Cluster metadata modification self.session.execute("DROP KEYSPACE new_keyspace") @@ -682,7 +806,9 @@ def test_refresh_keyspace_metadata(self): cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes - 
self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.session.execute( + "ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_keyspace_metadata(self.keyspace_name) assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes @@ -707,17 +833,38 @@ def test_refresh_table_metadata(self): """ table_name = "test" - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format( + self.keyspace_name, table_name + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns - self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) + self.session.execute( + "ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name) + ) + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.refresh_table_metadata(self.keyspace_name, table_name) - assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.shutdown() @@ -741,44 +888,92 @@ def test_refresh_metadata_for_mv(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b 
text)".format( + self.keyspace_name, self.function_table_name + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() try: - assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views - self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " - "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)" - .format(self.keyspace_name, self.function_table_name)) - assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) + self.session.execute( + "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " + "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)".format( + self.keyspace_name, self.function_table_name + ) + ) + assert ( + "mv1" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) cluster2.refresh_table_metadata(self.keyspace_name, "mv1") - assert "mv1" in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) finally: cluster2.shutdown() - original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] - assert original_meta is self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] - self.cluster.refresh_materialized_view_metadata(self.keyspace_name, 'mv1') + original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] + assert ( + original_meta + is self.session.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + ) + self.cluster.refresh_materialized_view_metadata(self.keyspace_name, "mv1") - current_meta = 
self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + current_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] assert current_meta is not original_meta - assert original_meta is not self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] + assert ( + original_meta + is not self.session.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + ) assert original_meta.as_cql_query() == current_meta.as_cql_query() cluster3 = TestCluster(schema_event_refresh_window=-1) cluster3.connect() try: - assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv2" + not in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT a, b FROM {0}.{1} " "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) + ) + assert ( + "mv2" + not in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) + cluster3.refresh_materialized_view_metadata(self.keyspace_name, "mv2") + assert ( + "mv2" + in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views ) - assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views - cluster3.refresh_materialized_view_metadata(self.keyspace_name, 'mv2') - assert "mv2" in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views finally: cluster3.shutdown() @@ -800,13 +995,19 @@ def test_refresh_user_type_metadata(self): """ if PROTOCOL_VERSION < 3: - raise unittest.SkipTest("Protocol 3+ is required for UDTs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 3+ 
is required for UDTs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} - self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_user_type_metadata(self.keyspace_name, "user") @@ -827,23 +1028,55 @@ def test_refresh_user_type_metadata_proto_2(self): """ supported_versions = get_supported_protocol_versions() if 2 not in supported_versions: # 1 and 2 were dropped in the same version - raise unittest.SkipTest("Protocol versions 1 and 2 are not supported in Cassandra version ".format(CASSANDRA_VERSION)) + raise unittest.SkipTest( + "Protocol versions 1 and 2 are not supported in Cassandra version ".format( + CASSANDRA_VERSION + ) + ) for protocol_version in (1, 2): cluster = TestCluster() session = cluster.connect() assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} - session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert "user" in cluster.metadata.keyspaces[self.keyspace_name].user_types - assert "age" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names - assert "name" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + assert ( + "age" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) + assert ( + "name" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) - session.execute("ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name)) - assert "flag" in 
cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + session.execute( + "ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name) + ) + assert ( + "flag" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) - session.execute("ALTER TYPE {0}.user RENAME flag TO something".format(self.keyspace_name)) - assert "something" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + session.execute( + "ALTER TYPE {0}.user RENAME flag TO something".format( + self.keyspace_name + ) + ) + assert ( + "something" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) session.execute("DROP TYPE {0}.user".format(self.keyspace_name)) assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} @@ -869,20 +1102,33 @@ def test_refresh_user_function_metadata(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDFs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 4+ is required for UDFs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS ' return key + val;';""".format(self.keyspace_name)) + LANGUAGE java AS ' return key + val;';""".format( + self.keyspace_name + ) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} - cluster2.refresh_user_function_metadata(self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"])) - assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions + cluster2.refresh_user_function_metadata( + 
self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"]) + ) + assert ( + "sum_int(int,int)" + in cluster2.metadata.keyspaces[self.keyspace_name].functions + ) cluster2.shutdown() @@ -906,26 +1152,39 @@ def test_refresh_user_aggregate_metadata(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDAs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 4+ is required for UDAs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key + val;';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key + val;';""".format( + self.keyspace_name + ) + ) - self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + self.session.execute( + """CREATE AGGREGATE {0}.sum_agg(int) SFUNC sum_int STYPE int - INITCOND 0""" - .format(self.keyspace_name)) + INITCOND 0""".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} - cluster2.refresh_user_aggregate_metadata(self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"])) - assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + cluster2.refresh_user_aggregate_metadata( + self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"]) + ) + assert ( + "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + ) cluster2.shutdown() @@ -944,14 +1203,30 @@ def test_multiple_indices(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format(self.keyspace_name, self.function_table_name)) - 
self.session.execute("CREATE INDEX index_1 ON {0}.{1}(b)".format(self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index_2 ON {0}.{1}(keys(b))".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format( + self.keyspace_name, self.function_table_name + ) + ) + self.session.execute( + "CREATE INDEX index_1 ON {0}.{1}(b)".format( + self.keyspace_name, self.function_table_name + ) + ) + self.session.execute( + "CREATE INDEX index_2 ON {0}.{1}(keys(b))".format( + self.keyspace_name, self.function_table_name + ) + ) - indices = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].indexes + indices = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .indexes + ) assert len(indices) == 2 index_1 = indices["index_1"] - index_2 = indices['index_2'] + index_2 = indices["index_2"] assert index_1.table_name == "test_multiple_indices" assert index_1.name == "index_1" assert index_1.kind == "COMPOSITES" @@ -969,7 +1244,7 @@ def test_table_extensions(self): ks = self.keyspace_name ks_meta = s.cluster.metadata.keyspaces[ks] t = self.function_table_name - v = t + 'view' + v = t + "view" s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t)) s.execute( @@ -993,7 +1268,7 @@ def after_table_cql(cls, table_meta, ext_key, ext_blob): return "%s %s %s %s" % (cls.name, table_meta.name, ext_key, ext_blob) class Ext1(Ext0): - name = t + '##' + name = t + "##" assert Ext0.name in _RegisteredExtensionType._extension_registry assert Ext1.name in _RegisteredExtensionType._extension_registry @@ -1007,13 +1282,22 @@ class Ext1(Ext0): assert table_meta.export_as_string() == original_table_cql assert view_meta.export_as_string() == original_view_cql - update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? 
AND table_name=?') # for blob type coercing - update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? AND view_name=?') + update_t = s.prepare( + "UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?" + ) # for blob type coercing + update_v = s.prepare( + "UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? AND view_name=?" + ) # extensions registered, one present # -------------------------------------- ext_map = {Ext0.name: b"THA VALUE"} - [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) - for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + [ + ( + s.execute(update_t, (ext_map, ks, t)), + s.execute(update_v, (ext_map, ks, v)), + ) + for _ in self.cluster.metadata.all_hosts() + ] # we're manipulating metadata - do it on all hosts self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_materialized_view_metadata(ks, v) table_meta = ks_meta.tables[t] @@ -1022,7 +1306,9 @@ class Ext1(Ext0): assert Ext0.name in table_meta.extensions new_cql = table_meta.export_as_string() assert new_cql != original_table_cql - assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert ( + Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + ) assert Ext1.name not in new_cql assert Ext0.name in view_meta.extensions @@ -1033,10 +1319,14 @@ class Ext1(Ext0): # extensions registered, one present # -------------------------------------- - ext_map = {Ext0.name: b"THA VALUE", - Ext1.name: b"OTHA VALUE"} - [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) - for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + ext_map = {Ext0.name: b"THA VALUE", Ext1.name: b"OTHA VALUE"} + [ + ( + s.execute(update_t, (ext_map, ks, t)), + s.execute(update_v, (ext_map, ks, v)), + ) + for _ in self.cluster.metadata.all_hosts() 
+ ] # we're manipulating metadata - do it on all hosts self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_materialized_view_metadata(ks, v) table_meta = ks_meta.tables[t] @@ -1046,8 +1336,12 @@ class Ext1(Ext0): assert Ext1.name in table_meta.extensions new_cql = table_meta.export_as_string() assert new_cql != original_table_cql - assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql - assert Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]) in new_cql + assert ( + Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + ) + assert ( + Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]) in new_cql + ) assert Ext0.name in view_meta.extensions assert Ext1.name in view_meta.extensions @@ -1059,8 +1353,10 @@ class Ext1(Ext0): def test_metadata_pagination(self): self.cluster.refresh_schema_metadata() for i in range(12): - self.session.execute("CREATE TABLE %s.%s_%d (a int PRIMARY KEY, b map)" - % (self.keyspace_name, self.function_table_name, i)) + self.session.execute( + "CREATE TABLE %s.%s_%d (a int PRIMARY KEY, b map)" + % (self.keyspace_name, self.function_table_name, i) + ) self.cluster.schema_metadata_page_size = 5 self.cluster.refresh_schema_metadata() @@ -1077,7 +1373,7 @@ def test_metadata_pagination_keyspaces(self): for ks in keyspaces: self.session.execute( - f"CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 3 }}" + f"CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3 }}" ) self.cluster.schema_metadata_page_size = 2000 @@ -1093,7 +1389,6 @@ def test_metadata_pagination_keyspaces(self): class TestCodeCoverage(unittest.TestCase): - def test_export_schema(self): """ Test export schema functionality @@ -1128,17 +1423,20 @@ def test_export_keyspace_schema_udts(self): if PROTOCOL_VERSION < 3: raise unittest.SkipTest( "Protocol 3.0+ is required for UDT 
change events, currently testing against %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) if sys.version_info[0:2] != (2, 7): - raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') + raise unittest.SkipTest( + "This test compares static strings generated from dict items, which may change orders. Test with 2.7." + ) cluster = TestCluster() session = cluster.connect() session.execute(""" CREATE KEYSPACE export_udts - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} AND durable_writes = true; """) session.execute(""" @@ -1162,7 +1460,7 @@ def test_export_keyspace_schema_udts(self): addresses map>) """) - expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; + expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} AND durable_writes = true; CREATE TYPE export_udts.street ( street_number int, @@ -1183,9 +1481,12 @@ def test_export_keyspace_schema_udts(self): user text PRIMARY KEY, addresses map>""" - assert_startswith_diff(cluster.metadata.keyspaces['export_udts'].export_as_string(), expected_prefix) + assert_startswith_diff( + cluster.metadata.keyspaces["export_udts"].export_as_string(), + expected_prefix, + ) - table_meta = cluster.metadata.keyspaces['export_udts'].tables['users'] + table_meta = cluster.metadata.keyspaces["export_udts"].tables["users"] expected_prefix = """CREATE TABLE export_udts.users ( user text PRIMARY KEY, @@ -1196,8 +1497,10 @@ def test_export_keyspace_schema_udts(self): cluster.shutdown() @greaterthancass21 - @xfail_scylla_version_lt(reason='scylladb/scylladb#10707 - Column name in CREATE INDEX is not quoted', - scylla_version="2023.1.1") + @xfail_scylla_version_lt( + 
reason="scylladb/scylladb#10707 - Column name in CREATE INDEX is not quoted", + scylla_version="2023.1.1", + ) def test_case_sensitivity(self): """ Test that names that need to be escaped in CREATE statements are @@ -1206,15 +1509,19 @@ def test_case_sensitivity(self): cluster = TestCluster() session = cluster.connect() - ksname = 'AnInterestingKeyspace' - cfname = 'AnInterestingTable' + ksname = "AnInterestingKeyspace" + cfname = "AnInterestingTable" session.execute("DROP KEYSPACE IF EXISTS {0}".format(ksname)) - session.execute(""" + session.execute( + """ CREATE KEYSPACE "%s" - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """ % (ksname,)) - session.execute(""" + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} + """ + % (ksname,) + ) + session.execute( + """ CREATE TABLE "%s"."%s" ( k int, "A" int, @@ -1222,13 +1529,21 @@ def test_case_sensitivity(self): "MyColumn" int, PRIMARY KEY (k, "A")) WITH CLUSTERING ORDER BY ("A" DESC) - """ % (ksname, cfname)) - session.execute(""" + """ + % (ksname, cfname) + ) + session.execute( + """ CREATE INDEX myindex ON "%s"."%s" ("MyColumn") - """ % (ksname, cfname)) - session.execute(""" + """ + % (ksname, cfname) + ) + session.execute( + """ CREATE INDEX "AnotherIndex" ON "%s"."%s" ("B") - """ % (ksname, cfname)) + """ + % (ksname, cfname) + ) ksmeta = cluster.metadata.keyspaces[ksname] schema = ksmeta.export_as_string() @@ -1239,8 +1554,14 @@ def test_case_sensitivity(self): assert '"MyColumn" int' in schema assert 'PRIMARY KEY (k, "A")' in schema assert 'WITH CLUSTERING ORDER BY ("A" DESC)' in schema - assert 'CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")' in schema - assert 'CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")' in schema + assert ( + 'CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")' + in schema + ) + assert ( + 'CREATE INDEX "AnotherIndex" ON 
"AnInterestingKeyspace"."AnInterestingTable" ("B")' + in schema + ) cluster.shutdown() def test_already_exists_exceptions(self): @@ -1251,41 +1572,43 @@ def test_already_exists_exceptions(self): cluster = TestCluster() session = cluster.connect() - ksname = 'test3rf' - cfname = 'test' + ksname = "test3rf" + cfname = "test" - ddl = ''' + ddl = """ CREATE KEYSPACE %s - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'}""" with pytest.raises(AlreadyExists): session.execute(ddl % ksname) - ddl = ''' + ddl = """ CREATE TABLE %s.%s ( k int PRIMARY KEY, - v int )''' + v int )""" with pytest.raises(AlreadyExists): session.execute(ddl % (ksname, cfname)) cluster.shutdown() @local - @pytest.mark.xfail(reason='AssertionError: \'RAC1\' != \'r1\' - probably a bug in driver or in Scylla') + @pytest.mark.xfail( + reason="AssertionError: 'RAC1' != 'r1' - probably a bug in driver or in Scylla" + ) def test_replicas(self): """ Ensure cluster.metadata.get_replicas return correctly when not attached to keyspace """ if murmur3 is None: - raise unittest.SkipTest('the murmur3 extension is not available') + raise unittest.SkipTest("the murmur3 extension is not available") cluster = TestCluster() - assert cluster.metadata.get_replicas('test3rf', 'key') == [] + assert cluster.metadata.get_replicas("test3rf", "key") == [] - cluster.connect('test3rf') + cluster.connect("test3rf") - assert list(cluster.metadata.get_replicas('test3rf', b'key')) != [] - host = list(cluster.metadata.get_replicas('test3rf', b'key'))[0] - assert host.datacenter == 'dc1' - assert host.rack == 'r1' + assert list(cluster.metadata.get_replicas("test3rf", b"key")) != [] + host = list(cluster.metadata.get_replicas("test3rf", b"key"))[0] + assert host.datacenter == "dc1" + assert host.rack == "r1" cluster.shutdown() def test_token_map(self): @@ -1294,18 +1617,22 @@ def test_token_map(self): """ cluster = 
TestCluster() - cluster.connect('test3rf') + cluster.connect("test3rf") ring = cluster.metadata.token_map.ring - owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring) + owners = list( + cluster.metadata.token_map.token_to_host_owner[token] for token in ring + ) get_replicas = cluster.metadata.token_map.get_replicas - for ksname in ('test1rf', 'test2rf', 'test3rf'): + for ksname in ("test1rf", "test2rf", "test3rf"): assert list(get_replicas(ksname, ring[0])) != [] for i, token in enumerate(ring): - assert set(get_replicas('test3rf', token)) == set(owners) - assert set(get_replicas('test2rf', token)) == set([owners[i], owners[(i + 1) % 3]]) - assert set(get_replicas('test1rf', token)) == set([owners[i]]) + assert set(get_replicas("test3rf", token)) == set(owners) + assert set(get_replicas("test2rf", token)) == set( + [owners[i], owners[(i + 1) % 3]] + ) + assert set(get_replicas("test1rf", token)) == set([owners[i]]) cluster.shutdown() @@ -1313,6 +1640,7 @@ class TokenMetadataTest(unittest.TestCase): """ Test of TokenMap creation and other behavior. 
""" + @local def test_token(self): expected_node_count = len(get_cluster().nodes) @@ -1334,35 +1662,38 @@ class TestMetadataTimeout: "opts, expected_query_chunk", [ ( - {"metadata_request_timeout": None}, - # Should be borrowed from control_connection_timeout - "USING TIMEOUT 2000ms" + {"metadata_request_timeout": None}, + # Should be borrowed from control_connection_timeout + "USING TIMEOUT 2000ms", ), + ({"metadata_request_timeout": 0.0}, False), + ({"metadata_request_timeout": 4.0}, "USING TIMEOUT 4000ms"), ( - {"metadata_request_timeout": 0.0}, - False + {"metadata_request_timeout": None, "control_connection_timeout": None}, + False, ), - ( - {"metadata_request_timeout": 4.0}, - "USING TIMEOUT 4000ms" - ), - ( - {"metadata_request_timeout": None, "control_connection_timeout": None}, - False, - ) ], - ids=["default", "zero", "4s", "both none"] + ids=["default", "zero", "4s", "both none"], ) def test_timeout(self, opts, expected_query_chunk): cluster = TestCluster(**opts) stmts = [] class ConnectionWrapper(cluster.connection_class): - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, - decoder=ProtocolHandler.decode_message, result_metadata=None): + def send_msg( + self, + msg, + request_id, + cb, + encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, + result_metadata=None, + ): if isinstance(msg, QueryMessage): stmts.append(msg.query) - return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata) + return super(ConnectionWrapper, self).send_msg( + msg, request_id, cb, encoder, decoder, result_metadata + ) cluster.connection_class = ConnectionWrapper s = cluster.connect() @@ -1373,26 +1704,34 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, if "SELECT now() FROM system.local WHERE key='local'" in stmt: continue if expected_query_chunk: - assert expected_query_chunk in stmt, f"query `{stmt}` does not contain 
`{expected_query_chunk}`" + assert expected_query_chunk in stmt, ( + f"query `{stmt}` does not contain `{expected_query_chunk}`" + ) else: - assert 'USING TIMEOUT' not in stmt, f"query `{stmt}` should not contain `USING TIMEOUT`" + assert "USING TIMEOUT" not in stmt, ( + f"query `{stmt}` should not contain `USING TIMEOUT`" + ) class KeyspaceAlterMetadata(unittest.TestCase): """ Test verifies that table metadata is preserved on keyspace alter """ + def setUp(self): self.cluster = TestCluster() self.session = self.cluster.connect() name = self._testMethodName.lower() - crt_ks = ''' - CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} AND durable_writes = true''' % name + crt_ks = ( + """ + CREATE KEYSPACE %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND durable_writes = true""" + % name + ) self.session.execute(crt_ks) def tearDown(self): name = self._testMethodName.lower() - self.session.execute('DROP KEYSPACE %s' % name) + self.session.execute("DROP KEYSPACE %s" % name) self.cluster.shutdown() def test_keyspace_alter(self): @@ -1407,20 +1746,19 @@ def test_keyspace_alter(self): """ name = self._testMethodName.lower() - self.session.execute('CREATE TABLE %s.d (d INT PRIMARY KEY)' % name) + self.session.execute("CREATE TABLE %s.d (d INT PRIMARY KEY)" % name) original_keyspace_meta = self.cluster.metadata.keyspaces[name] assert original_keyspace_meta.durable_writes == True assert len(original_keyspace_meta.tables) == 1 - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name) + self.session.execute("ALTER KEYSPACE %s WITH durable_writes = false" % name) new_keyspace_meta = self.cluster.metadata.keyspaces[name] assert original_keyspace_meta != new_keyspace_meta assert new_keyspace_meta.durable_writes == False class IndexMapTests(unittest.TestCase): - - keyspace_name = 'index_map_tests' + keyspace_name = "index_map_tests" @property def table_name(self): @@ -1437,8 
+1775,10 @@ def setup_class(cls): cls.session.execute( """ CREATE KEYSPACE %s - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}; - """ % cls.keyspace_name) + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}; + """ + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) except Exception: cls.cluster.shutdown() @@ -1452,7 +1792,9 @@ def teardown_class(cls): cls.cluster.shutdown() def create_basic_table(self): - self.session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int)" % self.table_name) + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, a int)" % self.table_name + ) def drop_basic_table(self): self.session.execute("DROP TABLE %s" % self.table_name) @@ -1462,10 +1804,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert 'a_idx' not in ks_meta.indexes - assert 'b_idx' not in ks_meta.indexes - assert 'a_idx' not in table_meta.indexes - assert 'b_idx' not in table_meta.indexes + assert "a_idx" not in ks_meta.indexes + assert "b_idx" not in ks_meta.indexes + assert "a_idx" not in table_meta.indexes + assert "b_idx" not in table_meta.indexes self.session.execute("CREATE INDEX a_idx ON %s (a)" % self.table_name) self.session.execute("ALTER TABLE %s ADD b int" % self.table_name) @@ -1473,10 +1815,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert isinstance(ks_meta.indexes['a_idx'], IndexMetadata) - assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) - assert isinstance(table_meta.indexes['a_idx'], IndexMetadata) - assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) + assert isinstance(ks_meta.indexes["a_idx"], IndexMetadata) + assert isinstance(ks_meta.indexes["b_idx"], IndexMetadata) + assert isinstance(table_meta.indexes["a_idx"], IndexMetadata) + assert 
isinstance(table_meta.indexes["b_idx"], IndexMetadata) # both indexes updated when index dropped self.session.execute("DROP INDEX a_idx") @@ -1486,28 +1828,30 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert 'a_idx' not in ks_meta.indexes - assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) - assert 'a_idx' not in table_meta.indexes - assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) + assert "a_idx" not in ks_meta.indexes + assert isinstance(ks_meta.indexes["b_idx"], IndexMetadata) + assert "a_idx" not in table_meta.indexes + assert isinstance(table_meta.indexes["b_idx"], IndexMetadata) # keyspace index updated when table dropped self.drop_basic_table() ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert self.table_name not in ks_meta.tables - assert 'a_idx' not in ks_meta.indexes - assert 'b_idx' not in ks_meta.indexes + assert "a_idx" not in ks_meta.indexes + assert "b_idx" not in ks_meta.indexes def test_index_follows_alter(self): self.create_basic_table() - idx = self.table_name + '_idx' + idx = self.table_name + "_idx" self.session.execute("CREATE INDEX %s ON %s (a)" % (idx, self.table_name)) ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] assert isinstance(ks_meta.indexes[idx], IndexMetadata) assert isinstance(table_meta.indexes[idx], IndexMetadata) - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) old_meta = ks_meta ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert ks_meta is not old_meta @@ -1516,6 +1860,7 @@ def test_index_follows_alter(self): assert isinstance(table_meta.indexes[idx], IndexMetadata) self.drop_basic_table() + @requires_java_udf class FunctionTest(unittest.TestCase): """ @@ -1528,7 +1873,9 
@@ def setUp(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Function metadata requires native protocol version 4+") + raise unittest.SkipTest( + "Function metadata requires native protocol version 4+" + ) @property def function_name(self): @@ -1540,10 +1887,17 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.execute( + "CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}" + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) - cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions - cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates + cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[ + cls.keyspace_name + ].functions + cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[ + cls.keyspace_name + ].aggregates @classmethod def teardown_class(cls): @@ -1552,7 +1906,6 @@ def teardown_class(cls): cls.cluster.shutdown() class Verified(object): - def __init__(self, test_case, meta_class, element_meta, **function_kwargs): self.test_case = test_case self.function_kwargs = dict(function_kwargs) @@ -1572,38 +1925,47 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): tc = self.test_case - tc.session.execute("DROP %s %s.%s" % (self.meta_class.__name__, tc.keyspace_name, self.signature)) + tc.session.execute( + "DROP %s %s.%s" + % (self.meta_class.__name__, tc.keyspace_name, self.signature) + ) assert self.signature not in self.element_meta @property def signature(self): - return SignatureDescriptor.format_signature(self.function_kwargs['name'], - self.function_kwargs['argument_types']) + return SignatureDescriptor.format_signature( + 
self.function_kwargs["name"], self.function_kwargs["argument_types"] + ) class VerifiedFunction(Verified): def __init__(self, test_case, **kwargs): - super(FunctionTest.VerifiedFunction, self).__init__(test_case, Function, test_case.keyspace_function_meta, **kwargs) + super(FunctionTest.VerifiedFunction, self).__init__( + test_case, Function, test_case.keyspace_function_meta, **kwargs + ) class VerifiedAggregate(Verified): def __init__(self, test_case, **kwargs): - super(FunctionTest.VerifiedAggregate, self).__init__(test_case, Aggregate, test_case.keyspace_aggregate_meta, **kwargs) + super(FunctionTest.VerifiedAggregate, self).__init__( + test_case, Aggregate, test_case.keyspace_aggregate_meta, **kwargs + ) @requires_java_udf class FunctionMetadata(FunctionTest): - def make_function_kwargs(self, called_on_null=True): - return {'keyspace': self.keyspace_name, - 'name': self.function_name, - 'argument_types': ['double', 'int'], - 'argument_names': ['d', 'i'], - 'return_type': 'double', - 'language': 'java', - 'body': 'return new Double(0.0);', - 'called_on_null_input': called_on_null, - 'deterministic': False, - 'monotonic': False, - 'monotonic_on': []} + return { + "keyspace": self.keyspace_name, + "name": self.function_name, + "argument_types": ["double", "int"], + "argument_names": ["d", "i"], + "return_type": "double", + "language": "java", + "body": "return new Double(0.0);", + "called_on_null_input": called_on_null, + "deterministic": False, + "monotonic": False, + "monotonic_on": [], + } def test_functions_after_udt(self): """ @@ -1629,15 +1991,19 @@ def test_functions_after_udt(self): assert self.function_name not in self.keyspace_function_meta - udt_name = 'udtx' + udt_name = "udtx" self.session.execute("CREATE TYPE %s (x int)" % udt_name) with self.VerifiedFunction(self, **self.make_function_kwargs()): # udts must come before functions in keyspace dump - keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() + keyspace_cql = 
self.cluster.metadata.keyspaces[ + self.keyspace_name + ].export_as_string() type_idx = keyspace_cql.rfind("CREATE TYPE") func_idx = keyspace_cql.find("CREATE FUNCTION") - assert -1 not in (type_idx, func_idx), "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert -1 not in (type_idx, func_idx), ( + "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql + ) assert func_idx > type_idx def test_function_same_name_diff_types(self): @@ -1656,16 +2022,19 @@ def test_function_same_name_diff_types(self): # Create a function kwargs = self.make_function_kwargs() with self.VerifiedFunction(self, **kwargs): - # another function: same name, different type sig. - assert len(kwargs['argument_types']) > 1 - assert len(kwargs['argument_names']) > 1 - kwargs['argument_types'] = kwargs['argument_types'][:1] - kwargs['argument_names'] = kwargs['argument_names'][:1] + assert len(kwargs["argument_types"]) > 1 + assert len(kwargs["argument_names"]) > 1 + kwargs["argument_types"] = kwargs["argument_types"][:1] + kwargs["argument_names"] = kwargs["argument_names"][:1] # Ensure they are surfaced separately with self.VerifiedFunction(self, **kwargs): - functions = [f for f in self.keyspace_function_meta.values() if f.name == self.function_name] + functions = [ + f + for f in self.keyspace_function_meta.values() + if f.name == self.function_name + ] assert len(functions) == 2 assert functions[0].argument_types != functions[1].argument_types @@ -1681,14 +2050,16 @@ def test_function_no_parameters(self): @test_category function """ kwargs = self.make_function_kwargs() - kwargs['argument_types'] = [] - kwargs['argument_names'] = [] - kwargs['return_type'] = 'bigint' - kwargs['body'] = 'return System.currentTimeMillis() / 1000L;' + kwargs["argument_types"] = [] + kwargs["argument_names"] = [] + kwargs["return_type"] = "bigint" + kwargs["body"] = "return System.currentTimeMillis() / 1000L;" with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = 
self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*%s\(\) .*' % kwargs['name']) + assertRegex( + fn_meta.as_cql_query(), r"CREATE FUNCTION.*%s\(\) .*" % kwargs["name"] + ) def test_functions_follow_keyspace_alter(self): """ @@ -1707,7 +2078,9 @@ def test_functions_follow_keyspace_alter(self): # Create function with self.VerifiedFunction(self, **self.make_function_kwargs()): original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) # After keyspace alter ensure that we maintain function equality. try: @@ -1715,7 +2088,9 @@ def test_functions_follow_keyspace_alter(self): assert original_keyspace_meta != new_keyspace_meta assert original_keyspace_meta.functions is new_keyspace_meta.functions finally: - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = true" % self.keyspace_name + ) def test_function_cql_called_on_null(self): """ @@ -1733,20 +2108,25 @@ def test_function_cql_called_on_null(self): """ kwargs = self.make_function_kwargs() - kwargs['called_on_null_input'] = True + kwargs["called_on_null_input"] = True with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*') + assertRegex( + fn_meta.as_cql_query(), + r"CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*", + ) - kwargs['called_on_null_input'] = False + kwargs["called_on_null_input"] = False with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*') + assertRegex( + 
fn_meta.as_cql_query(), + r"CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*", + ) @requires_java_udf class AggregateMetadata(FunctionTest): - @classmethod def setup_class(cls): if PROTOCOL_VERSION >= 4: @@ -1778,16 +2158,20 @@ def setup_class(cls): cls.session.execute("INSERT INTO t (k,v) VALUES (%s, %s)", (x, x)) cls.session.execute("INSERT INTO t (k) VALUES (%s)", (4,)) - def make_aggregate_kwargs(self, state_func, state_type, final_func=None, init_cond=None): - return {'keyspace': self.keyspace_name, - 'name': self.function_name + '_aggregate', - 'argument_types': ['int'], - 'state_func': state_func, - 'state_type': state_type, - 'final_func': final_func, - 'initial_condition': init_cond, - 'return_type': "does not matter for creation", - 'deterministic': False} + def make_aggregate_kwargs( + self, state_func, state_type, final_func=None, init_cond=None + ): + return { + "keyspace": self.keyspace_name, + "name": self.function_name + "_aggregate", + "argument_types": ["int"], + "state_func": state_func, + "state_type": state_type, + "final_func": final_func, + "initial_condition": init_cond, + "return_type": "does not matter for creation", + "deterministic": False, + } def test_return_type_meta(self): """ @@ -1803,8 +2187,10 @@ def test_return_type_meta(self): @test_category aggregate """ - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='1')) as va: - assert self.keyspace_aggregate_meta[va.signature].return_type == 'int' + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond="1") + ) as va: + assert self.keyspace_aggregate_meta[va.signature].return_type == "int" def test_init_cond(self): """ @@ -1831,27 +2217,59 @@ def test_init_cond(self): # int32 for init_cond in (-1, 0, 1): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond=cql_init)) as va: - sum_res = s.execute("SELECT 
%s(v) AS sum FROM t" % va.function_kwargs['name']).one().sum + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond=cql_init) + ) as va: + sum_res = ( + s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs["name"]) + .one() + .sum + ) assert sum_res == int(init_cond) + sum(expected_values) # list - for init_cond in ([], ['1', '2']): + for init_cond in ([], ["1", "2"]): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list', init_cond=cql_init)) as va: - list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name']).one().list_res - assertListEqual(list_res[:len(init_cond)], init_cond) - assert set(i for i in list_res[len(init_cond):]) == set(str(i) for i in expected_values) + with self.VerifiedAggregate( + self, + **self.make_aggregate_kwargs( + "extend_list", "list", init_cond=cql_init + ), + ) as va: + list_res = ( + s.execute( + "SELECT %s(v) AS list_res FROM t" % va.function_kwargs["name"] + ) + .one() + .list_res + ) + assertListEqual(list_res[: len(init_cond)], init_cond) + assert set(i for i in list_res[len(init_cond) :]) == set( + str(i) for i in expected_values + ) # map expected_map_values = dict((i, i) for i in expected_values) expected_key_set = set(expected_values) for init_cond in ({}, {1: 2, 3: 4}, {5: 5}): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map', init_cond=cql_init)) as va: - map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name']).one().map_res + with self.VerifiedAggregate( + self, + **self.make_aggregate_kwargs( + "update_map", "map", init_cond=cql_init + ), + ) as va: + map_res = ( + s.execute( + "SELECT %s(v) AS map_res FROM t" % va.function_kwargs["name"] + ) + .one() + .map_res + ) assert expected_map_values.items() <= map_res.items() - init_not_updated = dict((k, 
init_cond[k]) for k in set(init_cond) - expected_key_set) + init_not_updated = dict( + (k, init_cond[k]) for k in set(init_cond) - expected_key_set + ) assert init_not_updated.items() <= map_res.items() c.shutdown() @@ -1870,11 +2288,17 @@ def test_aggregates_after_functions(self): """ # functions must come before functions in keyspace dump - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list')): - keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("extend_list", "list") + ): + keyspace_cql = self.cluster.metadata.keyspaces[ + self.keyspace_name + ].export_as_string() func_idx = keyspace_cql.find("CREATE FUNCTION") aggregate_idx = keyspace_cql.rfind("CREATE AGGREGATE") - assert -1 not in (aggregate_idx, func_idx), "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert -1 not in (aggregate_idx, func_idx), ( + "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql + ) assert aggregate_idx > func_idx def test_same_name_diff_types(self): @@ -1890,12 +2314,16 @@ def test_same_name_diff_types(self): @test_category function """ - kwargs = self.make_aggregate_kwargs('sum_int', 'int', init_cond='0') + kwargs = self.make_aggregate_kwargs("sum_int", "int", init_cond="0") with self.VerifiedAggregate(self, **kwargs): - kwargs['state_func'] = 'sum_int_two' - kwargs['argument_types'] = ['int', 'int'] + kwargs["state_func"] = "sum_int_two" + kwargs["argument_types"] = ["int", "int"] with self.VerifiedAggregate(self, **kwargs): - aggregates = [a for a in self.keyspace_aggregate_meta.values() if a.name == kwargs['name']] + aggregates = [ + a + for a in self.keyspace_aggregate_meta.values() + if a.name == kwargs["name"] + ] assert len(aggregates) == 2 assert aggregates[0].argument_types != aggregates[1].argument_types @@ -1913,15 +2341,21 @@ def test_aggregates_follow_keyspace_alter(self): @test_category function """ 
- with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='0')): + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond="0") + ): original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) try: new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert original_keyspace_meta != new_keyspace_meta assert original_keyspace_meta.aggregates is new_keyspace_meta.aggregates finally: - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = true" % self.keyspace_name + ) def test_cql_optional_params(self): """ @@ -1937,53 +2371,57 @@ def test_cql_optional_params(self): @test_category function """ - kwargs = self.make_aggregate_kwargs('extend_list', 'list') + kwargs = self.make_aggregate_kwargs("extend_list", "list") encoder = Encoder() # no initial condition, final func - assert kwargs['initial_condition'] is None - assert kwargs['final_func'] is None + assert kwargs["initial_condition"] is None + assert kwargs["final_func"] is None with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] assert meta.initial_condition is None assert meta.final_func is None cql = meta.as_cql_query() - assert cql.find('INITCOND') == -1 - assert cql.find('FINALFUNC') == -1 + assert cql.find("INITCOND") == -1 + assert cql.find("FINALFUNC") == -1 # initial condition, no final func - kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) + kwargs["initial_condition"] = encoder.cql_encode_all_types(["init", "cond"]) with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - assert 
meta.initial_condition == kwargs['initial_condition'] + assert meta.initial_condition == kwargs["initial_condition"] assert meta.final_func is None cql = meta.as_cql_query() - search_string = "INITCOND %s" % kwargs['initial_condition'] - assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) - assert cql.find('FINALFUNC') == -1 + search_string = "INITCOND %s" % kwargs["initial_condition"] + assert cql.find(search_string) > 0, ( + '"%s" search string not found in cql:\n%s' % (search_string, cql) + ) + assert cql.find("FINALFUNC") == -1 # no initial condition, final func - kwargs['initial_condition'] = None - kwargs['final_func'] = 'List_As_String' + kwargs["initial_condition"] = None + kwargs["final_func"] = "List_As_String" with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] assert meta.initial_condition is None - assert meta.final_func == kwargs['final_func'] + assert meta.final_func == kwargs["final_func"] cql = meta.as_cql_query() - assert cql.find('INITCOND') == -1 - search_string = 'FINALFUNC "%s"' % kwargs['final_func'] - assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) + assert cql.find("INITCOND") == -1 + search_string = 'FINALFUNC "%s"' % kwargs["final_func"] + assert cql.find(search_string) > 0, ( + '"%s" search string not found in cql:\n%s' % (search_string, cql) + ) # both - kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) - kwargs['final_func'] = 'List_As_String' + kwargs["initial_condition"] = encoder.cql_encode_all_types(["init", "cond"]) + kwargs["final_func"] = "List_As_String" with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - assert meta.initial_condition == kwargs['initial_condition'] - assert meta.final_func == kwargs['final_func'] + assert meta.initial_condition == kwargs["initial_condition"] + assert meta.final_func 
== kwargs["final_func"] cql = meta.as_cql_query() - init_cond_idx = cql.find("INITCOND %s" % kwargs['initial_condition']) - final_func_idx = cql.find('FINALFUNC "%s"' % kwargs['final_func']) + init_cond_idx = cql.find("INITCOND %s" % kwargs["initial_condition"]) + final_func_idx = cql.find('FINALFUNC "%s"' % kwargs["final_func"]) assert -1 not in (init_cond_idx, final_func_idx) assert init_cond_idx > final_func_idx @@ -2007,7 +2445,10 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.execute( + "CREATE KEYSPACE %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}" + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) connection = cls.cluster.control_connection._connection @@ -2025,34 +2466,56 @@ def teardown_class(cls): drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster) def test_bad_keyspace(self): - with patch.object(self.parser_class, '_build_keyspace_metadata_internal', side_effect=self.BadMetaException): + with patch.object( + self.parser_class, + "_build_keyspace_metadata_internal", + side_effect=self.BadMetaException, + ): self.cluster.refresh_keyspace_metadata(self.keyspace_name) m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() def test_bad_table(self): - self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) - with patch.object(self.parser_class, '_build_column_metadata', side_effect=self.BadMetaException): + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, v int)" % self.function_name + ) + with patch.object( + self.parser_class, + "_build_column_metadata", + side_effect=self.BadMetaException, + ): 
self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) - m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_name + ] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() def test_bad_index(self): - self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) - self.session.execute('CREATE INDEX ON %s(v)' % self.function_name) - with patch.object(self.parser_class, '_build_index_metadata', side_effect=self.BadMetaException): + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, v int)" % self.function_name + ) + self.session.execute("CREATE INDEX ON %s(v)" % self.function_name) + with patch.object( + self.parser_class, + "_build_index_metadata", + side_effect=self.BadMetaException, + ): self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) - m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_name + ] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @greaterthancass20 def test_bad_user_type(self): - self.session.execute('CREATE TYPE %s (i int, d double)' % self.function_name) - with patch.object(self.parser_class, '_build_user_type', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + self.session.execute("CREATE TYPE %s (i int, d double)" % self.function_name) + with patch.object( + self.parser_class, "_build_user_type", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = 
self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @@ -2060,18 +2523,23 @@ def test_bad_user_type(self): @greaterthancass21 @requires_java_udf def test_bad_user_function(self): - self.session.execute("""CREATE FUNCTION IF NOT EXISTS %s (key int, val int) + self.session.execute( + """CREATE FUNCTION IF NOT EXISTS %s (key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key + val;';""" % self.function_name) - - #We need to patch as well the reconnect function because after patching the _build_function - #there will an Error refreshing schema which will trigger a reconnection. If this happened - #in a timely manner in the call self.cluster.refresh_schema_metadata() it would return an exception - #due to that a connection would be closed - with patch.object(self.cluster.control_connection, 'reconnect'): - with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + LANGUAGE java AS 'return key + val;';""" + % self.function_name + ) + + # We need to patch as well the reconnect function because after patching the _build_function + # there will be an Error refreshing schema which will trigger a reconnection. 
If this happened + # in a timely manner in the call self.cluster.refresh_schema_metadata() it would return an exception + # due to that a connection would be closed + with patch.object(self.cluster.control_connection, "reconnect"): + with patch.object( + self.parser_class, "_build_function", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @@ -2083,21 +2551,25 @@ def test_bad_user_aggregate(self): RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS 'return key + val;';""") - self.session.execute("""CREATE AGGREGATE %s(int) + self.session.execute( + """CREATE AGGREGATE %s(int) SFUNC sum_int STYPE int - INITCOND 0""" % self.function_name) - #We have the same issue here as in test_bad_user_function - with patch.object(self.cluster.control_connection, 'reconnect'): - with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + INITCOND 0""" + % self.function_name + ) + # We have the same issue here as in test_bad_user_function + with patch.object(self.cluster.control_connection, "reconnect"): + with patch.object( + self.parser_class, "_build_aggregate", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): - @requires_composite_type def 
test_dct_alias(self): """ @@ -2111,18 +2583,24 @@ def test_dct_alias(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (" - "k int PRIMARY KEY," - "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," - "c2 Text)".format(self.ks_name, self.function_table_name)) - dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name) + self.session.execute( + "CREATE TABLE {0}.{1} (" + "k int PRIMARY KEY," + "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," + "c2 Text)".format(self.ks_name, self.function_table_name) + ) + dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get( + self.function_table_name + ) # Format can very slightly between versions, strip out whitespace for consistency sake table_text = dct_table.as_cql_query().replace(" ", "") dynamic_type_text = "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" assert "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" in table_text # Types within in the composite can come out in random order, so grab the type definition and find each one - type_definition_start = table_text.index("(", table_text.find(dynamic_type_text)) + type_definition_start = table_text.index( + "(", table_text.find(dynamic_type_text) + ) type_definition_end = table_text.index(")") type_definition_text = table_text[type_definition_start:type_definition_end] assert "s=>org.apache.cassandra.db.marshal.UTF8Type" in type_definition_text @@ -2131,19 +2609,27 @@ def test_dct_alias(self): @greaterthanorequalcass30 class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): - def setUp(self): - self.session.execute("CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format( + self.keyspace_name, self.function_table_name + ) + ) self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c 
FROM {0}.{1} " "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c) " "WITH compaction = {{ 'class' : 'SizeTieredCompactionStrategy' }}".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) ) def tearDown(self): - self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) - self.session.execute("DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name) + ) + self.session.execute( + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name) + ) def test_materialized_view_metadata_creation(self): """ @@ -2162,10 +2648,27 @@ def test_materialized_view_metadata_creation(self): """ assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].views - assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) - assert self.keyspace_name == self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].keyspace_name - assert self.function_table_name == self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].base_table_name + assert ( + self.keyspace_name + == self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .keyspace_name + ) + assert ( + self.function_table_name + == self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .base_table_name + ) def test_materialized_view_metadata_alter(self): """ @@ -2182,10 +2685,26 @@ def test_materialized_view_metadata_alter(self): @test_category metadata """ - assert "SizeTieredCompactionStrategy" in 
self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + assert ( + "SizeTieredCompactionStrategy" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .options["compaction"]["class"] + ) - self.session.execute("ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format(self.keyspace_name)) - assert "LeveledCompactionStrategy" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + self.session.execute( + "ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format( + self.keyspace_name + ) + ) + assert ( + "LeveledCompactionStrategy" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .options["compaction"]["class"] + ) def test_materialized_view_metadata_drop(self): """ @@ -2203,17 +2722,30 @@ def test_materialized_view_metadata_drop(self): @test_category metadata """ - self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) + self.session.execute( + "DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name) + ) - assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + not in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].views - assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assertDictEqual( + {}, + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views, + ) assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].views) self.session.execute( "CREATE MATERIALIZED VIEW 
{0}.mv1 AS SELECT pk, c FROM {0}.{1} " "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c)".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) ) @@ -2249,36 +2781,40 @@ def test_create_view_metadata(self): SELECT game, year, month, score, user, day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] - mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['monthlyhigh'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views["monthlyhigh"] assert score_table.views["monthlyhigh"] is not None assert len(score_table.views) is not None, 1 # Make sure user is a partition key, and not null assert len(score_table.partition_key) == 1 - assert score_table.columns['user'] is not None - assert score_table.columns['user'], score_table.partition_key[0] + assert score_table.columns["user"] is not None + assert score_table.columns["user"], score_table.partition_key[0] # Validate clustering keys assert len(score_table.clustering_key) == 4 - assert score_table.columns['game'] is not None - assert score_table.columns['game'], score_table.clustering_key[0] + assert score_table.columns["game"] is not None + assert score_table.columns["game"], score_table.clustering_key[0] - assert score_table.columns['year'] is not None - assert score_table.columns['year'], score_table.clustering_key[1] + assert score_table.columns["year"] is not None + assert score_table.columns["year"], 
score_table.clustering_key[1] - assert score_table.columns['month'] is not None - assert score_table.columns['month'], score_table.clustering_key[2] + assert score_table.columns["month"] is not None + assert score_table.columns["month"], score_table.clustering_key[2] - assert score_table.columns['day'] is not None - assert score_table.columns['day'], score_table.clustering_key[3] + assert score_table.columns["day"] is not None + assert score_table.columns["day"], score_table.clustering_key[3] - assert score_table.columns['score'] is not None + assert score_table.columns["score"] is not None # Validate basic mv information assert mv.keyspace_name == self.keyspace_name @@ -2292,17 +2828,17 @@ def test_create_view_metadata(self): game_column = mv_columns[0] assert game_column is not None - assert game_column.name == 'game' + assert game_column.name == "game" assert game_column == mv.partition_key[0] year_column = mv_columns[1] assert year_column is not None - assert year_column.name == 'year' + assert year_column.name == "year" assert year_column == mv.partition_key[1] month_column = mv_columns[2] assert month_column is not None - assert month_column.name == 'month' + assert month_column.name == "month" assert month_column == mv.partition_key[2] def compare_columns(a, b, name): @@ -2314,13 +2850,13 @@ def compare_columns(a, b, name): assert a.is_reversed == b.is_reversed score_column = mv_columns[3] - compare_columns(score_column, mv.clustering_key[0], 'score') + compare_columns(score_column, mv.clustering_key[0], "score") user_column = mv_columns[4] - compare_columns(user_column, mv.clustering_key[1], 'user') + compare_columns(user_column, mv.clustering_key[1], "user") day_column = mv_columns[5] - compare_columns(day_column, mv.clustering_key[2], 'day') + compare_columns(day_column, mv.clustering_key[2], "day") def test_base_table_column_addition_mv(self): """ @@ -2351,44 +2887,60 @@ def test_base_table_column_addition_mv(self): SELECT game, year, month, score, user, 
day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) create_mv_alltime = """CREATE MATERIALIZED VIEW {0}.alltimehigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) self.session.execute(create_mv_alltime) - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] assert score_table.views["monthlyhigh"] is not None assert score_table.views["alltimehigh"] is not None assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 - insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format((self.keyspace_name)) + insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format( + (self.keyspace_name) + ) self.session.execute(insert_fouls) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] assert "fouls" in score_table.columns # This is a workaround for mv notifications being separate from base table schema responses. 
# This maybe fixed with future protocol changes for i in range(10): - mv_alltime = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"] - if("fouls" in mv_alltime.columns): + mv_alltime = self.cluster.metadata.keyspaces[self.keyspace_name].views[ + "alltimehigh" + ] + if "fouls" in mv_alltime.columns: break - time.sleep(.2) + time.sleep(0.2) assert "fouls" in mv_alltime.columns - mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - assert mv_alltime_fouls_comumn.cql_type == 'int' + mv_alltime_fouls_comumn = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .views["alltimehigh"] + .columns["fouls"] + ) + assert mv_alltime_fouls_comumn.cql_type == "int" @lessthancass30 def test_base_table_type_alter_mv(self): @@ -2423,25 +2975,37 @@ def test_base_table_type_alter_mv(self): SELECT game, year, month, score, user, day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 - alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format((self.keyspace_name)) + alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format( + (self.keyspace_name) + ) self.session.execute(alter_scores) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 - score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - assert score_column.cql_type == 'blob' + score_column = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables["scores"] + .columns["score"] + ) + assert 
score_column.cql_type == "blob" # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): - score_mv_column = self.cluster.metadata.keyspaces[self.keyspace_name].views["monthlyhigh"].columns['score'] + score_mv_column = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .views["monthlyhigh"] + .columns["score"] + ) if "blob" == score_mv_column.cql_type: break - time.sleep(.2) + time.sleep(0.2) - assert score_mv_column.cql_type == 'blob' + assert score_mv_column.cql_type == "blob" def test_metadata_with_quoted_identifiers(self): """ @@ -2462,7 +3026,9 @@ def test_metadata_with_quoted_identifiers(self): "theKey" int, "the;Clustering" int, "the Value" int, - PRIMARY KEY ("theKey", "the;Clustering"))""".format(self.keyspace_name) + PRIMARY KEY ("theKey", "the;Clustering"))""".format( + self.keyspace_name + ) self.session.execute(create_table) @@ -2470,28 +3036,30 @@ def test_metadata_with_quoted_identifiers(self): SELECT "theKey", "the;Clustering", "the Value" FROM {0}.t1 WHERE "theKey" IS NOT NULL AND "the;Clustering" IS NOT NULL AND "the Value" IS NOT NULL - PRIMARY KEY ("theKey", "the;Clustering")""".format(self.keyspace_name) + PRIMARY KEY ("theKey", "the;Clustering")""".format( + self.keyspace_name + ) self.session.execute(create_mv) - t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['t1'] - mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables["t1"] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] assert t1_table.views["mv1"] is not None assert len(t1_table.views) is not None, 1 # Validate partition key, and not null assert len(t1_table.partition_key) == 1 - assert t1_table.columns['theKey'] is not None - assert t1_table.columns['theKey'], t1_table.partition_key[0] + assert t1_table.columns["theKey"] is not None + assert t1_table.columns["theKey"], 
t1_table.partition_key[0] # Validate clustering key column assert len(t1_table.clustering_key) == 1 - assert t1_table.columns['the;Clustering'] is not None - assert t1_table.columns['the;Clustering'], t1_table.clustering_key[0] + assert t1_table.columns["the;Clustering"] is not None + assert t1_table.columns["the;Clustering"], t1_table.clustering_key[0] # Validate regular column - assert t1_table.columns['the Value'] is not None + assert t1_table.columns["the Value"] is not None # Validate basic mv information assert mv.keyspace_name == self.keyspace_name @@ -2505,12 +3073,12 @@ def test_metadata_with_quoted_identifiers(self): theKey_column = mv_columns[0] assert theKey_column is not None - assert theKey_column.name == 'theKey' + assert theKey_column.name == "theKey" assert theKey_column == mv.partition_key[0] cluster_column = mv_columns[1] assert cluster_column is not None - assert cluster_column.name == 'the;Clustering' + assert cluster_column.name == "the;Clustering" assert cluster_column.name == mv.clustering_key[0].name assert cluster_column.table == mv.clustering_key[0].table assert cluster_column.is_static == mv.clustering_key[0].is_static @@ -2518,7 +3086,7 @@ def test_metadata_with_quoted_identifiers(self): value_column = mv_columns[2] assert value_column is not None - assert value_column.name == 'the Value' + assert value_column.name == "the Value" class GroupPerHost(BasicSharedKeyspaceUnitTestCase): @@ -2527,13 +3095,13 @@ def setUpClass(cls): cls.common_setup(rf=1, create_class_table=True) cls.table_two_pk = "table_with_two_pk" cls.session.execute( - ''' + """ CREATE TABLE {0}.{1} ( k_one int, k_two int, v int, PRIMARY KEY ((k_one, k_two)) - )'''.format(cls.ks_name, cls.table_two_pk) + )""".format(cls.ks_name, cls.table_two_pk) ) def test_group_keys_by_host(self): @@ -2548,7 +3116,9 @@ def test_group_keys_by_host(self): @test_category metadata """ stmt = """SELECT * FROM {}.{} - WHERE k_one = ? AND k_two = ? 
""".format(self.ks_name, self.table_two_pk) + WHERE k_one = ? AND k_two = ? """.format( + self.ks_name, self.table_two_pk + ) keys = ((1, 2), (2, 2), (2, 3), (3, 4)) self._assert_group_keys_by_host(keys, self.table_two_pk, stmt) @@ -2558,7 +3128,9 @@ def test_group_keys_by_host(self): self._assert_group_keys_by_host(keys, self.ks_name, stmt) def _assert_group_keys_by_host(self, keys, table_name, stmt): - keys_per_host = group_keys_by_replica(self.session, self.ks_name, table_name, keys) + keys_per_host = group_keys_by_replica( + self.session, self.ks_name, table_name, keys + ) assert NO_VALID_REPLICA not in keys_per_host prepared_stmt = self.session.prepare(stmt) diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py index 2de12f7b7f..7a638b89d4 100644 --- a/tests/integration/standard/test_policies.py +++ b/tests/integration/standard/test_policies.py @@ -15,8 +15,14 @@ import unittest from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT -from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, SimpleConvictionPolicy, \ - WhiteListRoundRobinPolicy, ExponentialBackoffRetryPolicy, ColDesc +from cassandra.policies import ( + HostFilterPolicy, + RoundRobinPolicy, + SimpleConvictionPolicy, + WhiteListRoundRobinPolicy, + ExponentialBackoffRetryPolicy, + ColDesc, +) from cassandra.pool import Host from cassandra.connection import DefaultEndPoint @@ -30,7 +36,6 @@ def setup_module(): class HostFilterPolicyTests(unittest.TestCase): - def test_predicate_changes(self): """ Test to validate host filter reacts correctly when the predicate return @@ -45,13 +50,20 @@ def test_predicate_changes(self): external_event = True contact_point = DefaultEndPoint("127.0.0.1") - predicate = lambda host: host.endpoint == contact_point if external_event else True + predicate = ( + lambda host: host.endpoint == contact_point if external_event else True + ) hfp = ExecutionProfile( - 
load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate) + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), predicate=predicate + ) + ) + cluster = TestCluster( + contact_points=(contact_point,), + execution_profiles={EXEC_PROFILE_DEFAULT: hfp}, + topology_event_refresh_window=0, + status_event_refresh_window=0, ) - cluster = TestCluster(contact_points=(contact_point,), execution_profiles={EXEC_PROFILE_DEFAULT: hfp}, - topology_event_refresh_window=0, - status_event_refresh_window=0) session = cluster.connect(wait_for_all_pools=True) queried_hosts = set() @@ -71,30 +83,37 @@ def test_predicate_changes(self): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) assert len(queried_hosts) == 3 - assert {host.endpoint for host in queried_hosts} == {DefaultEndPoint(f"127.0.0.{i}") for i in range(1, 4)} + assert {host.endpoint for host in queried_hosts} == { + DefaultEndPoint(f"127.0.0.{i}") for i in range(1, 4) + } class WhiteListRoundRobinPolicyTests(unittest.TestCase): - @local def test_only_connects_to_subset(self): only_connect_hosts = {"127.0.0.1", "127.0.0.2"} - white_list = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + white_list = ExecutionProfile( + load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts) + ) cluster = TestCluster(execution_profiles={"white_list": white_list}) - #cluster = Cluster(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + # cluster = Cluster(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) session = cluster.connect(wait_for_all_pools=True) queried_hosts = set() for _ in range(10): - response = session.execute("SELECT * from system.local WHERE key='local'", execution_profile="white_list") + response = session.execute( + "SELECT * from system.local WHERE key='local'", + execution_profile="white_list", + ) 
queried_hosts.update(response.response_future.attempted_hosts) queried_hosts = set(host.address for host in queried_hosts) assert queried_hosts == only_connect_hosts class ExponentialRetryPolicyTests(unittest.TestCase): - def setUp(self): - self.cluster = TestCluster(default_retry_policy=ExponentialBackoffRetryPolicy(max_num_retries=3)) + self.cluster = TestCluster( + default_retry_policy=ExponentialBackoffRetryPolicy(max_num_retries=3) + ) self.session = self.cluster.connect() def tearDown(self): @@ -104,5 +123,6 @@ def test_exponential_retries(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """) + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} + """ + ) diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 3f63b881ef..e5d513f9d2 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -13,7 +13,12 @@ # limitations under the License. 
-from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster, CASSANDRA_VERSION +from tests.integration import ( + use_singledc, + PROTOCOL_VERSION, + TestCluster, + CASSANDRA_VERSION, +) import unittest @@ -23,8 +28,11 @@ from cassandra import ConsistencyLevel, ProtocolVersion from cassandra.query import PreparedStatement, UNSET_VALUE -from tests.integration import (get_server_versions, greaterthanorequalcass40, - BasicSharedKeyspaceUnitTestCase) +from tests.integration import ( + get_server_versions, + greaterthanorequalcass40, + BasicSharedKeyspaceUnitTestCase, +) import logging import pytest @@ -38,13 +46,16 @@ def setup_module(): class PreparedStatementTests(unittest.TestCase): - @classmethod def setUpClass(cls): cls.cass_version = get_server_versions() def setUp(self): - self.cluster = TestCluster(metrics_enabled=True, allow_beta_protocol_version=True, protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster( + metrics_enabled=True, + allow_beta_protocol_version=True, + protocol_version=PROTOCOL_VERSION, + ) self.session = self.cluster.connect() def tearDown(self): @@ -62,8 +73,9 @@ def test_basic(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """) + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} + """ + ) self.session.set_keyspace("preparedtests") self.session.execute( @@ -74,53 +86,54 @@ def test_basic(self): c text, PRIMARY KEY (a, b) ) - """) + """ + ) prepared = self.session.prepare( """ INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind(('a', 'b', 'c')) + bound = prepared.bind(("a", "b", "c")) self.session.execute(bound) prepared = self.session.prepare( """ SELECT * FROM cf0 WHERE a=? 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind(('a')) + bound = prepared.bind(("a")) results = self.session.execute(bound) - assert results == [('a', 'b', 'c')] + assert results == [("a", "b", "c")] # test with new dict binding prepared = self.session.prepare( """ INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind({ - 'a': 'x', - 'b': 'y', - 'c': 'z' - }) + bound = prepared.bind({"a": "x", "b": "y", "c": "z"}) self.session.execute(bound) prepared = self.session.prepare( """ SELECT * FROM cf0 WHERE a=? - """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind({'a': 'x'}) + bound = prepared.bind({"a": "x"}) results = self.session.execute(bound) - assert results == [('x', 'y', 'z')] + assert results == [("x", "y", "z")] def test_missing_primary_key(self): """ @@ -160,7 +173,7 @@ def _run_missing_primary_key_dicts(self, session): else: prepared = session.prepare(statement_to_prepare) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind({'v': 1}) + bound = prepared.bind({"v": 1}) with pytest.raises(InvalidRequest): session.execute(bound) @@ -172,7 +185,7 @@ def test_too_many_bind_values(self): def _run_too_many_bind_values(self, session): statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES (?)""" - # logic needed work with changes in CASSANDRA-6237 + # logic needed work with changes in CASSANDRA-6237 if self.cass_version[0] >= (2, 2, 8): with pytest.raises(InvalidRequest): session.prepare(statement_to_prepare) @@ -191,21 +204,22 @@ def test_imprecise_bind_values_dicts(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) # too many values is ok - others are ignored - prepared.bind({'k': 1, 'v': 2, 'v2': 3}) + prepared.bind({"k": 1, "v": 2, "v2": 3}) # right number, but one does not belong if PROTOCOL_VERSION < 4: # pre v4, the driver bails with key error when 'v' is found missing with pytest.raises(KeyError): - prepared.bind({'k': 1, 'v2': 3}) + prepared.bind({"k": 1, "v2": 3}) else: # post v4, the driver uses UNSET_VALUE for 'v' and 'v2' is ignored - prepared.bind({'k': 1, 'v2': 3}) + prepared.bind({"k": 1, "v2": 3}) # also catch too few variables with dicts assert isinstance(prepared, PreparedStatement) @@ -225,7 +239,8 @@ def test_none_values(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) @@ -234,7 +249,8 @@ def test_none_values(self): prepared = self.session.prepare( """ SELECT * FROM test3rf.test WHERE k=? 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1,)) @@ -256,27 +272,41 @@ def test_unset_values(self): @test_category prepared_statements:binding """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4") + raise unittest.SkipTest( + "Binding UNSET values is not supported in protocol version < 4" + ) # table with at least two values so one can be used as a marker - self.session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)") - insert = self.session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)") - select = self.session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?") + self.session.execute( + "CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)" + ) + insert = self.session.prepare( + "INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)" + ) + select = self.session.prepare( + "SELECT * FROM test1rf.test_unset_values WHERE k=?" 
+ ) bind_expected = [ # initial condition - ((0, 0, 0), (0, 0, 0)), + ((0, 0, 0), (0, 0, 0)), # unset implicit - ((0, 1,), (0, 1, 0)), - ({'k': 0, 'v0': 2}, (0, 2, 0)), - ({'k': 0, 'v1': 1}, (0, 2, 1)), + ( + ( + 0, + 1, + ), + (0, 1, 0), + ), + ({"k": 0, "v0": 2}, (0, 2, 0)), + ({"k": 0, "v1": 1}, (0, 2, 1)), # unset explicit - ((0, 3, UNSET_VALUE), (0, 3, 1)), - ((0, UNSET_VALUE, 2), (0, 3, 2)), - ({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)), - ({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)), + ((0, 3, UNSET_VALUE), (0, 3, 1)), + ((0, UNSET_VALUE, 2), (0, 3, 2)), + ({"k": 0, "v0": 4, "v1": UNSET_VALUE}, (0, 4, 2)), + ({"k": 0, "v0": UNSET_VALUE, "v1": 3}, (0, 4, 3)), # nulls still work - ((0, None, None), (0, None, None)), + ((0, None, None), (0, None, None)), ] for params, expected in bind_expected: @@ -288,11 +318,11 @@ def test_unset_values(self): self.session.execute(select, (UNSET_VALUE, 0, 0)) def test_no_meta(self): - prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (0, 0) - """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind(None) @@ -302,7 +332,8 @@ def test_no_meta(self): prepared = self.session.prepare( """ SELECT * FROM test3rf.test WHERE k=0 - """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind(None) @@ -319,19 +350,21 @@ def test_none_values_dicts(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind({'k': 1, 'v': None}) + bound = prepared.bind({"k": 1, "v": None}) self.session.execute(bound) prepared = self.session.prepare( """ SELECT * FROM test3rf.test WHERE k=? 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) - bound = prepared.bind({'k': 1}) + bound = prepared.bind({"k": 1}) results = self.session.execute(bound) assert results.one().v == None @@ -343,7 +376,8 @@ def test_async_binding(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, (873, None)) @@ -352,7 +386,8 @@ def test_async_binding(self): prepared = self.session.prepare( """ SELECT * FROM test3rf.test WHERE k=? - """) + """ + ) assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, (873,)) @@ -366,19 +401,21 @@ def test_async_binding_dicts(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) - future = self.session.execute_async(prepared, {'k': 873, 'v': None}) + future = self.session.execute_async(prepared, {"k": 873, "v": None}) future.result() prepared = self.session.prepare( """ SELECT * FROM test3rf.test WHERE k=? 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) - future = self.session.execute_async(prepared, {'k': 873}) + future = self.session.execute_async(prepared, {"k": 873}) results = future.result() assert results.one().v == None @@ -399,7 +436,9 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self): @test_category prepared_statements """ - self.session.execute("CREATE TABLE test3rf.error_test (k int PRIMARY KEY, v int)") + self.session.execute( + "CREATE TABLE test3rf.error_test (k int PRIMARY KEY, v int)" + ) prepared = self.session.prepare("SELECT * FROM test3rf.error_test WHERE k=?") self.session.execute("DROP TABLE test3rf.error_test") @@ -407,11 +446,17 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self): self.session.execute(prepared, [0]) def test_recognize_lwt_query(self): - self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.bound_statement_test (a int PRIMARY KEY, b int)") + self.session.execute( + "CREATE TABLE IF NOT EXISTS preparedtests.bound_statement_test (a int PRIMARY KEY, b int)" + ) # Prepare a non-LWT statement - statementNonLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ?") + statementNonLWT = self.session.prepare( + "UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ?" + ) # Prepare an LWT statement - statementLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ? IF b = ?") + statementLWT = self.session.prepare( + "UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ? IF b = ?" 
+ ) boundNonLWT = statementNonLWT.bind((3, 1)) boundLWT = statementLWT.bind((3, 1, 5)) @@ -420,27 +465,41 @@ def test_recognize_lwt_query(self): assert not boundNonLWT.is_lwt() assert boundLWT.is_lwt() - self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.prepared_statement_test (a int PRIMARY KEY, b int)") + self.session.execute( + "CREATE TABLE IF NOT EXISTS preparedtests.prepared_statement_test (a int PRIMARY KEY, b int)" + ) # Prepare a non-LWT statement - statementNonLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ?") + statementNonLWT = self.session.prepare( + "UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ?" + ) # Prepare an LWT statement - statementLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ? IF b = ?") + statementLWT = self.session.prepare( + "UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ? IF b = ?" + ) # Check LWT detection assert not statementNonLWT.is_lwt() assert statementLWT.is_lwt() - @unittest.skipIf((CASSANDRA_VERSION >= Version('3.11.12') and CASSANDRA_VERSION < Version('4.0')) or \ - CASSANDRA_VERSION >= Version('4.0.2'), - "Fixed server-side in Cassandra 3.11.12, 4.0.2") + @unittest.skipIf( + (CASSANDRA_VERSION >= Version("3.11.12") and CASSANDRA_VERSION < Version("4.0")) + or CASSANDRA_VERSION >= Version("4.0.2"), + "Fixed server-side in Cassandra 3.11.12, 4.0.2", + ) def test_fail_if_different_query_id_on_reprepare(self): - """ PYTHON-1124 and CASSANDRA-15252 """ + """PYTHON-1124 and CASSANDRA-15252""" keyspace = "test_fail_if_different_query_id_on_reprepare" self.session.execute( "CREATE KEYSPACE IF NOT EXISTS {} WITH replication = " - "{{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(keyspace) + "{{'class': 'NetworkTopologyStrategy', 'replication_factor': 1}}".format( + keyspace + ) + ) + self.session.execute( + "CREATE TABLE IF NOT EXISTS {}.foo(k int PRIMARY KEY)".format(keyspace) + 
) + prepared = self.session.prepare( + "SELECT * FROM {}.foo WHERE k=?".format(keyspace) ) - self.session.execute("CREATE TABLE IF NOT EXISTS {}.foo(k int PRIMARY KEY)".format(keyspace)) - prepared = self.session.prepare("SELECT * FROM {}.foo WHERE k=?".format(keyspace)) self.session.execute("DROP TABLE {}.foo".format(keyspace)) self.session.execute("CREATE TABLE {}.foo(k int PRIMARY KEY)".format(keyspace)) self.session.execute("USE {}".format(keyspace)) @@ -451,14 +510,25 @@ def test_fail_if_different_query_id_on_reprepare(self): @greaterthanorequalcass40 class PreparedStatementInvalidationTest(BasicSharedKeyspaceUnitTestCase): - def setUp(self): - self.table_name = "{}.prepared_statement_invalidation_test".format(self.keyspace_name) - self.session.execute("CREATE TABLE {} (a int PRIMARY KEY, b int, d int);".format(self.table_name)) - self.session.execute("INSERT INTO {} (a, b, d) VALUES (1, 1, 1);".format(self.table_name)) - self.session.execute("INSERT INTO {} (a, b, d) VALUES (2, 2, 2);".format(self.table_name)) - self.session.execute("INSERT INTO {} (a, b, d) VALUES (3, 3, 3);".format(self.table_name)) - self.session.execute("INSERT INTO {} (a, b, d) VALUES (4, 4, 4);".format(self.table_name)) + self.table_name = "{}.prepared_statement_invalidation_test".format( + self.keyspace_name + ) + self.session.execute( + "CREATE TABLE {} (a int PRIMARY KEY, b int, d int);".format(self.table_name) + ) + self.session.execute( + "INSERT INTO {} (a, b, d) VALUES (1, 1, 1);".format(self.table_name) + ) + self.session.execute( + "INSERT INTO {} (a, b, d) VALUES (2, 2, 2);".format(self.table_name) + ) + self.session.execute( + "INSERT INTO {} (a, b, d) VALUES (3, 3, 3);".format(self.table_name) + ) + self.session.execute( + "INSERT INTO {} (a, b, d) VALUES (4, 4, 4);".format(self.table_name) + ) def tearDown(self): self.session.execute("DROP TABLE {}".format(self.table_name)) @@ -473,7 +543,9 @@ def test_invalidated_result_metadata(self): Prior to this fix, the request would 
blow up with a protocol error when the result was decoded expecting a different number of columns. """ - wildcard_prepared = self.session.prepare("SELECT * FROM {}".format(self.table_name)) + wildcard_prepared = self.session.prepare( + "SELECT * FROM {}".format(self.table_name) + ) original_result_metadata = wildcard_prepared.result_metadata assert len(original_result_metadata) == 3 @@ -483,7 +555,9 @@ def test_invalidated_result_metadata(self): self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) # Get a bunch of requests in the pipeline with varying states of result_meta, reprepare, resolved - futures = set(self.session.execute_async(wildcard_prepared.bind(None)) for _ in range(200)) + futures = set( + self.session.execute_async(wildcard_prepared.bind(None)) for _ in range(200) + ) for f in futures: assert f.result()[0] == (1, 1) @@ -500,12 +574,14 @@ def test_prepared_id_is_update(self): The query id from the prepared statment must have changed """ - prepared_statement = self.session.prepare("SELECT * from {} WHERE a = ?".format(self.table_name)) + prepared_statement = self.session.prepare( + "SELECT * from {} WHERE a = ?".format(self.table_name) + ) id_before = prepared_statement.result_metadata_id assert len(prepared_statement.result_metadata) == 3 self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) - bound_statement = prepared_statement.bind((1, )) + bound_statement = prepared_statement.bind((1,)) self.session.execute(bound_statement, timeout=1) id_after = prepared_statement.result_metadata_id @@ -523,7 +599,9 @@ def test_prepared_id_is_updated_across_pages(self): @since 3.12 @jira_ticket PYTHON-808 """ - prepared_statement = self.session.prepare("SELECT * from {}".format(self.table_name)) + prepared_statement = self.session.prepare( + "SELECT * from {}".format(self.table_name) + ) id_before = prepared_statement.result_metadata_id assert len(prepared_statement.result_metadata) == 3 @@ -534,7 +612,9 @@ def 
test_prepared_id_is_updated_across_pages(self): self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) - result_set = set(x for x in ((1, 1, 1), (2, 2, 2), (3, 3, None, 3), (4, 4, None, 4))) + result_set = set( + x for x in ((1, 1, 1), (2, 2, 2), (3, 3, None, 3), (4, 4, None, 4)) + ) expected_result_set = set(row for row in result) id_after = prepared_statement.result_metadata_id @@ -563,7 +643,7 @@ def test_prepare_id_is_updated_across_session(self): one_id_before = one_prepared_stm.result_metadata_id self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) - one_session.execute(one_prepared_stm, (1, )) + one_session.execute(one_prepared_stm, (1,)) one_id_after = one_prepared_stm.result_metadata_id assert one_id_before != one_id_after @@ -578,10 +658,11 @@ def test_not_reprepare_invalid_statements(self): @jira_ticket PYTHON-808 """ prepared_statement = self.session.prepare( - "SELECT a, b, d FROM {} WHERE a = ?".format(self.table_name)) + "SELECT a, b, d FROM {} WHERE a = ?".format(self.table_name) + ) self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) with pytest.raises(InvalidRequest): - self.session.execute(prepared_statement.bind((1, ))) + self.session.execute(prepared_statement.bind((1,))) def test_id_is_not_updated_conditional_v4(self): """ @@ -612,13 +693,18 @@ def test_id_is_not_updated_conditional_v5(self): def _test_updated_conditional(self, session, value): prepared_statement = session.prepare( - "INSERT INTO {}(a, b, d) VALUES " - "(?, ? , ?) IF NOT EXISTS".format(self.table_name)) + "INSERT INTO {}(a, b, d) VALUES (?, ? , ?) 
IF NOT EXISTS".format( + self.table_name + ) + ) first_id = prepared_statement.result_metadata_id - LOG.debug('initial result_metadata_id: {}'.format(first_id)) + LOG.debug("initial result_metadata_id: {}".format(first_id)) def check_result_and_metadata(expected): - assert session.execute(prepared_statement, (value, value, value)).one() == expected + assert ( + session.execute(prepared_statement, (value, value, value)).one() + == expected + ) assert prepared_statement.result_metadata_id == first_id assert prepared_statement.result_metadata is None diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index f9d3dc26bc..bdec4f943b 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -20,13 +20,37 @@ import pytest from cassandra import ProtocolVersion from cassandra import ConsistencyLevel, Unavailable, InvalidRequest, cluster -from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, - BatchStatement, BatchType, dict_factory, TraceUnavailable) -from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, Cluster +from cassandra.query import ( + PreparedStatement, + BoundStatement, + SimpleStatement, + BatchStatement, + BatchType, + dict_factory, + TraceUnavailable, +) +from cassandra.cluster import ( + NoHostAvailable, + ExecutionProfile, + EXEC_PROFILE_DEFAULT, + Cluster, +) from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \ - greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ - USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla +from tests.integration import ( + use_singledc, + PROTOCOL_VERSION, + BasicSharedKeyspaceUnitTestCase, + greaterthanprotocolv3, + MockLoggingHandler, + get_supported_protocol_versions, 
+ local, + get_cluster, + setup_keyspace, + USE_CASS_EXTERNAL, + greaterthanorequalcass40, + TestCluster, + xfail_scylla, +) from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node from tests.util import assertListEqual, wait_until @@ -48,7 +72,7 @@ def setup_module(): ccm_cluster.stop() # This is necessary because test_too_many_statements may # timeout otherwise - config_options = {'write_request_timeout_in_ms': '20000'} + config_options = {"write_request_timeout_in_ms": "20000"} ccm_cluster.set_configuration_options(config_options) ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) @@ -56,20 +80,19 @@ def setup_module(): class QueryTests(BasicSharedKeyspaceUnitTestCase): - def test_query(self): - prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """.format(self.keyspace_name)) + """.format(self.keyspace_name) + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) assert isinstance(bound, BoundStatement) assert 2 == len(bound.values) self.session.execute(bound) - assert bound.routing_key == b'\x00\x00\x00\x01' + assert bound.routing_key == b"\x00\x00\x00\x01" def test_trace_prints_okay(self): """ @@ -96,16 +119,29 @@ def test_row_error_message(self): @test_category tracing """ - self.session.execute("CREATE TABLE {0}.{1} (k int PRIMARY KEY, v timestamp)".format(self.keyspace_name,self.function_table_name)) - ss = SimpleStatement("INSERT INTO {0}.{1} (k, v) VALUES (1, 1000000000000000)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (k int PRIMARY KEY, v timestamp)".format( + self.keyspace_name, self.function_table_name + ) + ) + ss = SimpleStatement( + "INSERT INTO {0}.{1} (k, v) VALUES (1, 1000000000000000)".format( + self.keyspace_name, self.function_table_name + ) + ) self.session.execute(ss) with pytest.raises(DriverException) as context: - self.session.execute("SELECT * FROM 
{0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "SELECT * FROM {0}.{1}".format( + self.keyspace_name, self.function_table_name + ) + ) assert "Failed decoding result column" in str(context.value) def test_trace_id_to_resultset(self): - - future = self.session.execute_async("SELECT * FROM system.local WHERE key='local'", trace=True) + future = self.session.execute_async( + "SELECT * FROM system.local WHERE key='local'", trace=True + ) # future should have the current trace rs = future.result() @@ -121,7 +157,9 @@ def test_trace_id_to_resultset(self): def test_trace_ignores_row_factory(self): with TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) as cluster: s = cluster.connect() query = "SELECT * FROM system.local WHERE key='local'" @@ -168,11 +206,13 @@ def test_client_ip_in_trace(self): client_ip = trace.client # Ip address should be in the local_host range - pat = re.compile(r'127.0.0.\d{1,3}') + pat = re.compile(r"127.0.0.\d{1,3}") # Ensure that ip is set assert client_ip is not None, "Client IP was not set in trace with C* >= 2.2" - assert pat.match(client_ip), "Client IP from trace did not match the expected value" + assert pat.match(client_ip), ( + "Client IP from trace did not match the expected value" + ) def test_trace_cl(self): """ @@ -190,15 +230,35 @@ def test_trace_cl(self): with pytest.raises(Unavailable): response_future.get_query_trace(query_cl=ConsistencyLevel.THREE) # Try again with a smattering of other CL's - assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.TWO).trace_id is not None + assert ( + response_future.get_query_trace( + max_wait=2.0, query_cl=ConsistencyLevel.TWO + ).trace_id + is not None + ) response_future = self.session.execute_async(statement, trace=True) response_future.result() - assert 
response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id is not None + assert ( + response_future.get_query_trace( + max_wait=2.0, query_cl=ConsistencyLevel.ONE + ).trace_id + is not None + ) response_future = self.session.execute_async(statement, trace=True) response_future.result() with pytest.raises(InvalidRequest): - assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id is not None - assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id is not None + assert ( + response_future.get_query_trace( + max_wait=2.0, query_cl=ConsistencyLevel.ANY + ).trace_id + is not None + ) + assert ( + response_future.get_query_trace( + max_wait=2.0, query_cl=ConsistencyLevel.QUORUM + ).trace_id + is not None + ) @notwindows def test_incomplete_query_trace(self): @@ -217,10 +277,18 @@ def test_incomplete_query_trace(self): """ # Create table and run insert, then select - self.session.execute("CREATE TABLE {0} (k INT, i INT, PRIMARY KEY(k, i))".format(self.keyspace_table_name)) - self.session.execute("INSERT INTO {0} (k, i) VALUES (0, 1)".format(self.keyspace_table_name)) + self.session.execute( + "CREATE TABLE {0} (k INT, i INT, PRIMARY KEY(k, i))".format( + self.keyspace_table_name + ) + ) + self.session.execute( + "INSERT INTO {0} (k, i) VALUES (0, 1)".format(self.keyspace_table_name) + ) - response_future = self.session.execute_async("SELECT i FROM {0} WHERE k=0".format(self.keyspace_table_name), trace=True) + response_future = self.session.execute_async( + "SELECT i FROM {0} WHERE k=0".format(self.keyspace_table_name), trace=True + ) response_future.result() assert len(response_future._query_traces) == 1 @@ -228,7 +296,12 @@ def test_incomplete_query_trace(self): assert self._wait_for_trace_to_populate(trace.trace_id) # Delete trace duration from the session (this is what the driver polls for "complete") - delete_statement = SimpleStatement("DELETE duration FROM 
system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) + delete_statement = SimpleStatement( + "DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format( + trace.trace_id + ), + consistency_level=ConsistencyLevel.ALL, + ) self.session.execute(delete_statement) assert self._wait_for_trace_to_delete(trace.trace_id) @@ -249,21 +322,26 @@ def test_incomplete_query_trace(self): def _wait_for_trace_to_populate(self, trace_id): count = 0 retry_max = 10 - while(not self._is_trace_present(trace_id) and count < retry_max): - time.sleep(.2) + while not self._is_trace_present(trace_id) and count < retry_max: + time.sleep(0.2) count += 1 return count != retry_max def _wait_for_trace_to_delete(self, trace_id): count = 0 retry_max = 10 - while(self._is_trace_present(trace_id) and count < retry_max): - time.sleep(.2) + while self._is_trace_present(trace_id) and count < retry_max: + time.sleep(0.2) count += 1 return count != retry_max def _is_trace_present(self, trace_id): - select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format(trace_id), consistency_level=ConsistencyLevel.ALL) + select_statement = SimpleStatement( + "SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format( + trace_id + ), + consistency_level=ConsistencyLevel.ALL, + ) ssrs = self.session.execute(select_statement) if not len(ssrs.current_rows) or ssrs.one().duration is None: return False @@ -279,17 +357,31 @@ def test_query_by_id(self): @test_category queries basic """ - create_table = "CREATE TABLE {0}.{1} (id int primary key, m map)".format(self.keyspace_name, self.function_table_name) + create_table = ( + "CREATE TABLE {0}.{1} (id int primary key, m map)".format( + self.keyspace_name, self.function_table_name + ) + ) self.session.execute(create_table) - self.session.execute("insert into "+self.keyspace_name+"."+self.function_table_name+" (id, m) VALUES ( 1, {1: 
'one', 2: 'two', 3:'three'})") - results1 = self.session.execute("select id, m from {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "insert into " + + self.keyspace_name + + "." + + self.function_table_name + + " (id, m) VALUES ( 1, {1: 'one', 2: 'two', 3:'three'})" + ) + results1 = self.session.execute( + "select id, m from {0}.{1}".format( + self.keyspace_name, self.function_table_name + ) + ) assert results1.column_types is not None - assert results1.column_types[0].typename == 'int' - assert results1.column_types[1].typename == 'map' - assert results1.column_types[0].cassname == 'Int32Type' - assert results1.column_types[1].cassname == 'MapType' + assert results1.column_types[0].typename == "int" + assert results1.column_types[1].typename == "map" + assert results1.column_types[0].cassname == "Int32Type" + assert results1.column_types[1].cassname == "MapType" assert len(results1.column_types[0].subtypes) == 0 assert len(results1.column_types[1].subtypes) == 2 assert results1.column_types[1].subtypes[0].typename == "int" @@ -318,17 +410,31 @@ def test_column_names(self): PRIMARY KEY (user, game, year, month, day) )""".format(self.keyspace_name, self.function_table_name) - self.session.execute(create_table) - result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + result_set = self.session.execute( + "SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name) + ) assert result_set.column_types is not None - assert result_set.column_names == [u'user', u'game', u'year', u'month', u'day', u'score'] + assert result_set.column_names == [ + "user", + "game", + "year", + "month", + "day", + "score", + ] @greaterthanorequalcass30 def test_basic_json_query(self): - insert_query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)", consistency_level = ConsistencyLevel.QUORUM) - json_query = SimpleStatement("SELECT JSON * FROM test3rf.test where 
k=1", consistency_level = ConsistencyLevel.QUORUM) + insert_query = SimpleStatement( + "INSERT INTO test3rf.test(k, v) values (1, 1)", + consistency_level=ConsistencyLevel.QUORUM, + ) + json_query = SimpleStatement( + "SELECT JSON * FROM test3rf.test where k=1", + consistency_level=ConsistencyLevel.QUORUM, + ) self.session.execute(insert_query) results = self.session.execute(json_query) @@ -348,14 +454,16 @@ def test_host_targeting_query(self): # copy of default EP with checkable LBP checkable_ep = self.session.execution_profile_clone_update( ep=default_ep, - load_balancing_policy=mock.Mock(wraps=default_ep.load_balancing_policy) + load_balancing_policy=mock.Mock(wraps=default_ep.load_balancing_policy), ) query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)") for i in range(10): host = random.choice(self.cluster.metadata.all_hosts()) - log.debug('targeting {}'.format(host)) - future = self.session.execute_async(query, host=host, execution_profile=checkable_ep) + log.debug("targeting {}".format(host)) + future = self.session.execute_async( + query, host=host, execution_profile=checkable_ep + ) future.result() # check we're using the selected host assert host == future.coordinator_host @@ -364,7 +472,6 @@ def test_host_targeting_query(self): class PreparedStatementTests(unittest.TestCase): - def setUp(self): self.cluster = TestCluster() self.session = self.cluster.connect() @@ -379,11 +486,12 @@ def test_routing_key(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) - assert bound.routing_key == b'\x00\x00\x00\x01' + assert bound.routing_key == b"\x00\x00\x00\x01" def test_empty_routing_key_indexes(self): """ @@ -393,7 +501,8 @@ def test_empty_routing_key_indexes(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
- """) + """ + ) prepared.routing_key_indexes = None assert isinstance(prepared, PreparedStatement) @@ -408,12 +517,13 @@ def test_predefined_routing_key(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) - bound._set_routing_key('fake_key') - assert bound.routing_key == 'fake_key' + bound._set_routing_key("fake_key") + assert bound.routing_key == "fake_key" def test_multiple_routing_key_indexes(self): """ @@ -422,16 +532,23 @@ def test_multiple_routing_key_indexes(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """ + ) assert isinstance(prepared, PreparedStatement) prepared.routing_key_indexes = [0, 1] bound = prepared.bind((1, 2)) - assert bound.routing_key == b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00' + assert ( + bound.routing_key + == b"\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00" + ) prepared.routing_key_indexes = [1, 0] bound = prepared.bind((1, 2)) - assert bound.routing_key == b'\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00' + assert ( + bound.routing_key + == b"\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00" + ) def test_bound_keyspace(self): """ @@ -440,11 +557,12 @@ def test_bound_keyspace(self): prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
- """) + """ + ) assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, 2)) - assert bound.keyspace == 'test3rf' + assert bound.keyspace == "test3rf" class ForcedHostIndexPolicy(RoundRobinPolicy): @@ -453,7 +571,7 @@ def __init__(self, host_index_to_use=0): self.host_index_to_use = host_index_to_use def set_host(self, host_index): - """ 0-based index of which host to use """ + """0-based index of which host to use""" self.host_index_to_use = host_index def make_query_plan(self, working_keyspace=None, query=None): @@ -464,14 +582,14 @@ def make_query_plan(self, working_keyspace=None, query=None): host = [live_hosts[self.host_index_to_use]] except IndexError as e: raise IndexError( - 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( + "You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}".format( len(live_hosts), self.host_index_to_use - )) from e + ) + ) from e return host class PreparedStatementMetdataTest(unittest.TestCase): - def test_prepared_metadata_generation(self): """ Test to validate that result metadata is appropriately populated across protocol version @@ -487,11 +605,17 @@ def test_prepared_metadata_generation(self): base_line = None for proto_version in get_supported_protocol_versions(): - beta_flag = True if proto_version in ProtocolVersion.BETA_VERSIONS else False - cluster = Cluster(protocol_version=proto_version, allow_beta_protocol_version=beta_flag) + beta_flag = ( + True if proto_version in ProtocolVersion.BETA_VERSIONS else False + ) + cluster = Cluster( + protocol_version=proto_version, allow_beta_protocol_version=beta_flag + ) session = cluster.connect() - select_statement = session.prepare("SELECT * FROM system.local WHERE key='local'") + select_statement = session.prepare( + "SELECT * FROM system.local WHERE key='local'" + ) if proto_version == 1: assert select_statement.result_metadata == None else: @@ -524,8 +648,8 @@ def 
test_prepare_on_all_hosts(self): session = clus.connect(wait_for_all_pools=True) select_statement = session.prepare("SELECT k FROM test3rf.test WHERE k = ?") for host in clus.metadata.all_hosts(): - session.execute(select_statement, (1, ), host=host) - assert 2 == mock_handler.get_message_count('debug', "Re-preparing") + session.execute(select_statement, (1,), host=host) + assert 2 == mock_handler.get_message_count("debug", "Re-preparing") def test_prepare_batch_statement(self): """ @@ -541,7 +665,9 @@ def test_prepare_batch_statement(self): policy = ForcedHostIndexPolicy() clus = TestCluster( execution_profiles={ - EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=policy), + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=policy + ), }, prepare_on_all_hosts=False, reprepare_on_up=False, @@ -555,7 +681,9 @@ def test_prepare_batch_statement(self): session.execute("DROP TABLE IF EXISTS %s" % table) session.execute("CREATE TABLE %s (k int PRIMARY KEY, v int )" % table) - insert_statement = session.prepare("INSERT INTO %s (k, v) VALUES (?, ?)" % table) + insert_statement = session.prepare( + "INSERT INTO %s (k, v) VALUES (?, ?)" % table + ) # This is going to query a host where the query # is not prepared @@ -565,10 +693,14 @@ def test_prepare_batch_statement(self): session.execute(batch_statement) # To verify our test assumption that queries are getting re-prepared properly - assert 1 == mock_handler.get_message_count('debug', "Re-preparing") + assert 1 == mock_handler.get_message_count("debug", "Re-preparing") - select_results = session.execute(SimpleStatement("SELECT * FROM %s WHERE k = 1" % table, - consistency_level=ConsistencyLevel.ALL)) + select_results = session.execute( + SimpleStatement( + "SELECT * FROM %s WHERE k = 1" % table, + consistency_level=ConsistencyLevel.ALL, + ) + ) first_row = select_results.one()[:2] assert (1, 2) == first_row @@ -592,8 +724,12 @@ def test_prepare_batch_statement_after_alter(self): session = 
clus.connect(wait_for_all_pools=True) session.execute("DROP TABLE IF EXISTS %s" % table) - session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int, b int, d int)" % table) - insert_statement = session.prepare("INSERT INTO %s (k, b, d) VALUES (?, ?, ?)" % table) + session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, a int, b int, d int)" % table + ) + insert_statement = session.prepare( + "INSERT INTO %s (k, b, d) VALUES (?, ?, ?)" % table + ) # Altering the table might trigger an update in the insert metadata session.execute("ALTER TABLE %s ADD c int" % table) @@ -615,13 +751,13 @@ def test_prepare_batch_statement_after_alter(self): (1, None, 2, None, 3), (2, None, 3, None, 4), (3, None, 4, None, 5), - (4, None, 5, None, 6) + (4, None, 5, None, 6), ] assert set(expected_results) == set(select_results._current_rows) # To verify our test assumption that queries are getting re-prepared properly - assert 3 == mock_handler.get_message_count('debug', "Re-preparing") + assert 3 == mock_handler.get_message_count("debug", "Re-preparing") class PrintStatementTests(unittest.TestCase): @@ -634,8 +770,13 @@ def test_simple_statement(self): Highlight the format of printing SimpleStatements """ - ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ONE) - assert str(ss) == '' + ss = SimpleStatement( + "SELECT * FROM test3rf.test", consistency_level=ConsistencyLevel.ONE + ) + assert ( + str(ss) + == '' + ) def test_prepared_statement(self): """ @@ -645,24 +786,30 @@ def test_prepared_statement(self): cluster = TestCluster() session = cluster.connect() - prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') + prepared = session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)") prepared.consistency_level = ConsistencyLevel.ONE - assert str(prepared) == '' + assert ( + str(prepared) + == '' + ) bound = prepared.bind((1, 2)) - assert str(bound) == '' + assert ( + str(bound) + == '' + ) cluster.shutdown() class 
BatchStatementTests(BasicSharedKeyspaceUnitTestCase): - def setUp(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest( "Protocol 2.0+ is required for BATCH operations, currently testing against %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) self.cluster = TestCluster() self.session = self.cluster.connect(wait_for_all_pools=True) @@ -675,8 +822,11 @@ def confirm_results(self): values = set() # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see # everything inserted - results = self.session.execute(SimpleStatement("SELECT * FROM test3rf.test", - consistency_level=ConsistencyLevel.ALL)) + results = self.session.execute( + SimpleStatement( + "SELECT * FROM test3rf.test", consistency_level=ConsistencyLevel.ALL + ) + ) for result in results: keys.add(result.k) values.add(result.v) @@ -696,7 +846,10 @@ def test_string_statements(self): def test_simple_statements(self): batch = BatchStatement(BatchType.LOGGED) for i in range(10): - batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i)) + batch.add( + SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), + (i, i), + ) self.session.execute(batch) self.session.execute_async(batch).result() @@ -754,16 +907,18 @@ def test_no_parameters(self): self.confirm_results() def test_unicode(self): - ddl = ''' + ddl = """ CREATE TABLE test3rf.testtext ( k int PRIMARY KEY, - v text )''' + v text )""" self.session.execute(ddl) - unicode_text = u'Fran\u00E7ois' - query = u'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)' + unicode_text = "Fran\u00e7ois" + query = "INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)" try: batch = BatchStatement(BatchType.LOGGED) - batch.add(u"INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text)) + batch.add( + "INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text) + ) self.session.execute(batch) finally: self.session.execute("DROP TABLE test3rf.testtext") @@ -773,13 
+928,17 @@ def test_too_many_statements(self): large_batch = 0xFFF max_statements = 0xFFFF ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") - b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + b = BatchStatement( + batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE + ) # large number works works b.add_all([ss] * large_batch, [None] * large_batch) self.session.execute(b, timeout=30.0) - b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + b = BatchStatement( + batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE + ) # max + 1 raises b.add_all([ss] * max_statements, [None] * max_statements) with pytest.raises(ValueError): @@ -796,7 +955,8 @@ def setUp(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest( "Protocol 2.0+ is required for Serial Consistency, currently testing against %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) self.cluster = TestCluster() self.session = self.cluster.connect() @@ -808,7 +968,8 @@ def test_conditional_update(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") statement = SimpleStatement( "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1", - serial_consistency_level=ConsistencyLevel.SERIAL) + serial_consistency_level=ConsistencyLevel.SERIAL, + ) # crazy test, but PYTHON-299 # TODO: expand to check more parameters get passed to statement, and on to messages assert statement.serial_consistency_level == ConsistencyLevel.SERIAL @@ -820,7 +981,8 @@ def test_conditional_update(self): statement = SimpleStatement( "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", - serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL, + ) assert statement.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(statement) result = future.result() @@ -830,8 +992,7 @@ def test_conditional_update(self): 
def test_conditional_update_with_prepared_statements(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") - statement = self.session.prepare( - "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2") + statement = self.session.prepare("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2") statement.serial_consistency_level = ConsistencyLevel.SERIAL future = self.session.execute_async(statement) @@ -840,8 +1001,7 @@ def test_conditional_update_with_prepared_statements(self): assert result assert not result.one().applied - statement = self.session.prepare( - "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") + statement = self.session.prepare("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") bound = statement.bind(()) bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(bound) @@ -861,7 +1021,9 @@ def test_conditional_update_with_batch_statements(self): assert result assert not result.one().applied - statement = BatchStatement(serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + statement = BatchStatement( + serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL + ) statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") assert statement.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(statement) @@ -873,9 +1035,9 @@ def test_conditional_update_with_batch_statements(self): def test_bad_consistency_level(self): statement = SimpleStatement("foo") with pytest.raises(ValueError): - setattr(statement, 'serial_consistency_level', ConsistencyLevel.ONE) + setattr(statement, "serial_consistency_level", ConsistencyLevel.ONE) with pytest.raises(ValueError): - SimpleStatement('foo', serial_consistency_level=ConsistencyLevel.ONE) + SimpleStatement("foo", serial_consistency_level=ConsistencyLevel.ONE) class LightweightTransactionTests(unittest.TestCase): @@ -887,24 +1049,25 @@ def setUp(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest( "Protocol 2.0+ is required for 
Lightweight transactions, currently testing against %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) serial_profile = ExecutionProfile(consistency_level=ConsistencyLevel.SERIAL) - self.cluster = TestCluster(execution_profiles={'serial': serial_profile}) + self.cluster = TestCluster(execution_profiles={"serial": serial_profile}) self.session = self.cluster.connect() - ddl = ''' + ddl = """ CREATE TABLE test3rf.lwt ( k int PRIMARY KEY, - v int )''' + v int )""" self.session.execute(ddl) - ddl = ''' + ddl = """ CREATE TABLE test3rf.lwt_clustering ( k int, c int, v int, - PRIMARY KEY (k, c))''' + PRIMARY KEY (k, c))""" self.session.execute(ddl) def tearDown(self): @@ -922,8 +1085,12 @@ def test_no_connection_refused_on_timeout(self): Number of iterations can be specified with LWT_ITERATIONS environment variable. Default value is 1000 """ - insert_statement = self.session.prepare("INSERT INTO test3rf.lwt (k, v) VALUES (0, 0) IF NOT EXISTS") - delete_statement = self.session.prepare("DELETE FROM test3rf.lwt WHERE k = 0 IF EXISTS") + insert_statement = self.session.prepare( + "INSERT INTO test3rf.lwt (k, v) VALUES (0, 0) IF NOT EXISTS" + ) + delete_statement = self.session.prepare( + "DELETE FROM test3rf.lwt WHERE k = 0 IF EXISTS" + ) iterations = int(os.getenv("LWT_ITERATIONS", 1000)) @@ -934,26 +1101,40 @@ def test_no_connection_refused_on_timeout(self): statements_and_params.append((delete_statement, ())) received_timeout = False - results = execute_concurrent(self.session, statements_and_params, raise_on_first_error=False) - for (success, result) in results: + results = execute_concurrent( + self.session, statements_and_params, raise_on_first_error=False + ) + for success, result in results: if success: continue else: # In this case result is an exception exception_type = type(result).__name__ if exception_type == "NoHostAvailable": - pytest.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) - if exception_type in ["WriteTimeout", 
"WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]: + pytest.fail( + "PYTHON-91: Disconnected from Cassandra: %s" % result.message + ) + if exception_type in [ + "WriteTimeout", + "WriteFailure", + "ReadTimeout", + "ReadFailure", + "ErrorMessageSub", + ]: if type(result).__name__ in ["WriteTimeout", "WriteFailure"]: received_timeout = True continue - pytest.fail("Unexpected exception %s: %s" % (exception_type, result.message)) + pytest.fail( + "Unexpected exception %s: %s" % (exception_type, result.message) + ) # Make sure test passed assert received_timeout - @xfail_scylla('Fails on Scylla with error `SERIAL/LOCAL_SERIAL consistency may only be requested for one partition at a time`') + @xfail_scylla( + "Fails on Scylla with error `SERIAL/LOCAL_SERIAL consistency may only be requested for one partition at a time`" + ) def test_was_applied_batch_stmt(self): """ Test to ensure `:attr:cassandra.cluster.ResultSet.was_applied` works as expected @@ -974,49 +1155,75 @@ def test_was_applied_batch_stmt(self): """ for batch_type in (BatchType.UNLOGGED, BatchType.LOGGED): batch_statement = BatchStatement(batch_type) - batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + batch_statement.add_all( + [ + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);", + ], + [None] * 3, + ) result = self.session.execute(batch_statement) - #assert result.was_applied + # assert result.was_applied # Should fail since (0, 0, 10) have already been written # The non conditional insert shouldn't be written as well batch_statement = BatchStatement(batch_type) - batch_statement.add_all(["INSERT INTO 
test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10);", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 4) + batch_statement.add_all( + [ + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;", + ], + [None] * 4, + ) result = self.session.execute(batch_statement) assert not result.was_applied - all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + all_rows = self.session.execute( + "SELECT * from test3rf.lwt_clustering", execution_profile="serial" + ) # Verify the non conditional insert hasn't been inserted assert len(all_rows.current_rows) == 3 # Should fail since (0, 0, 10) have already been written batch_statement = BatchStatement(batch_type) - batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + batch_statement.add_all( + [ + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;", + ], + [None] * 3, + ) result = self.session.execute(batch_statement) assert not result.was_applied # Should fail since (0, 0, 10) have already been written - batch_statement.add("INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;") 
+ batch_statement.add( + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;" + ) result = self.session.execute(batch_statement) assert not result.was_applied # Should succeed batch_statement = BatchStatement(batch_type) - batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS;", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + batch_statement.add_all( + [ + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;", + ], + [None] * 3, + ) result = self.session.execute(batch_statement) assert result.was_applied - all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + all_rows = self.session.execute( + "SELECT * from test3rf.lwt_clustering", execution_profile="serial" + ) for i, row in enumerate(all_rows): assert (0, i, 10) == (row[0], row[1], row[2]) @@ -1039,12 +1246,17 @@ def test_empty_batch_statement(self): with pytest.raises(RuntimeError): results.was_applied - @pytest.mark.xfail(reason='Skipping until PYTHON-943 is resolved') + @pytest.mark.xfail(reason="Skipping until PYTHON-943 is resolved") def test_was_applied_batch_string(self): batch_statement = BatchStatement(BatchType.LOGGED) - batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10);", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", - "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + batch_statement.add_all( + [ + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering 
(k, c, v) VALUES (0, 2, 10);", + ], + [None] * 3, + ) self.session.execute(batch_statement) batch_str = """ @@ -1075,13 +1287,16 @@ def setUp(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest( "Protocol 2.0+ is required for BATCH operations, currently testing against %r" - % (PROTOCOL_VERSION,)) + % (PROTOCOL_VERSION,) + ) self.cluster = TestCluster() self.session = self.cluster.connect() query = """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) """ - self.simple_statement = SimpleStatement(query, routing_key='ss_rk', keyspace='keyspace_name') + self.simple_statement = SimpleStatement( + query, routing_key="ss_rk", keyspace="keyspace_name" + ) self.prepared = self.session.prepare(query) def tearDown(self): @@ -1155,7 +1370,6 @@ def test_inherit_first_rk_prepared_param(self): @greaterthanorequalcass30 class MaterializedViewQueryTest(BasicSharedKeyspaceUnitTestCase): - def test_mv_filtering(self): """ Test to ensure that cql filtering where clauses are properly supported in the python driver. 
@@ -1185,68 +1399,96 @@ def test_mv_filtering(self): SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format( + self.keyspace_name + ) create_mv_dailyhigh = """CREATE MATERIALIZED VIEW {0}.dailyhigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL PRIMARY KEY ((game, year, month, day), score, user) - WITH CLUSTERING ORDER BY (score DESC, user ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC)""".format( + self.keyspace_name + ) create_mv_monthlyhigh = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) create_mv_filtereduserhigh = """CREATE MATERIALIZED VIEW {0}.filtereduserhigh AS SELECT * FROM {0}.scores WHERE user in ('jbellis', 'pcmanus') AND game IS NOT NULL AND score IS NOT NULL AND year is NOT NULL AND day is not NULL and month IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv_alltime) self.session.execute(create_mv_dailyhigh) 
self.session.execute(create_mv_monthlyhigh) self.session.execute(create_mv_filtereduserhigh) - self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.alltimehigh".format(self.keyspace_name)) - self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.dailyhigh".format(self.keyspace_name)) - self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.monthlyhigh".format(self.keyspace_name)) - self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.filtereduserhigh".format(self.keyspace_name)) + self.addCleanup( + self.session.execute, + "DROP MATERIALIZED VIEW {0}.alltimehigh".format(self.keyspace_name), + ) + self.addCleanup( + self.session.execute, + "DROP MATERIALIZED VIEW {0}.dailyhigh".format(self.keyspace_name), + ) + self.addCleanup( + self.session.execute, + "DROP MATERIALIZED VIEW {0}.monthlyhigh".format(self.keyspace_name), + ) + self.addCleanup( + self.session.execute, + "DROP MATERIALIZED VIEW {0}.filtereduserhigh".format(self.keyspace_name), + ) - prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? ,?, ?)""".format(self.keyspace_name)) + prepared_insert = self.session.prepare( + """INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? 
,?, ?)""".format( + self.keyspace_name + ) + ) - bound = prepared_insert.bind(('pcmanus', 'Coup', 2015, 5, 1, 4000)) + bound = prepared_insert.bind(("pcmanus", "Coup", 2015, 5, 1, 4000)) self.session.execute(bound) - bound = prepared_insert.bind(('jbellis', 'Coup', 2015, 5, 3, 1750)) + bound = prepared_insert.bind(("jbellis", "Coup", 2015, 5, 3, 1750)) self.session.execute(bound) - bound = prepared_insert.bind(('yukim', 'Coup', 2015, 5, 3, 2250)) + bound = prepared_insert.bind(("yukim", "Coup", 2015, 5, 3, 2250)) self.session.execute(bound) - bound = prepared_insert.bind(('tjake', 'Coup', 2015, 5, 3, 500)) + bound = prepared_insert.bind(("tjake", "Coup", 2015, 5, 3, 500)) self.session.execute(bound) - bound = prepared_insert.bind(('iamaleksey', 'Coup', 2015, 6, 1, 2500)) + bound = prepared_insert.bind(("iamaleksey", "Coup", 2015, 6, 1, 2500)) self.session.execute(bound) - bound = prepared_insert.bind(('tjake', 'Coup', 2015, 6, 2, 1000)) + bound = prepared_insert.bind(("tjake", "Coup", 2015, 6, 2, 1000)) self.session.execute(bound) - bound = prepared_insert.bind(('pcmanus', 'Coup', 2015, 6, 2, 2000)) + bound = prepared_insert.bind(("pcmanus", "Coup", 2015, 6, 2, 2000)) self.session.execute(bound) - bound = prepared_insert.bind(('jmckenzie', 'Coup', 2015, 6, 9, 2700)) + bound = prepared_insert.bind(("jmckenzie", "Coup", 2015, 6, 9, 2700)) self.session.execute(bound) - bound = prepared_insert.bind(('jbellis', 'Coup', 2015, 6, 20, 3500)) + bound = prepared_insert.bind(("jbellis", "Coup", 2015, 6, 20, 3500)) self.session.execute(bound) - bound = prepared_insert.bind(('jbellis', 'Checkers', 2015, 6, 20, 1200)) + bound = prepared_insert.bind(("jbellis", "Checkers", 2015, 6, 20, 1200)) self.session.execute(bound) - bound = prepared_insert.bind(('jbellis', 'Chess', 2015, 6, 21, 3500)) + bound = prepared_insert.bind(("jbellis", "Chess", 2015, 6, 21, 3500)) self.session.execute(bound) - bound = prepared_insert.bind(('pcmanus', 'Chess', 2015, 1, 25, 3200)) + bound = 
prepared_insert.bind(("pcmanus", "Chess", 2015, 1, 25, 3200)) self.session.execute(bound) # Test simple statement and alltime high filtering - query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name), - consistency_level=ConsistencyLevel.QUORUM) + query_statement = SimpleStatement( + "SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format( + self.keyspace_name + ), + consistency_level=ConsistencyLevel.QUORUM, + ) results = self.session.execute(query_statement) - assert results.one().game == 'Coup' + assert results.one().game == "Coup" assert results.one().year == 2015 assert results.one().month == 5 assert results.one().day == 1 @@ -1254,17 +1496,21 @@ def test_mv_filtering(self): assert results.one().user == "pcmanus" # Test prepared statement and daily high filtering - prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? and day=?".format(self.keyspace_name)) + prepared_query = self.session.prepare( + "SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? and day=?".format( + self.keyspace_name + ) + ) bound_query = prepared_query.bind(("Coup", 2015, 6, 2)) results = self.session.execute(bound_query) - assert results.one().game == 'Coup' + assert results.one().game == "Coup" assert results.one().year == 2015 assert results.one().month == 6 assert results.one().day == 2 assert results.one().score == 2000 assert results.one().user == "pcmanus" - assert results[1].game == 'Coup' + assert results[1].game == "Coup" assert results[1].year == 2015 assert results[1].month == 6 assert results[1].day == 2 @@ -1272,24 +1518,28 @@ def test_mv_filtering(self): assert results[1].user == "tjake" # Test montly high range queries - prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? 
and score <= ?".format(self.keyspace_name)) + prepared_query = self.session.prepare( + "SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? and score <= ?".format( + self.keyspace_name + ) + ) bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) results = self.session.execute(bound_query) - assert results.one().game == 'Coup' + assert results.one().game == "Coup" assert results.one().year == 2015 assert results.one().month == 6 assert results.one().day == 20 assert results.one().score == 3500 assert results.one().user == "jbellis" - assert results[1].game == 'Coup' + assert results[1].game == "Coup" assert results[1].year == 2015 assert results[1].month == 6 assert results[1].day == 9 assert results[1].score == 2700 assert results[1].user == "jmckenzie" - assert results[2].game == 'Coup' + assert results[2].game == "Coup" assert results[2].year == 2015 assert results[2].month == 6 assert results[2].day == 1 @@ -1297,17 +1547,21 @@ def test_mv_filtering(self): assert results[2].user == "iamaleksey" # Test filtered user high scores - query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), - consistency_level=ConsistencyLevel.QUORUM) + query_statement = SimpleStatement( + "SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format( + self.keyspace_name + ), + consistency_level=ConsistencyLevel.QUORUM, + ) results = self.session.execute(query_statement) - assert results.one().game == 'Chess' + assert results.one().game == "Chess" assert results.one().year == 2015 assert results.one().month == 6 assert results.one().day == 21 assert results.one().score == 3500 assert results.one().user == "jbellis" - assert results[1].game == 'Chess' + assert results[1].game == "Chess" assert results[1].year == 2015 assert results[1].month == 1 assert results[1].day == 25 @@ -1316,16 +1570,17 @@ def test_mv_filtering(self): class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): - 
def setUp(self): - ddl = ''' + ddl = """ CREATE TABLE {0}.{1} ( k int PRIMARY KEY, - v text )'''.format(self.keyspace_name, self.function_table_name) + v text )""".format(self.keyspace_name, self.function_table_name) self.session.execute(ddl) def tearDown(self): - self.session.execute("DROP TABLE {0}.{1}".format(self.keyspace_name,self.function_table_name)) + self.session.execute( + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name) + ) def test_unicode(self): """ @@ -1338,17 +1593,31 @@ def test_unicode(self): @test_category query """ - unicode_text = u'Fran\u00E7ois' + unicode_text = "Fran\u00e7ois" batch = BatchStatement(BatchType.LOGGED) - batch.add(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) + batch.add( + "INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format( + self.keyspace_name, self.function_table_name + ), + (0, unicode_text), + ) self.session.execute(batch) - self.session.execute(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) - prepared = self.session.prepare(u"INSERT INTO {0}.{1} (k, v) VALUES (?, ?)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format( + self.keyspace_name, self.function_table_name + ), + (0, unicode_text), + ) + prepared = self.session.prepare( + "INSERT INTO {0}.{1} (k, v) VALUES (?, ?)".format( + self.keyspace_name, self.function_table_name + ) + ) bound = prepared.bind((1, unicode_text)) self.session.execute(bound) -class BaseKeyspaceTests(): +class BaseKeyspaceTests: @classmethod def setUpClass(cls): cls.cluster = TestCluster() @@ -1359,28 +1628,36 @@ def setUpClass(cls): cls.table_name = "table_query_keyspace_tests" ddl = """CREATE KEYSPACE {0} WITH replication = - {{'class': 'SimpleStrategy', + {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}""".format(cls.ks_name, 
1) cls.session.execute(ddl) ddl = """CREATE KEYSPACE {0} WITH replication = - {{'class': 'SimpleStrategy', - 'replication_factor': '{1}'}}""".format(cls.alternative_ks, 1) + {{'class': 'NetworkTopologyStrategy', + 'replication_factor': '{1}'}}""".format( + cls.alternative_ks, 1 + ) cls.session.execute(ddl) - ddl = ''' + ddl = """ CREATE TABLE {0}.{1} ( k int PRIMARY KEY, - v int )'''.format(cls.ks_name, cls.table_name) + v int )""".format(cls.ks_name, cls.table_name) cls.session.execute(ddl) - ddl = ''' + ddl = """ CREATE TABLE {0}.{1} ( k int PRIMARY KEY, - v int )'''.format(cls.alternative_ks, cls.table_name) + v int )""".format(cls.alternative_ks, cls.table_name) cls.session.execute(ddl) - cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (1, 1)".format(cls.ks_name, cls.table_name)) - cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (2, 2)".format(cls.alternative_ks, cls.table_name)) + cls.session.execute( + "INSERT INTO {}.{} (k, v) VALUES (1, 1)".format(cls.ks_name, cls.table_name) + ) + cls.session.execute( + "INSERT INTO {}.{} (k, v) VALUES (2, 2)".format( + cls.alternative_ks, cls.table_name + ) + ) @classmethod def tearDownClass(cls): @@ -1392,7 +1669,6 @@ def tearDownClass(cls): class QueryKeyspaceTests(BaseKeyspaceTests): - def test_setting_keyspace(self): """ Test the basic functionality of PYTHON-678, the keyspace can be set @@ -1418,7 +1694,9 @@ def test_setting_keyspace_and_session(self): @test_category query """ - cluster = TestCluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + cluster = TestCluster( + protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True + ) session = cluster.connect(self.alternative_ks) self.addCleanup(cluster.shutdown) @@ -1468,7 +1746,9 @@ def test_lower_protocol(self): session = cluster.connect(self.ks_name) self.addCleanup(cluster.shutdown) - simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) + simple_stmt = SimpleStatement( + 
"SELECT * from {}".format(self.table_name), keyspace=self.ks_name + ) # This raises cassandra.cluster.NoHostAvailable: ('Unable to complete the operation against # any hosts', {: UnsupportedOperation('Keyspaces may only be # set on queries with protocol version 5 or higher. Consider setting Cluster.protocol_version to 5.',), @@ -1478,7 +1758,9 @@ def test_lower_protocol(self): session.execute(simple_stmt) def _check_set_keyspace_in_statement(self, session): - simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) + simple_stmt = SimpleStatement( + "SELECT * from {}".format(self.table_name), keyspace=self.ks_name + ) results = session.execute(simple_stmt) assert results.one() == (1, 1) @@ -1493,7 +1775,9 @@ class BatchWithKeyspaceTests(QueryKeyspaceTests, unittest.TestCase): def _check_set_keyspace_in_statement(self, session): batch_stmt = BatchStatement() for i in range(10): - batch_stmt.add("INSERT INTO {} (k, v) VALUES (%s, %s)".format(self.table_name), (i, i)) + batch_stmt.add( + "INSERT INTO {} (k, v) VALUES (%s, %s)".format(self.table_name), (i, i) + ) batch_stmt.keyspace = self.ks_name session.execute(batch_stmt) @@ -1504,8 +1788,12 @@ def confirm_results(self): values = set() # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see # everything inserted - results = self.session.execute(SimpleStatement("SELECT * FROM {}.{}".format(self.ks_name, self.table_name), - consistency_level=ConsistencyLevel.ALL)) + results = self.session.execute( + SimpleStatement( + "SELECT * FROM {}.{}".format(self.ks_name, self.table_name), + consistency_level=ConsistencyLevel.ALL, + ) + ) for result in results: keys.add(result.k) values.add(result.v) @@ -1516,7 +1804,6 @@ def confirm_results(self): @greaterthanorequalcass40 class PreparedWithKeyspaceTests(BaseKeyspaceTests, unittest.TestCase): - def setUp(self): self.cluster = TestCluster() self.session = self.cluster.connect() @@ -1538,10 +1825,12 @@ 
def test_prepared_with_keyspace_explicit(self): query = "SELECT * from {} WHERE k = ?".format(self.table_name) prepared_statement = self.session.prepare(query, keyspace=self.ks_name) - results = self.session.execute(prepared_statement, (1, )) + results = self.session.execute(prepared_statement, (1,)) assert results.one() == (1, 1) - prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks) + prepared_statement_alternative = self.session.prepare( + query, keyspace=self.alternative_ks + ) assert prepared_statement.query_id != prepared_statement_alternative.query_id @@ -1562,25 +1851,38 @@ def test_reprepare_after_host_is_down(self): with MockLoggingHandler().set_module_name(cluster.__name__) as mock_handler: get_node(1).stop(wait=True, gently=True, wait_other_notice=True) - only_first = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"])) + only_first = ExecutionProfile( + load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"]) + ) self.cluster.add_execution_profile("only_first", only_first) query = "SELECT v from {} WHERE k = ?".format(self.table_name) prepared_statement = self.session.prepare(query, keyspace=self.ks_name) - prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks) + prepared_statement_alternative = self.session.prepare( + query, keyspace=self.alternative_ks + ) get_node(1).start(wait_for_binary_proto=True, wait_other_notice=True) # Wait for cluster._prepare_all_queries to be called wait_until( - lambda: mock_handler.get_message_count('debug', 'Preparing all known prepared statements') >= 1, - delay=0.5, max_attempts=20) + lambda: mock_handler.get_message_count( + "debug", "Preparing all known prepared statements" + ) + >= 1, + delay=0.5, + max_attempts=20, + ) - results = self.session.execute(prepared_statement, (1,), execution_profile="only_first") - assert results.one() == (1, ) + results = self.session.execute( + prepared_statement, (1,), 
execution_profile="only_first" + ) + assert results.one() == (1,) - results = self.session.execute(prepared_statement_alternative, (2,), execution_profile="only_first") - assert results.one() == (2, ) + results = self.session.execute( + prepared_statement_alternative, (2,), execution_profile="only_first" + ) + assert results.one() == (2,) def test_prepared_not_found(self): """ @@ -1603,7 +1905,7 @@ def test_prepared_not_found(self): prepared_statement = session.prepare(query, keyspace=self.ks_name) for _ in range(10): - results = session.execute(prepared_statement, (1, )) + results = session.execute(prepared_statement, (1,)) assert results.one() == (1,) def test_prepared_in_query_keyspace(self): @@ -1625,7 +1927,9 @@ def test_prepared_in_query_keyspace(self): results = session.execute(prepared_statement, (1,)) assert results.one() == (1,) - query = "SELECT k from {}.{} WHERE k = ?".format(self.alternative_ks, self.table_name) + query = "SELECT k from {}.{} WHERE k = ?".format( + self.alternative_ks, self.table_name + ) prepared_statement = session.prepare(query) results = session.execute(prepared_statement, (2,)) assert results.one() == (2,) diff --git a/tests/integration/standard/test_rate_limit_exceeded.py b/tests/integration/standard/test_rate_limit_exceeded.py index ea7dfc7d61..8c89be2967 100644 --- a/tests/integration/standard/test_rate_limit_exceeded.py +++ b/tests/integration/standard/test_rate_limit_exceeded.py @@ -2,22 +2,31 @@ import unittest from cassandra import OperationType, RateLimitReached from cassandra.cluster import Cluster -from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy +from cassandra.policies import ( + ConstantReconnectionPolicy, + RoundRobinPolicy, + TokenAwarePolicy, +) from tests.integration import PROTOCOL_VERSION, use_singledc import pytest LOGGER = logging.getLogger(__name__) + def setup_module(): use_singledc() + class TestRateLimitExceededException(unittest.TestCase): @classmethod def 
setup_class(cls): - cls.cluster = Cluster(contact_points=["127.0.0.1"], protocol_version=PROTOCOL_VERSION, - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - reconnection_policy=ConstantReconnectionPolicy(1)) + cls.cluster = Cluster( + contact_points=["127.0.0.1"], + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + reconnection_policy=ConstantReconnectionPolicy(1), + ) cls.session = cls.cluster.connect() @classmethod @@ -33,20 +42,23 @@ def test_rate_limit_exceeded(self): self.session.execute( """ CREATE KEYSPACE IF NOT EXISTS ratetests - WITH REPLICATION = {'class' : 'SimpleStrategy', 'replication_factor' : 1} - """) + WITH REPLICATION = {'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1} + """ + ) self.session.execute("USE ratetests") self.session.execute( """ CREATE TABLE tbl (pk int PRIMARY KEY, v int) WITH per_partition_rate_limit = {'max_writes_per_second': 1} - """) + """ + ) prepared = self.session.prepare( """ INSERT INTO tbl (pk, v) VALUES (?, ?) 
- """) + """ + ) # The rate limit is 1 write/s, so repeat the same query # until an error occurs, it should happen quickly diff --git a/tests/integration/standard/test_shard_aware.py b/tests/integration/standard/test_shard_aware.py index d1f3e27abd..397d985ae7 100644 --- a/tests/integration/standard/test_shard_aware.py +++ b/tests/integration/standard/test_shard_aware.py @@ -22,7 +22,11 @@ import pytest from cassandra.cluster import Cluster -from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy +from cassandra.policies import ( + TokenAwarePolicy, + RoundRobinPolicy, + ConstantReconnectionPolicy, +) from cassandra import OperationTimedOut, ConsistencyLevel from tests.integration import use_cluster, get_node, PROTOCOL_VERSION @@ -36,24 +40,27 @@ def setup_module(): global _saved_scylla_ext_opts - _saved_scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS') - os.environ['SCYLLA_EXT_OPTS'] = "--smp 2" - use_cluster('cluster_tests', [3], start=True) + _saved_scylla_ext_opts = os.environ.get("SCYLLA_EXT_OPTS") + os.environ["SCYLLA_EXT_OPTS"] = "--smp 2" + use_cluster("cluster_tests", [3], start=True) def teardown_module(): if _saved_scylla_ext_opts is None: - os.environ.pop('SCYLLA_EXT_OPTS', None) + os.environ.pop("SCYLLA_EXT_OPTS", None) else: - os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + os.environ["SCYLLA_EXT_OPTS"] = _saved_scylla_ext_opts class TestShardAwareIntegration(unittest.TestCase): @classmethod def setup_class(cls): - cls.cluster = Cluster(contact_points=["127.0.0.1"], protocol_version=PROTOCOL_VERSION, - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - reconnection_policy=ConstantReconnectionPolicy(1)) + cls.cluster = Cluster( + contact_points=["127.0.0.1"], + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + reconnection_policy=ConstantReconnectionPolicy(1), + ) cls.session = cls.cluster.connect() LOGGER.info(cls.cluster.is_shard_aware()) 
LOGGER.info(cls.cluster.shard_aware_stats()) @@ -69,16 +76,18 @@ def verify_same_shard_in_tracing(self, results, shard_name): LOGGER.info("%s %s %s", event.source, event.thread_name, event.description) for event in events: assert shard_name in event.thread_name - assert 'querying locally' in "\n".join([event.description for event in events]) + assert "querying locally" in "\n".join([event.description for event in events]) trace_id = results.response_future.get_query_trace_ids()[0] - traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,)) + traces = self.session.execute( + "SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,) + ) events = [event for event in traces] for event in events: LOGGER.info("%s %s", event.thread, event.activity) for event in events: assert shard_name in event.thread - assert 'querying locally' in "\n".join([event.activity for event in events]) + assert "querying locally" in "\n".join([event.activity for event in events]) def create_ks_and_cf(self): self.session.execute( @@ -89,8 +98,9 @@ def create_ks_and_cf(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """) + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} + """ + ) self.session.execute("USE preparedtests") self.session.execute( @@ -101,7 +111,8 @@ def create_ks_and_cf(self): c text, PRIMARY KEY (a, b) ) - """) + """ + ) @staticmethod def create_data(session): @@ -109,35 +120,37 @@ def create_data(session): prepared = session.prepare( """ INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) 
- """) + """ + ) - bound = prepared.bind(('a', 'b', 'c')) + bound = prepared.bind(("a", "b", "c")) session.execute(bound) - bound = prepared.bind(('e', 'f', 'g')) + bound = prepared.bind(("e", "f", "g")) session.execute(bound) - bound = prepared.bind(('100002', 'f', 'g')) + bound = prepared.bind(("100002", "f", "g")) session.execute(bound) def query_data(self, session, verify_in_tracing=True): prepared = session.prepare( """ SELECT * FROM cf0 WHERE a=? AND b=? - """) + """ + ) - bound = prepared.bind(('a', 'b')) + bound = prepared.bind(("a", "b")) results = session.execute(bound, trace=True) - assert results == [('a', 'b', 'c')] + assert results == [("a", "b", "c")] if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 0") - bound = prepared.bind(('100002', 'f')) + bound = prepared.bind(("100002", "f")) results = session.execute(bound, trace=True) - assert results == [('100002', 'f', 'g')] + assert results == [("100002", "f", "g")] if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 1") - bound = prepared.bind(('e', 'f')) + bound = prepared.bind(("e", "f")) results = session.execute(bound, trace=True) if verify_in_tracing: @@ -145,25 +158,37 @@ def query_data(self, session, verify_in_tracing=True): def _assert_blocked_node_disconnected(self, node_ip_address, node_port): control_connection = self.cluster.control_connection - active_control_connection = control_connection._connection if control_connection else None - if active_control_connection and \ - active_control_connection.endpoint.address == node_ip_address and \ - active_control_connection.endpoint.port == node_port: - assert active_control_connection.is_closed or active_control_connection.is_defunct + active_control_connection = ( + control_connection._connection if control_connection else None + ) + if ( + active_control_connection + and active_control_connection.endpoint.address == node_ip_address + and active_control_connection.endpoint.port == node_port + ): + assert 
( + active_control_connection.is_closed + or active_control_connection.is_defunct + ) - pools = getattr(self.session, '_pools', None) or {} + pools = getattr(self.session, "_pools", None) or {} for host, pool in pools.items(): - if host.endpoint.address != node_ip_address or host.endpoint.port != node_port: + if ( + host.endpoint.address != node_ip_address + or host.endpoint.port != node_port + ): continue open_connections = [ - connection for connection in pool._connections.values() + connection + for connection in pool._connections.values() if not (connection.is_closed or connection.is_defunct) ] assert not open_connections pending_connections = [ - connection for connection in pool._pending_connections + connection + for connection in pool._pending_connections if not (connection.is_closed or connection.is_defunct) ] assert not pending_connections @@ -188,18 +213,24 @@ def test_connect_from_multiple_clients(self): self.create_ks_and_cf() number_of_clients = 15 - session_list = [self.session] + [self.cluster.connect() for _ in range(number_of_clients)] + session_list = [self.session] + [ + self.cluster.connect() for _ in range(number_of_clients) + ] with ThreadPoolExecutor(number_of_clients) as pool: - futures = [pool.submit(self.create_data, session) for session in session_list] + futures = [ + pool.submit(self.create_data, session) for session in session_list + ] for result in as_completed(futures): print(result) - futures = [pool.submit(self.query_data, session) for session in session_list] + futures = [ + pool.submit(self.query_data, session) for session in session_list + ] for result in as_completed(futures): print(result) - @pytest.mark.skip(reason='https://github.com/scylladb/python-driver/issues/221') + @pytest.mark.skip(reason="https://github.com/scylladb/python-driver/issues/221") def test_closing_connections(self): """ Verify that reconnection is working as expected, when connection are being closed. 
@@ -217,18 +248,20 @@ def test_closing_connections(self): pool._connections.get(shard_id).close() wait_until_not_raised( lambda: self.query_data(self.session, verify_in_tracing=False), - delay=0.5, max_attempts=30) + delay=0.5, + max_attempts=30, + ) wait_until_not_raised( - lambda: self.query_data(self.session), - delay=0.5, max_attempts=60) + lambda: self.query_data(self.session), delay=0.5, max_attempts=60 + ) @pytest.mark.skip def test_blocking_connections(self): """ Verify that reconnection is working as expected, when connection are being blocked. """ - res = run('which iptables'.split(' ')) + res = run("which iptables".split(" ")) if not res.returncode == 0: self.skipTest("iptables isn't installed") @@ -236,26 +269,37 @@ def test_blocking_connections(self): self.create_data(self.session) self.query_data(self.session) - node1_ip_address, node1_port = get_node(1).network_interfaces['binary'] + node1_ip_address, node1_port = get_node(1).network_interfaces["binary"] def remove_iptables(): - run(('sudo iptables -t filter -D INPUT -p tcp --dport {node1_port} ' - '--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable' - ).format(node1_ip_address=node1_ip_address, node1_port=node1_port).split(' ') + run( + ( + "sudo iptables -t filter -D INPUT -p tcp --dport {node1_port} " + "--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable" ) + .format(node1_ip_address=node1_ip_address, node1_port=node1_port) + .split(" ") + ) self.addCleanup(remove_iptables) for i in range(3): - run(('sudo iptables -t filter -A INPUT -p tcp --dport {node1_port} ' - '--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable' - ).format(node1_ip_address=node1_ip_address, node1_port=node1_port).split(' ') + run( + ( + "sudo iptables -t filter -A INPUT -p tcp --dport {node1_port} " + "--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable" ) + .format(node1_ip_address=node1_ip_address, 
node1_port=node1_port) + .split(" ") + ) wait_until_not_raised( - lambda: self._assert_blocked_node_disconnected(node1_ip_address, node1_port), + lambda: self._assert_blocked_node_disconnected( + node1_ip_address, node1_port + ), delay=0.1, - max_attempts=50) + max_attempts=50, + ) try: self.query_data(self.session, verify_in_tracing=False) except OperationTimedOut: @@ -263,6 +307,8 @@ def remove_iptables(): remove_iptables() wait_until_not_raised( lambda: self.query_data(self.session, verify_in_tracing=False), - delay=0.5, max_attempts=30) + delay=0.5, + max_attempts=30, + ) self.query_data(self.session) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 18f3dfb298..19448f261d 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -17,18 +17,37 @@ from functools import partial from cassandra import InvalidRequest -from cassandra.cluster import UserTypeDoesNotExist, ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.cluster import ( + UserTypeDoesNotExist, + ExecutionProfile, + EXEC_PROFILE_DEFAULT, +) from cassandra.query import dict_factory from cassandra.util import OrderedMap -from tests.integration import use_single_node, execute_until_pass, \ - BasicSegregatedKeyspaceUnitTestCase, greaterthancass20, lessthancass30, greaterthanorequalcass36, TestCluster -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, PRIMITIVE_DATATYPES_KEYS, \ - COLLECTION_TYPES, get_sample, get_collection_sample +from tests.integration import ( + use_single_node, + execute_until_pass, + BasicSegregatedKeyspaceUnitTestCase, + greaterthancass20, + lessthancass30, + greaterthanorequalcass36, + TestCluster, +) +from tests.integration.datatype_utils import ( + update_datatypes, + PRIMITIVE_DATATYPES, + PRIMITIVE_DATATYPES_KEYS, + COLLECTION_TYPES, + get_sample, + get_collection_sample, +) import pytest -nested_collection_udt = 
namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) -nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) +nested_collection_udt = namedtuple("nested_collection_udt", ["m", "t", "l", "s"]) +nested_collection_udt_nested = namedtuple( + "nested_collection_udt_nested", ["m", "t", "l", "s", "u"] +) def setup_module(): @@ -38,7 +57,6 @@ def setup_module(): @greaterthancass20 class UDTTests(BasicSegregatedKeyspaceUnitTestCase): - @property def table_name(self): return self._testMethodName.lower() @@ -60,14 +78,31 @@ def test_non_frozen_udts(self): """ self.session.execute("USE {0}".format(self.keyspace_name)) self.session.execute("CREATE TYPE user (state text, has_corn boolean)") - self.session.execute("CREATE TABLE {0} (a int PRIMARY KEY, b user)".format(self.function_table_name)) - User = namedtuple('user', ('state', 'has_corn')) + self.session.execute( + "CREATE TABLE {0} (a int PRIMARY KEY, b user)".format( + self.function_table_name + ) + ) + User = namedtuple("user", ("state", "has_corn")) self.cluster.register_user_type(self.keyspace_name, "user", User) - self.session.execute("INSERT INTO {0} (a, b) VALUES (%s, %s)".format(self.function_table_name), (0, User("Nebraska", True))) - self.session.execute("UPDATE {0} SET b.has_corn = False where a = 0".format(self.function_table_name)) - result = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) + self.session.execute( + "INSERT INTO {0} (a, b) VALUES (%s, %s)".format(self.function_table_name), + (0, User("Nebraska", True)), + ) + self.session.execute( + "UPDATE {0} SET b.has_corn = False where a = 0".format( + self.function_table_name + ) + ) + result = self.session.execute( + "SELECT * FROM {0}".format(self.function_table_name) + ) assert not result.one().b.has_corn - table_sql = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].as_cql_query() + table_sql = ( + 
self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .as_cql_query() + ) assert "" not in table_sql def test_can_insert_unprepared_registered_udts(self): @@ -81,32 +116,34 @@ def test_can_insert_unprepared_registered_udts(self): s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") - User = namedtuple('user', ('age', 'name')) + User = namedtuple("user", ("age", "name")) c.register_user_type(self.keyspace_name, "user", User) - s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, "bob"))) result = s.execute("SELECT b FROM mytable WHERE a=0") row = result.one() assert 42 == row.b.age - assert 'bob' == row.b.name + assert "bob" == row.b.name assert type(row.b) is User # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_unprepared_registered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_unprepared_registered2") s.execute("CREATE TYPE user (state text, is_cool boolean)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") - User = namedtuple('user', ('state', 'is_cool')) + User = namedtuple("user", ("state", "is_cool")) c.register_user_type("udt_test_unprepared_registered2", "user", User) - s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True))) + s.execute( + "INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User("Texas", True)) + ) result = s.execute("SELECT b FROM mytable WHERE a=0") row = result.one() - assert 'Texas' == row.b.state + assert "Texas" == row.b.state assert True == row.b.is_cool assert type(row.b) is User @@ -124,24 +161,32 @@ def test_can_register_udt_before_connecting(self): s.execute(""" CREATE KEYSPACE udt_test_register_before_connecting 
- WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) - s.execute("CREATE TYPE udt_test_register_before_connecting.user (age int, name text)") - s.execute("CREATE TABLE udt_test_register_before_connecting.mytable (a int PRIMARY KEY, b frozen)") + s.execute( + "CREATE TYPE udt_test_register_before_connecting.user (age int, name text)" + ) + s.execute( + "CREATE TABLE udt_test_register_before_connecting.mytable (a int PRIMARY KEY, b frozen)" + ) s.execute(""" CREATE KEYSPACE udt_test_register_before_connecting2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) - s.execute("CREATE TYPE udt_test_register_before_connecting2.user (state text, is_cool boolean)") - s.execute("CREATE TABLE udt_test_register_before_connecting2.mytable (a int PRIMARY KEY, b frozen)") + s.execute( + "CREATE TYPE udt_test_register_before_connecting2.user (state text, is_cool boolean)" + ) + s.execute( + "CREATE TABLE udt_test_register_before_connecting2.mytable (a int PRIMARY KEY, b frozen)" + ) # now that types are defined, shutdown and re-create Cluster c.shutdown() c = TestCluster() - User1 = namedtuple('user', ('age', 'name')) - User2 = namedtuple('user', ('state', 'is_cool')) + User1 = namedtuple("user", ("age", "name")) + User2 = namedtuple("user", ("state", "is_cool")) c.register_user_type("udt_test_register_before_connecting", "user", User1) c.register_user_type("udt_test_register_before_connecting2", "user", User2) @@ -149,18 +194,28 @@ def test_can_register_udt_before_connecting(self): s = c.connect(wait_for_all_pools=True) s.wait_for_schema_agreement() - s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) - result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable 
WHERE a=0") + s.execute( + "INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", + (0, User1(42, "bob")), + ) + result = s.execute( + "SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0" + ) row = result.one() assert 42 == row.b.age - assert 'bob' == row.b.name + assert "bob" == row.b.name assert type(row.b) is User1 # use the same UDT name in a different keyspace - s.execute("INSERT INTO udt_test_register_before_connecting2.mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True))) - result = s.execute("SELECT b FROM udt_test_register_before_connecting2.mytable WHERE a=0") + s.execute( + "INSERT INTO udt_test_register_before_connecting2.mytable (a, b) VALUES (%s, %s)", + (0, User2("Texas", True)), + ) + result = s.execute( + "SELECT b FROM udt_test_register_before_connecting2.mytable WHERE a=0" + ) row = result.one() - assert 'Texas' == row.b.state + assert "Texas" == row.b.state assert True == row.b.is_cool assert type(row.b) is User2 @@ -180,33 +235,33 @@ def test_can_insert_prepared_unregistered_udts(self): s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") - User = namedtuple('user', ('age', 'name')) + User = namedtuple("user", ("age", "name")) insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") - s.execute(insert, (0, User(42, 'bob'))) + s.execute(insert, (0, User(42, "bob"))) select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() assert 42 == row.b.age - assert 'bob' == row.b.name + assert "bob" == row.b.name # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_prepared_unregistered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_prepared_unregistered2") s.execute("CREATE TYPE user (state text, is_cool 
boolean)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") - User = namedtuple('user', ('state', 'is_cool')) + User = namedtuple("user", ("state", "is_cool")) insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") - s.execute(insert, (0, User('Texas', True))) + s.execute(insert, (0, User("Texas", True))) select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - assert 'Texas' == row.b.state + assert "Texas" == row.b.state assert True == row.b.is_cool s.execute("DROP KEYSPACE udt_test_prepared_unregistered2") @@ -222,40 +277,40 @@ def test_can_insert_prepared_registered_udts(self): s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (age int, name text)") - User = namedtuple('user', ('age', 'name')) + User = namedtuple("user", ("age", "name")) c.register_user_type(self.keyspace_name, "user", User) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") - s.execute(insert, (0, User(42, 'bob'))) + s.execute(insert, (0, User(42, "bob"))) select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() assert 42 == row.b.age - assert 'bob' == row.b.name + assert "bob" == row.b.name assert type(row.b) is User # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_prepared_registered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_prepared_registered2") s.execute("CREATE TYPE user (state text, is_cool boolean)") - User = namedtuple('user', ('state', 'is_cool')) + User = namedtuple("user", ("state", "is_cool")) c.register_user_type("udt_test_prepared_registered2", "user", User) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") insert = s.prepare("INSERT INTO 
mytable (a, b) VALUES (?, ?)") - s.execute(insert, (0, User('Texas', True))) + s.execute(insert, (0, User("Texas", True))) select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - assert 'Texas' == row.b.state + assert "Texas" == row.b.state assert True == row.b.is_cool assert type(row.b) is User @@ -272,7 +327,7 @@ def test_can_insert_udts_with_nulls(self): s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)") - User = namedtuple('user', ('a', 'b', 'c', 'd')) + User = namedtuple("user", ("a", "b", "c", "d")) c.register_user_type(self.keyspace_name, "user", User) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -287,9 +342,9 @@ def test_can_insert_udts_with_nulls(self): assert (None, None, None, None) == s.execute(select).one().b # also test empty strings - s.execute(insert, [User('', None, None, bytes())]) + s.execute(insert, [User("", None, None, bytes())]) results = s.execute("SELECT b FROM mytable WHERE a=0") - assert ('', None, None, bytes()) == results.one().b + assert ("", None, None, bytes()) == results.one().b c.shutdown() @@ -304,18 +359,20 @@ def test_can_insert_udts_with_varying_lengths(self): max_test_length = 254 # create the seed udt, increase timeout to avoid the query failure on slow systems - s.execute("CREATE TYPE lengthy_udt ({0})" - .format(', '.join(['v_{0} int'.format(i) - for i in range(max_test_length)]))) + s.execute( + "CREATE TYPE lengthy_udt ({0})".format( + ", ".join(["v_{0} int".format(i) for i in range(max_test_length)]) + ) + ) # create a table with multiple sizes of nested udts # no need for all nested types, only a spot checked few and the largest one - s.execute("CREATE TABLE mytable (" - "k int PRIMARY KEY, " - "v frozen)") + s.execute("CREATE TABLE mytable (k int PRIMARY KEY, v frozen)") # create and register the seed udt type - udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in 
range(max_test_length)])) + udt = namedtuple( + "lengthy_udt", tuple(["v_{0}".format(i) for i in range(max_test_length)]) + ) c.register_user_type(self.keyspace_name, "lengthy_udt", udt) # verify inserts and reads @@ -339,21 +396,27 @@ def nested_udt_schema_helper(self, session, max_nesting_depth): # create the nested udts for i in range(max_nesting_depth): - execute_until_pass(session, "CREATE TYPE depth_{0} (value frozen)".format(i + 1, i)) + execute_until_pass( + session, + "CREATE TYPE depth_{0} (value frozen)".format(i + 1, i), + ) # create a table with multiple sizes of nested udts # no need for all nested types, only a spot checked few and the largest one - execute_until_pass(session, "CREATE TABLE mytable (" - "k int PRIMARY KEY, " - "v_0 frozen, " - "v_1 frozen, " - "v_2 frozen, " - "v_3 frozen, " - "v_{0} frozen)".format(max_nesting_depth)) + execute_until_pass( + session, + "CREATE TABLE mytable (" + "k int PRIMARY KEY, " + "v_0 frozen, " + "v_1 frozen, " + "v_2 frozen, " + "v_3 frozen, " + "v_{0} frozen)".format(max_nesting_depth), + ) def nested_udt_creation_helper(self, udts, i): if i == 0: - return udts[0](42, 'Bob') + return udts[0](42, "Bob") else: return udts[i](self.nested_udt_creation_helper(udts, i - 1)) @@ -366,20 +429,28 @@ def nested_udt_verification_helper(self, session, max_nesting_depth, udts): session.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", [i, udt]) # verify udt was written and read correctly - result = session.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i)).one() + result = session.execute( + "SELECT v_{0} FROM mytable WHERE k=0".format(i) + ).one() assert udt == result["v_{0}".format(i)] # write udt via prepared statement - insert = session.prepare("INSERT INTO mytable (k, v_{0}) VALUES (1, ?)".format(i)) + insert = session.prepare( + "INSERT INTO mytable (k, v_{0}) VALUES (1, ?)".format(i) + ) session.execute(insert, [udt]) # verify udt was written and read correctly - result = session.execute("SELECT v_{0} 
FROM mytable WHERE k=1".format(i)).one() + result = session.execute( + "SELECT v_{0} FROM mytable WHERE k=1".format(i) + ).one() assert udt == result["v_{0}".format(i)] def _cluster_default_dict_factory(self): return TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) def test_can_insert_nested_registered_udts(self): @@ -396,15 +467,17 @@ def test_can_insert_nested_registered_udts(self): # create and register the seed udt type udts = [] - udt = namedtuple('depth_0', ('age', 'name')) + udt = namedtuple("depth_0", ("age", "name")) udts.append(udt) c.register_user_type(self.keyspace_name, "depth_0", udts[0]) # create and register the nested udt types for i in range(max_nesting_depth): - udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udt = namedtuple("depth_{0}".format(i + 1), ("value")) udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type( + self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1] + ) # insert udts and verify inserts with reads self.nested_udt_verification_helper(s, max_nesting_depth, udts) @@ -424,12 +497,12 @@ def test_can_insert_nested_unregistered_udts(self): # create the seed udt type udts = [] - udt = namedtuple('depth_0', ('age', 'name')) + udt = namedtuple("depth_0", ("age", "name")) udts.append(udt) # create the nested udt types for i in range(max_nesting_depth): - udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udt = namedtuple("depth_{0}".format(i + 1), ("value")) udts.append(udt) # insert udts via prepared statements and verify inserts with reads @@ -438,11 +511,15 @@ def test_can_insert_nested_unregistered_udts(self): udt = self.nested_udt_creation_helper(udts, i) # write udt - insert = s.prepare("INSERT INTO mytable (k, v_{0}) VALUES (0, ?)".format(i)) + insert = s.prepare( + "INSERT INTO mytable (k, 
v_{0}) VALUES (0, ?)".format(i) + ) s.execute(insert, [udt]) # verify udt was written and read correctly - result = s.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i)).one() + result = s.execute( + "SELECT v_{0} FROM mytable WHERE k=0".format(i) + ).one() assert udt == result["v_{0}".format(i)] def test_can_insert_nested_registered_udts_with_different_namedtuples(self): @@ -461,15 +538,17 @@ def test_can_insert_nested_registered_udts_with_different_namedtuples(self): # create and register the seed udt type udts = [] - udt = namedtuple('level_0', ('age', 'name')) + udt = namedtuple("level_0", ("age", "name")) udts.append(udt) c.register_user_type(self.keyspace_name, "depth_0", udts[0]) # create and register the nested udt types for i in range(max_nesting_depth): - udt = namedtuple('level_{0}'.format(i + 1), ('value')) + udt = namedtuple("level_{0}".format(i + 1), ("value")) udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type( + self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1] + ) # insert udts and verify inserts with reads self.nested_udt_verification_helper(s, max_nesting_depth, udts) @@ -481,7 +560,7 @@ def test_raise_error_on_nonexisting_udts(self): c = TestCluster() s = c.connect(self.keyspace_name, wait_for_all_pools=True) - User = namedtuple('user', ('age', 'name')) + User = namedtuple("user", ("age", "name")) with pytest.raises(UserTypeDoesNotExist): c.register_user_type("some_bad_keyspace", "user", User) @@ -504,21 +583,22 @@ def test_can_insert_udt_all_datatypes(self): # create UDT alpha_type_list = [] - start_index = ord('a') + start_index = ord("a") for i, datatype in enumerate(PRIMITIVE_DATATYPES): alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) - s.execute(""" + s.execute( + """ CREATE TYPE alldatatypes ({0}) - """.format(', '.join(alpha_type_list)) - ) + """.format(", ".join(alpha_type_list)) + ) s.execute("CREATE TABLE mytable (a int 
PRIMARY KEY, b frozen)") # register UDT alphabet_list = [] - for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)): - alphabet_list.append('{0}'.format(chr(i))) + for i in range(ord("a"), ord("a") + len(PRIMITIVE_DATATYPES)): + alphabet_list.append("{0}".format(chr(i))) Alldatatypes = namedtuple("alldatatypes", alphabet_list) c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) @@ -549,32 +629,45 @@ def test_can_insert_udt_all_collection_datatypes(self): # create UDT alpha_type_list = [] - start_index = ord('a') + start_index = ord("a") for i, collection_type in enumerate(COLLECTION_TYPES): for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): if collection_type == "map": - type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}, {3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) elif collection_type == "tuple": - type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} frozen<{2}<{3}>>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) else: - type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) alpha_type_list.append(type_string) - s.execute(""" + s.execute( + """ CREATE TYPE alldatatypes ({0}) - """.format(', '.join(alpha_type_list)) + """.format(", ".join(alpha_type_list)) ) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") # register UDT alphabet_list = [] - for i in range(ord('a'), ord('a') + len(COLLECTION_TYPES)): - for j in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES_KEYS)): - alphabet_list.append('{0}_{1}'.format(chr(i), chr(j))) + for i in range(ord("a"), 
ord("a") + len(COLLECTION_TYPES)): + for j in range(ord("a"), ord("a") + len(PRIMITIVE_DATATYPES_KEYS)): + alphabet_list.append("{0}_{1}".format(chr(i), chr(j))) Alldatatypes = namedtuple("alldatatypes", alphabet_list) c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) @@ -598,9 +691,13 @@ def test_can_insert_udt_all_collection_datatypes(self): c.shutdown() def insert_select_column(self, session, table_name, column_name, value): - insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name)) + insert = session.prepare( + "INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name) + ) session.execute(insert, (0, value)) - result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,)).one()[0] + result = session.execute( + "SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,) + ).one()[0] assert result == value def test_can_insert_nested_collections(self): @@ -609,7 +706,9 @@ def test_can_insert_nested_collections(self): """ if self.cass_version < (2, 1, 3): - raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") + raise unittest.SkipTest( + "Support for nested collections was introduced in Cassandra 2.1.3" + ) c = TestCluster() s = c.connect(self.keyspace_name, wait_for_all_pools=True) @@ -617,22 +716,29 @@ def test_can_insert_nested_collections(self): name = self._testMethodName - s.execute(""" + s.execute( + """ CREATE TYPE %s ( m frozen>, t tuple, l frozen>, s frozen> - )""" % name) - s.execute(""" + )""" + % name + ) + s.execute( + """ CREATE TYPE %s_nested ( m frozen>, t tuple, l frozen>, s frozen>, u frozen<%s> - )""" % (name, name)) - s.execute(""" + )""" + % (name, name) + ) + s.execute( + """ CREATE TABLE %s ( k int PRIMARY KEY, map_map map>, frozen>>, @@ -640,18 +746,28 @@ def test_can_insert_nested_collections(self): map_list map>, frozen>>, map_tuple map>, frozen>>, map_udt map, frozen<%s>>, - )""" % (name, name, name)) 
+ )""" + % (name, name, name) + ) validate = partial(self.insert_select_column, s, name) - validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})])) - validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))])) - validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])])) - validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))])) + validate( + "map_map", + OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]), + ) + validate( + "map_set", + OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]), + ) + validate("map_list", OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])])) + validate("map_tuple", OrderedMap([((1, 2), (3,)), ((4, 5), (6,))])) - value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10))) + value = nested_collection_udt( + {1: "v1", 2: "v2"}, (3, "v3"), [4, 5, 6, 7], set((8, 9, 10)) + ) key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value) - key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value) - validate('map_udt', OrderedMap([(key, value), (key2, value)])) + key2 = nested_collection_udt_nested({3: "v3"}, value.t, value.l, value.s, value) + validate("map_udt", OrderedMap([(key, value), (key2, value)])) c.shutdown() @@ -660,25 +776,33 @@ def test_non_alphanum_identifiers(self): PYTHON-413 """ s = self.session - non_alphanum_name = 'test.field@#$%@%#!' - type_name = 'type2' - s.execute('CREATE TYPE "%s" ("%s" text)' % (non_alphanum_name, non_alphanum_name)) + non_alphanum_name = "test.field@#$%@%#!" 
+ type_name = "type2" + s.execute( + 'CREATE TYPE "%s" ("%s" text)' % (non_alphanum_name, non_alphanum_name) + ) s.execute('CREATE TYPE %s ("%s" text)' % (type_name, non_alphanum_name)) # table with types as map keys to make sure the tuple lookup works - s.execute('CREATE TABLE %s (k int PRIMARY KEY, non_alphanum_type_map map, int>, alphanum_type_map map, int>)' % (self.table_name, non_alphanum_name, type_name)) - s.execute('INSERT INTO %s (k, non_alphanum_type_map, alphanum_type_map) VALUES (%s, {{"%s": \'nonalphanum\'}: 0}, {{"%s": \'alphanum\'}: 1})' % (self.table_name, 0, non_alphanum_name, non_alphanum_name)) - row = s.execute('SELECT * FROM %s' % (self.table_name,)).one() + s.execute( + 'CREATE TABLE %s (k int PRIMARY KEY, non_alphanum_type_map map, int>, alphanum_type_map map, int>)' + % (self.table_name, non_alphanum_name, type_name) + ) + s.execute( + "INSERT INTO %s (k, non_alphanum_type_map, alphanum_type_map) VALUES (%s, {{\"%s\": 'nonalphanum'}: 0}, {{\"%s\": 'alphanum'}: 1})" + % (self.table_name, 0, non_alphanum_name, non_alphanum_name) + ) + row = s.execute("SELECT * FROM %s" % (self.table_name,)).one() k, v = row.non_alphanum_type_map.popitem() assert v == 0 assert k.__class__ == tuple - assert k[0] == 'nonalphanum' + assert k[0] == "nonalphanum" k, v = row.alphanum_type_map.popitem() assert v == 1 assert k.__class__ != tuple # should be the namedtuple type - assert k[0] == 'alphanum' - assert k.field_0_ == 'alphanum' # named tuple with positional field name + assert k[0] == "alphanum" + assert k.field_0_ == "alphanum" # named tuple with positional field name @lessthancass30 def test_type_alteration(self): @@ -687,34 +811,43 @@ def test_type_alteration(self): """ s = self.session type_name = "type_name" - assert type_name not in s.cluster.metadata.keyspaces['udttests'].user_types - s.execute('CREATE TYPE %s (v0 int)' % (type_name,)) - assert type_name in s.cluster.metadata.keyspaces['udttests'].user_types + assert type_name not in 
s.cluster.metadata.keyspaces["udttests"].user_types + s.execute("CREATE TYPE %s (v0 int)" % (type_name,)) + assert type_name in s.cluster.metadata.keyspaces["udttests"].user_types - s.execute('CREATE TABLE %s (k int PRIMARY KEY, v frozen<%s>)' % (self.table_name, type_name)) - s.execute('INSERT INTO %s (k, v) VALUES (0, {v0 : 1})' % (self.table_name,)) + s.execute( + "CREATE TABLE %s (k int PRIMARY KEY, v frozen<%s>)" + % (self.table_name, type_name) + ) + s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 1})" % (self.table_name,)) - s.cluster.register_user_type('udttests', type_name, dict) + s.cluster.register_user_type("udttests", type_name, dict) - val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - assert val['v0'] == 1 + val = s.execute("SELECT v FROM %s" % self.table_name).one()[0] + assert val["v0"] == 1 # add field - s.execute('ALTER TYPE %s ADD v1 text' % (type_name,)) - val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - assert val['v0'] == 1 - assert val['v1'] is None - s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 2, v1 : 'sometext'})" % (self.table_name,)) - val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - assert val['v0'] == 2 - assert val['v1'] == 'sometext' + s.execute("ALTER TYPE %s ADD v1 text" % (type_name,)) + val = s.execute("SELECT v FROM %s" % self.table_name).one()[0] + assert val["v0"] == 1 + assert val["v1"] is None + s.execute( + "INSERT INTO %s (k, v) VALUES (0, {v0 : 2, v1 : 'sometext'})" + % (self.table_name,) + ) + val = s.execute("SELECT v FROM %s" % self.table_name).one()[0] + assert val["v0"] == 2 + assert val["v1"] == "sometext" # alter field type - s.execute('ALTER TYPE %s ALTER v1 TYPE blob' % (type_name,)) - s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 3, v1 : 0xdeadbeef})" % (self.table_name,)) - val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - assert val['v0'] == 3 - assert val['v1'] == b'\xde\xad\xbe\xef' + s.execute("ALTER TYPE %s ALTER v1 TYPE blob" % 
(type_name,)) + s.execute( + "INSERT INTO %s (k, v) VALUES (0, {v0 : 3, v1 : 0xdeadbeef})" + % (self.table_name,) + ) + val = s.execute("SELECT v FROM %s" % self.table_name).one()[0] + assert val["v0"] == 3 + assert val["v1"] == b"\xde\xad\xbe\xef" @lessthancass30 def test_alter_udt(self): @@ -731,20 +864,32 @@ def test_alter_udt(self): # Create udt ensure it has the proper column names. self.session.set_keyspace(self.keyspace_name) self.session.execute("CREATE TYPE typetoalter (a int)") - typetoalter = namedtuple('typetoalter', ('a')) - self.session.execute("CREATE TABLE {0} (pk int primary key, typetoalter frozen)".format(self.function_table_name)) - insert_statement = self.session.prepare("INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format(self.function_table_name)) + typetoalter = namedtuple("typetoalter", ("a")) + self.session.execute( + "CREATE TABLE {0} (pk int primary key, typetoalter frozen)".format( + self.function_table_name + ) + ) + insert_statement = self.session.prepare( + "INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format( + self.function_table_name + ) + ) self.session.execute(insert_statement, [1, typetoalter(1)]) - results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + results = self.session.execute( + "SELECT * from {0}".format(self.function_table_name) + ) for result in results: - assert hasattr(result.typetoalter, 'a') - assert not hasattr(result.typetoalter, 'b') + assert hasattr(result.typetoalter, "a") + assert not hasattr(result.typetoalter, "b") # Alter UDT and ensure the alter is honored in results self.session.execute("ALTER TYPE typetoalter add b int") - typetoalter = namedtuple('typetoalter', ('a', 'b')) + typetoalter = namedtuple("typetoalter", ("a", "b")) self.session.execute(insert_statement, [2, typetoalter(2, 2)]) - results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + results = self.session.execute( + "SELECT * from {0}".format(self.function_table_name) 
+ ) for result in results: - assert hasattr(result.typetoalter, 'a') - assert hasattr(result.typetoalter, 'b') + assert hasattr(result.typetoalter, "a") + assert hasattr(result.typetoalter, "b") diff --git a/tests/integration/standard/test_use_keyspace.py b/tests/integration/standard/test_use_keyspace.py index 80e7cfe5f3..a8654fb6ac 100644 --- a/tests/integration/standard/test_use_keyspace.py +++ b/tests/integration/standard/test_use_keyspace.py @@ -8,7 +8,11 @@ from cassandra.connection import Connection from cassandra.cluster import Cluster -from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy +from cassandra.policies import ( + TokenAwarePolicy, + RoundRobinPolicy, + ConstantReconnectionPolicy, +) from tests.integration import use_cluster, PROTOCOL_VERSION, local @@ -19,25 +23,28 @@ def setup_module(): global _saved_scylla_ext_opts - _saved_scylla_ext_opts = os.environ.get('SCYLLA_EXT_OPTS') - os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" - use_cluster('shared_aware', [3], start=True) + _saved_scylla_ext_opts = os.environ.get("SCYLLA_EXT_OPTS") + os.environ["SCYLLA_EXT_OPTS"] = "--smp 2 --memory 2048M" + use_cluster("shared_aware", [3], start=True) def teardown_module(): if _saved_scylla_ext_opts is None: - os.environ.pop('SCYLLA_EXT_OPTS', None) + os.environ.pop("SCYLLA_EXT_OPTS", None) else: - os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + os.environ["SCYLLA_EXT_OPTS"] = _saved_scylla_ext_opts @local class TestUseKeyspace(unittest.TestCase): @classmethod def setup_class(cls): - cls.cluster = Cluster(contact_points=["127.0.0.1"], protocol_version=PROTOCOL_VERSION, - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - reconnection_policy=ConstantReconnectionPolicy(1)) + cls.cluster = Cluster( + contact_points=["127.0.0.1"], + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + reconnection_policy=ConstantReconnectionPolicy(1), + ) cls.session = 
cls.cluster.connect() LOGGER.info(cls.cluster.is_shard_aware()) LOGGER.info(cls.cluster.shard_aware_stats()) @@ -45,13 +52,13 @@ def setup_class(cls): @classmethod def teardown_class(cls): cls.cluster.shutdown() - + def test_set_keyspace_slow_connection(self): # Test that "USE keyspace" gets propagated # to all connections. # # Reproduces an issue #187 where some pending - # connections for shards would not + # connections for shards would not # receive "USE keyspace". # # Simulate that scenario by adding an artifical @@ -64,11 +71,19 @@ def patched_set_keyspace_blocking(*args, **kwargs): time.sleep(1) return original_set_keyspace_blocking(*args, **kwargs) - with patch.object(Connection, "set_keyspace_blocking", patched_set_keyspace_blocking): - self.session.execute("CREATE KEYSPACE test_set_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") - self.session.execute("CREATE TABLE test_set_keyspace.set_keyspace_slow_connection(pk int, PRIMARY KEY(pk))") + with patch.object( + Connection, "set_keyspace_blocking", patched_set_keyspace_blocking + ): + self.session.execute( + "CREATE KEYSPACE test_set_keyspace WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}" + ) + self.session.execute( + "CREATE TABLE test_set_keyspace.set_keyspace_slow_connection(pk int, PRIMARY KEY(pk))" + ) session2 = self.cluster.connect() session2.execute("USE test_set_keyspace") for i in range(200): - session2.execute(f"SELECT * FROM set_keyspace_slow_connection WHERE pk = 1") + session2.execute( + f"SELECT * FROM set_keyspace_slow_connection WHERE pk = 1" + ) diff --git a/tests/unit/advanced/test_metadata.py b/tests/unit/advanced/test_metadata.py index 5ccfa5e477..f2daf36347 100644 --- a/tests/unit/advanced/test_metadata.py +++ b/tests/unit/advanced/test_metadata.py @@ -15,32 +15,50 @@ import unittest from cassandra.metadata import ( - KeyspaceMetadata, TableMetadataDSE68, - VertexMetadata, EdgeMetadata, SchemaParserV22, _SchemaParser 
+ KeyspaceMetadata, + TableMetadataDSE68, + VertexMetadata, + EdgeMetadata, + SchemaParserV22, + _SchemaParser, ) from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS class GraphMetadataToCQLTests(unittest.TestCase): - - def _create_edge_metadata(self, partition_keys=['pk1'], clustering_keys=['c1']): + def _create_edge_metadata(self, partition_keys=["pk1"], clustering_keys=["c1"]): return EdgeMetadata( - 'keyspace', 'table', 'label', 'from_table', 'from_label', - partition_keys, clustering_keys, 'to_table', 'to_label', - partition_keys, clustering_keys) - - def _create_vertex_metadata(self, label_name='label'): - return VertexMetadata('keyspace', 'table', label_name) + "keyspace", + "table", + "label", + "from_table", + "from_label", + partition_keys, + clustering_keys, + "to_table", + "to_label", + partition_keys, + clustering_keys, + ) + + def _create_vertex_metadata(self, label_name="label"): + return VertexMetadata("keyspace", "table", label_name) def _create_keyspace_metadata(self, graph_engine): return KeyspaceMetadata( - 'keyspace', True, 'org.apache.cassandra.locator.SimpleStrategy', - {'replication_factor': 1}, graph_engine=graph_engine) + "keyspace", + True, + "org.apache.cassandra.locator.NetworkTopologyStrategy", + {"dc1": 1}, + graph_engine=graph_engine, + ) def _create_table_metadata(self, with_vertex=False, with_edge=False): - tm = TableMetadataDSE68('keyspace', 'table') + tm = TableMetadataDSE68("keyspace", "table") if with_vertex: - tm.vertex = self._create_vertex_metadata() if with_vertex is True else with_vertex + tm.vertex = ( + self._create_vertex_metadata() if with_vertex is True else with_vertex + ) elif with_edge: tm.edge = self._create_edge_metadata() if with_edge is True else with_edge @@ -52,7 +70,7 @@ def test_keyspace_no_graph_engine(self): assert "graph_engine" not in km.as_cql_query() def test_keyspace_with_graph_engine(self): - graph_engine = 'Core' + graph_engine = "Core" km = self._create_keyspace_metadata(graph_engine) 
assert km.graph_engine == graph_engine cql = km.as_cql_query() @@ -86,33 +104,34 @@ def test_table_with_edge(self): assert "TO to_label" in cql def test_vertex_with_label(self): - tm = self. _create_table_metadata(with_vertex=True) - assert tm.as_cql_query().endswith('VERTEX LABEL label') + tm = self._create_table_metadata(with_vertex=True) + assert tm.as_cql_query().endswith("VERTEX LABEL label") def test_edge_single_partition_key_and_clustering_key(self): tm = self._create_table_metadata(with_edge=True) - assert 'FROM from_label(pk1, c1)' in tm.as_cql_query() + assert "FROM from_label(pk1, c1)" in tm.as_cql_query() def test_edge_multiple_partition_keys(self): - edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2']) - tm = self. _create_table_metadata(with_edge=edge) - assert 'FROM from_label((pk1, pk2), ' in tm.as_cql_query() + edge = self._create_edge_metadata(partition_keys=["pk1", "pk2"]) + tm = self._create_table_metadata(with_edge=edge) + assert "FROM from_label((pk1, pk2), " in tm.as_cql_query() def test_edge_no_clustering_keys(self): edge = self._create_edge_metadata(clustering_keys=[]) - tm = self. _create_table_metadata(with_edge=edge) - assert 'FROM from_label(pk1) ' in tm.as_cql_query() + tm = self._create_table_metadata(with_edge=edge) + assert "FROM from_label(pk1) " in tm.as_cql_query() def test_edge_multiple_clustering_keys(self): - edge = self._create_edge_metadata(clustering_keys=['c1', 'c2']) - tm = self. _create_table_metadata(with_edge=edge) - assert 'FROM from_label(pk1, c1, c2) ' in tm.as_cql_query() + edge = self._create_edge_metadata(clustering_keys=["c1", "c2"]) + tm = self._create_table_metadata(with_edge=edge) + assert "FROM from_label(pk1, c1, c2) " in tm.as_cql_query() def test_edge_multiple_partition_and_clustering_keys(self): - edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2'], - clustering_keys=['c1', 'c2']) - tm = self. 
_create_table_metadata(with_edge=edge) - assert 'FROM from_label((pk1, pk2), c1, c2) ' in tm.as_cql_query() + edge = self._create_edge_metadata( + partition_keys=["pk1", "pk2"], clustering_keys=["c1", "c2"] + ) + tm = self._create_table_metadata(with_edge=edge) + assert "FROM from_label((pk1, pk2), c1, c2) " in tm.as_cql_query() class SchemaParsersTests(unittest.TestCase): @@ -135,19 +154,26 @@ def wait_for_responses(self, *msgs, **kwargs): p._query_all() for q in conn.queries: - assert "USING TIMEOUT" not in q.query, f"<{schemaClass.__name__}> query `{q.query}` contains `USING TIMEOUT`, while should not" + assert "USING TIMEOUT" not in q.query, ( + f"<{schemaClass.__name__}> query `{q.query}` contains `USING TIMEOUT`, while should not" + ) conn = FakeConnection() p = schemaClass(conn, 2.0, 1000, datetime.timedelta(seconds=2)) p._query_all() for q in conn.queries: - assert "USING TIMEOUT 2000ms" in q.query, f"{schemaClass.__name__} query `{q.query}` does not contain `USING TIMEOUT 2000ms`" + assert "USING TIMEOUT 2000ms" in q.query, ( + f"{schemaClass.__name__} query `{q.query}` does not contain `USING TIMEOUT 2000ms`" + ) def get_all_schema_parser_classes(cl): for child in cl.__subclasses__(): - if not child.__name__.startswith('SchemaParser') or child.__module__ != 'cassandra.metadata': + if ( + not child.__name__.startswith("SchemaParser") + or child.__module__ != "cassandra.metadata" + ): continue yield child for c in get_all_schema_parser_classes(child): diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index dcbb840447..1318ea8538 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -23,16 +23,32 @@ import cassandra from cassandra.cqltypes import strip_frozen from cassandra.marshal import uint16_unpack, uint16_pack -from cassandra.metadata import (Murmur3Token, MD5Token, - BytesToken, ReplicationStrategy, - NetworkTopologyStrategy, SimpleStrategy, - LocalStrategy, protect_name, - protect_names, protect_value, 
is_valid_name, - UserType, KeyspaceMetadata, get_schema_parser, - _UnknownStrategy, ColumnMetadata, TableMetadata, - IndexMetadata, Function, Aggregate, - Metadata, TokenMap, ReplicationFactor, - SchemaParserDSE68) +from cassandra.metadata import ( + Murmur3Token, + MD5Token, + BytesToken, + ReplicationStrategy, + NetworkTopologyStrategy, + SimpleStrategy, + LocalStrategy, + protect_name, + protect_names, + protect_value, + is_valid_name, + UserType, + KeyspaceMetadata, + get_schema_parser, + _UnknownStrategy, + ColumnMetadata, + TableMetadata, + IndexMetadata, + Function, + Aggregate, + Metadata, + TokenMap, + ReplicationFactor, + SchemaParserDSE68, +) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host from cassandra.protocol import QueryMessage @@ -44,39 +60,36 @@ class ReplicationFactorTest(unittest.TestCase): - def test_replication_factor_parsing(self): - rf = ReplicationFactor.create('3') + rf = ReplicationFactor.create("3") assert rf.all_replicas == 3 assert rf.full_replicas == 3 assert rf.transient_replicas == None - assert str(rf) == '3' + assert str(rf) == "3" - rf = ReplicationFactor.create('3/1') + rf = ReplicationFactor.create("3/1") assert rf.all_replicas == 3 assert rf.full_replicas == 2 assert rf.transient_replicas == 1 - assert str(rf) == '3/1' + assert str(rf) == "3/1" with pytest.raises(ValueError): - ReplicationFactor.create('3/') + ReplicationFactor.create("3/") with pytest.raises(ValueError): - ReplicationFactor.create('a/1') + ReplicationFactor.create("a/1") with pytest.raises(ValueError): - ReplicationFactor.create('a') + ReplicationFactor.create("a") with pytest.raises(ValueError): - ReplicationFactor.create('3/a') + ReplicationFactor.create("3/a") def test_replication_factor_equality(self): - assert ReplicationFactor.create('3/1') == ReplicationFactor.create('3/1') - assert ReplicationFactor.create('3') == ReplicationFactor.create('3') - assert ReplicationFactor.create('3') != 
ReplicationFactor.create('3/1') - assert ReplicationFactor.create('3') != ReplicationFactor.create('3/1') - + assert ReplicationFactor.create("3/1") == ReplicationFactor.create("3/1") + assert ReplicationFactor.create("3") == ReplicationFactor.create("3") + assert ReplicationFactor.create("3") != ReplicationFactor.create("3/1") + assert ReplicationFactor.create("3") != ReplicationFactor.create("3/1") class StrategiesTest(unittest.TestCase): - def test_replication_strategy(self): """ Basic code coverage testing that ensures different ReplicationStrategies @@ -85,27 +98,42 @@ def test_replication_strategy(self): rs = ReplicationStrategy() - assert rs.create('OldNetworkTopologyStrategy', None) == _UnknownStrategy('OldNetworkTopologyStrategy', None) - fake_options_map = {'options': 'map'} - uks = rs.create('OldNetworkTopologyStrategy', fake_options_map) - assert uks == _UnknownStrategy('OldNetworkTopologyStrategy', fake_options_map) + assert rs.create("OldNetworkTopologyStrategy", None) == _UnknownStrategy( + "OldNetworkTopologyStrategy", None + ) + fake_options_map = {"options": "map"} + uks = rs.create("OldNetworkTopologyStrategy", fake_options_map) + assert uks == _UnknownStrategy("OldNetworkTopologyStrategy", fake_options_map) assert uks.make_token_replica_map({}, []) == {} - fake_options_map = {'dc1': '3'} - assert isinstance(rs.create('NetworkTopologyStrategy', fake_options_map), NetworkTopologyStrategy) - assert rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors == NetworkTopologyStrategy(fake_options_map).dc_replication_factors + fake_options_map = {"dc1": "3"} + assert isinstance( + rs.create("NetworkTopologyStrategy", fake_options_map), + NetworkTopologyStrategy, + ) + assert ( + rs.create( + "NetworkTopologyStrategy", fake_options_map + ).dc_replication_factors + == NetworkTopologyStrategy(fake_options_map).dc_replication_factors + ) - fake_options_map = {'options': 'map'} - assert rs.create('SimpleStrategy', fake_options_map) is 
None + fake_options_map = {"options": "map"} + assert rs.create("SimpleStrategy", fake_options_map) is None - fake_options_map = {'options': 'map'} - assert isinstance(rs.create('LocalStrategy', fake_options_map), LocalStrategy) + fake_options_map = {"options": "map"} + assert isinstance(rs.create("LocalStrategy", fake_options_map), LocalStrategy) - fake_options_map = {'options': 'map', 'replication_factor': 3} - assert isinstance(rs.create('SimpleStrategy', fake_options_map), SimpleStrategy) - assert rs.create('SimpleStrategy', fake_options_map).replication_factor == SimpleStrategy(fake_options_map).replication_factor + fake_options_map = {"options": "map", "replication_factor": 3} + assert isinstance(rs.create("SimpleStrategy", fake_options_map), SimpleStrategy) + assert ( + rs.create("SimpleStrategy", fake_options_map).replication_factor + == SimpleStrategy(fake_options_map).replication_factor + ) - assert rs.create('xxxxxxxx', fake_options_map) == _UnknownStrategy('xxxxxxxx', fake_options_map) + assert rs.create("xxxxxxxx", fake_options_map) == _UnknownStrategy( + "xxxxxxxx", fake_options_map + ) with pytest.raises(NotImplementedError): rs.make_token_replica_map(None, None) @@ -113,111 +141,139 @@ def test_replication_strategy(self): rs.export_for_schema() def test_simple_replication_type_parsing(self): - """ Test equality between passing numeric and string replication factor for simple strategy """ + """Test equality between passing numeric and string replication factor for simple strategy""" rs = ReplicationStrategy() - simple_int = rs.create('SimpleStrategy', {'replication_factor': 3}) - simple_str = rs.create('SimpleStrategy', {'replication_factor': '3'}) + simple_int = rs.create("SimpleStrategy", {"replication_factor": 3}) + simple_str = rs.create("SimpleStrategy", {"replication_factor": "3"}) assert simple_int.export_for_schema() == simple_str.export_for_schema() assert simple_int == simple_str # make token replica map ring = [MD5Token(0), MD5Token(1), 
MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [ + Host("dc1.{}".format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) + for host in range(3) + ] token_to_host = dict(zip(ring, hosts)) - assert simple_int.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) + assert simple_int.make_token_replica_map( + token_to_host, ring + ) == simple_str.make_token_replica_map(token_to_host, ring) def test_transient_replication_parsing(self): - """ Test that we can PARSE a transient replication factor for SimpleStrategy """ + """Test that we can PARSE a transient replication factor for SimpleStrategy""" rs = ReplicationStrategy() - simple_transient = rs.create('SimpleStrategy', {'replication_factor': '3/1'}) + simple_transient = rs.create("SimpleStrategy", {"replication_factor": "3/1"}) assert simple_transient.replication_factor_info == ReplicationFactor(3, 1) assert simple_transient.replication_factor == 2 assert "'replication_factor': '3/1'" in simple_transient.export_for_schema() - simple_str = rs.create('SimpleStrategy', {'replication_factor': '2'}) + simple_str = rs.create("SimpleStrategy", {"replication_factor": "2"}) assert simple_transient != simple_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [ + Host("dc1.{}".format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) + for host in range(3) + ] token_to_host = dict(zip(ring, hosts)) - assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) + assert simple_transient.make_token_replica_map( + token_to_host, ring + ) == simple_str.make_token_replica_map(token_to_host, ring) def test_nts_replication_parsing(self): - """ Test equality between passing numeric and string 
replication factor for NTS """ + """Test equality between passing numeric and string replication factor for NTS""" rs = ReplicationStrategy() - nts_int = rs.create('NetworkTopologyStrategy', {'dc1': 3, 'dc2': 5}) - nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) + nts_int = rs.create("NetworkTopologyStrategy", {"dc1": 3, "dc2": 5}) + nts_str = rs.create("NetworkTopologyStrategy", {"dc1": "3", "dc2": "5"}) - assert nts_int.dc_replication_factors['dc1'] == 3 - assert nts_str.dc_replication_factors['dc1'] == 3 - assert nts_int.dc_replication_factors_info['dc1'] == ReplicationFactor(3) - assert nts_str.dc_replication_factors_info['dc1'] == ReplicationFactor(3) + assert nts_int.dc_replication_factors["dc1"] == 3 + assert nts_str.dc_replication_factors["dc1"] == 3 + assert nts_int.dc_replication_factors_info["dc1"] == ReplicationFactor(3) + assert nts_str.dc_replication_factors_info["dc1"] == ReplicationFactor(3) assert nts_int.export_for_schema() == nts_str.export_for_schema() assert nts_int == nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [ + Host("dc1.{}".format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) + for host in range(3) + ] token_to_host = dict(zip(ring, hosts)) - assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) + assert nts_int.make_token_replica_map( + token_to_host, ring + ) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_transient_parsing(self): - """ Test that we can PARSE a transient replication factor for NTS """ + """Test that we can PARSE a transient replication factor for NTS""" rs = ReplicationStrategy() - nts_transient = rs.create('NetworkTopologyStrategy', {'dc1': '3/1', 'dc2': '5/1'}) - assert nts_transient.dc_replication_factors_info['dc1'] == ReplicationFactor(3, 1) - assert 
nts_transient.dc_replication_factors_info['dc2'] == ReplicationFactor(5, 1) - assert nts_transient.dc_replication_factors['dc1'] == 2 - assert nts_transient.dc_replication_factors['dc2'] == 4 + nts_transient = rs.create( + "NetworkTopologyStrategy", {"dc1": "3/1", "dc2": "5/1"} + ) + assert nts_transient.dc_replication_factors_info["dc1"] == ReplicationFactor( + 3, 1 + ) + assert nts_transient.dc_replication_factors_info["dc2"] == ReplicationFactor( + 5, 1 + ) + assert nts_transient.dc_replication_factors["dc1"] == 2 + assert nts_transient.dc_replication_factors["dc2"] == 4 assert "'dc1': '3/1', 'dc2': '5/1'" in nts_transient.export_for_schema() - nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) + nts_str = rs.create("NetworkTopologyStrategy", {"dc1": "3", "dc2": "5"}) assert nts_transient != nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [ + Host("dc1.{}".format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) + for host in range(3) + ] token_to_host = dict(zip(ring, hosts)) - assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) + assert nts_transient.make_token_replica_map( + token_to_host, ring + ) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_make_token_replica_map(self): token_to_host_owner = {} - dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_1 = Host("dc1.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host("dc1.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host("dc1.3", SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in (dc1_1, dc1_2, dc1_3): - host.set_location_info('dc1', 'rack1') + 
host.set_location_info("dc1", "rack1") token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 - dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_1.set_location_info('dc2', 'rack1') - dc2_2.set_location_info('dc2', 'rack1') + dc2_1 = Host("dc2.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host("dc2.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_1.set_location_info("dc2", "rack1") + dc2_2.set_location_info("dc2", "rack1") token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 - dc3_1 = Host('dc3.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc3_1.set_location_info('dc3', 'rack3') + dc3_1 = Host("dc3.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc3_1.set_location_info("dc3", "rack3") token_to_host_owner[MD5Token(2)] = dc3_1 - ring = [MD5Token(0), - MD5Token(1), - MD5Token(2), - MD5Token(100), - MD5Token(101), - MD5Token(200)] + ring = [ + MD5Token(0), + MD5Token(1), + MD5Token(2), + MD5Token(100), + MD5Token(101), + MD5Token(200), + ] - nts = NetworkTopologyStrategy({'dc1': 2, 'dc2': 2, 'dc3': 1}) + nts = NetworkTopologyStrategy({"dc1": 2, "dc2": 2, "dc3": 1}) replica_map = nts.make_token_replica_map(token_to_host_owner, ring) assertCountEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) @@ -240,21 +296,22 @@ def test_nts_token_performance(self): current_token = 0 vnodes_per_host = 500 for i in range(dc1hostnum): - - host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info('dc1', "rack1") + host = Host( + "dc1.{0}".format(i), SimpleConvictionPolicy, host_id=uuid.uuid4() + ) + host.set_location_info("dc1", "rack1") for vnode_num in range(vnodes_per_host): - md5_token = MD5Token(current_token+vnode_num) + md5_token = MD5Token(current_token + vnode_num) 
token_to_host_owner[md5_token] = host ring.append(md5_token) current_token += 1000 - nts = NetworkTopologyStrategy({'dc1': 3}) + nts = NetworkTopologyStrategy({"dc1": 3}) start_time = timeit.default_timer() nts.make_token_replica_map(token_to_host_owner, ring) elapsed_base = timeit.default_timer() - start_time - nts = NetworkTopologyStrategy({'dc1': 1500}) + nts = NetworkTopologyStrategy({"dc1": 1500}) start_time = timeit.default_timer() nts.make_token_replica_map(token_to_host_owner, ring) elapsed_bad = timeit.default_timer() - start_time @@ -265,117 +322,141 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} # (A) not enough distinct racks, first skipped is used - dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_4 = Host('dc1.4', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_1.set_location_info('dc1', 'rack1') - dc1_2.set_location_info('dc1', 'rack1') - dc1_3.set_location_info('dc1', 'rack2') - dc1_4.set_location_info('dc1', 'rack2') + dc1_1 = Host("dc1.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host("dc1.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host("dc1.3", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_4 = Host("dc1.4", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_1.set_location_info("dc1", "rack1") + dc1_2.set_location_info("dc1", "rack1") + dc1_3.set_location_info("dc1", "rack2") + dc1_4.set_location_info("dc1", "rack2") token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 token_to_host_owner[MD5Token(300)] = dc1_4 # (B) distinct racks, but not contiguous - dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_3 = Host('dc2.3', SimpleConvictionPolicy, 
host_id=uuid.uuid4()) - dc2_1.set_location_info('dc2', 'rack1') - dc2_2.set_location_info('dc2', 'rack1') - dc2_3.set_location_info('dc2', 'rack2') + dc2_1 = Host("dc2.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host("dc2.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_3 = Host("dc2.3", SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_1.set_location_info("dc2", "rack1") + dc2_2.set_location_info("dc2", "rack1") + dc2_3.set_location_info("dc2", "rack2") token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 token_to_host_owner[MD5Token(201)] = dc2_3 - ring = [MD5Token(0), - MD5Token(1), - MD5Token(100), - MD5Token(101), - MD5Token(200), - MD5Token(201), - MD5Token(300)] + ring = [ + MD5Token(0), + MD5Token(1), + MD5Token(100), + MD5Token(101), + MD5Token(200), + MD5Token(201), + MD5Token(300), + ] - nts = NetworkTopologyStrategy({'dc1': 3, 'dc2': 2}) + nts = NetworkTopologyStrategy({"dc1": 3, "dc2": 2}) replica_map = nts.make_token_replica_map(token_to_host_owner, ring) token_replicas = replica_map[MD5Token(0)] assertCountEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) def test_nts_make_token_replica_map_empty_dc(self): - host = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info('dc1', 'rack1') + host = Host("1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_location_info("dc1", "rack1") token_to_host_owner = {MD5Token(0): host} ring = [MD5Token(0)] - nts = NetworkTopologyStrategy({'dc1': 1, 'dc2': 0}) + nts = NetworkTopologyStrategy({"dc1": 1, "dc2": 0}) replica_map = nts.make_token_replica_map(token_to_host_owner, ring) assert set(replica_map[MD5Token(0)]) == set([host]) def test_nts_export_for_schema(self): - strategy = NetworkTopologyStrategy({'dc1': '1', 'dc2': '2'}) - assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() + strategy = NetworkTopologyStrategy({"dc1": "1", "dc2": "2"}) + assert ( + "{'class': 
'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" + == strategy.export_for_schema() + ) def test_simple_strategy_make_token_replica_map(self): - host1 = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host2 = Host('2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host3 = Host('3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host1 = Host( + "1", + SimpleConvictionPolicy, + datacenter="dc1", + rack="rack1", + host_id=uuid.uuid4(), + ) + host2 = Host( + "2", + SimpleConvictionPolicy, + datacenter="dc1", + rack="rack1", + host_id=uuid.uuid4(), + ) + host3 = Host( + "3", + SimpleConvictionPolicy, + datacenter="dc1", + rack="rack1", + host_id=uuid.uuid4(), + ) token_to_host_owner = { MD5Token(0): host1, MD5Token(100): host2, - MD5Token(200): host3 + MD5Token(200): host3, } ring = [MD5Token(0), MD5Token(100), MD5Token(200)] - rf1_replicas = SimpleStrategy({'replication_factor': '1'}).make_token_replica_map(token_to_host_owner, ring) + rf1_replicas = NetworkTopologyStrategy({"dc1": "1"}).make_token_replica_map( + token_to_host_owner, ring + ) assertCountEqual(rf1_replicas[MD5Token(0)], [host1]) assertCountEqual(rf1_replicas[MD5Token(100)], [host2]) assertCountEqual(rf1_replicas[MD5Token(200)], [host3]) - rf2_replicas = SimpleStrategy({'replication_factor': '2'}).make_token_replica_map(token_to_host_owner, ring) + rf2_replicas = NetworkTopologyStrategy({"dc1": "2"}).make_token_replica_map( + token_to_host_owner, ring + ) assertCountEqual(rf2_replicas[MD5Token(0)], [host1, host2]) assertCountEqual(rf2_replicas[MD5Token(100)], [host2, host3]) assertCountEqual(rf2_replicas[MD5Token(200)], [host3, host1]) - rf3_replicas = SimpleStrategy({'replication_factor': '3'}).make_token_replica_map(token_to_host_owner, ring) + rf3_replicas = NetworkTopologyStrategy({"dc1": "3"}).make_token_replica_map( + token_to_host_owner, ring + ) assertCountEqual(rf3_replicas[MD5Token(0)], [host1, host2, host3]) assertCountEqual(rf3_replicas[MD5Token(100)], [host2, host3, host1]) 
assertCountEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) def test_ss_equals(self): - assert SimpleStrategy({'replication_factor': '1'}) != NetworkTopologyStrategy({'dc1': 2}) + assert NetworkTopologyStrategy({"dc1": "1"}) != NetworkTopologyStrategy( + {"dc1": 2} + ) class NameEscapingTest(unittest.TestCase): - def test_protect_name(self): """ Test cassandra.metadata.protect_name output """ - assert protect_name('tests') == 'tests' - assert protect_name('test\'s') == '"test\'s"' - assert protect_name('test\'s') == "\"test's\"" - assert protect_name('tests ?!@#$%^&*()') == '"tests ?!@#$%^&*()"' - assert protect_name('1') == '"1"' - assert protect_name('1test') == '"1test"' + assert protect_name("tests") == "tests" + assert protect_name("test's") == '"test\'s"' + assert protect_name("test's") == '"test\'s"' + assert protect_name("tests ?!@#$%^&*()") == '"tests ?!@#$%^&*()"' + assert protect_name("1") == '"1"' + assert protect_name("1test") == '"1test"' def test_protect_names(self): """ Test cassandra.metadata.protect_names output """ - assert protect_names(['tests']) == ['tests'] - assert protect_names( - [ - 'tests', - 'test\'s', - 'tests ?!@#$%^&*()', - '1' - ]) == [ - 'tests', - "\"test's\"", + assert protect_names(["tests"]) == ["tests"] + assert protect_names(["tests", "test's", "tests ?!@#$%^&*()", "1"]) == [ + "tests", + '"test\'s"', '"tests ?!@#$%^&*()"', - '"1"' + '"1"', ] def test_protect_value(self): @@ -384,35 +465,42 @@ def test_protect_value(self): """ assert protect_value(True) == "true" assert protect_value(False) == "false" - assert protect_value(3.14) == '3.14' - assert protect_value(3) == '3' - assert protect_value('test') == "'test'" - assert protect_value('test\'s') == "'test''s'" - assert protect_value(None) == 'NULL' + assert protect_value(3.14) == "3.14" + assert protect_value(3) == "3" + assert protect_value("test") == "'test'" + assert protect_value("test's") == "'test''s'" + assert protect_value(None) == "NULL" def 
test_is_valid_name(self): """ Test cassandra.metadata.is_valid_name output """ assert is_valid_name(None) == False - assert is_valid_name('test') == True - assert is_valid_name('Test') == False - assert is_valid_name('t_____1') == True - assert is_valid_name('test1') == True - assert is_valid_name('1test1') == False - - invalid_keywords = cassandra.metadata.cql_keywords - cassandra.metadata.cql_keywords_unreserved + assert is_valid_name("test") == True + assert is_valid_name("Test") == False + assert is_valid_name("t_____1") == True + assert is_valid_name("test1") == True + assert is_valid_name("1test1") == False + + invalid_keywords = ( + cassandra.metadata.cql_keywords - cassandra.metadata.cql_keywords_unreserved + ) for keyword in invalid_keywords: assert is_valid_name(keyword) == False class GetReplicasTest(unittest.TestCase): def _get_replicas(self, token_klass): - tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(len(tokens))] + tokens = [token_klass(i) for i in range(0, (2**127 - 1), 2**125)] + hosts = [ + Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) + for i in range(len(tokens)) + ] + for host in hosts: + host.set_location_info("dc1", "rack1") token_to_primary_replica = dict(zip(tokens, hosts)) - keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) - metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) + keyspace = KeyspaceMetadata("ks", True, "NetworkTopologyStrategy", {"dc1": "1"}) + metadata = Mock(spec=Metadata, keyspaces={"ks": keyspace}) token_map = TokenMap(token_klass, token_to_primary_replica, tokens, metadata) # tokens match node tokens exactly @@ -442,95 +530,107 @@ def test_bytes_tokens(self): class Murmur3TokensTest(unittest.TestCase): - def test_murmur3_init(self): murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1) - assert str(murmur3_token) == '' + assert 
str(murmur3_token) == "" def test_python_vs_c(self): from cassandra.murmur3 import _murmur3 as mm3_python + try: from cassandra.cmurmur3 import murmur3 as mm3_c iterations = 100 for _ in range(iterations): - for len in range(0, 32): # zero to one block plus full range of tail lengths + for len in range( + 0, 32 + ): # zero to one block plus full range of tail lengths key = os.urandom(len) assert mm3_python(key) == mm3_c(key) except ImportError: - raise unittest.SkipTest('The cmurmur3 extension is not available') + raise unittest.SkipTest("The cmurmur3 extension is not available") def test_murmur3_python(self): from cassandra.murmur3 import _murmur3 + self._verify_hash(_murmur3) def test_murmur3_c(self): try: from cassandra.cmurmur3 import murmur3 + self._verify_hash(murmur3) except ImportError: - raise unittest.SkipTest('The cmurmur3 extension is not available') + raise unittest.SkipTest("The cmurmur3 extension is not available") def _verify_hash(self, fn): - assert fn(b'123') == -7468325962851647638 - assert fn(b'\x00\xff\x10\xfa\x99' * 10) == 5837342703291459765 - assert fn(b'\xfe' * 8) == -8927430733708461935 - assert fn(b'\x10' * 8) == 1446172840243228796 + assert fn(b"123") == -7468325962851647638 + assert fn(b"\x00\xff\x10\xfa\x99" * 10) == 5837342703291459765 + assert fn(b"\xfe" * 8) == -8927430733708461935 + assert fn(b"\x10" * 8) == 1446172840243228796 assert fn(str(cassandra.metadata.MAX_LONG).encode()) == 7162290910810015547 class MD5TokensTest(unittest.TestCase): - def test_md5_tokens(self): md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) - assert md5_token.hash_fn('123') == 42767516990368493138776584305024125808 - assert md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)) == 28528976619278518853815276204542453639 - assert str(md5_token) == '' % -9223372036854775809 + assert md5_token.hash_fn("123") == 42767516990368493138776584305024125808 + assert ( + md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)) + == 28528976619278518853815276204542453639 
+ ) + assert str(md5_token) == "" % -9223372036854775809 class BytesTokensTest(unittest.TestCase): - def test_bytes_tokens(self): - bytes_token = BytesToken(unhexlify(b'01')) - assert bytes_token.value == b'\x01' + bytes_token = BytesToken(unhexlify(b"01")) + assert bytes_token.value == b"\x01" assert str(bytes_token) == "" % bytes_token.value - assert bytes_token.hash_fn('123') == '123' + assert bytes_token.hash_fn("123") == "123" assert bytes_token.hash_fn(123) == 123 - assert bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)) == str(cassandra.metadata.MAX_LONG) + assert bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)) == str( + cassandra.metadata.MAX_LONG + ) def test_from_string(self): - from_unicode = BytesToken.from_string('0123456789abcdef') - from_bin = BytesToken.from_string(b'0123456789abcdef') + from_unicode = BytesToken.from_string("0123456789abcdef") + from_bin = BytesToken.from_string(b"0123456789abcdef") assert from_unicode == from_bin assert isinstance(from_unicode.value, bytes) assert isinstance(from_bin.value, bytes) def test_comparison(self): - tok = BytesToken.from_string('0123456789abcdef') + tok = BytesToken.from_string("0123456789abcdef") token_high_order = uint16_unpack(tok.value[0:2]) assert BytesToken(uint16_pack(token_high_order - 1)) < tok assert BytesToken(uint16_pack(token_high_order + 1)) > tok def test_comparison_unicode(self): - value = b'\'_-()"\xc2\xac' + value = b"'_-()\"\xc2\xac" t0 = BytesToken(value) - t1 = BytesToken.from_string('00') + t1 = BytesToken.from_string("00") assert t0 > t1 assert not t0 < t1 class KeyspaceMetadataTest(unittest.TestCase): - def test_export_as_string_user_types(self): - keyspace_name = 'test' - keyspace = KeyspaceMetadata(keyspace_name, True, 'SimpleStrategy', dict(replication_factor=3)) - keyspace.user_types['a'] = UserType(keyspace_name, 'a', ['one', 'two'], ['c', 'int']) - keyspace.user_types['b'] = UserType(keyspace_name, 'b', ['one', 'two', 'three'], ['d', 'int', 'a']) - 
keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'], ['int']) - keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'], ['c']) + keyspace_name = "test" + keyspace = KeyspaceMetadata( + keyspace_name, True, "NetworkTopologyStrategy", dict(dc1=3) + ) + keyspace.user_types["a"] = UserType( + keyspace_name, "a", ["one", "two"], ["c", "int"] + ) + keyspace.user_types["b"] = UserType( + keyspace_name, "b", ["one", "two", "three"], ["d", "int", "a"] + ) + keyspace.user_types["c"] = UserType(keyspace_name, "c", ["one"], ["int"]) + keyspace.user_types["d"] = UserType(keyspace_name, "d", ["one"], ["c"]) - assert """CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true; + assert """CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': '3'} AND durable_writes = true; CREATE TYPE test.c ( one int @@ -553,11 +653,13 @@ def test_export_as_string_user_types(self): class UserTypesTest(unittest.TestCase): - def test_as_cql_query(self): - field_types = ['varint', 'ascii', 'frozen>'] + field_types = ["varint", "ascii", "frozen>"] udt = UserType("ks1", "mytype", ["a", "b", "c"], field_types) - assert "CREATE TYPE ks1.mytype (a varint, b ascii, c frozen>)" == udt.as_cql_query(formatted=False) + assert ( + "CREATE TYPE ks1.mytype (a varint, b ascii, c frozen>)" + == udt.as_cql_query(formatted=False) + ) assert """CREATE TYPE ks1.mytype ( a varint, @@ -566,15 +668,27 @@ def test_as_cql_query(self): );""" == udt.export_as_string() def test_as_cql_query_name_escaping(self): - udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], ['ascii', 'ascii']) - assert 'CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)' == udt.as_cql_query(formatted=False) + udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], ["ascii", "ascii"]) + assert ( + 'CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)' + == udt.as_cql_query(formatted=False) + ) class 
UserDefinedFunctionTest(unittest.TestCase): def test_as_cql_query_removes_frozen(self): func = Function( - "ks1", "myfunction", ["frozen>"], ["a"], - "int", "java", "return 0;", True, False, False, False + "ks1", + "myfunction", + ["frozen>"], + ["a"], + "int", + "java", + "return 0;", + True, + False, + False, + False, ) expected_result = ( "CREATE FUNCTION ks1.myfunction(a tuple) " @@ -588,7 +702,17 @@ def test_as_cql_query_removes_frozen(self): class UserDefinedAggregateTest(unittest.TestCase): def test_as_cql_query_removes_frozen(self): - aggregate = Aggregate("ks1", "myaggregate", ["frozen>"], "statefunc", "frozen>", "finalfunc", "(0)", "tuple", False) + aggregate = Aggregate( + "ks1", + "myaggregate", + ["frozen>"], + "statefunc", + "frozen>", + "finalfunc", + "(0)", + "tuple", + False, + ) expected_result = ( "CREATE AGGREGATE ks1.myaggregate(tuple) " "SFUNC statefunc " @@ -600,27 +724,31 @@ def test_as_cql_query_removes_frozen(self): class IndexTest(unittest.TestCase): - def test_build_index_as_cql(self): column_meta = Mock() - column_meta.name = 'column_name_here' - column_meta.table.name = 'table_name_here' - column_meta.table.keyspace_name = 'keyspace_name_here' + column_meta.name = "column_name_here" + column_meta.table.name = "table_name_here" + column_meta.table.keyspace_name = "keyspace_name_here" column_meta.table.columns = {column_meta.name: column_meta} - parser = get_schema_parser(Mock(), '2.1.0', None, 0.1, None) + parser = get_schema_parser(Mock(), "2.1.0", None, 0.1, None) - row = {'index_name': 'index_name_here', 'index_type': 'index_type_here'} + row = {"index_name": "index_name_here", "index_type": "index_type_here"} index_meta = parser._build_index_metadata(column_meta, row) - assert index_meta.as_cql_query() == 'CREATE INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here)' + assert ( + index_meta.as_cql_query() + == "CREATE INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here)" + ) - 
row['index_options'] = '{ "class_name": "class_name_here" }' - row['index_type'] = 'CUSTOM' + row["index_options"] = '{ "class_name": "class_name_here" }' + row["index_type"] = "CUSTOM" index_meta = parser._build_index_metadata(column_meta, row) - assert index_meta.as_cql_query() == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'" + assert ( + index_meta.as_cql_query() + == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'" + ) class SchemaParserLookupTests(unittest.TestCase): - def test_reads_versions_from_system_local_when_missing(self): connection = Mock() @@ -659,10 +787,10 @@ class UnicodeIdentifiersTests(unittest.TestCase): Looking for encoding errors like PYTHON-447 """ - name = b'\'_-()"\xc2\xac'.decode('utf-8') + name = b"'_-()\"\xc2\xac".decode("utf-8") def test_keyspace_name(self): - km = KeyspaceMetadata(self.name, False, 'SimpleStrategy', {'replication_factor': 1}) + km = KeyspaceMetadata(self.name, False, "NetworkTopologyStrategy", {"dc1": 1}) km.export_as_string() def test_table_name(self): @@ -670,146 +798,181 @@ def test_table_name(self): tm.export_as_string() def test_column_name_single_partition(self): - tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + tm = TableMetadata("ks", "table") + cm = ColumnMetadata(tm, self.name, "int") tm.columns[cm.name] = cm tm.partition_key.append(cm) tm.export_as_string() def test_column_name_single_partition_single_clustering(self): - tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + tm = TableMetadata("ks", "table") + cm = ColumnMetadata(tm, self.name, "int") tm.columns[cm.name] = cm tm.partition_key.append(cm) - cm = ColumnMetadata(tm, self.name + 'x', u'int') + cm = ColumnMetadata(tm, self.name + "x", "int") tm.columns[cm.name] = cm tm.clustering_key.append(cm) tm.export_as_string() def 
test_column_name_multiple_partition(self): - tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + tm = TableMetadata("ks", "table") + cm = ColumnMetadata(tm, self.name, "int") tm.columns[cm.name] = cm tm.partition_key.append(cm) - cm = ColumnMetadata(tm, self.name + 'x', u'int') + cm = ColumnMetadata(tm, self.name + "x", "int") tm.columns[cm.name] = cm tm.partition_key.append(cm) tm.export_as_string() def test_index(self): - im = IndexMetadata(self.name, self.name, self.name, kind='', index_options={'target': self.name}) + im = IndexMetadata( + self.name, + self.name, + self.name, + kind="", + index_options={"target": self.name}, + ) log.debug(im.export_as_string()) - im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class'}) + im = IndexMetadata( + self.name, + self.name, + self.name, + kind="CUSTOM", + index_options={"target": self.name, "class_name": "Class"}, + ) log.debug(im.export_as_string()) # PYTHON-1008 - im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class', 'delimiter': self.name}) + im = IndexMetadata( + self.name, + self.name, + self.name, + kind="CUSTOM", + index_options={ + "target": self.name, + "class_name": "Class", + "delimiter": self.name, + }, + ) log.debug(im.export_as_string()) def test_function(self): - fm = Function(keyspace=self.name, name=self.name, - argument_types=(u'int', u'int'), - argument_names=(u'x', u'y'), - return_type=u'int', language=u'language', - body=self.name, called_on_null_input=False, - deterministic=True, - monotonic=False, monotonic_on=(u'x',)) + fm = Function( + keyspace=self.name, + name=self.name, + argument_types=("int", "int"), + argument_names=("x", "y"), + return_type="int", + language="language", + body=self.name, + called_on_null_input=False, + deterministic=True, + monotonic=False, + monotonic_on=("x",), + ) fm.export_as_string() def 
test_aggregate(self): - am = Aggregate(self.name, self.name, (u'text',), self.name, u'text', self.name, self.name, u'text', True) + am = Aggregate( + self.name, + self.name, + ("text",), + self.name, + "text", + self.name, + self.name, + "text", + True, + ) am.export_as_string() def test_user_type(self): - um = UserType(self.name, self.name, [self.name, self.name], [u'int', u'text']) + um = UserType(self.name, self.name, [self.name, self.name], ["int", "text"]) um.export_as_string() class FunctionToCQLTests(unittest.TestCase): - base_vars = { - 'keyspace': 'ks_name', - 'name': 'function_name', - 'argument_types': (u'int', u'int'), - 'argument_names': (u'x', u'y'), - 'return_type': u'int', - 'language': u'language', - 'body': 'body', - 'called_on_null_input': False, - 'deterministic': True, - 'monotonic': False, - 'monotonic_on': () + "keyspace": "ks_name", + "name": "function_name", + "argument_types": ("int", "int"), + "argument_names": ("x", "y"), + "return_type": "int", + "language": "language", + "body": "body", + "called_on_null_input": False, + "deterministic": True, + "monotonic": False, + "monotonic_on": (), } def _function_with_kwargs(self, **kwargs): - return Function(**dict(self.base_vars, - **kwargs) - ) + return Function(**dict(self.base_vars, **kwargs)) def test_non_monotonic(self): - assert 'MONOTONIC' not in self._function_with_kwargs( - monotonic=False, - monotonic_on=() - ).export_as_string() + assert ( + "MONOTONIC" + not in self._function_with_kwargs( + monotonic=False, monotonic_on=() + ).export_as_string() + ) def test_monotonic_all(self): - mono_function = self._function_with_kwargs( - monotonic=True, - monotonic_on=() - ) - assert 'MONOTONIC LANG' in mono_function.as_cql_query(formatted=False) - assert 'MONOTONIC\n LANG' in mono_function.as_cql_query(formatted=True) + mono_function = self._function_with_kwargs(monotonic=True, monotonic_on=()) + assert "MONOTONIC LANG" in mono_function.as_cql_query(formatted=False) + assert "MONOTONIC\n LANG" 
in mono_function.as_cql_query(formatted=True) def test_monotonic_one(self): mono_on_function = self._function_with_kwargs( - monotonic=False, - monotonic_on=('x',) + monotonic=False, monotonic_on=("x",) + ) + assert "MONOTONIC ON x LANG" in mono_on_function.as_cql_query(formatted=False) + assert "MONOTONIC ON x\n LANG" in mono_on_function.as_cql_query( + formatted=True ) - assert 'MONOTONIC ON x LANG' in mono_on_function.as_cql_query(formatted=False) - assert 'MONOTONIC ON x\n LANG' in mono_on_function.as_cql_query(formatted=True) def test_nondeterministic(self): - assert 'DETERMINISTIC' not in self._function_with_kwargs( + assert "DETERMINISTIC" not in self._function_with_kwargs( deterministic=False ).as_cql_query(formatted=False) def test_deterministic(self): - assert 'DETERMINISTIC' in self._function_with_kwargs( + assert "DETERMINISTIC" in self._function_with_kwargs( deterministic=True ).as_cql_query(formatted=False) - assert 'DETERMINISTIC\n' in self._function_with_kwargs( + assert "DETERMINISTIC\n" in self._function_with_kwargs( deterministic=True ).as_cql_query(formatted=True) class AggregateToCQLTests(unittest.TestCase): base_vars = { - 'keyspace': 'ks_name', - 'name': 'function_name', - 'argument_types': (u'int', u'int'), - 'state_func': 'funcname', - 'state_type': u'int', - 'return_type': u'int', - 'final_func': None, - 'initial_condition': '0', - 'deterministic': True + "keyspace": "ks_name", + "name": "function_name", + "argument_types": ("int", "int"), + "state_func": "funcname", + "state_type": "int", + "return_type": "int", + "final_func": None, + "initial_condition": "0", + "deterministic": True, } def _aggregate_with_kwargs(self, **kwargs): - return Aggregate(**dict(self.base_vars, - **kwargs) - ) + return Aggregate(**dict(self.base_vars, **kwargs)) def test_nondeterministic(self): - assert 'DETERMINISTIC' not in self._aggregate_with_kwargs( + assert "DETERMINISTIC" not in self._aggregate_with_kwargs( deterministic=False 
).as_cql_query(formatted=True) def test_deterministic(self): for formatted in (True, False): - query = self._aggregate_with_kwargs( - deterministic=True - ).as_cql_query(formatted=formatted) - assert query.endswith('DETERMINISTIC'), "'DETERMINISTIC' not found in {}".format(query) + query = self._aggregate_with_kwargs(deterministic=True).as_cql_query( + formatted=formatted + ) + assert query.endswith("DETERMINISTIC"), ( + "'DETERMINISTIC' not found in {}".format(query) + ) class HostsTests(unittest.TestCase): @@ -818,8 +981,12 @@ def test_iterate_all_hosts_and_modify(self): PYTHON-572 """ metadata = Metadata() - metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4())) - metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4())) + metadata.add_or_return_host( + Host("dc1.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + ) + metadata.add_or_return_host( + Host("dc1.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + ) assert len(metadata.all_hosts()) == 2 @@ -830,18 +997,27 @@ def test_iterate_all_hosts_and_modify(self): class MetadataHelpersTest(unittest.TestCase): - """ For any helper functions that need unit tests """ + """For any helper functions that need unit tests""" + def test_strip_frozen(self): self.longMessage = True argument_to_expected_results = [ - ('int', 'int'), - ('tuple', 'tuple'), - (r'map<"!@#$%^&*()[]\ frozen >>>", int>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), # A valid UDT name - ('frozen>', 'tuple'), - (r'frozen>>", int>>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), - ('frozen>, int>>, frozen>>>>>', - 'map, int>, map>>'), + ("int", "int"), + ("tuple", "tuple"), + ( + r'map<"!@#$%^&*()[]\ frozen >>>", int>', + r'map<"!@#$%^&*()[]\ frozen >>>", int>', + ), # A valid UDT name + ("frozen>", "tuple"), + ( + r'frozen>>", int>>', + r'map<"!@#$%^&*()[]\ frozen >>>", int>', + ), + ( + "frozen>, int>>, frozen>>>>>", + "map, int>, map>>", + ), ] for argument, expected_result in 
argument_to_expected_results: result = strip_frozen(argument)