https://github.com/NVIDIA/DIGITS
Revision 88666fa7061dd0ade9f6a0d6f13f7078a3507f75 authored by Luke Yeager on 26 September 2016, 21:07:46 UTC, committed by GitHub on 26 September 2016, 21:07:46 UTC
Move image gradient extensions to plug-in area
Tip revision: 88666fa7061dd0ade9f6a0d6f13f7078a3507f75 authored by Luke Yeager on 26 September 2016, 21:07:46 UTC
Merge pull request #947 from gheinrich/dev/gradients-plugin
Merge pull request #947 from gheinrich/dev/gradients-plugin
Tip revision: 88666fa
digits-walkthrough
#!/usr/bin/env python2
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
import argparse
import json
import os
import requests
import socket
import sys
import time
from urlparse import urlparse
from selenium import webdriver
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.keys import Keys
# Default pause, in seconds, used by wait() when no duration is given.
wait_time = 2


def wait(s=wait_time):
    """Sleep for ``s`` seconds.

    Args:
        s: Number of seconds to pause. The default is the value ``wait_time``
            held when this function was defined (2); rebinding ``wait_time``
            later does not change it.
    """
    time.sleep(s)
def get_page(driver, url):
    """Load ``url`` in ``driver`` and then pause for the default wait.

    Args:
        driver: Object exposing ``get(url)``; this script passes a Selenium
            WebDriver.
        url: Address to load. ``main`` also passes ``javascript:`` URLs here.
    """
    driver.get(url)
    # Pause after every navigation before the caller's next action.
    wait()
def create_dataset(driver, name, folder):
    """Fill in and submit the new-dataset form, then wait for the job.

    The form is reached through the first ``dropdown-toggle`` element on the
    current page. The dataset is configured as 'Grayscale' with a 28x28
    resize, and the function returns only after the job's ``/status``
    endpoint reports ``'Done'``.

    Args:
        driver: Selenium WebDriver already showing the page that holds the
            dropdowns (``main`` loads the home page first).
        name: Text typed into the ``dataset_name`` field.
        folder: Text typed into the ``folder_train`` field.
    """
    # Open the first dropdown and follow the first link inside its menu.
    toggles = driver.find_elements_by_class_name('dropdown-toggle')
    toggles[0].click()
    wait()
    menus = driver.find_elements_by_class_name('dropdown-menu')
    menus[0].find_element_by_tag_name('a').click()
    wait()

    # Training folder: click its explanation element, then type the path.
    driver.find_element_by_name('folder_train_explanation').click()
    wait()
    driver.find_element_by_name('folder_train').send_keys(folder)
    wait()

    # Image type: pick the 'Grayscale' entry of the resize_channels select.
    driver.find_element_by_name('resize_channels_explanation').click()
    wait()
    channels_select = driver.find_element_by_name('resize_channels')
    channels_select.click()
    wait()
    for choice in channels_select.find_elements_by_tag_name('option'):
        if choice.text != 'Grayscale':
            continue
        channels_select.click()
        choice.click()
        break
    wait()

    # Image size: overwrite both dimensions with 28.
    driver.find_element_by_name('resize_dims_explanation').click()
    wait()
    for field_name in ('resize_width', 'resize_height'):
        dim_field = driver.find_element_by_name(field_name)
        dim_field.clear()
        dim_field.send_keys('28')
        wait()

    name_field = driver.find_element_by_name('dataset_name')
    name_field.click()
    name_field.send_keys(name)
    wait()

    driver.find_element_by_name('create-dataset').click()

    # The status URL is derived from whatever page the browser is on after
    # the submit click, with 'datasets' swapped for 'jobs'.
    status_url = driver.current_url.replace('datasets', 'jobs') + '/status'

    # NOTE(review): the loop ends only when the status is exactly 'Done'; a
    # job that never reaches that status keeps this polling forever.
    while True:
        response = requests.get(status_url)
        finished = json.loads(response.content)['status'] == 'Done'
        wait()
        if finished:
            break
    wait()
