#!/usr/bin/env python3





##############################################################################
# MG2 : please note this is *not* my idea of quality code
#       these scripts started out being very nice
#       then it turned out boto3 had pervasive reliability issues
#       and it became a very long, many day, experience of discovery and work-arounds
#       having finally slugged through all that, I now know enough to write
#       a decent script - but that will mean re-writing and debugging from scratch
#       and right now, I just don't have the emotional strength to go another round
#       of slugging it out with boto; and these scripts now are viable in terms of
#       functionality, and I want to get the site to beta so testers can begin
#       giving feedback, so I'm going with this for now





##############################################################################
import argparse
import boto3
import json
import pprint
import psycopg2
import psycopg2.extras
import socket
import threading
import time





##############################################################################
# Number of microseconds in one second (unit-conversion constant).
MICROSECONDS_PER_SECOND = 1000000





##############################################################################
def create_global_state( rc, ec2r, ec2c, region, test_name_prefix, queue_slots, queue_memory ):
  """Create the shared AWS resources that the clusters of a test run rely on.

  In order: a tagged VPC (DNS hostnames and DNS support enabled), an internet
  gateway attached to it, a 0.0.0.0/0 route through that gateway, one tagged
  subnet in availability zone '<region>a' with public IP auto-assignment, a
  TCP 5439 ingress rule on the first security group of the VPC, a Redshift
  cluster subnet group and a Redshift parameter group.

  Args:
    rc: Redshift client.
    ec2r: EC2 resource object (creates the VPC, gateway and subnet).
    ec2c: EC2 client (attribute modifications).
    region: region name; 'a' is appended to form the availability zone.
    test_name_prefix: used to build the 'arrp-<prefix>' tag value / group names.
    queue_slots: forwarded to create_redshift_parameter_group().
    queue_memory: forwarded to create_redshift_parameter_group().

  Returns:
    dict with keys 'vpc', 'ig', 'sn', 'rsn', 'rpg', 'rt', 'route' and 'sg'.
  """

  print( 'create_global_state()' )

  tag_name = 'arrp-%s' % (test_name_prefix)

  def creator_tags( resource_type ):
    # every taggable resource gets the same 'creator' tag, which is what
    # reconstruct_global_state() filters on to find the resources again
    return [ { 'ResourceType':resource_type, 'Tags':[{'Key':'creator', 'Value':tag_name}] } ]

  global_state = dict.fromkeys( ('vpc', 'ig', 'sn', 'rsn', 'rpg') )

  availability_zone = region + 'a'

  # creating a VPC also creates a routing table, a NACL, a security group and a DHCP option set
  vpc = ec2r.create_vpc( CidrBlock         = '172.168.0.0/24',
                         TagSpecifications = creator_tags( 'vpc' ) )
  global_state['vpc'] = vpc
  vpc.wait_until_available()

  # DNS hostnames first, then DNS support - one attribute per call
  for dns_attribute in ( 'EnableDnsHostnames', 'EnableDnsSupport' ):
    ec2c.modify_vpc_attribute( VpcId = vpc.id, **{ dns_attribute : { 'Value':True } } )

  # internet gateway, hooked up to the VPC
  gateway = ec2r.create_internet_gateway( TagSpecifications = creator_tags( 'internet-gateway' ) )
  global_state['ig'] = gateway
  vpc.attach_internet_gateway( InternetGatewayId = gateway.id )

  # take the first routing table of the VPC and route all traffic via the gateway
  for first_route_table in vpc.route_tables.all():
    global_state['rt'] = first_route_table
    break

  route_table = global_state['rt']
  global_state['route'] = route_table.create_route( DestinationCidrBlock = '0.0.0.0/0',
                                                    GatewayId            = gateway.id,
                                                    RouteTableId         = route_table.id )

  # one subnet spanning the whole VPC range, associated with the routing table
  subnet = ec2r.create_subnet( AvailabilityZone  = availability_zone,
                               CidrBlock         = '172.168.0.0/24',
                               VpcId             = vpc.id,
                               TagSpecifications = creator_tags( 'subnet' ) )
  global_state['sn'] = subnet
  route_table.associate_with_subnet( SubnetId = subnet.id )

  # auto-assign public IPs to clusters, because there is a default limit of
  # seven Elastic IP addresses and a test can easily need more clusters than that
  ec2c.modify_subnet_attribute( SubnetId            = subnet.id,
                                MapPublicIpOnLaunch = { 'Value':True } )

  # take the first security group of the VPC and open the Redshift port
  for first_security_group in vpc.security_groups.all():
    global_state['sg'] = first_security_group
    break

  global_state['sg'].authorize_ingress( CidrIp     = '0.0.0.0/0',
                                        IpProtocol = 'tcp',
                                        FromPort   = 5439,
                                        ToPort     = 5439 )

  # Redshift subnet group, named after the tag
  global_state['rsn'] = tag_name
  rc.create_cluster_subnet_group( ClusterSubnetGroupName = tag_name,
                                  Description            = 'Temporary ARRP Redshift Subnet Group',
                                  SubnetIds              = [ subnet.id ] )

  # Redshift parameter group, also named after the tag
  global_state['rpg'] = tag_name
  create_redshift_parameter_group( rc, tag_name, queue_slots, queue_memory )

  return global_state





##############################################################################
def delete_global_state( rc, ec2r, ec2c, global_state ):
  """Delete whatever resources a global-state dict records.

  Entries that are None are skipped.  Order: detach the internet gateway from
  the VPC (only when both are recorded), then delete the gateway, the subnet,
  the Redshift subnet group, the VPC and finally the Redshift parameter group.

  Args:
    rc: Redshift client.
    ec2r: EC2 resource object (not used by this function).
    ec2c: EC2 client.
    global_state: dict with at least the keys 'vpc', 'ig', 'sn', 'rsn', 'rpg'.
  """

  print( 'delete_global_state()' )

  # this assumes a recorded gateway is actually attached to the recorded VPC
  if global_state['vpc'] is not None and global_state['ig'] is not None:
    global_state['vpc'].detach_internet_gateway( InternetGatewayId = global_state['ig'].id )

  # (state key, how to delete what is stored under it), in deletion order
  teardown_steps = (
    ( 'ig',  lambda gateway : ec2c.delete_internet_gateway( InternetGatewayId = gateway.id ) ),
    ( 'sn',  lambda subnet  : ec2c.delete_subnet( SubnetId = subnet.id ) ),
    ( 'rsn', lambda name    : rc.delete_cluster_subnet_group( ClusterSubnetGroupName = name ) ),
    ( 'vpc', lambda vpc     : ec2c.delete_vpc( VpcId = vpc.id ) ),
    ( 'rpg', lambda name    : delete_redshift_parameter_group( rc, name ) ),
  )

  for state_key, teardown in teardown_steps:
    recorded = global_state[state_key]
    if recorded is not None:
      teardown( recorded )

  return





##############################################################################
def reconstruct_global_state( rc, ec2r, ec2c, test_name_prefix ):
  """Rebuild a global-state dict by looking up previously created resources.

  The VPC, internet gateway and subnet are found by their 'creator' tag
  ('arrp-<test_name_prefix>'); an entry is only filled in when exactly one
  matching resource is found, otherwise it stays None.  The Redshift subnet
  group and parameter group are looked up by name; when the describe call
  raises, the group is taken not to exist and its entry stays None.

  Args:
    rc: Redshift client.
    ec2r: EC2 resource object, used to wrap the ids that were found.
    ec2c: EC2 client, used for the tag-filtered describe calls.
    test_name_prefix: prefix the resources were created with.

  Returns:
    dict with keys 'vpc', 'ig', 'sn', 'rsn', 'rpg' (no 'rt', 'route' or 'sg').
  """

  print( 'reconstruct_global_state()' )

  tag_name = 'arrp-%s' % (test_name_prefix)

  def creator_filter():
    # a fresh filter list per call, so no call can see another call's object
    return [ {'Name':'tag:creator', 'Values':[ tag_name ]} ]

  global_state = \
  {
    'vpc' : None,
    'ig'  : None,
    'sn'  : None,
    'rsn' : None,
    'rpg' : None
  }

  vpc_info = ec2c.describe_vpcs(              Filters = creator_filter() )
  ig_info  = ec2c.describe_internet_gateways( Filters = creator_filter() )
  sn_info  = ec2c.describe_subnets(           Filters = creator_filter() )

  # The two Redshift lookups are best-effort: any ordinary error is read as
  # "group does not exist".  Exception (not a bare except) is caught so that
  # SystemExit and KeyboardInterrupt still propagate.
  try:
    rc.describe_cluster_subnet_groups( ClusterSubnetGroupName = tag_name )
  except Exception:
    pass
  else:
    global_state['rsn'] = tag_name

  try:
    rc.describe_cluster_parameter_groups( ParameterGroupName = tag_name )
  except Exception:
    pass
  else:
    global_state['rpg'] = tag_name

  # zero matches means nothing to reconstruct; more than one is ambiguous - both leave None
  if len(vpc_info['Vpcs']) == 1:
    global_state['vpc'] = ec2r.Vpc( vpc_info['Vpcs'][0]['VpcId'] )

  if len(ig_info['InternetGateways']) == 1:
    global_state['ig'] = ec2r.InternetGateway( ig_info['InternetGateways'][0]['InternetGatewayId'] )

  if len(sn_info['Subnets']) == 1:
    global_state['sn'] = ec2r.Subnet( sn_info['Subnets'][0]['SubnetId'] )

  return global_state





