https://github.com/NYUCCL/psiTurk
Tip revision: c375bfc273502053608684705ece5c88ffa8516b authored by Dave Eargle on 23 August 2016, 20:24:34 UTC
bumping again for travis' sake
bumping again for travis' sake
Tip revision: c375bfc
amt_services.py
# -*- coding: utf-8 -*-
""" This module is a facade for AMT (Boto) services. """
import boto.rds
import boto.ec2
from boto.exception import EC2ResponseError
# from boto.rds import RDSConnection
from boto.mturk.connection import MTurkConnection, MTurkRequestError
from boto.mturk.question import ExternalQuestion
from boto.mturk.qualification import LocaleRequirement, \
PercentAssignmentsApprovedRequirement, Qualifications
from flask import jsonify
import re as re
from psiturk.psiturk_config import PsiturkConfig
# MySQL reserved words, upper-case.  Written as one whitespace-separated
# literal so the words are easy to scan; ``split()`` turns it into a list of
# strings in exactly the order they appear here.
MYSQL_RESERVED_WORDS_CAP = """
    ACCESSIBLE ADD ALL ALTER ANALYZE AND AS ASC
    ASENSITIVE BEFORE BETWEEN BIGINT BINARY BLOB BOTH
    BY CALL CASCADE CASE CHANGE CHAR CHARACTER CHECK
    COLLATE COLUMN CONDITION CONSTRAINT CONTINUE CONVERT
    CREATE CROSS CURRENT_DATE CURRENT_TIME CURRENT_TIMESTAMP
    CURRENT_USER CURSOR DATABASE DATABASES DAY_HOUR
    DAY_MICROSECOND DAY_MINUTE DAY_SECOND DEC DECIMAL DECLARE
    DEFAULT DELAYED DELETE DESC DESCRIBE DETERMINISTIC
    DISTINCT DISTINCTROW DIV DOUBLE DROP DUAL EACH ELSE
    ELSEIF ENCLOSED ESCAPED EXISTS EXIT EXPLAIN FALSE
    FETCH FLOAT FLOAT4 FLOAT8 FOR FORCE FOREIGN FROM
    FULLTEXT GET GRANT GROUP HAVING HIGH_PRIORITY
    HOUR_MICROSECOND HOUR_MINUTE HOUR_SECOND IF IGNORE IN
    INDEX INFILE INNER INOUT INSENSITIVE INSERT INT
    INT1 INT2 INT3 INT4 INT8 INTEGER INTERVAL INTO
    IO_AFTER_GTIDS IO_BEFORE_GTIDS IS ITERATE JOIN KEY
    KEYS KILL LEADING LEAVE LEFT LIKE LIMIT LINEAR
    LINES LOAD LOCALTIME LOCALTIMESTAMP LOCK LONG LONGBLOB
    LONGTEXT LOOP LOW_PRIORITY MASTER_BIND
    MASTER_SSL_VERIFY_SERVER_CERT MATCH MAXVALUE MEDIUMBLOB
    MEDIUMINT MEDIUMTEXT MIDDLEINT MINUTE_MICROSECOND
    MINUTE_SECOND MOD MODIFIES NATURAL NOT NO_WRITE_TO_BINLOG
    NULL NUMERIC ON OPTIMIZE OPTION OPTIONALLY OR ORDER
    OUT OUTER OUTFILE PARTITION PRECISION PRIMARY
    PROCEDURE PURGE RANGE READ READS READ_WRITE REAL
    REFERENCES REGEXP RELEASE RENAME REPEAT REPLACE
    REQUIRE RESIGNAL RESTRICT RETURN REVOKE RIGHT RLIKE
    SCHEMA SCHEMAS SECOND_MICROSECOND SELECT SENSITIVE
    SEPARATOR SET SHOW SIGNAL SMALLINT SPATIAL SPECIFIC
    SQL SQLEXCEPTION SQLSTATE SQLWARNING SQL_BIG_RESULT
    SQL_CALC_FOUND_ROWS SQL_SMALL_RESULT SSL STARTING
    STRAIGHT_JOIN TABLE TERMINATED THEN TINYBLOB TINYINT
    TINYTEXT TO TRAILING TRIGGER TRUE UNDO UNION UNIQUE
    UNLOCK UNSIGNED UPDATE USAGE USE USING UTC_DATE
    UTC_TIME UTC_TIMESTAMP VALUES VARBINARY VARCHAR
    VARCHARACTER VARYING WHEN WHERE WHILE WITH WRITE
    XOR YEAR_MONTH ZEROFILL
""".split()

