Update prediction sample to use sample_tools.

Reviewed in https://codereview.appspot.com/9325044/.
This commit is contained in:
Joe Gregorio
2013-06-28 01:30:57 -04:00
parent 1a5e30e432
commit e8391150a7

View File

@@ -37,68 +37,29 @@ To get detailed log output run:
__author__ = ('jcgregorio@google.com (Joe Gregorio), '
'marccohen@google.com (Marc Cohen)')
import apiclient.errors
import gflags
import httplib2
import logging
import argparse
import os
import pprint
import sys
import time
from apiclient.discovery import build
from oauth2client.file import Storage
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import flow_from_clientsecrets
from oauth2client.tools import run
from apiclient import discovery
from apiclient import sample_tools
from oauth2client import client
FLAGS = gflags.FLAGS
# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
# application, including client_id and client_secret, which are found
# on the API Access tab on the Google APIs
# Console <http://code.google.com/apis/console>
CLIENT_SECRETS = 'samples/prediction/client_secrets.json'
# Helpful message to display in the browser if the CLIENT_SECRETS file
# is missing.
MISSING_CLIENT_SECRETS_MESSAGE = """
WARNING: Please configure OAuth 2.0
To make this sample run you will need to populate the client_secrets.json file
found at:
%s
with information from the APIs Console <https://code.google.com/apis/console>.
""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)
# Set up a Flow object to be used if we need to authenticate.
FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
scope='https://www.googleapis.com/auth/prediction',
message=MISSING_CLIENT_SECRETS_MESSAGE)
# The gflags module makes defining command-line options easy for
# applications. Run this program with the '--help' argument to see
# all the flags that it understands.
gflags.DEFINE_enum('logging_level', 'ERROR',
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
'Set the level of logging detail.')
gflags.DEFINE_string('object_name',
None,
'Full Google Storage path of csv data (ex bucket/object)')
gflags.MarkFlagAsRequired('object_name')
gflags.DEFINE_string('id',
None,
'Model Id of your choosing to name trained model')
gflags.MarkFlagAsRequired('id')
# Time to wait (in seconds) between successive checks of training status.
SLEEP_TIME = 10
# Declare command-line flags.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument('object_name',
help='Full Google Storage path of csv data (ex bucket/object)')
argparser.add_argument('id',
help='Model Id of your choosing to name trained model')
def print_header(line):
'''Format and print header block sized to length of line'''
header_str = '='
@@ -107,34 +68,14 @@ def print_header(line):
print line
print header_line
def main(argv):
# Let the gflags module process the command-line arguments.
try:
argv = FLAGS(argv)
except gflags.FlagsError, e:
print '%s\\nUsage: %s ARGS\\n%s' % (e, argv[0], FLAGS)
sys.exit(1)
# Set the logging according to the command-line flag
logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))
# If the Credentials don't exist or are invalid run through the native client
# flow. The Storage object will ensure that if successful the good
# Credentials will get written back to a file.
storage = Storage('prediction.dat')
credentials = storage.get()
if credentials is None or credentials.invalid:
credentials = run(FLOW, storage)
# Create an httplib2.Http object to handle our HTTP requests and authorize it
# with our good Credentials.
http = httplib2.Http()
http = credentials.authorize(http)
service, flags = sample_tools.init(
argv, 'prediction', 'v1.5', __doc__, __file__, parents=[argparser],
scope='https://www.googleapis.com/auth/prediction')
try:
# Get access to the Prediction API.
service = build("prediction", "v1.5", http=http)
papi = service.trainedmodels()
# List models.
@@ -145,7 +86,7 @@ def main(argv):
# Start training request on a data set.
print_header('Submitting model training request')
body = {'id': FLAGS.id, 'storageDataLocation': FLAGS.object_name}
body = {'id': flags.id, 'storageDataLocation': flags.object_name}
start = papi.insert(body=body).execute()
print 'Training results:'
pprint.pprint(start)
@@ -153,7 +94,7 @@ def main(argv):
# Wait for the training to complete.
print_header('Waiting for training to complete')
while True:
status = papi.get(id=FLAGS.id).execute()
status = papi.get(id=flags.id).execute()
state = status['trainingStatus']
print 'Training state: ' + state
if state == 'DONE':
@@ -171,25 +112,26 @@ def main(argv):
# Describe model.
print_header('Fetching model description')
result = papi.analyze(id=FLAGS.id).execute()
result = papi.analyze(id=flags.id).execute()
print 'Analyze results:'
pprint.pprint(result)
# Make a prediction using the newly trained model.
print_header('Making a prediction')
body = {'input': {'csvInstance': ["mucho bueno"]}}
result = papi.predict(body=body, id=FLAGS.id).execute()
result = papi.predict(body=body, id=flags.id).execute()
print 'Prediction results...'
pprint.pprint(result)
# Delete model.
print_header('Deleting model')
result = papi.delete(id=FLAGS.id).execute()
result = papi.delete(id=flags.id).execute()
print 'Model deleted.'
except AccessTokenRefreshError:
except client.AccessTokenRefreshError:
print ("The credentials have been revoked or expired, please re-run"
"the application to re-authorize")
if __name__ == '__main__':
main(sys.argv)