##############################################################################
def create_redshift_parameter_group( rc, parameter_group_name, queue_slots, queue_memory ):
  """Create a Redshift parameter group and load a WLM queue configuration into it.

  queue_slots[i] / queue_memory[i] describe queue i.  Index 0 is the default
  queue: it gets no query group and is appended last.  Every later index k
  becomes a queue reachable through query group 'queue_<k-1>'.  The parameters
  'auto_analyze' = 'false' and 'datestyle' = 'ISO' are set as well.

  Args:
    rc: Redshift client.
    parameter_group_name: name of the parameter group to create.
    queue_slots: sequence of slot counts, default queue first.
    queue_memory: sequence of memory percentages, default queue first.

  Returns:
    whatever rc.create_cluster_parameter_group() returned.
  """

  print( 'create_redshift_parameter_group()' )

  created_group = rc.create_cluster_parameter_group( ParameterGroupName   = parameter_group_name,
                                                     ParameterGroupFamily = 'redshift-1.0',
                                                     Description          = 'Temporary parameter group for ARRP testing' )

  # the 'wlm_json_configuration' layout was arrived at by trial and error with
  # the AWS CLI (per the original author); these two leading entries are fixed
  fixed_entries = \
  [
    {
      "query_group"           : [],
      "query_group_wild_card" : 0,
      "user_group"            : [],
      "user_group_wild_card"  : 0,
      "auto_wlm"              : False
    },
    {
      "short_query_queue" : False
    }
  ]

  # the non-default queues: list position p here corresponds to argument index p+1
  named_queues = \
  [
    {
      "query_group"           : [ "queue_%d" % (position) ],
      "query_concurrency"     : slots,
      "memory_percent_to_use" : queue_memory[position+1]
    }
    for position, slots in enumerate( queue_slots[1:] )
  ]

  # the default queue has no name and goes at the very end
  default_queue = {
                    "query_concurrency"     : queue_slots[0],
                    "memory_percent_to_use" : queue_memory[0]
                  }

  wlm_configuration = fixed_entries + named_queues + [ default_queue ]

  static_parameters = \
  [
    ( 'wlm_json_configuration', json.dumps(wlm_configuration) ),
    ( 'auto_analyze',           'false' ),
    ( 'datestyle',              'ISO' )
  ]

  rc.modify_cluster_parameter_group( ParameterGroupName = parameter_group_name,
                                     Parameters         = [ { 'ApplyType'      : 'static',
                                                              'ParameterName'  : name,
                                                              'ParameterValue' : value }
                                                            for name, value in static_parameters ] )

  return created_group





##############################################################################
def delete_redshift_parameter_group( rc, parameter_group_name ):
  """Delete a Redshift parameter group, warning (not failing) when it is absent.

  A missing group is recognised both when the describe call returns an empty
  'ParameterGroups' list and when it raises the client's
  ClusterParameterGroupNotFoundFault.

  Args:
    rc: Redshift client.
    parameter_group_name: name of the parameter group to delete.
  """

  print( 'delete_redshift_parameter_group()' )

  try:
    pg_info = rc.describe_cluster_parameter_groups( ParameterGroupName = parameter_group_name )
  except rc.exceptions.ClusterParameterGroupNotFoundFault:
    # describing a named group that does not exist raises rather than
    # returning an empty list; fold that into the "nothing found" case
    pg_info = { 'ParameterGroups' : [] }

  if not pg_info['ParameterGroups']:
    print( 'Warning : delete issued for non-existant Redshift parameter group "%s"' % (parameter_group_name) )
    return

  rc.delete_cluster_parameter_group( ParameterGroupName = parameter_group_name )

  return





##############################################################################
def apply_redshift_parameter_group( rc, cluster_state, parameter_group_name ):
  """Attach a parameter group to a cluster (if not already attached) and reboot it.

  When the parameter group does not exist an error is printed and SystemExit(1)
  is raised.  Otherwise, if the cluster's first listed parameter group differs
  from the requested one, the cluster is modified and the function waits for
  the change to settle.  The cluster is then always rebooted and waited on.

  Args:
    rc: Redshift client.
    cluster_state: dict; only 'cluster_id' is read here.
    parameter_group_name: name of the parameter group to apply.

  Raises:
    SystemExit: with code 1 when the parameter group does not exist.
  """

  print( "apply_redshift_parameter_group()" )

  try:
    pg_info = rc.describe_cluster_parameter_groups( ParameterGroupName = parameter_group_name )
  except rc.exceptions.ClusterParameterGroupNotFoundFault:
    # describing a named group that does not exist raises rather than
    # returning an empty list; fold that into the "nothing found" case
    pg_info = { 'ParameterGroups' : [] }

  if len(pg_info['ParameterGroups']) == 0:
    print( 'Error : apply issued for non-existant Redshift parameter group "%s"' % (parameter_group_name) )
    # same effect as exit( 1 ), without depending on the site-provided builtin
    raise SystemExit( 1 )

  cluster_id = cluster_state['cluster_id']
  cluster_info = rc.describe_clusters( ClusterIdentifier=cluster_id )

  if cluster_info['Clusters'][0]['ClusterParameterGroups'][0]['ParameterGroupName'] != parameter_group_name:
    rc.modify_cluster( ClusterIdentifier=cluster_id, ClusterParameterGroupName=parameter_group_name )

    # we cannot move straight on to the reboot; the cluster may not be ready.
    # Status reporting is unreliable: a stray first 'available' may show up,
    # so wait for one 'modifying' and then a run of consecutive 'available's.
    # NOTE(review): the original comment says 'modifying' may never be observed,
    # in which case the first wait below would not return - confirm.

    cluster_wait_for_status( rc, cluster_state, 'modifying', 1 )

    cluster_wait_for_status( rc, cluster_state, 'available', 5 )

  # TODO : check whether a reboot is actually required; for now always reboot

  rc.reboot_cluster( ClusterIdentifier=cluster_id )

  cluster_wait_for_status( rc, cluster_state, 'available', 3 )

  return





##############################################################################
def cluster_exists( rc, cluster_state ):
  """Report whether the cluster exists and, if so, where to reach it.

  If describing the cluster raises, the cluster is taken not to exist.
  Otherwise the function waits until the cluster has been seen 'available'
  twice in a row and resolves its endpoint host name.

  Args:
    rc: Redshift client.
    cluster_state: dict; only 'cluster_id' is read here.

  Returns:
    (False, None, None) when the describe call fails, otherwise
    (True, ip, port) with the resolved endpoint address and the endpoint port.
  """

  print( 'cluster_exists()' )

  try:
    cluster_info = rc.describe_clusters( ClusterIdentifier=cluster_state['cluster_id'] )
  except Exception:
    # any ordinary failure counts as "no such cluster"; Exception rather than
    # a bare except so SystemExit and KeyboardInterrupt are not swallowed
    return False, None, None

  cluster_wait_for_status( rc, cluster_state, 'available', 2 )

  endpoint = cluster_info['Clusters'][0]['Endpoint']
  ip       = socket.gethostbyname( endpoint['Address'] )
  port     = endpoint['Port']

  return True, ip, port





##########################################################################
def init_cluster_state( region, test_name_prefix, node_type, node_count ):
  """Build the dict describing one cluster to create.

  The cluster id is 'arrp-<prefix>-<node_type>-<node_count>' with every '.'
  replaced by '-'.  The parameter group and Redshift subnet group names are
  both 'arrp-<prefix>', and the availability zone is '<region>a'.

  Args:
    region: region name.
    test_name_prefix: test name prefix used in the generated names.
    node_type: node type string (e.g. containing dots, which get replaced in the id).
    node_count: number of nodes (formatted with %d).

  Returns:
    dict with keys 'cluster_id', 'availability_zone', 'node_type',
    'admin_username', 'admin_password', 'database', 'parameter_group',
    'rs_subnet_group', 'node_count' and 'region'.
  """

  print( 'init_cluster_state()' )

  shared_name = 'arrp-%s' % (test_name_prefix)
  dotted_id   = '%s-%s-%d' % (shared_name, node_type, node_count)

  # NOTE(review): the admin credentials are hard-coded in this file
  return dict( cluster_id        = dotted_id.replace( '.', '-' ),
               availability_zone = region + 'a',
               node_type         = node_type,
               admin_username    = 'admin',
               admin_password    = 'BlackSesame2',
               database          = 'dev',
               parameter_group   = shared_name,
               rs_subnet_group   = shared_name,
               node_count        = node_count,
               region            = region )