# Lower-case copy, used by the RDS username/database-name validators.
MYSQL_RESERVED_WORDS = [reserved.lower()
                        for reserved in MYSQL_RESERVED_WORDS_CAP]
class MTurkHIT(object):
    ''' Structure for dealing with MTurk HITs.

    ``json_options`` is a dict of HIT fields.  ``__repr__`` reads the keys
    listed in ``_REPR_FIELDS`` and renders them as an ASCII-only, multi-line
    summary without modifying the dict.
    '''

    # Output template for __repr__; placeholders follow _REPR_FIELDS order.
    _REPR_TEMPLATE = (
        "%s \n\tStatus: %s \n\tHITid: %s "
        "\n\tmax:%s/pending:%s/complete:%s/remain:%s \n\tCreated:%s "
        "\n\tExpires:%s\n"
    )
    _REPR_FIELDS = (
        'title',
        'status',
        'hitid',
        'max_assignments',
        'number_assignments_pending',
        'number_assignments_completed',
        'number_assignments_available',
        'creation_time',
        'expiration',
    )

    def __init__(self, json_options):
        self.options = json_options

    @staticmethod
    def _ascii(value):
        ''' Return ``value`` as a native str containing only ASCII.

        Non-ASCII characters become '?'.  Values without an ``encode``
        method (e.g. ints, None) are passed through str() first.
        '''
        if not hasattr(value, 'encode'):
            value = str(value)
        try:
            encoded = value.encode('ascii', 'replace')
        except UnicodeDecodeError:
            # Python 2 byte string holding non-ASCII bytes: encode() first
            # decodes it as strict ASCII, so decode leniently instead.
            encoded = value.decode('ascii', 'replace').encode('ascii',
                                                              'replace')
        if not isinstance(encoded, str):
            # Python 3: encode() returns bytes; turn it back into str.
            encoded = encoded.decode('ascii')
        return encoded

    def __repr__(self):
        ''' Return a multi-line, ASCII-only summary of the HIT.

        ``self.options`` is left untouched.  Raises KeyError if one of the
        ``_REPR_FIELDS`` keys is missing.
        '''
        values = tuple(self._ascii(self.options[field])
                       for field in self._REPR_FIELDS)
        return self._REPR_TEMPLATE % values