def create_model(driver, name, dataset_name, test_image):
    """Fill in and submit the new-model form, wait for the job, then test it.

    The form is reached through the second ``dropdown-toggle`` element on the
    current page. After the job's ``/status`` endpoint reports ``'Done'``,
    ``test_image`` is submitted through the ``classify-one-btn`` button and
    the driver is switched to the most recently opened window.

    Args:
        driver: Selenium WebDriver already showing the page that holds the
            dropdowns (``main`` loads the home page first).
        name: Text typed into the ``model_name`` field.
        dataset_name: Visible text of the option to pick in the ``dataset``
            select.
        test_image: Text typed into the ``image_url`` field.
    """
    # Open the second dropdown and follow the first link inside its menu.
    toggles = driver.find_elements_by_class_name('dropdown-toggle')
    toggles[1].click()
    wait()
    menus = driver.find_elements_by_class_name('dropdown-menu')
    menus[1].find_element_by_tag_name('a').click()

    # Click the page body first; per the original author's note this keeps
    # a hover element from being selected by accident.
    driver.find_element_by_css_selector('body').click()
    wait()

    # Dataset: choose the option whose text matches dataset_name.
    driver.find_element_by_name('dataset_explanation').click()
    wait()
    dataset_select = driver.find_element_by_name('dataset')
    for choice in dataset_select.find_elements_by_tag_name('option'):
        if choice.text != dataset_name:
            continue
        choice.click()
        break
    wait()

    # Network: take the first 'standard_networks' element (presumably LeNet,
    # going by the original variable name -- confirm against the form).
    driver.find_elements_by_name('standard_networks')[0].click()
    wait()

    name_field = driver.find_element_by_name('model_name')
    name_field.click()
    name_field.send_keys(name)
    wait()

    driver.find_element_by_name('create-model').click()

    # The status URL is derived from whatever page the browser is on after
    # the submit click, with 'models' swapped for 'jobs'.
    status_url = driver.current_url.replace('models', 'jobs') + '/status'

    # NOTE(review): the loop ends only when the status is exactly 'Done'; a
    # job that never reaches that status keeps this polling forever.
    while True:
        response = requests.get(status_url)
        finished = json.loads(response.content)['status'] == 'Done'
        wait()
        if finished:
            break

    # test image
    print('Testing...')
    driver.find_element_by_name('image_url').send_keys(test_image)
    wait()
    driver.find_element_by_name('show_visualizations_explanation').click()
    wait()
    driver.find_element_by_name('show_visualizations').click()
    wait()
    driver.find_element_by_name('classify-one-btn').click()

    # Opens in a new window - close the current one and switch to the newest.
    driver.close()
    newest_window = driver.window_handles[-1]
    driver.switch_to_window(newest_window)
def main(argv):
    """Run the scripted DIGITS walkthrough in a Firefox browser.

    Steps: parse the command line, check that the server answers at
    ``http://0.0.0.0:<port>/``, start Firefox, create the 'MNIST Dataset'
    dataset and the 'MNIST Model' model through the web UI, show a
    completion alert, then keep polling the browser until a call on it
    raises (treated as the window having been closed).

    Args:
        argv: Full argument vector with the program name at index 0 (the
            script passes ``sys.argv``). ``None`` falls back to argparse's
            default of reading ``sys.argv``.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
        AssertionError: If the home page does not answer with HTTP 200.
    """
    parser = argparse.ArgumentParser(description='Run a Selenium demo of DIGITS')

    ### Positional arguments

    parser.add_argument('mnist_image_folder',
            type=str,
            help='Path to the MNIST dataset folder')
    parser.add_argument('test_image',
            type=str,
            help='Image to test with')

    ### Optional arguments

    parser.add_argument('-p', '--port',
            type=int,
            default=80,
            help='Port the server is running on (default 80)')

    # Parse the vector we were handed instead of silently re-reading
    # sys.argv; index 0 is the program name, which argparse does not expect.
    cli_args = argv[1:] if argv is not None else None
    args = vars(parser.parse_args(cli_args))

    home_page = 'http://0.0.0.0:%d/' % args['port']
    dataset_path = args['mnist_image_folder']
    dataset_name = 'MNIST Dataset'
    model_name = 'MNIST Model'
    test_image = args['test_image']

    # Explicit check rather than an ``assert`` statement so it still runs
    # under ``python -O``; the exception type is kept as AssertionError.
    r = requests.get(home_page)
    if r.status_code != requests.codes.ok:
        raise AssertionError(
            'page "%s" does not exist - are you looking on the wrong port?' % home_page)

    # Start selenium driver. Nothing else runs between here and the ``try``
    # so the browser is always shut down by the ``finally`` below.
    driver = webdriver.Firefox()
    print('Firefox webdriver started.')

    try:
        driver.maximize_window()
        get_page(driver, home_page)

        print('Creating dataset...')
        create_dataset(driver, dataset_name, dataset_path)

        get_page(driver, home_page)

        print('Creating model...')
        create_model(driver, model_name, dataset_name, test_image)

        print('Done.')

        # display an alert message
        get_page(driver, "javascript:alert('Completed Walkthrough!');void(0);")
        wait()
        driver.switch_to_alert().accept()

        # wait until the window is closed: any exception from the browser
        # call is taken as the signal that it is gone, so the broad catch
        # is deliberate.
        while True:
            try:
                get_page(driver, "javascript:console.log('Waiting...');void(0);")
            except Exception as e:
                print(e)
                break
            wait()
    except KeyboardInterrupt:
        # Ctrl-C ends the demo quietly; cleanup still happens below.
        pass
    finally:
        try:
            driver.quit()
        except socket.error:
            # Best-effort shutdown: the browser may already be gone.
            pass
# Script entry point: hand the full argument vector (program name included)
# to main() only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main(sys.argv)
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...