##########################################################################
def cluster_create( rc, ec2r, ec2c, global_state, cluster_state ):
  """Create the cluster described by cluster_state, retrying until it is usable.

  Up to five attempts are made.  An attempt creates the cluster, waits for it
  to be 'available' three times in a row, resolves the endpoint host name and
  opens a test connection.  If the name cannot be resolved or the connection
  fails, the cluster is deleted and the next attempt starts.

  On success cluster_state gains 'ip', 'port' and 'number_slices', and the
  test connection is closed again.  On failure a message is printed and
  'number_slices' is not set; nothing is raised for that case.

  Args:
    rc: Redshift client.
    ec2r: EC2 resource object (only passed on to cluster_delete()).
    ec2c: EC2 client (only passed on to cluster_delete()).
    global_state: dict; the id of global_state['sg'] is used as security group.
    cluster_state: dict from init_cluster_state(); updated in place.
  """

  print( 'cluster_create()' )

  max_attempts = 5
  connection_state = None

  for attempt in range( max_attempts ):
    start = time.time()

    rc.create_cluster( DBName                           = cluster_state['database'],
                       ClusterIdentifier                = cluster_state['cluster_id'],
                       ClusterSubnetGroupName           = cluster_state['rs_subnet_group'],
                       ClusterType                      = 'multi-node',
                       NodeType                         = cluster_state['node_type'],
                       MasterUsername                   = cluster_state['admin_username'],
                       MasterUserPassword               = cluster_state['admin_password'],
                       AvailabilityZone                 = cluster_state['availability_zone'],
                       ClusterParameterGroupName        = cluster_state['parameter_group'],
                       NumberOfNodes                    = cluster_state['node_count'],
                       PubliclyAccessible               = True,
                       Encrypted                        = False,
                       VpcSecurityGroupIds              = [ global_state['sg'].id ],
                       AquaConfigurationStatus          = 'disabled',
                       ManualSnapshotRetentionPeriod    = 1,
                       AutomatedSnapshotRetentionPeriod = 1 if cluster_state['node_type'][0:3] == 'ra3' else 0 )

    # one steady check is not enough: moving straight on to a delete can
    # otherwise be rejected because an operation is still in progress

    cluster_wait_for_status( rc, cluster_state, 'available', 3 )

    end = time.time()

    print( '%s %d nodes %d seconds startup' % (cluster_state['node_type'], cluster_state['node_count'], end-start) )

    # clusters sometimes report 'available' while their address cannot yet be
    # looked up, or while connections still fail

    cluster_info = rc.describe_clusters( ClusterIdentifier = cluster_state['cluster_id'] )

    try:
      cluster_state["ip"] = socket.gethostbyname( cluster_info['Clusters'][0]['Endpoint']['Address'] )
    except Exception:
      ip_known_flag = False
    else:
      cluster_state["port"] = cluster_info['Clusters'][0]['Endpoint']['Port']
      ip_known_flag = True

    if ip_known_flag:
      connection_state = cluster_connect( cluster_state, None )
      if connection_state['connection_flag']:
        break

    # unusable cluster: zap it and try again with a new one
    print( "Failed to lookup or connect to available cluster %s, deleting cluster and trying again." % (cluster_state['cluster_id']) )
    cluster_delete( rc, ec2r, ec2c, cluster_state )
    connection_state = None

  if connection_state is None:
    print( "Unable to create and connect to cluster %s, bailing out." % (cluster_state['cluster_id']) )
    return

  try:
    rows, row_count = issue_sql( connection_state, "select count(distinct slice) from stv_slices;" )
    cluster_state['number_slices'] = rows[0][0]
  finally:
    # the connection was only needed for the probe above; do not leak it
    connection_disconnect( connection_state )

  return





##########################################################################
def cluster_delete( rc, ec2r, ec2c, cluster_state ):
  """Delete a cluster without a final snapshot and wait until it is gone.

  Deletion is best-effort: if the delete request itself fails, the failure is
  printed and the function returns without waiting.

  Args:
    rc: Redshift client.
    ec2r: EC2 resource object (not used by this function).
    ec2c: EC2 client (not used by this function).
    cluster_state: dict; only 'cluster_id' is read here.
  """

  print( 'cluster_delete( %s )' % (cluster_state["cluster_id"]) )

  try:
    rc.delete_cluster( ClusterIdentifier        = cluster_state["cluster_id"],
                       SkipFinalClusterSnapshot = True )
  except Exception as e:
    # still best-effort, but say what went wrong instead of hiding it
    print( 'cluster_delete( %s ) : delete request failed (%s), not waiting for deletion' % (cluster_state["cluster_id"], str(e)) )
    return

  cluster_wait_for_status( rc, cluster_state, 'deleted', 3 )

  return





##########################################################################
def cluster_wait_for_status( rc, cluster_state, status, steady_count, poll_interval = 5 ):
  """Poll a cluster until it has shown a status steady_count times in a row.

  For any status other than 'deleted', the reported ClusterStatus must equal
  `status` on consecutive polls; a different status resets the count, and an
  error from the describe call propagates to the caller.

  For status 'deleted', a poll counts when the describe call raises (the
  cluster can no longer be described); a successful describe resets the count.

  There is no timeout: the function returns only once the condition is met.

  Args:
    rc: Redshift client.
    cluster_state: dict; only 'cluster_id' is read here.
    status: status string to wait for, or 'deleted'.
    steady_count: number of consecutive matching polls required.
    poll_interval: seconds slept after every poll (default 5).
  """

  cluster_id = cluster_state["cluster_id"]

  print( 'cluster_wait_for_status( %s, %s )' % (cluster_id, status) )

  waiting_for_deletion = ( status == 'deleted' )
  status_seen_count = 0

  while status_seen_count < steady_count:
    try:
      cluster_info = rc.describe_clusters( ClusterIdentifier = cluster_id )
    except Exception:
      if not waiting_for_deletion:
        raise
      cluster_info = None

    if cluster_info is None:
      # describe failed while waiting for deletion : counts as 'deleted'
      status_seen_count += 1
      print( 'cluster %s status = deleted' % (cluster_id) )
    elif waiting_for_deletion:
      # still describable, so not deleted yet
      status_seen_count = 0
      print( 'cluster %s status = %s' % (cluster_id, cluster_info['Clusters'][0]['ClusterStatus']) )
    elif 'Clusters' in cluster_info:
      current_status = cluster_info['Clusters'][0]['ClusterStatus']
      print( 'cluster %s status = %s' % (cluster_id, current_status) )
      if current_status == status:
        status_seen_count += 1
      else:
        status_seen_count = 0

    time.sleep( poll_interval )

  return





##########################################################################
def cluster_connect( cluster_state, database ):
  """Open a connection to the cluster and prepare a session on it.

  Args:
    cluster_state: dict providing 'ip', 'port', 'admin_username',
      'admin_password', 'cluster_id' and (when database is None) 'database'.
    database: database name to connect to, or None for cluster_state['database'].

  Returns:
    dict with 'connection_flag'.  When the connection attempt raised, the flag
    is False, the error has been printed and nothing else is in the dict.
    Otherwise the flag is True and the dict also holds 'connection' (set to
    autocommit) and 'cursor' (a DictCursor), with the session's result cache
    switched off and analyze_threshold_percent set to 0.
  """

  print( 'cluster_connect()' )

  db = cluster_state["database"] if database is None else database

  connection_state = { 'connection_flag' : False }

  try:
    connection_state['connection'] = psycopg2.connect( host            = cluster_state["ip"],
                                                       port            = cluster_state["port"],
                                                       database        = db,
                                                       user            = cluster_state["admin_username"],
                                                       password        = cluster_state["admin_password"],
                                                       connect_timeout = 8 )
  except Exception as e:
    print( "Connection to %s (%s:%d) failed with error %s" % (cluster_state["cluster_id"], cluster_state["ip"], cluster_state["port"], str(e)) )
    return connection_state

  connection_state['connection_flag'] = True

  connection = connection_state['connection']
  connection.autocommit = True
  connection_state['cursor'] = connection.cursor( cursor_factory=psycopg2.extras.DictCursor )

  # session settings, issued in this order on every new connection
  for session_setting in ( 'set enable_result_cache_for_session=off;',
                           'set analyze_threshold_percent to 0;' ):
    issue_sql( connection_state, session_setting )

  return connection_state





##########################################################################
def connection_disconnect( connection_state ):
  """Close the cursor, then the connection, and mark the state as disconnected.

  Args:
    connection_state: dict holding 'cursor' and 'connection'; its
      'connection_flag' entry is set to False.
  """

  print( 'connection_disconnect()' )

  # cursor first, connection second
  for handle_key in ( 'cursor', 'connection' ):
    connection_state[handle_key].close()

  connection_state['connection_flag'] = False

  return





##########################################################################
def cluster_status( rc, cluster_state ):
  """Print the current ClusterStatus of the cluster; returns nothing.

  Args:
    rc: Redshift client.
    cluster_state: dict; only 'cluster_id' is read here.
  """

  print( 'cluster_status()' )

  description = rc.describe_clusters( ClusterIdentifier = cluster_state["cluster_id"] )
  first_cluster = description['Clusters'][0]
  print( "cluster status = '%s'" % (first_cluster['ClusterStatus']) )

  return





##############################################################################
def cluster_get_total_used_blocks( connection_state ):
  """Return sum(used) over stv_node_storage_capacity (first column of the first row).

  Args:
    connection_state: connected state dict, as returned by cluster_connect().
  """

  print( 'cluster_get_total_used_blocks()' )

  result_rows, _ = issue_sql( connection_state, 'select sum(used) from stv_node_storage_capacity;' )
  first_row = result_rows[0]

  return first_row[0]