class RDSServices(object):
    ''' Relational database services via AWS.

    Thin wrapper around ``boto.rds`` (the connection is kept in
    ``self.rdsc``) plus validators for the parameters of a new MySQL
    instance.  Validators return True on success or an error string.
    '''

    def __init__(self, aws_access_key_id, aws_secret_access_key,
                 region='us-east-1', quiet=False):
        ''' Store credentials and region; verify them unless ``quiet``.

        With ``quiet=True`` no AWS request is made here: ``valid_login``
        starts as None and connect_to_aws_rds() verifies the credentials on
        first use.
        '''
        self.update_credentials(aws_access_key_id, aws_secret_access_key)
        self.set_region(region)
        self.valid_login = None  # None means "not verified yet"
        if not quiet:
            self.valid_login = self.verify_aws_login()

    def update_credentials(self, aws_access_key_id, aws_secret_access_key):
        ''' Update credentials '''
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key

    def list_regions(self):
        ''' Return the names of the regions reported by boto.rds.regions(). '''
        return [reg.name for reg in boto.rds.regions()]

    def get_region(self):
        ''' Return the configured region name. '''
        return self.region

    def set_region(self, region):
        ''' Set the region used for new connections. '''
        self.region = region

    def _open_rds_connection(self):
        ''' (Re)create self.rdsc for the configured region and credentials. '''
        self.rdsc = boto.rds.connect_to_region(
            self.region,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key
        )

    def verify_aws_login(self):
        ''' Verify AWS login by listing DB instances.

        Returns False for the placeholder credentials of the default config
        or when the request fails (printing a message); True otherwise.
        '''
        if (self.aws_access_key_id == 'YourAccessKeyId' or
                self.aws_secret_access_key == 'YourSecretAccessKey'):
            return False
        self._open_rds_connection()
        try:
            self.rdsc.get_all_dbinstances()
        except MTurkRequestError as exception:
            print(exception.error_message)
            return False
        except AttributeError:
            # Presumably self.rdsc is None because boto did not recognise
            # the region name -- TODO confirm against boto.rds.
            print("*** Unable to establish connection to AWS region %s "
                  "using your access key/secret key" % self.region)
            return False
        except boto.exception.BotoServerError:
            banner = ("*************************************************"
                      "**********")
            print(banner)
            print("WARNING")
            print("Unable to establish connection to AWS RDS (Amazon "
                  "relational database services).")
            print("See relevant psiturk docs here:")
            print("\thttp://psiturk.readthedocs.io/en/latest/"
                  "configure_databases.html#obtaining-a-low-cost-or-free-"
                  "mysql-database-on-amazon-s-web-services-cloud")
            print(banner)
            return False
        return True

    def connect_to_aws_rds(self):
        ''' Connect to AWS RDS.

        Verifies the credentials first if that has not happened yet.  Returns
        False (after printing a message) when they are invalid; otherwise
        opens a fresh connection in self.rdsc and returns True.
        '''
        if getattr(self, 'valid_login', None) is None:
            self.valid_login = self.verify_aws_login()
        if not self.valid_login:
            print('Sorry, unable to connect to Amazon\'s RDS database server. '
                  'AWS credentials invalid.')
            return False
        self._open_rds_connection()
        return True

    def get_db_instance_info(self, dbid):
        ''' Return the first DB instance matching ``dbid``, or False. '''
        if not self.connect_to_aws_rds():
            return False
        try:
            instances = self.rdsc.get_all_dbinstances(dbid)
        except Exception:
            return False
        if not instances:
            return False
        return instances[0]

    def allow_access_to_instance(self, _, ip_address):
        ''' Open TCP port 3306 in the 'default' EC2 security group to
        ``ip_address``/32.  The first argument is ignored.

        Returns True on success or when the rule already exists, else False.
        '''
        if not self.connect_to_aws_rds():
            return False
        try:
            conn = boto.ec2.connect_to_region(
                self.region,
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key
            )
            default_sg = conn.get_all_security_groups('default')[0]
            default_sg.authorize(ip_protocol='tcp', from_port=3306,
                                 to_port=3306,
                                 cidr_ip=str(ip_address) + '/32')
        except EC2ResponseError as exception:
            # A duplicate rule means access is already granted.
            return exception.error_code == "InvalidPermission.Duplicate"
        return True

    def get_db_instances(self):
        ''' Return all DB instances, or False on failure. '''
        if not self.connect_to_aws_rds():
            return False
        try:
            return self.rdsc.get_all_dbinstances()
        except Exception:
            return False

    def delete_db_instance(self, dbid):
        ''' Delete DB instance ``dbid`` without a final snapshot. '''
        if not self.connect_to_aws_rds():
            return False
        try:
            database = self.rdsc.delete_dbinstance(dbid,
                                                   skip_final_snapshot=True)
            print(database)
        except Exception:
            return False
        return True

    @staticmethod
    def _only_word_chars(value):
        ''' True if ``value`` is one or more word characters or hyphens.

        ``\\Z`` (not ``$``) so a trailing newline is rejected.
        '''
        return re.match(r'[\w-]+\Z', value) is not None

    def validate_instance_id(self, instid):
        ''' Validate instance ID: 1-63 word characters or hyphens, first
        must be a letter.  Returns True or an error string. '''
        if (self._only_word_chars(instid) and 1 <= len(instid) <= 63 and
                instid[0].isalpha()):
            return True
        return ("*** Error: Instance ids must be 1-63 alphanumeric "
                "characters, first is a letter.")

    def validate_instance_size(self, size):
        ''' Validate size: integer between 5-1024 (inclusive). '''
        try:
            size_gb = int(size)
        except (TypeError, ValueError):
            return '*** Error: size must be a whole number between 5 and 1024.'
        if not 5 <= size_gb <= 1024:
            return '*** Error: size must be between 5-1024 GB.'
        return True

    def validate_instance_username(self, username):
        ''' Validate username: 1-16 word characters or hyphens, first must
        be a letter, not a MySQL reserved word (case-insensitive). '''
        if (self._only_word_chars(username) and 1 <= len(username) <= 16 and
                username[0].isalpha() and
                username.lower() not in MYSQL_RESERVED_WORDS):
            return True
        return ('*** Error: Usernames must be 1-16 alphanumeric characters, '
                'first a letter, cannot be reserved MySQL word.')

    def validate_instance_password(self, password):
        ''' Validate password: 8-41 word characters or hyphens. '''
        if self._only_word_chars(password) and 8 <= len(password) <= 41:
            return True
        return '*** Error: Passwords must be 8-41 alphanumeric characters'

    def validate_instance_dbname(self, dbname):
        ''' Validate database name: 1-64 word characters or hyphens, not a
        MySQL reserved word (case-insensitive). '''
        if (self._only_word_chars(dbname) and 1 <= len(dbname) <= 64 and
                dbname.lower() not in MYSQL_RESERVED_WORDS):
            return True
        return ('*** Error: Database names must be 1-64 alphanumeric '
                'characters, cannot be a reserved MySQL word.')

    def create_db_instance(self, params):
        ''' Create a single-AZ db.t1.micro MySQL instance from ``params``
        (keys: id, size, username, password, dbname).  Returns True/False. '''
        if not self.connect_to_aws_rds():
            return False
        try:
            self.rdsc.create_dbinstance(
                id=params['id'],
                allocated_storage=params['size'],
                instance_class='db.t1.micro',
                engine='MySQL',
                master_username=params['username'],
                master_password=params['password'],
                db_name=params['dbname'],
                multi_az=False
            )
        except Exception:
            return False
        return True
