https://github.com/NVIDIA/DIGITS
Revision 88666fa7061dd0ade9f6a0d6f13f7078a3507f75 authored by Luke Yeager on 26 September 2016, 21:07:46 UTC, committed by GitHub on 26 September 2016, 21:07:46 UTC
Move image gradient extensions to plug-in area
2 parents ba2bbf6 + aab35bd
Raw File
Tip revision: 88666fa7061dd0ade9f6a0d6f13f7078a3507f75 authored by Luke Yeager on 26 September 2016, 21:07:46 UTC
Merge pull request #947 from gheinrich/dev/gradients-plugin
Tip revision: 88666fa
digits-walkthrough
#!/usr/bin/env python2
# Copyright (c) 2014-2016, NVIDIA CORPORATION.  All rights reserved.

import argparse
import json
import os
import requests
import socket
import sys
import time
from urlparse import urlparse

from selenium import webdriver
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.keys import Keys

# Default pause, in seconds, used by wait() when no duration is given.
wait_time = 2


def wait(s=wait_time):
    """Block the calling thread for ``s`` seconds.

    The default is bound to the value of ``wait_time`` at definition time,
    so rebinding ``wait_time`` later does not change it.
    """
    time.sleep(s)

def get_page(driver, url):
    """Point ``driver`` at ``url`` and then pause for the default wait.

    Args:
        driver: object exposing a ``get(url)`` method (a Selenium WebDriver
            in this script).
        url: address handed to ``driver.get``.
    """
    driver.get(url)
    wait()

def create_dataset(driver, name, folder):
    """Create a dataset through the web UI and block until its job finishes.

    Opens the first entry of the first drop-down menu, fills in the form
    (training folder, grayscale channels, 28x28 resize, dataset name),
    submits it, and then polls the job's JSON status endpoint.

    Args:
        driver: Selenium WebDriver; the caller in this file loads the home
            page before calling.
        name: text typed into the ``dataset_name`` field.
        folder: text typed into the ``folder_train`` field.

    Raises:
        RuntimeError: if the status endpoint reports no status, or the job
            ends in a failure state instead of ``'Done'``.
    """
    # Open the first drop-down toggle on the page.
    dropdown_elements = driver.find_elements_by_class_name('dropdown-toggle')
    dropdown_dataset = dropdown_elements[0]
    dropdown_dataset.click()
    wait()

    # Follow the first link inside the first drop-down menu.
    dropdown_menu = driver.find_elements_by_class_name('dropdown-menu')
    dropdown_menu_dataset = dropdown_menu[0]
    dataset_link = dropdown_menu_dataset.find_element_by_tag_name('a')
    dataset_link.click()
    wait()

    # Click the explanation element before filling in each field.
    folder_train_tooltip = driver.find_element_by_name('folder_train_explanation')
    folder_train_tooltip.click()
    wait()

    folder_train = driver.find_element_by_name('folder_train')
    folder_train.send_keys(folder)
    wait()

    resize_channels_tooltip = driver.find_element_by_name('resize_channels_explanation')
    resize_channels_tooltip.click()
    wait()

    # Pick the 'Grayscale' option of the channels selector, if present.
    image_type = driver.find_element_by_name('resize_channels')
    image_type.click()
    wait()
    for option in image_type.find_elements_by_tag_name('option'):
        if option.text == 'Grayscale':
            image_type.click()
            option.click()
            break
    wait()

    resize_width_tooltip = driver.find_element_by_name('resize_dims_explanation')
    resize_width_tooltip.click()
    wait()

    # Clear whatever is in the field before typing the new size.
    resize_width = driver.find_element_by_name('resize_width')
    resize_width.clear()
    resize_width.send_keys('28')
    wait()

    resize_height = driver.find_element_by_name('resize_height')
    resize_height.clear()
    resize_height.send_keys('28')
    wait()

    dataset_name = driver.find_element_by_name('dataset_name')
    dataset_name.click()
    dataset_name.send_keys(name)
    wait()

    create_button = driver.find_element_by_name('create-dataset')
    create_button.click()

    # The status endpoint is derived from the page the browser landed on.
    job_url = driver.current_url.replace('datasets', 'jobs')
    status_url = job_url + '/status'

    # NOTE(review): failure state names assumed from DIGITS' job status
    # names -- confirm against the server version in use.
    failed_states = ('Error', 'Aborted')
    while True:
        r = requests.get(status_url)
        status = json.loads(r.content)
        state = status.get('status')
        if state is None:
            raise RuntimeError('No job status reported by %s (error: %s)'
                               % (status_url, status.get('error')))
        if state in failed_states:
            # Without this check a failed job would be polled forever.
            raise RuntimeError('Dataset job ended with status "%s" (%s)'
                               % (state, status_url))
        wait()
        if state == 'Done':
            break
    wait()