##############################################################################
def get_blocks_used_by_table( connection_state, schema_name, table_name ):
  """Return the number of stv_blocklist rows belonging to one table.

  The table's oid is first looked up in pg_class / pg_namespace by schema and
  table name, then the stv_blocklist rows with that tbl value are counted.

  Args:
    connection_state: connected state dict, as returned by cluster_connect().
    schema_name: schema of the table (interpolated into the SQL text as-is).
    table_name: name of the table (interpolated into the SQL text as-is).

  Returns:
    the count(*) value of the second query.
  """

  print( 'get_blocks_used_by_table()' )

  # step 1 : which oid does the table have?
  oid_lookup_sql = '''select
             pg_class.oid
           from
             pg_class,
             pg_namespace
           where
             pg_namespace.nspname  = '%s' and
             pg_class.relname      = '%s' and
             pg_class.relnamespace = pg_namespace.oid;''' % (schema_name, table_name)
  oid_rows, _ = issue_sql( connection_state, oid_lookup_sql )
  table_oid = oid_rows[0][0]

  # step 2 : how many blocks are listed for that oid?
  block_count_sql = '''select
             count(*)
           from
             stv_blocklist
           where
             tbl = %d;''' % (table_oid)
  count_rows, _ = issue_sql( connection_state, block_count_sql )

  return count_rows[0][0]





##########################################################################
def issue_insert_sql( connection_state, sql ):
  """Execute a statement and return the cursor's rowcount afterwards.

  Args:
    connection_state: dict holding the 'cursor' to execute on.
    sql: statement text.
  """

  cursor = connection_state['cursor']
  cursor.execute( sql )

  return cursor.rowcount





##########################################################################
def issue_sql( connection_state, sql ):
  """Execute a statement and fetch everything it returns.

  Args:
    connection_state: dict holding the 'cursor' to execute on.
    sql: statement text.

  Returns:
    (rows, row_count) : all fetched rows and the cursor's rowcount, or
    ([], -1) when fetching raises psycopg2.ProgrammingError.
  """

  cursor = connection_state['cursor']
  cursor.execute( sql )

  # rowcount is read before the fetch, as the fetch may fail
  reported_count = cursor.rowcount

  try:
    return cursor.fetchall(), reported_count
  except psycopg2.ProgrammingError:
    return [], -1





##########################################################################
def issue_sql_with_status( connection_state, sql ):
  """Execute a statement, reporting failure through a flag instead of raising.

  Args:
    connection_state: dict holding the 'cursor' to execute on.
    sql: statement text.

  Returns:
    (True, rows, row_count) on success.  (False, [], -1) when executing raises
    psycopg2.DatabaseError or fetching raises psycopg2.ProgrammingError.
  """

  cursor = connection_state['cursor']

  try:
    cursor.execute( sql )
  except psycopg2.DatabaseError:
    return False, [], -1

  reported_count = cursor.rowcount

  try:
    fetched = cursor.fetchall()
  except psycopg2.ProgrammingError:
    return False, [], -1

  return True, fetched, reported_count





##########################################################################
def get_column_names( connection_state ):
  """Return the 'description' attribute of the state's cursor, unmodified.

  Args:
    connection_state: dict holding the 'cursor'.
  """

  print( 'get_column_names()' )

  cursor = connection_state['cursor']

  return cursor.description





##############################################################################
def get_xid_qid_by_most_recent_matching_query_text( connection_state, query_text ):
  """Find the most recent query whose stl_querytext rows match query_text.

  The text is cut into 200-character pieces, one per expected stl_querytext
  row, and a query is built that selects (xid, query) pairs for which every
  piece matches at its sequence number.  The match with the latest
  stl_query.starttime wins.

  Args:
    connection_state: connected state dict, as returned by cluster_connect().
    query_text: the SQL text to look for.

  Returns:
    (trans_id, query_id) of the most recent matching query.

  NOTE(review): rows[0] is indexed unconditionally, so this raises IndexError
  when nothing matches.  query_text is placed between single quotes without
  escaping - presumably text containing a single quote breaks the generated
  SQL; confirm against callers.
  """

  print( 'get_xid_qid_by_most_recent_matching_query_text()' )

  # MG2 : this is awkward as all hell
  #       stl_querytext contains one row for every 200 characters of query text, per query

  # MG2 : there's a bug in how STL_QUERYTEXT stores text
  #       I've not tried to pin down exactly when, but what I do know is
  #       if the row has 200 characters of text and the final character is a space
  #       the next line of text will have an extra space added at its beginning

  # prepared_query_text = query_text.replace('\n', '\\n')

  query_length = len( query_text )
  start = 0

  prepared_query_text = ''
  length = 200

  # rebuild the text as it is expected to be stored : copy it piece by piece and,
  # whenever the text copied so far ends in a space, insert one extra space and
  # take one character fewer (199) for the next piece
  while start < query_length:
    end = start + length
    text_block = query_text[start:end]
    prepared_query_text += text_block
    start += length
    if prepared_query_text[-1:] == ' ':
      prepared_query_text += ' '
      length = 199
    else:
      length = 200

  # number of 200-character rows needed for the prepared text, rounding up
  query_length         = len( prepared_query_text )
  expected_number_rows = int( query_length / 200 )
  if query_length % 200 != 0:
    expected_number_rows += 1

  sql = '''select
             unnamed_0.trans_id,
             unnamed_0.query_id
           from
             stl_query,
             (
               select
                 xid      as trans_id,
                 query    as query_id,
                 count(*) as number_rows
               from
                 stl_querytext
               where\n''';

  # one "( sequence = N and text = '...' ) or" condition per expected row
  for loop in range( 0, expected_number_rows ):
    start = 200 * loop
    end   = 200 * (loop+1)
    sql += "      ( sequence = %d and text = '%s' ) or\n" % (loop, prepared_query_text[start:end])

  # sql[:-4] removes the trailing " or\n" of the last condition; the having
  # clause keeps only queries for which every expected row matched
  sql = sql[:-4] + '''\n    group by
      xid,
      query
    having
      number_rows = %s
  ) as unnamed_0
where
  stl_query.query = unnamed_0.query_id
order by
  stl_query.starttime desc
limit 1;''' % (expected_number_rows)

  rows, row_count = issue_sql( connection_state, sql )

  return rows[0][0], rows[0][1]





##############################################################################
def get_run_time_for_query_id( connection_state, query_id ):
  """Return the run time of a query with its compile time subtracted.

  The value is stl_wlm_query.total_exec_time divided by 1000000, minus the
  summed (endtime - starttime) of the query's svl_compile rows, likewise
  divided by 1000000.

  Args:
    connection_state: connected state dict, as returned by cluster_connect().
    query_id: query id (formatted with %d).

  Returns:
    the 'run_time' column of the first result row.
  """

  print( 'get_run_time_for_query_id()' )

  run_time_sql = '''select
                    total_exec_time::varchar::float/1000000 - (select sum(endtime - starttime)::varchar::float/1000000 from svl_compile where query = %d) as run_time
                  from
                    stl_wlm_query
                  where
                    query = %d;''' % (query_id, query_id)

  timing_rows, _ = issue_sql( connection_state, run_time_sql )
  first_row = timing_rows[0]

  return first_row[0]





##########################################################################
def generate_slice_distkey_value_lookup( cluster_state, connection_state ):
  """Find, for every slice, a distkey value whose row is stored on that slice.

  A scratch table 'distkey_test_table' is created; the values 0, 1, 2, ... are
  inserted one at a time, and after each insert stv_blocklist is asked which
  slice holds the row.  The first value seen for a slice is recorded.  The
  table is truncated after every probe and dropped at the end.  The loop runs
  until cluster_state['number_slices'] different slices have been recorded.

  Args:
    cluster_state: dict; only 'number_slices' is read here.
    connection_state: connected state dict, as returned by cluster_connect().

  Returns:
    dict mapping slice number -> distkey value.
  """

  print( 'generate_slice_distkey_value_lookup()' )

  create_sql = '''create table distkey_test_table
           (
             column_1  bigint not null distkey
           )
           diststyle key;'''
  issue_sql( connection_state, create_sql )

  # the same probe is issued after every single-row insert
  which_slice_sql = '''select
               slice
             from
               stv_blocklist,
               pg_class
             where
               pg_class.relname     = 'distkey_test_table'  and 
               stv_blocklist.tbl    = pg_class.oid          and
               stv_blocklist.col    = 0;'''

  value_for_slice = {}
  candidate = 0

  while len( value_for_slice ) < cluster_state['number_slices']:
    issue_sql( connection_state, 'insert into distkey_test_table ( column_1 ) values ( %d );' % (candidate) )

    probe_rows, _ = issue_sql( connection_state, which_slice_sql )

    # keep only the first value found for each slice
    value_for_slice.setdefault( probe_rows[0][0], candidate )

    candidate += 1

    issue_sql( connection_state, 'truncate table distkey_test_table;' )

  issue_sql( connection_state, 'drop table distkey_test_table;' )

  return value_for_slice