class MTurkServices(object):
    ''' MTurk services.

    Wraps a boto ``MTurkConnection`` stored in ``self.mtc``.  Methods that
    talk to MTurk first call connect_to_turk(), which returns False when the
    credentials were not verified; most of them return False on failure.
    '''

    def __init__(self, aws_access_key_id, aws_secret_access_key, is_sandbox):
        ''' Store credentials and sandbox flag, then verify the credentials
        (printing a warning when they are invalid). '''
        self.update_credentials(aws_access_key_id, aws_secret_access_key)
        self.set_sandbox(is_sandbox)
        self.valid_login = self.verify_aws_login()
        if not self.valid_login:
            print('WARNING *****************************')
            print('Sorry, AWS Credentials invalid.\nYou will only be able to '
                  'test experiments locally until you enter\nvalid '
                  'credentials in the AWS Access section of ~/.psiturkconfig\n')

    def update_credentials(self, aws_access_key_id, aws_secret_access_key):
        ''' Update credentials '''
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key

    def set_sandbox(self, is_sandbox):
        ''' Select the sandbox (True) or live (False) MTurk host. '''
        self.is_sandbox = is_sandbox

    def _open_connection(self, host):
        ''' Create self.mtc for ``host`` using the stored credentials. '''
        self.mtc = MTurkConnection(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            host=host)

    @staticmethod
    def _hit_summary(hit):
        ''' Wrap a boto HIT object in an MTurkHIT with the listed fields. '''
        return MTurkHIT({
            'hitid': hit.HITId,
            'title': hit.Title,
            'status': hit.HITStatus,
            'max_assignments': hit.MaxAssignments,
            'number_assignments_completed': hit.NumberOfAssignmentsCompleted,
            'number_assignments_pending': hit.NumberOfAssignmentsPending,
            'number_assignments_available': hit.NumberOfAssignmentsAvailable,
            'creation_time': hit.CreationTime,
            'expiration': hit.Expiration,
        })

    def _fetch_all_hits(self):
        ''' Connect and return every HIT as a list, or None on failure. '''
        if not self.connect_to_turk():
            return None
        try:
            # NOTE(review): boto's get_all_hits() may fetch pages lazily;
            # building the list here keeps every page request in the try.
            return list(self.mtc.get_all_hits())
        except MTurkRequestError:
            return None

    def get_reviewable_hits(self):
        ''' Get HITs whose status is Reviewable or Reviewing (or False). '''
        hits = self._fetch_all_hits()
        if hits is None:
            return False
        return [self._hit_summary(hit) for hit in hits
                if hit.HITStatus in ("Reviewable", "Reviewing")]

    def get_all_hits(self):
        ''' Get all HITs as MTurkHIT objects (or False on failure). '''
        hits = self._fetch_all_hits()
        if hits is None:
            return False
        return [self._hit_summary(hit) for hit in hits]

    def get_active_hits(self):
        ''' Get HITs that have not expired (or False on failure). '''
        hits = self._fetch_all_hits()
        if hits is None:
            return False
        return [self._hit_summary(hit) for hit in hits if not hit.expired]

    def _get_hit_assignments(self, hit_id, assignment_status, page_size=100):
        ''' Fetch every assignment of one HIT, requesting page after page
        until TotalNumResults (read from the first page) is covered. '''
        assignments = []
        page_number = 1
        total_pages = 1
        while page_number <= total_pages:
            page = self.mtc.get_assignments(
                hit_id,
                status=assignment_status,
                sort_by='SubmitTime',
                page_size=page_size,
                page_number=page_number
            )
            if page_number == 1:
                total_results = int(page.TotalNumResults)
                # Ceiling division: pages needed for total_results.
                total_pages = (total_results + page_size - 1) // page_size
            assignments.extend(page)
            page_number += 1
        return assignments

    def get_workers(self, assignment_status=None):
        ''' Return one dict per assignment across all HITs, optionally
        filtered by ``assignment_status``; False on failure. '''
        if not self.connect_to_turk():
            return False
        try:
            hit_ids = [hit.HITId for hit in self.mtc.get_all_hits()]
            workers = []
            for hit_id in hit_ids:
                workers.extend(
                    self._get_hit_assignments(hit_id, assignment_status))
        except MTurkRequestError:
            return False
        return [{
            'hitId': worker.HITId,
            'assignmentId': worker.AssignmentId,
            'workerId': worker.WorkerId,
            'submit_time': worker.SubmitTime,
            'accept_time': worker.AcceptTime,
            'status': worker.AssignmentStatus
        } for worker in workers]

    def bonus_worker(self, assignment_id, amount, reason=""):
        ''' Grant ``amount`` as a bonus to the worker of ``assignment_id``.
        Returns True on success, False otherwise (printing the error). '''
        if not self.connect_to_turk():
            return False
        try:
            bonus = MTurkConnection.get_price_as_price(amount)
            assignment = self.mtc.get_assignment(assignment_id)[0]
            self.mtc.grant_bonus(assignment.WorkerId, assignment_id, bonus,
                                 reason)
        except MTurkRequestError as exception:
            print(exception)
            return False
        return True

    def approve_worker(self, assignment_id):
        ''' Approve an assignment.  Returns True/False. '''
        if not self.connect_to_turk():
            return False
        try:
            self.mtc.approve_assignment(assignment_id, feedback=None)
        except MTurkRequestError:
            return False
        return True

    def reject_worker(self, assignment_id):
        ''' Reject an assignment.  Returns True/False. '''
        if not self.connect_to_turk():
            return False
        try:
            self.mtc.reject_assignment(assignment_id, feedback=None)
        except MTurkRequestError:
            return False
        return True

    def unreject_worker(self, assignment_id):
        ''' Approve a previously rejected assignment.  Returns True/False. '''
        if not self.connect_to_turk():
            return False
        try:
            self.mtc.approve_rejected_assignment(assignment_id)
        except MTurkRequestError:
            return False
        return True

    def verify_aws_login(self):
        ''' Verify AWS login by requesting the account balance from the
        live (non-sandbox) host.  Returns True/False. '''
        if (self.aws_access_key_id == 'YourAccessKeyId' or
                self.aws_secret_access_key == 'YourSecretAccessKey'):
            return False
        self._open_connection('mechanicalturk.amazonaws.com')
        try:
            self.mtc.get_account_balance()
        except MTurkRequestError as exception:
            print(exception.error_message)
            return False
        return True

    def connect_to_turk(self):
        ''' Open a connection to the sandbox or live host, per is_sandbox.
        Returns False (printing a message) if the login was not verified. '''
        if not self.valid_login:
            print('Sorry, unable to connect to Amazon Mechanical Turk. AWS '
                  'credentials invalid.')
            return False
        if self.is_sandbox:
            host = 'mechanicalturk.sandbox.amazonaws.com'
        else:
            host = 'mechanicalturk.amazonaws.com'
        self._open_connection(host)
        return True

    def configure_hit(self, hit_config):
        ''' Register the HIT type and build self.param_dict, the keyword
        arguments later passed to create_hit().

        Requires an open connection in self.mtc.  If the psiTurk config has
        'notification_url' under 'Server Parameters', REST notifications
        for the listed assignment/HIT events are registered for the type.
        '''
        # The HIT shows the ad page as an ExternalQuestion.
        experiment_portal_url = hit_config['ad_location']
        frame_height = 600
        mturk_question = ExternalQuestion(experiment_portal_url, frame_height)

        # Qualifications: minimum approval rate, optionally US locale only.
        quals = Qualifications()
        approve_requirement = hit_config['approve_requirement']
        quals.add(
            PercentAssignmentsApprovedRequirement("GreaterThanOrEqualTo",
                                                  approve_requirement))
        if hit_config['us_only']:
            quals.add(LocaleRequirement("EqualTo", "US"))

        # Create a HIT type for this HIT.
        hit_type = self.mtc.register_hit_type(
            hit_config['title'],
            hit_config['description'],
            hit_config['reward'],
            hit_config['duration'],
            keywords=hit_config['keywords'],
            approval_delay=None,
            qual_req=None)[0]

        # Check the config file to see if notifications are wanted.
        config = PsiturkConfig()
        config.load_config()
        try:
            url = config.get('Server Parameters', 'notification_url')
            all_event_types = [
                "AssignmentAccepted",
                "AssignmentAbandoned",
                "AssignmentReturned",
                "AssignmentSubmitted",
                "HITReviewable",
                "HITExpired",
            ]
            self.mtc.set_rest_notification(
                hit_type.HITTypeId,
                url,
                event_types=all_event_types)
        except Exception:
            # Notifications are optional: a missing option or a failed
            # registration must not stop HIT creation.
            pass

        # Specify all the HIT parameters
        self.param_dict = dict(
            hit_type=hit_type.HITTypeId,
            question=mturk_question,
            lifetime=hit_config['lifetime'],
            max_assignments=hit_config['max_assignments'],
            title=hit_config['title'],
            description=hit_config['description'],
            keywords=hit_config['keywords'],
            reward=hit_config['reward'],
            duration=hit_config['duration'],
            approval_delay=None,
            questions=None,
            qualifications=quals,
            response_groups=[
                'Minimal',
                'HITDetail',
                'HITQuestion',
                'HITAssignmentSummary'
            ])

    def check_balance(self):
        ''' Return the account balance, or '-' when not connected. '''
        if not self.connect_to_turk():
            return '-'
        return self.mtc.get_account_balance()[0]

    # TODO (if valid AWS credentials haven't been provided then
    # connect_to_turk() will fail, not error checking here and elsewhere)
    def create_hit(self, hit_config):
        ''' Create a HIT from ``hit_config``; return its HITId or False. '''
        try:
            if not self.connect_to_turk():
                return False
            self.configure_hit(hit_config)
            myhit = self.mtc.create_hit(**self.param_dict)[0]
            self.hitid = myhit.HITId
        except Exception:
            return False
        return self.hitid

    # TODO(Jay): Have a wrapper around functions that serializes them.
    # Default output should not be serialized.
    def expire_hit(self, hitid):
        ''' Expire a HIT.  Returns True/False. '''
        if not self.connect_to_turk():
            return False
        try:
            self.mtc.expire_hit(hitid)
        except MTurkRequestError:
            print("Failed to expire HIT. Please check the ID and try again.")
            return False
        return True

    def dispose_hit(self, hitid):
        ''' Dispose of a HIT.  Returns True on success, False on failure
        (printing a hint). '''
        if not self.connect_to_turk():
            return False
        try:
            self.mtc.dispose_hit(hitid)
        except Exception:
            print("Failed to dispose of HIT %s. Make sure there are no "
                  "assignments remaining to be reviewed." % hitid)
            return False
        return True

    def extend_hit(self, hitid, assignments_increment=None,
                   expiration_increment=None):
        ''' Extend a HIT by extra assignments and/or extra time.

        ``expiration_increment`` is in minutes and is sent to MTurk as
        seconds.  Each increment is sent in its own extend_hit call; an
        increment that is None or 0 is skipped instead of being sent as a
        zero-sized extension.  Returns True on success, False on failure.
        '''
        if not self.connect_to_turk():
            return False
        try:
            extra_assignments = int(assignments_increment or 0)
            extra_seconds = int(expiration_increment or 0) * 60
            if extra_assignments:
                self.mtc.extend_hit(hitid,
                                    assignments_increment=extra_assignments)
            if extra_seconds:
                self.mtc.extend_hit(hitid,
                                    expiration_increment=extra_seconds)
        except Exception:
            print("Failed to extend HIT %s. Please check the ID and try "
                  "again." % hitid)
            return False
        return True

    def get_hit_status(self, hitid):
        ''' Return the HITStatus of ``hitid``, or False on failure. '''
        if not self.connect_to_turk():
            return False
        try:
            return self.mtc.get_hit(hitid)[0].HITStatus
        except Exception:
            return False

    def get_summary(self):
        ''' Return a flask JSON response holding the account balance. '''
        try:
            balance = self.check_balance()
            return jsonify(balance=str(balance))
        except MTurkRequestError as exception:
            print(exception.error_message)
            return False