Update prediction sample to use sample_tools.
Reviewed in https://codereview.appspot.com/9325044/.
This commit is contained in:
@@ -37,68 +37,29 @@ To get detailed log output run:
|
||||
__author__ = ('jcgregorio@google.com (Joe Gregorio), '
|
||||
'marccohen@google.com (Marc Cohen)')
|
||||
|
||||
import apiclient.errors
|
||||
import gflags
|
||||
import httplib2
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import time
|
||||
|
||||
from apiclient.discovery import build
|
||||
from oauth2client.file import Storage
|
||||
from oauth2client.client import AccessTokenRefreshError
|
||||
from oauth2client.client import flow_from_clientsecrets
|
||||
from oauth2client.tools import run
|
||||
from apiclient import discovery
|
||||
from apiclient import sample_tools
|
||||
from oauth2client import client
|
||||
|
||||
FLAGS = gflags.FLAGS
|
||||
|
||||
# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
|
||||
# application, including client_id and client_secret, which are found
|
||||
# on the API Access tab on the Google APIs
|
||||
# Console <http://code.google.com/apis/console>
|
||||
CLIENT_SECRETS = 'samples/prediction/client_secrets.json'
|
||||
|
||||
# Helpful message to display in the browser if the CLIENT_SECRETS file
|
||||
# is missing.
|
||||
MISSING_CLIENT_SECRETS_MESSAGE = """
|
||||
WARNING: Please configure OAuth 2.0
|
||||
|
||||
To make this sample run you will need to populate the client_secrets.json file
|
||||
found at:
|
||||
|
||||
%s
|
||||
|
||||
with information from the APIs Console <https://code.google.com/apis/console>.
|
||||
|
||||
""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)
|
||||
|
||||
# Set up a Flow object to be used if we need to authenticate.
|
||||
FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
|
||||
scope='https://www.googleapis.com/auth/prediction',
|
||||
message=MISSING_CLIENT_SECRETS_MESSAGE)
|
||||
|
||||
# The gflags module makes defining command-line options easy for
|
||||
# applications. Run this program with the '--help' argument to see
|
||||
# all the flags that it understands.
|
||||
gflags.DEFINE_enum('logging_level', 'ERROR',
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
'Set the level of logging detail.')
|
||||
|
||||
gflags.DEFINE_string('object_name',
|
||||
None,
|
||||
'Full Google Storage path of csv data (ex bucket/object)')
|
||||
gflags.MarkFlagAsRequired('object_name')
|
||||
|
||||
gflags.DEFINE_string('id',
|
||||
None,
|
||||
'Model Id of your choosing to name trained model')
|
||||
gflags.MarkFlagAsRequired('id')
|
||||
|
||||
# Time to wait (in seconds) between successive checks of training status.
|
||||
SLEEP_TIME = 10
|
||||
|
||||
|
||||
# Declare command-line flags.
|
||||
argparser = argparse.ArgumentParser(add_help=False)
|
||||
argparser.add_argument('object_name',
|
||||
help='Full Google Storage path of csv data (ex bucket/object)')
|
||||
argparser.add_argument('id',
|
||||
help='Model Id of your choosing to name trained model')
|
||||
|
||||
|
||||
def print_header(line):
|
||||
'''Format and print header block sized to length of line'''
|
||||
header_str = '='
|
||||
@@ -107,34 +68,14 @@ def print_header(line):
|
||||
print line
|
||||
print header_line
|
||||
|
||||
|
||||
def main(argv):
|
||||
# Let the gflags module process the command-line arguments.
|
||||
try:
|
||||
argv = FLAGS(argv)
|
||||
except gflags.FlagsError, e:
|
||||
print '%s\\nUsage: %s ARGS\\n%s' % (e, argv[0], FLAGS)
|
||||
sys.exit(1)
|
||||
|
||||
# Set the logging according to the command-line flag
|
||||
logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))
|
||||
|
||||
# If the Credentials don't exist or are invalid run through the native client
|
||||
# flow. The Storage object will ensure that if successful the good
|
||||
# Credentials will get written back to a file.
|
||||
storage = Storage('prediction.dat')
|
||||
credentials = storage.get()
|
||||
if credentials is None or credentials.invalid:
|
||||
credentials = run(FLOW, storage)
|
||||
|
||||
# Create an httplib2.Http object to handle our HTTP requests and authorize it
|
||||
# with our good Credentials.
|
||||
http = httplib2.Http()
|
||||
http = credentials.authorize(http)
|
||||
service, flags = sample_tools.init(
|
||||
argv, 'prediction', 'v1.5', __doc__, __file__, parents=[argparser],
|
||||
scope='https://www.googleapis.com/auth/prediction')
|
||||
|
||||
try:
|
||||
|
||||
# Get access to the Prediction API.
|
||||
service = build("prediction", "v1.5", http=http)
|
||||
papi = service.trainedmodels()
|
||||
|
||||
# List models.
|
||||
@@ -145,7 +86,7 @@ def main(argv):
|
||||
|
||||
# Start training request on a data set.
|
||||
print_header('Submitting model training request')
|
||||
body = {'id': FLAGS.id, 'storageDataLocation': FLAGS.object_name}
|
||||
body = {'id': flags.id, 'storageDataLocation': flags.object_name}
|
||||
start = papi.insert(body=body).execute()
|
||||
print 'Training results:'
|
||||
pprint.pprint(start)
|
||||
@@ -153,7 +94,7 @@ def main(argv):
|
||||
# Wait for the training to complete.
|
||||
print_header('Waiting for training to complete')
|
||||
while True:
|
||||
status = papi.get(id=FLAGS.id).execute()
|
||||
status = papi.get(id=flags.id).execute()
|
||||
state = status['trainingStatus']
|
||||
print 'Training state: ' + state
|
||||
if state == 'DONE':
|
||||
@@ -171,25 +112,26 @@ def main(argv):
|
||||
|
||||
# Describe model.
|
||||
print_header('Fetching model description')
|
||||
result = papi.analyze(id=FLAGS.id).execute()
|
||||
result = papi.analyze(id=flags.id).execute()
|
||||
print 'Analyze results:'
|
||||
pprint.pprint(result)
|
||||
|
||||
# Make a prediction using the newly trained model.
|
||||
print_header('Making a prediction')
|
||||
body = {'input': {'csvInstance': ["mucho bueno"]}}
|
||||
result = papi.predict(body=body, id=FLAGS.id).execute()
|
||||
result = papi.predict(body=body, id=flags.id).execute()
|
||||
print 'Prediction results...'
|
||||
pprint.pprint(result)
|
||||
|
||||
# Delete model.
|
||||
print_header('Deleting model')
|
||||
result = papi.delete(id=FLAGS.id).execute()
|
||||
result = papi.delete(id=flags.id).execute()
|
||||
print 'Model deleted.'
|
||||
|
||||
except AccessTokenRefreshError:
|
||||
except client.AccessTokenRefreshError:
|
||||
print ("The credentials have been revoked or expired, please re-run"
|
||||
"the application to re-authorize")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv)
|
||||
|
||||
Reference in New Issue
Block a user