##############################################################################
def drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, table_name, blocks_per_slice ):
  """(Re)create a single-column test table and fill it to a block count per slice.

  The table is dropped if present and created with one raw-encoded bigint
  column that is both distkey and sortkey.  When blocks_per_slice > 0, each
  slice in slice_distkey_value_lookup is seeded with 48 rows of its distkey
  value, then one thread per slice grows that slice's rows up to
  130994 * blocks_per_slice rows in total, and the table is vacuumed.  The
  table is analyzed at the end in every case.

  Args:
    cluster_state: cluster dict, handed to the populate threads.
    connection_state: connected state dict, as returned by cluster_connect().
    slice_distkey_value_lookup: dict mapping slice number -> distkey value.
    table_name: name of the table (interpolated into the SQL text as-is).
    blocks_per_slice: target blocks per slice; 0 leaves the table empty.
  """

  print( 'drop_make_and_populate_test_table()' )

  not_null_float8_or_bigint_number_values_per_block = 130994

  issue_sql( connection_state, "drop table if exists %s;" % (table_name) )
  ddl = '''create table %s
           (
             column_1  bigint  not null encode raw distkey
           )
           diststyle key
           compound sortkey( column_1 );''' % (table_name)
  issue_sql( connection_state, ddl )

  if blocks_per_slice > 0:

    # seed every slice with an initial 48 rows of its own distkey value

    initial_values_per_slice = 48

    values_list = []

    for distkey_value in slice_distkey_value_lookup.values():
      values_list.extend( ['%d' % (distkey_value)] * initial_values_per_slice )

    sql = 'insert into %s ( column_1 ) values ( %s );' % (table_name, '), ('.join(values_list) )
    issue_insert_sql( connection_state, sql )

    issue_sql( connection_state, "analyze %s;" % (table_name) )

    # bulk populate : one thread per slice, each self-joining repeatedly on its
    # own value only, and so running on just that slice

    target_row_count_per_slice = ( not_null_float8_or_bigint_number_values_per_block * blocks_per_slice ) - initial_values_per_slice

    # Threads are started for exactly the slices that were seeded above, i.e.
    # the keys of the lookup.  Indexing the lookup with range(number_slices)
    # fails with KeyError as soon as slice numbers are not exactly 0..n-1.
    thread_list = [
      threading.Thread( target=thread_populate_test_table, args=(cluster_state, table_name, slice_id, distkey_value, target_row_count_per_slice) )
      for slice_id, distkey_value in slice_distkey_value_lookup.items()
    ]

    for thread in thread_list:
      thread.start()

    for thread in thread_list:
      thread.join()

    print( 'vacuum test table...' )
    issue_sql( connection_state, "vacuum full %s to 100 percent;" % (table_name) )

  print( 'analyze test table...' )
  issue_sql( connection_state, "analyze %s;" % (table_name) )

  return





##########################################################################
def thread_populate_test_table( cluster_state, table_name, slice_id, slice_distkey_value, target_row_count_per_slice ):
  """Thread body : grow one slice's rows of the test table by repeated self-join.

  Opens its own connection and keeps inserting the slice's distkey value,
  selected from a self-join restricted to that value and limited to the number
  of rows still missing, until target_row_count_per_slice rows were inserted.

  The thread gives up, with a message, when it cannot connect or when an
  insert adds no rows (which would otherwise repeat forever).  The connection
  is closed on every exit path.

  Args:
    cluster_state: cluster dict used to connect.
    table_name: name of the table (interpolated into the SQL text as-is).
    slice_id: slice number, used in progress messages only.
    slice_distkey_value: distkey value whose rows live on that slice.
    target_row_count_per_slice: number of rows this thread has to add.
  """

  print( 'thread_populate_test_table( slice_id = %d )' % (slice_id) )

  connection_state = cluster_connect( cluster_state, None )

  if not connection_state['connection_flag']:
    print( 'slice %d (value %d) could not connect to the cluster, no rows inserted' % (slice_id, slice_distkey_value) )
    return

  rows_inserted = 0

  try:
    while rows_inserted < target_row_count_per_slice:
      limit = target_row_count_per_slice - rows_inserted
      print( 'slice %d (value %d) inserting rows (%d remaining to insert)...' % (slice_id, slice_distkey_value, limit) )

      sql = '''insert into
               %s ( column_1 )
             select
               %d
             from
                    %s as t1
               join %s as t2 on t1.column_1 = %d and t2.column_1 = %d and t1.column_1 = t2.column_1
             where
               t1.column_1 = %d and t2.column_1 = %d
             limit %d;''' % (table_name, slice_distkey_value, table_name, table_name, slice_distkey_value, slice_distkey_value, slice_distkey_value, slice_distkey_value, limit)

      insert_count = issue_insert_sql( connection_state, sql )

      if insert_count <= 0:
        # no seed rows to join on (or no usable rowcount) : the loop could never finish
        print( 'slice %d (value %d) insert made no progress, giving up (%d remaining to insert)' % (slice_id, slice_distkey_value, limit) )
        break

      rows_inserted += insert_count
  finally:
    connection_disconnect( connection_state )

  return





##############################################################################
def cluster_get_version_string( cluster_state ):
  """Return the result of 'select version();' on the cluster.

  A connection is opened for the query and closed again afterwards.

  Args:
    cluster_state: cluster dict used to connect.
  """

  print( 'cluster_get_version_string()' )

  connection_state = cluster_connect( cluster_state, None )

  version_rows, _ = issue_sql( connection_state, "select version();" )

  connection_disconnect( connection_state )

  first_row = version_rows[0]

  return first_row[0]





##########################################################################
def validate_and_get_cluster_version_strings( results ):
  """Check that each node type reports a single version, and extract it.

  Within one node type, the version string of every node count is compared
  with that of the node count before it; on a difference both are printed and
  the process exits with code 1.  The extracted version is the text after the
  last space of the version string, e.g. '1.0.28422' from
  '... (Red Hat 3.4.2-6.fc3), Redshift 1.0.28422'.

  Args:
    results: dict mapping node type -> dict mapping node count -> version string.

  Returns:
    dict mapping node type -> extracted version.  Node types without any node
    count entry are left out.

  Raises:
    SystemExit: with code 1 when a node type has differing version strings.
  """

  print( 'validate_and_get_cluster_version_strings()' )

  version_strings = {}

  for node_type, version_by_node_count in results.items():
    if not version_by_node_count:
      # nothing recorded for this node type, so nothing to validate or extract
      continue

    # Reset per node type : comparisons must stay within one node type, and a
    # node count of the previous node type need not exist in this one.
    prev_node_count = None

    for node_count, version_string in version_by_node_count.items():
      if prev_node_count is not None and version_string != version_by_node_count[prev_node_count]:
        print( "Uh-oh : disparate cluster version strings across node counts" )
        print( "1. %s %d nodes %s" % (node_type, prev_node_count, version_by_node_count[prev_node_count]) )
        print( "2. %s %d nodes %s" % (node_type, node_count,      version_string) )
        raise SystemExit( 1 )

      prev_node_count = node_count

    # all strings of this node type are equal here, so the last one seen will do
    final_space = version_string.rfind( ' ' )
    version_strings[node_type] = version_string[final_space+1:]

  return version_strings





##########################################################################
def one_shot( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_init_function, test_function, test_cleanup_function, results ):
  """Run init, test and cleanup for every cluster in a single invocation.

  Sequence: test_init_function(), create_global_state() (called with [8]
  and [100] for its queue_slots / queue_memory parameters), one
  thread_oneshot() thread per (node_type, node_count) pair, then
  test_cleanup_function(), delete_global_state(), version validation and
  result display.

  Note: display_results_function is not a parameter of this function; the
  name is looked up in module scope when it is called.
  """

  print( "one_shot()" )

  test_init_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results )

  global_state = create_global_state( rc, ec2r, ec2c, region, test_name_prefix, [8], [100] )

  # One worker per cluster; all are created before any is started.
  workers = \
  [
    threading.Thread( target=thread_oneshot, args=(rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count, test_function, results) )
    for node_type in node_types
    for node_count in node_counts
  ]

  for worker in workers:
    worker.start()

  for worker in workers:
    worker.join()

  test_cleanup_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results )

  delete_global_state( rc, ec2r, ec2c, global_state )

  display_results_function( node_types, node_counts, validate_and_get_cluster_version_strings( results['versions'] ), results )

  return





##########################################################################
def thread_oneshot( rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count, test_function, results ):
  """Thread body for one_shot(): create one cluster, test it, delete it.

  Stores the cluster's version string in
  results['versions'][node_type][node_count] and then calls test_function()
  with the cluster state.  The cluster is deleted afterwards whether or not
  the version query or the test raised.
  """

  print( 'thread_oneshot()' )

  cluster_state = init_cluster_state( region, test_name_prefix, node_type, node_count )

  cluster_create( rc, ec2r, ec2c, global_state, cluster_state )

  # From here on the cluster exists; make sure a failing query or test does
  # not leave it running (the exception still propagates after the delete).
  try:
    results['versions'][node_type][node_count] = cluster_get_version_string( cluster_state )

    test_function( rc, ec2r, ec2c, region, cluster_state, test_name_prefix, node_type, node_count, results )
  finally:
    cluster_delete( rc, ec2r, ec2c, cluster_state )

  return





##########################################################################
def init( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_init_function, results ):
  """Prepare everything the 'run' command needs.

  Calls test_init_function(), builds the shared state with
  create_global_state() (called with [8] and [100] for its queue_slots /
  queue_memory parameters) and then runs one thread_init() thread per
  (node_type, node_count) pair, waiting for all of them to finish.
  """

  print( 'init()' )

  test_init_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results )

  global_state = create_global_state( rc, ec2r, ec2c, region, test_name_prefix, [8], [100] )

  # One worker per cluster; all are created before any is started.
  workers = \
  [
    threading.Thread( target=thread_init, args=(rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count) )
    for node_type in node_types
    for node_count in node_counts
  ]

  for worker in workers:
    worker.start()

  for worker in workers:
    worker.join()

  return