def create_model(driver, name, dataset_name, test_image):
    """Create a model through the web UI, wait for training, then test it.

    Opens the first entry of the second drop-down menu, selects the dataset
    whose option text equals ``dataset_name``, picks the first
    ``standard_networks`` choice, names the model and submits the form.
    Once the job reports ``'Done'`` it submits ``test_image`` with
    visualizations enabled and switches the driver to the window that the
    result opens in.

    Args:
        driver: Selenium WebDriver; the caller in this file loads the home
            page before calling.
        name: text typed into the ``model_name`` field.
        dataset_name: visible text of the dataset option to select.
        test_image: text typed into the ``image_url`` field.

    Raises:
        RuntimeError: if the status endpoint reports no status, or the job
            ends in a failure state instead of ``'Done'``.
    """
    # Open the second drop-down toggle on the page.
    dropdown_elements = driver.find_elements_by_class_name('dropdown-toggle')
    dropdown_model = dropdown_elements[1]
    dropdown_model.click()
    wait()

    # Follow the first link inside the second drop-down menu.
    dropdown_menu = driver.find_elements_by_class_name('dropdown-menu')
    dropdown_menu_model = dropdown_menu[1]
    model_link = dropdown_menu_model.find_element_by_tag_name('a')
    model_link.click()
    # move to 0,0 so we don't accidentally select a hover element
    body = driver.find_element_by_css_selector('body')
    body.click()
    wait()

    dataset_tooltip = driver.find_element_by_name('dataset_explanation')
    dataset_tooltip.click()
    wait()

    # Select the dataset by its visible text; no match leaves the default.
    datasets = driver.find_element_by_name('dataset')
    for option in datasets.find_elements_by_tag_name('option'):
        if option.text == dataset_name:
            option.click()
            break
    wait()

    # First standard network choice (called "lenet" in the original code).
    standard_networks = driver.find_elements_by_name('standard_networks')
    lenet = standard_networks[0]
    lenet.click()
    wait()

    model_name = driver.find_element_by_name('model_name')
    model_name.click()
    model_name.send_keys(name)
    wait()

    create_button = driver.find_element_by_name('create-model')
    create_button.click()

    # The status endpoint is derived from the page the browser landed on.
    job_url = driver.current_url.replace('models', 'jobs')
    status_url = job_url + '/status'

    # NOTE(review): failure state names assumed from DIGITS' job status
    # names -- confirm against the server version in use.
    failed_states = ('Error', 'Aborted')
    while True:
        r = requests.get(status_url)
        status = json.loads(r.content)
        state = status.get('status')
        if state is None:
            raise RuntimeError('No job status reported by %s (error: %s)'
                               % (status_url, status.get('error')))
        if state in failed_states:
            # Without this check a failed job would be polled forever.
            raise RuntimeError('Model job ended with status "%s" (%s)'
                               % (state, status_url))
        wait()
        if state == 'Done':
            break

    # test image
    # Single-argument parenthesised print prints the same under Python 2.
    print('Testing...')
    image_path = driver.find_element_by_name('image_url')
    image_path.send_keys(test_image)
    wait()

    show_visualizations_tooltip = driver.find_element_by_name('show_visualizations_explanation')
    show_visualizations_tooltip.click()
    wait()

    show_visualizations = driver.find_element_by_name('show_visualizations')
    show_visualizations.click()
    wait()

    test_button = driver.find_element_by_name('classify-one-btn')
    test_button.click()

    # Opens in a new window - switch to it
    driver.close()
    driver.switch_to_window(driver.window_handles[-1])


def main(argv):
    """Run the Selenium walkthrough: create a dataset, train and test a model.

    Checks that the server answers on the requested port, starts a Firefox
    WebDriver, runs ``create_dataset`` and ``create_model``, shows a
    completion alert, then idles until the browser window is closed. The
    driver is always shut down on exit.

    Args:
        argv: full argument vector including the program name (for example
            ``sys.argv``). When empty or ``None``, argparse falls back to
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Run a Selenium demo of DIGITS')

    ### Positional arguments

    parser.add_argument('mnist_image_folder',
            type=str,
            help='Path to the MNIST dataset folder')
    parser.add_argument('test_image',
            type=str,
            help='Image to test with')

    ### Optional arguments

    parser.add_argument('-p', '--port',
            type=int,
            default=80,
            help='Port the server is running on (default 80)')

    # Parse the arguments that were actually passed in; previously ``argv``
    # was ignored and sys.argv was always used.
    args = vars(parser.parse_args(argv[1:] if argv else None))

    home_page = 'http://0.0.0.0:%d/' % args['port']
    dataset_path = args['mnist_image_folder']
    dataset_name = 'MNIST Dataset'
    model_name = 'MNIST Model'
    test_image = args['test_image']

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    r = requests.get(home_page)
    if r.status_code != requests.codes.ok:
        parser.error('page "%s" does not exist - are you looking on the wrong port?' % home_page)

    # Start selenium driver.
    driver = webdriver.Firefox()
    print('Firefox webdriver started.')

    try:
        driver.maximize_window()

        get_page(driver, home_page)

        print('Creating dataset...')
        create_dataset(driver, dataset_name, dataset_path)

        get_page(driver, home_page)

        print('Creating model...')
        create_model(driver, model_name, dataset_name, test_image)

        print('Done.')

        #display an alert message
        get_page(driver, "javascript:alert('Completed Walkthrough!');void(0);")
        wait()
        driver.switch_to_alert().accept()

        # wait until the window is closed
        # Any exception from the driver is taken to mean the window is gone.
        while True:
            try:
                get_page(driver, "javascript:console.log('Waiting...');void(0);")
            except Exception as e:
                print(e)
                break
            wait()

    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the demo early.
        pass
    finally:
        try:
            driver.quit()
        except socket.error:
            # Best effort: the browser may already have gone away.
            pass

# Run the walkthrough only when executed as a script, not when imported.
if __name__ == '__main__':
    main(sys.argv)

back to top