##########################################################################
def thread_init( rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count ):
  """Thread body for init(): make sure one cluster exists.

  When cluster_exists() reports the cluster, its ip and port are copied into
  the cluster state; otherwise the cluster is created with cluster_create().
  Finally the node count, node type and cluster ip are printed.
  """

  print( 'thread_init()' )

  cluster_state = init_cluster_state( region, test_name_prefix, node_type, node_count )

  exist_flag, ip, port = cluster_exists( rc, cluster_state )

  # Exactly one of the two branches always runs, so the print below never
  # sees a cluster state that was neither looked up nor created.
  if exist_flag:
    cluster_state['ip']   = ip
    cluster_state['port'] = port
  else:
    cluster_create( rc, ec2r, ec2c, global_state, cluster_state )

  print( '%d x %s : %s' % (node_count, node_type, cluster_state['ip']) )

  return





##########################################################################
def run( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_function, display_results_function, results ):
  """Run the test on clusters set up earlier, then display the results.

  Rebuilds the shared state with reconstruct_global_state(), runs one
  thread_run() thread per (node_type, node_count) pair, waits for all of
  them, validates the collected version strings and hands everything to
  display_results_function().
  """

  print( 'run()' )

  global_state = reconstruct_global_state( rc, ec2r, ec2c, test_name_prefix )

  # One worker per cluster; all are created before any is started.
  workers = \
  [
    threading.Thread( target=thread_run, args=(rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count, test_function, results) )
    for node_type in node_types
    for node_count in node_counts
  ]

  for worker in workers:
    worker.start()

  for worker in workers:
    worker.join()

  # MG2 : confirm the versions strings for different node counts of a given node type are all identical
  #       if not, bail out

  display_results_function( node_types, node_counts, validate_and_get_cluster_version_strings( results['versions'] ), results )

  return





##########################################################################
def thread_run( rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count, test_function, results ):
  """Thread body for run(): run test_function() against one existing cluster.

  Looks the cluster up with cluster_exists(), records its version string in
  results['versions'][node_type][node_count], stores the number of distinct
  slices in cluster_state['number_slices'] and then calls test_function().

  global_state is accepted for signature symmetry but not used here.
  """

  print( 'thread_run()' )

  cluster_state = init_cluster_state( region, test_name_prefix, node_type, node_count )

  exist_flag, ip, port = cluster_exists( rc, cluster_state )

  if not exist_flag:
    print( "Missing cluster." )
    # NOTE: run() starts this function as a threading.Thread target, where
    # SystemExit ends only the current thread, not the whole script.
    exit( 1 )

  cluster_state['ip']   = ip
  cluster_state['port'] = port

  results['versions'][node_type][node_count] = cluster_get_version_string( cluster_state )

  # MG2 : this is a hack; I can't set number_slices by a query in init_cluster_state() because the cluster may not be up yet
  connection_state = cluster_connect( cluster_state, None )

  # Release the connection even when the slice query raises.
  try:
    rows, _ = issue_sql( connection_state, "select count(distinct slice) from stv_slices;" )
    cluster_state['number_slices'] = rows[0][0]
  finally:
    connection_disconnect( connection_state )

  test_function( rc, ec2r, ec2c, region, cluster_state, test_name_prefix, node_type, node_count, results )

  return





##########################################################################
def cleanup( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_cleanup_function, results ):
  """Tear down what init() set up.

  Rebuilds the shared state with reconstruct_global_state(), runs one
  thread_cleanup() thread per (node_type, node_count) pair, waits for all of
  them, then calls delete_global_state() followed by
  test_cleanup_function().
  """

  print( 'cleanup()' )

  global_state = reconstruct_global_state( rc, ec2r, ec2c, test_name_prefix )

  # One worker per cluster; all are created before any is started.
  workers = \
  [
    threading.Thread( target=thread_cleanup, args=(rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count) )
    for node_type in node_types
    for node_count in node_counts
  ]

  for worker in workers:
    worker.start()

  for worker in workers:
    worker.join()

  delete_global_state( rc, ec2r, ec2c, global_state )

  test_cleanup_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results )

  return





##########################################################################
def thread_cleanup( rc, ec2r, ec2c, region, global_state, test_name_prefix, node_type, node_count ):
  """Thread body for cleanup(): delete the cluster of one (node_type, node_count) pair.

  The cluster state is rebuilt from its naming inputs with
  init_cluster_state(); global_state is accepted but not used here.
  """

  print( 'thread_cleanup()' )

  cluster_delete( rc, ec2r, ec2c, init_cluster_state( region, test_name_prefix, node_type, node_count ) )

  return





##########################################################################
def command_line( test_name_prefix, help_text, region, node_types, node_counts, test_init_function, test_function, test_cleanup_function, display_results_function ):
  """Parse the command line and run the requested command.

  Exactly one sub-command is required: 'one-shot', 'init', 'run' or
  'cleanup'.  An empty results structure ('proofs', 'tests' and 'versions',
  each keyed by node type and then node count) is built, the boto3 handles
  for the region are created, and the matching function is called.  The
  wall-clock duration is printed at the end.

  display_results_function is only forwarded to run(); one_shot() does not
  take it as an argument.
  """

  print( "command_line()" )

  start = time.time()

  parser = argparse.ArgumentParser( description=help_text )
  subparsers = parser.add_subparsers( title='command', description='Tells the script what to do.', dest='command', required=True )

  command_help = \
  (
    ( 'one-shot', 'Issues init, run and then cleanup.' ),
    ( 'init',     'Initializes VPC config and starts up clusters.' ),
    ( 'run',      'Runs the test; init must have been performed.' ),
    ( 'cleanup',  'Shuts down clusters, deletes VPC config.' )
  )

  for command_name, command_summary in command_help:
    subparsers.add_parser( command_name, help=command_summary )

  command = parser.parse_args().command

  # Every (node_type, node_count) slot exists up front, so worker threads
  # only ever fill in their own slot.
  results = \
  {
    'proofs'   : { node_type : { node_count : {}      for node_count in node_counts } for node_type in node_types },
    'tests'    : { node_type : { node_count : {}      for node_count in node_counts } for node_type in node_types },
    'versions' : { node_type : { node_count : 'unset' for node_count in node_counts } for node_type in node_types }
  }

  rc   = boto3.client( 'redshift', region )
  ec2r = boto3.resource( 'ec2', region )
  ec2c = boto3.client( 'ec2', region )

  if command == 'one-shot':
    one_shot( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_init_function, test_function, test_cleanup_function, results )
  elif command == 'init':
    init( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_init_function, results )
  elif command == 'run':
    run( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_function, display_results_function, results )
  elif command == 'cleanup':
    cleanup( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, test_cleanup_function, results )

  print( 'Duration : %d seconds.' % (time.time() - start) )

  return





##############################################################################
import datetime
import pprint
import statistics
import sys
import threading
import time





##############################################################################
def test_init_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results ):
  """Test-specific set-up: create the parameter group 'arrp-<prefix>-2'.

  Only rc and test_name_prefix are used; the other parameters are part of
  the hook signature.
  """

  print( 'test_init_function()' )

  # NOTE(review): the two lists are presumably per-queue slots and per-queue
  # memory (compare queue_slots / queue_memory of create_global_state) --
  # confirm against create_redshift_parameter_group().
  create_redshift_parameter_group( rc, 'arrp-%s-2' % (test_name_prefix), [1,1,1], [2,49,49] )

  return





##############################################################################
def test_function( rc, ec2r, ec2c, region, cluster_state, test_name_prefix, node_type, node_count, results ):
  """Per-cluster test body: performance_test() first, then general_proofs().

  Both stages receive the same arguments; region is part of the hook
  signature but is not forwarded.
  """

  print( 'test_function()' )

  for stage in ( performance_test, general_proofs ):
    stage( rc, ec2r, ec2c, cluster_state, test_name_prefix, node_type, node_count, results )

  return





##############################################################################
def test_cleanup_function( rc, ec2r, ec2c, region, test_name_prefix, node_types, node_counts, results ):
  """Test-specific tear-down: delete the parameter group 'arrp-<prefix>-2'.

  The delete is best-effort: a failure (for example because the group does
  not exist) is reported on stdout and otherwise ignored.  Only rc and
  test_name_prefix are used; the other parameters are part of the hook
  signature.
  """

  print( 'test_cleanup_function()' )

  rpg_name = 'arrp-%s-2' % (test_name_prefix)

  # MG2 : hack - I should check it exists

  # Catch Exception rather than everything, so Ctrl-C and SystemExit still
  # stop the script, and say what went wrong instead of hiding it.
  try:
    delete_redshift_parameter_group( rc, rpg_name )
  except Exception as error:
    print( 'test_cleanup_function() : could not delete parameter group %s : %s' % (rpg_name, error) )

  return





##############################################################################
def performance_test( rc, ec2r, ec2c, cluster_state, test_name_prefix, node_type, node_count, results ):
  """Collect STL_COMMIT_STATS rows for a grid of table and insert sizes.

  For every combination of existing-table size (1, 10, 100) and insert size
  (1, 10, 100, 200) the test table is rebuilt and the insert is issued five
  times.  After each insert, the STL_COMMIT_STATS rows of its transaction
  are appended to
  results['tests'][node_type][node_count][table size][insert size].
  The cluster's slice count is stored under 'number_slices' alongside.

  NOTE(review): the sizes look like per-slice block counts, going by the
  messages printed below -- confirm in drop_make_and_populate_test_table().
  """

  print( 'performance_test()' )

  # MG2 : we iterate over
  #       . number rows in table
  #       . number of rows being added
  #       . run each combination five times

  existing_block_counts = [ 1, 10, 100 ]
  inserted_block_counts = [ 1, 10, 100, 200 ]

  # MG2 : proof function swaps parameter group, so we need to make sure we're using the right group

  apply_redshift_parameter_group( rc, cluster_state, 'arrp-%s' % (test_name_prefix) )

  connection_state = cluster_connect( cluster_state, None )

  slice_distkey_value_lookup = generate_slice_distkey_value_lookup( cluster_state, connection_state )

  # One source table per insert size, built once and reused by every run.
  for inserted_block_count in inserted_block_counts:
    drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'source_%d' % (inserted_block_count), inserted_block_count )

  cluster_results = results['tests'][node_type][node_count]
  cluster_results['number_slices'] = cluster_state['number_slices']

  for existing_block_count in existing_block_counts:
    print( 'table_per_slice_block_count = %d' % (existing_block_count) )
    cluster_results[existing_block_count] = {}

    for inserted_block_count in inserted_block_counts:
      print( 'insert_per_slice_block_count = %d' % (inserted_block_count) )
      commit_stats_per_run = []
      cluster_results[existing_block_count][inserted_block_count] = commit_stats_per_run

      for iteration in range( 5 ):
        print( 'iteration = %d' % (iteration) )

        # Start every run from a freshly built target table.
        drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'table_1', existing_block_count )

        insert_sql = 'insert into table_1( column_1 ) select column_1 from source_%d;' % (inserted_block_count)
        issue_sql( connection_state, insert_sql )

        # The transaction id is looked up by query text, so this has to
        # follow the insert directly.
        transaction_id, query_id = get_xid_qid_by_most_recent_matching_query_text( connection_state, insert_sql )

        # MG2 : now get the commit duration

        sql = '''select
                   node,
                   case when node = -1 then 'leader' else 'worker#' || node::varchar end as node_name,
                   startqueue,
                   startwork,
                   endflush,
                   endstage,
                   endlocal,
                   startglobal,
                   endtime,
                   permblocks,
                   newblocks,
                   dirtyblocks,
                   headers
                 from
                   stl_commit_stats
                 where
                   xid = %s
                 order by
                   node asc;''' % (transaction_id)
        commit_rows, commit_row_count = issue_sql( connection_state, sql )

        commit_stats_per_run.append( commit_rows )

  for inserted_block_count in inserted_block_counts:
    issue_sql( connection_state, 'drop table source_%d;' % (inserted_block_count) )

  connection_disconnect( connection_state )

  return





##############################################################################
def general_proofs( rc, ec2r, ec2c, cluster_state, test_name_prefix, node_type, node_count, results ):
  """Measure how many commits are ever in progress at once on each node.

  Builds two source tables and two empty target tables, switches the
  cluster to the parameter group 'arrp-<prefix>-2', then runs
  thread_proofs_slow() and thread_proofs_fast() concurrently.  Afterwards
  STL_COMMIT_STATS is queried for the maximum number of concurrent commits
  per node, and the resulting rows (node name, maximum) are stored in
  results['proofs'][node_type][node_count].  The four tables are dropped
  at the end.

  ec2r and ec2c are accepted but not used here.
  """

  print( 'general_proofs()' )

  # MG2 : we look now to find out if there is a single commit queue, or many
  #       now the basic problem we face is that it's not possible to know when a query has entered the commit phase
  #       so we have two test tables
  #       in one queue, we write 100000000 rows per slice to the first table, and let it run till it completes
  #       while it's running, we issue short (1000 row inserts) in the other queue to other table, and we run these continually
  #       (remember, writes to tables are serialized, and there could in theory be a commit queue per WLM queue - so two queues, two tables)
  #       once the long commit query is complete, we let the other thread finish its current query, then we examine STL_COMMIT_STATS
  #       if we see that no queries committed concurrently, we know there's only one commit queue

  connection_state = cluster_connect( cluster_state, None )

  slice_distkey_value_lookup = generate_slice_distkey_value_lookup( cluster_state, connection_state )

  # Source tables read by the fast (source_1) and slow (source_1000) threads.
  # NOTE(review): the last argument appears to be a per-slice block count,
  # as in performance_test() -- confirm in drop_make_and_populate_test_table().
  drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'source_1', 1 )
  drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'source_1000', 1000 )

  # Target tables, created with size 0; each thread inserts into its own.
  drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'table_slow', 0 )
  drop_make_and_populate_test_table( cluster_state, connection_state, slice_distkey_value_lookup, 'table_fast', 0 )

  # The connection is closed before the parameter group is swapped and a new
  # one is opened afterwards.
  # NOTE(review): presumably applying the group interrupts existing sessions
  # -- confirm in apply_redshift_parameter_group().
  connection_disconnect( connection_state )

  rpg_name = 'arrp-%s-2' % (test_name_prefix)
  apply_redshift_parameter_group( rc, cluster_state, rpg_name )

  connection_state = cluster_connect( cluster_state, None )

  # Shared stop flag: the fast thread keeps inserting until 'finish' is set.
  finish_dict = { 'finish' : False }

  slow_thread = threading.Thread( target=thread_proofs_slow, args=(cluster_state, cluster_state['number_slices']) )
  fast_thread = threading.Thread( target=thread_proofs_fast, args=(cluster_state, cluster_state['number_slices'], finish_dict) )

  slow_thread.start()
  fast_thread.start()

  # Wait for the single slow insert, then tell the fast thread to stop and
  # wait for it to finish its remaining inserts.
  slow_thread.join()
  finish_dict['finish'] = True
  fast_thread.join()

  # MG2 : now examine STL_COMMIT_STATS
  #       do we ever have overlapping queries?
  #       turns out commits can overlap, fractionally, on a per-node basis
  #       a given commit can be ongoing, a given node completes, and it can begin working on the next commit
  #       before the other nodes have finished the first commit
  #       I think though it cannot proceed to global commit

  # Every commit contributes +1 at its startwork and -1 at its endtime; the
  # running sum per node, ordered by time, is the number of commits in
  # progress, and the query reports its maximum per node.  There is no
  # filter, so every row in STL_COMMIT_STATS is counted, not only the
  # commits issued by the two threads above.
  sql = '''select
             case when node_id = -1 then 'leader' else 'worker#' || node_id::varchar end  as node_name,
             max( concurrent_commits )                                                    as max_concurrent_commits
           from
             (
               select
                 node_id                                                                                                                as node_id,
                 sum(query_delta) over (partition by node_id order by event_time asc rows between unbounded preceding and current row)  as concurrent_commits
               from
                 (
                   select
                     node       as node_id,
                     startwork  as event_time,
                     1          as query_delta
                   from
                     stl_commit_stats

                   union all

                   select
                     node     as node_id,
                     endtime  as event_time,
                     -1       as query_delta
                   from
                     stl_commit_stats
                 )
              )
              group by
                node_id
              order by
                node_id;'''
  rows, row_count = issue_sql( connection_state, sql )

  results['proofs'][node_type][node_count] = rows

  issue_sql( connection_state, 'drop table source_1;' )
  issue_sql( connection_state, 'drop table source_1000;' )
  issue_sql( connection_state, 'drop table table_slow;' )
  issue_sql( connection_state, 'drop table table_fast;' )

  connection_disconnect( connection_state )

  return





##############################################################################
def thread_proofs_slow( cluster_state, number_slices ):
  """Thread body for general_proofs(): one large insert in query group 'queue_0'.

  Opens its own connection, inserts all of source_1000 into table_slow and
  closes the connection again, also when the insert fails.

  number_slices is accepted but not used.
  """

  print( 'thread_proofs_slow()' )

  connection_state = cluster_connect( cluster_state, None )

  # Release the connection even when one of the statements raises.
  try:
    issue_sql( connection_state, "set query_group to 'queue_0';" )

    print( 'slow start' )
    issue_insert_sql( connection_state, 'insert into table_slow( column_1 ) select column_1 from source_1000;' )
    print( 'slow end' )
  finally:
    connection_disconnect( connection_state )

  return





##############################################################################
def thread_proofs_fast( cluster_state, number_slices, finish_dict ):
  """Thread body for general_proofs(): repeated small inserts in query group 'queue_1'.

  Opens its own connection and keeps inserting source_1 into table_fast
  until finish_dict['finish'] becomes true, then issues one final insert.
  The connection is closed afterwards, also when an insert fails.

  number_slices is accepted but not used.
  """

  print( 'thread_proofs_fast()' )

  insert_sql = 'insert into table_fast( column_1 ) select column_1 from source_1;'

  connection_state = cluster_connect( cluster_state, None )

  # Release the connection even when one of the statements raises.
  try:
    issue_sql( connection_state, "set query_group to 'queue_1';" )

    count = 0

    # general_proofs() sets the flag once the slow insert has finished.
    while not finish_dict['finish']:
      print( 'fast start %d' %(count) )
      issue_insert_sql( connection_state, insert_sql )
      print( 'fast end' )
      count += 1

    # MG2 : one more, so we definitely have one started after the slow query
    issue_insert_sql( connection_state, insert_sql )
  finally:
    connection_disconnect( connection_state )

  return





##############################################################################
# MG2 : needed if we manually re-use results from Appendix A (the column names are lost)
#       positions of the columns in the commit rows stored in results['tests']
_COMMIT_COLUMN_NODE        = 0
_COMMIT_COLUMN_NODE_NAME   = 1
_COMMIT_COLUMN_STARTQUEUE  = 2
_COMMIT_COLUMN_STARTWORK   = 3
_COMMIT_COLUMN_ENDFLUSH    = 4
_COMMIT_COLUMN_ENDSTAGE    = 5
_COMMIT_COLUMN_ENDLOCAL    = 6
_COMMIT_COLUMN_STARTGLOBAL = 7
_COMMIT_COLUMN_ENDTIME     = 8


def _seconds_after( base_time, moment ):
  """Return the time from base_time to moment in seconds, as a float."""

  return (moment - base_time) / datetime.timedelta(microseconds=1) / MICROSECONDS_PER_SECOND


def _commit_duration( commit_rows ):
  """Return the span from the earliest startwork to the latest endtime of one commit."""

  first_start = min( row[_COMMIT_COLUMN_STARTWORK] for row in commit_rows )
  last_end    = max( row[_COMMIT_COLUMN_ENDTIME]   for row in commit_rows )

  return last_end - first_start


def _trimmed_mean_and_sd( runs ):
  """Return (mean, population sd) in seconds of the run durations, without the fastest and the slowest run."""

  durations = sorted( _commit_duration( commit_rows ) for commit_rows in runs )

  # Whole microseconds, as the durations are stored with that resolution.
  microsecond_results = [ int( duration / datetime.timedelta(microseconds=1) ) for duration in durations[1:-1] ]

  return ( statistics.mean( microsecond_results ) / MICROSECONDS_PER_SECOND,
           statistics.pstdev( microsecond_results ) / MICROSECONDS_PER_SECOND )


def _print_commit_timeline( commit_rows ):
  """Print one markdown table with the per-node timestamps of a commit, relative to its earliest timestamp."""

  print( '\n| Node | Start Queue | Start Work | End Flush | End Staging | End Local | Start Global | End Time |' )
  print( '| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |' )

  # MG2 : find the earliest time in all rows, this is the base time
  # Values equal to 2000-01-01 00:00 are skipped.
  # NOTE(review): presumably that is the placeholder for a timestamp which
  # does not apply to the node -- confirm against STL_COMMIT_STATS.
  placeholder = datetime.datetime(2000, 1, 1, 0, 0)
  candidates  = [ commit_rows[0][_COMMIT_COLUMN_STARTQUEUE] ]
  for row in commit_rows:
    for column in range( _COMMIT_COLUMN_STARTQUEUE, _COMMIT_COLUMN_ENDTIME + 1 ):
      if row[column] != placeholder:
        candidates.append( row[column] )
  base_time = min( candidates )

  leader_columns = ( _COMMIT_COLUMN_STARTQUEUE, _COMMIT_COLUMN_STARTWORK, _COMMIT_COLUMN_STARTGLOBAL, _COMMIT_COLUMN_ENDTIME )
  worker_columns = ( _COMMIT_COLUMN_STARTWORK, _COMMIT_COLUMN_ENDFLUSH, _COMMIT_COLUMN_ENDSTAGE, _COMMIT_COLUMN_ENDLOCAL, _COMMIT_COLUMN_STARTGLOBAL, _COMMIT_COLUMN_ENDTIME )

  for row in commit_rows:
    # Node -1 is the leader; it is printed with its queue, work, global and
    # end times only, worker nodes with everything but the queue time.
    if row[_COMMIT_COLUMN_NODE] == -1:
      line_format, columns = '| %s | %f | %f | | | | %f | %f |', leader_columns
    else:
      line_format, columns = '| %s | | %f | %f | %f | %f | %f | %f |', worker_columns

    print( line_format % ( (row[_COMMIT_COLUMN_NODE_NAME],) + tuple( _seconds_after( base_time, row[column] ) for column in columns ) ) )


def display_results_function( node_types, node_counts, version_strings, results ):
  """Print the raw results followed by a markdown report.

  node_types      -- node types to report on, in output order
  node_counts     -- node counts to report on, in output order; the last
                     one is treated as the largest cluster
  version_strings -- dict of node_type -> version text for the headings
  results         -- structure filled by performance_test() ('tests') and
                     general_proofs() ('proofs')

  The report has three parts: a table of mean/sd commit durations per
  cluster (fastest and slowest run of each combination dropped), detailed
  per-node timelines for one run of the largest table/insert combination,
  and the maximum number of concurrent commits per node.
  """

  print( 'display_results_function()' )

  print( '\n# Raw Results\n' )

  pprint.pprint( results )

  print( '\n## Performance Test\n' )

  table_per_slice_block_counts = [ 1, 10, 100 ]
  insert_per_slice_block_counts = [ 1, 10, 100, 200 ]

  # MG2 : leader node always has the earlier start time
  #       leader node never has first finish time
  #       turns out slices commit serially!
  #       compute slowest and fastest total times, remove those two runs
  #       of those which remain, compute total time (leader node start to last slice finish) and take mean of that

  # All statistics are computed before anything is printed, so incomplete
  # results fail before any table has been written.
  stats = {}

  for node_type in node_types:
    for node_count in node_counts:
      for table_per_slice_block_count in table_per_slice_block_counts:
        for insert_per_slice_block_count in insert_per_slice_block_counts:
          runs = results['tests'][node_type][node_count][table_per_slice_block_count][insert_per_slice_block_count]
          stats[(node_type, node_count, table_per_slice_block_count, insert_per_slice_block_count)] = _trimmed_mean_and_sd( runs )

  for node_type in node_types:
    for node_count in node_counts:
      print( '### %s, %d nodes (%d slices) (%s)\n' % (node_type, node_count, results['tests'][node_type][node_count]['number_slices'], version_strings[node_type]) )

      # Columns are table sizes, rows are insert sizes.
      print( '| |' + ''.join( ' %d |' % (table_per_slice_block_count) for table_per_slice_block_count in table_per_slice_block_counts ) )
      print( '| ---: ' * (len(table_per_slice_block_counts)+1) + '|' )

      for insert_per_slice_block_count in insert_per_slice_block_counts:
        cells = ''.join( ' %.2f/%.2f |' % stats[(node_type, node_count, table_per_slice_block_count, insert_per_slice_block_count)]
                         for table_per_slice_block_count in table_per_slice_block_counts )
        print( '| %d |' % (insert_per_slice_block_count) + cells )

      print()

  print( '## Performance Test (Detailed Results)' )

  # MG2 : show third iteration for largest clusters for dc2 and ra3, but show all clusters for ds2
  # NOTE(review): the index used is 3, i.e. the fourth run stored by
  # performance_test() -- confirm which run is intended.
  iteration = 3

  largest_node_count           = node_counts[-1]
  table_per_slice_block_count  = table_per_slice_block_counts[-1]
  insert_per_slice_block_count = insert_per_slice_block_counts[-1]

  for node_type in node_types:
    print( '\n### %s, %d nodes (%d slices) (%s)' % (node_type, largest_node_count, results['tests'][node_type][largest_node_count]['number_slices'], version_strings[node_type]) )

    # Walk the node counts that were actually tested; stepping through the
    # integers between the smallest and largest one breaks as soon as the
    # node counts are not contiguous.
    if node_type == 'ds2.xlarge':
      shown_node_counts = node_counts
    else:
      shown_node_counts = node_counts[-1:]

    for node_count in shown_node_counts:
      _print_commit_timeline( results['tests'][node_type][node_count][table_per_slice_block_count][insert_per_slice_block_count][iteration] )

  print( '\n## General Proofs' )

  print( '\n### Maximum Number Concurrent Commits' )

  for node_type in node_types:
    for node_count in node_counts:
      print( '\n#### %s, %d nodes (%d slices) (%s)\n' % (node_type, node_count, results['tests'][node_type][node_count]['number_slices'], version_strings[node_type]) )
      print( '```' )
      for row in results['proofs'][node_type][node_count]:
        print( "%s : %d" % (row[0], row[1]) )
      print( '```' )

  return





##############################################################################
# Configuration of this test: resource name prefix, help text for the
# command line, AWS region, and the node types / node counts to test.
test_name_prefix = 'cmtperf'
help_text        = "Commit Performance"
region           = 'eu-central-1'
node_types       = [ 'dc2.large', 'ds2.xlarge', 'ra3.xlplus' ]
node_counts      = [ 2, 3, 4 ]

# Entry point: hand the configuration and the four test hooks to the harness.
command_line( test_name_prefix         = test_name_prefix,
              help_text                = help_text,
              region                   = region,
              node_types               = node_types,
              node_counts              = node_counts,
              test_init_function       = test_init_function,
              test_function            = test_function,
              test_cleanup_function    = test_cleanup_function,
              display_results_function = display_results_function )

