https://github.com/zfrenchee/NN-SVG
Raw File
Tip revision: e65a5b1747d227857192f8c3e424660e489f6887 authored by Johan Pauwels on 12 February 2024, 01:27:00 UTC
Correct width and height label placement in AlexNet (#60)
Tip revision: e65a5b1
AlexNet.js

function AlexNet() {

    // /////////////////////////////////////////////////////////////////////////////
    //                       ///////    Variables    ///////
    // /////////////////////////////////////////////////////////////////////////////

    // Viewport size sampled once at construction time; the initial camera
    // frustum is derived from these two values.
    const w = window.innerWidth;
    const h = window.innerHeight;

    // Fill colours: color1 -> layer boxes, color2 -> filter boxes, color3 -> pyramids.
    let color1 = '#eeeeee';
    let color2 = '#99ddff';
    let color3 = '#ffbbbb';

    // rectOpacity applies to the layer boxes, filterOpacity to filter boxes and pyramids.
    let rectOpacity = 0.4;
    let filterOpacity = 0.4;
    // Multiplier applied to the size of every text sprite.
    let fontScale = 1;

    // Option bag shared by the three translucent, double-sided mesh materials.
    const translucentOptions = (color, opacity) => ({
        'color': color,
        'side': THREE.DoubleSide,
        'transparent': true,
        'opacity': opacity,
        'depthWrite': false,
        'needsUpdate': true,
    });

    const line_material = new THREE.LineBasicMaterial( { 'color': 0x000000 } );
    const box_material = new THREE.MeshBasicMaterial( translucentOptions(color1, rectOpacity) );
    const conv_material = new THREE.MeshBasicMaterial( translucentOptions(color2, filterOpacity) );
    const pyra_material = new THREE.MeshBasicMaterial( translucentOptions(color3, filterOpacity) );

    // Convolutional layer descriptors and dense layer sizes; both are replaced by redraw().
    let architecture = [];
    let architecture2 = [];
    // Gap left between consecutive layers along the z axis.
    let betweenLayers = 20;

    // Scaling settings: each dimension is drawn either linearly or on a natural-log
    // scale, then multiplied by its own factor. All of them are overwritten by redraw().
    let logDepth = true;
    let depthScale = 10;
    let logWidth = true;
    let widthScale = 10;
    let logConvSize = false;
    let convScale = 1;

    // Whether to label layer dimensions / filter dimensions with text sprites.
    let showDims = false;
    let showConvDims = false;

    // Shared scaler: optionally take the natural log, then apply the factor.
    const scaleValue = (value, useLog, factor) => (useLog ? Math.log(value) : value) * factor;

    // The settings are read at call time, so changes made by redraw() take effect immediately.
    const depthFn = (depth) => scaleValue(depth, logDepth, depthScale);
    const widthFn = (width) => scaleValue(width, logWidth, widthScale);
    const convFn = (conv) => scaleValue(conv, logConvSize, convScale);

    // Drawn width of a layer descriptor.
    function wf(layer) {
        return widthFn(layer.width);
    }

    // Drawn height of a layer descriptor; heights share the width scaling.
    function hf(layer) {
        return widthFn(layer.height);
    }

    // One group per kind of object; redraw() fills them and attaches them to the scene.
    const layers = new THREE.Group();
    const convs = new THREE.Group();
    const pyramids = new THREE.Group();
    const sprites = new THREE.Group();

    // White-background scene.
    const scene = new THREE.Scene();
    scene.background = new THREE.Color( 0xffffff );

    // Orthographic camera whose frustum spans the viewport, centred on the origin.
    // (A perspective alternative was: new THREE.PerspectiveCamera( 75, window.innerWidth/window.innerHeight, 0.1, 100000 ).)
    const halfWidth = w / 2;
    const halfHeight = h / 2;
    const camera = new THREE.OrthographicCamera( -halfWidth, halfWidth, halfHeight, -halfHeight, -10000000, 10000000 );
    camera.position.set(-219, 92, 84);

    // Created by restartRenderer(); rendererType is either 'webgl' or 'svg'.
    let renderer;
    let rendererType = 'webgl';

    // OrbitControls bound to the current renderer's DOM element.
    let controls;


    // /////////////////////////////////////////////////////////////////////////////
    //                       ///////    Methods    ///////
    // /////////////////////////////////////////////////////////////////////////////

    // Handle of the pending animation frame, kept so that a restart can stop the
    // loop that was driving the previous renderer.
    let animationFrameId = null;

    /**
     * (Re)creates the renderer, mounts its DOM element in #graph-container,
     * rebinds the orbit controls and starts the render loop.
     *
     * The scene is emptied first, so the diagram has to be redrawn afterwards.
     *
     * @param {Object} [options]
     * @param {string} [options.rendererType_] 'webgl' or 'svg'. Defaults to the
     *     type used last time. Any other value leaves the current renderer in place.
     */
    function restartRenderer({rendererType_=rendererType}={}) {

        rendererType = rendererType_;

        clearThree(scene);

        // Stop the previous loop; otherwise every restart stacks one more
        // never-ending requestAnimationFrame loop on top of the existing ones.
        if (animationFrameId !== null) {
            cancelAnimationFrame(animationFrameId);
            animationFrameId = null;
        }

        const previousRenderer = renderer;

        if (rendererType === 'webgl') { renderer = new THREE.WebGLRenderer( { 'alpha':true } ); }
        else if (rendererType === 'svg') { renderer = new THREE.SVGRenderer(); }

        // Dispose the renderer that was just replaced, if it supports disposal.
        // It is left alone when no new renderer was created.
        if (previousRenderer && previousRenderer !== renderer && typeof previousRenderer.dispose === 'function') {
            previousRenderer.dispose();
        }

        renderer.setPixelRatio(window.devicePixelRatio || 1);
        renderer.setSize( window.innerWidth, window.innerHeight );

        // Replace whatever is currently mounted with the new renderer's element.
        const graph_container = document.getElementById('graph-container');
        while (graph_container.firstChild) { graph_container.removeChild(graph_container.firstChild); }
        graph_container.appendChild( renderer.domElement );

        if (controls) { controls.dispose(); }
        controls = new THREE.OrbitControls( camera, renderer.domElement );

        animate();

    }

    // Render loop: schedules the next frame, then draws the current one.
    function animate() {
        animationFrameId = requestAnimationFrame( animate );
        renderer.render(scene, camera);
    }

    restartRenderer();

    /**
     * Rebuilds the whole diagram from the given settings and remembers them.
     *
     * Every option defaults to the value stored by the previous call, so callers
     * may pass only what changed. The scene is cleared and repopulated each time.
     *
     * @param {Object} [options]
     * @param {Object[]} [options.architecture_] Convolutional layers. Each entry is read
     *     for 'depth', 'width', 'height', 'rel_x' and 'rel_y'; all but the last are also
     *     read for 'filterWidth' and 'filterHeight'.
     * @param {number[]} [options.architecture2_] Dense layer sizes.
     * @param {number} [options.betweenLayers_] Gap between consecutive layers along z.
     * @param {boolean} [options.logDepth_] Draw depths on a log scale.
     * @param {number} [options.depthScale_] Factor applied to depths.
     * @param {boolean} [options.logWidth_] Draw widths/heights on a log scale.
     * @param {number} [options.widthScale_] Factor applied to widths/heights.
     * @param {boolean} [options.logConvSize_] Draw filter sizes on a log scale.
     * @param {number} [options.convScale_] Factor applied to filter sizes.
     * @param {boolean} [options.showDims_] Label layer dimensions.
     * @param {boolean} [options.showConvDims_] Label filter dimensions.
     */
    function redraw({architecture_=architecture,
                     architecture2_=architecture2,
                     betweenLayers_=betweenLayers,
                     logDepth_=logDepth,
                     depthScale_=depthScale,
                     logWidth_=logWidth,
                     widthScale_=widthScale,
                     logConvSize_=logConvSize,
                     convScale_=convScale,
                     showDims_=showDims,
                     showConvDims_=showConvDims}={}) {

        architecture = architecture_;
        architecture2 = architecture2_;
        betweenLayers = betweenLayers_;
        logDepth = logDepth_;
        depthScale = depthScale_;
        logWidth = logWidth_;
        widthScale = widthScale_;
        logConvSize = logConvSize_;
        convScale = convScale_;
        showDims = showDims_;
        showConvDims = showConvDims_;

        clearThree(scene);

        // NOTE(review): architecture.last() is dereferenced unconditionally below, so this
        // assumes at least one convolutional layer. sum(), pairWise() and
        // Array.prototype.last() are expected to be provided outside this function.

        // z position of the centre of every layer: convolutional layers first, dense layers after.
        const z_offset = -(sum(architecture.map(layer => depthFn(layer['depth']))) + (betweenLayers * (architecture.length - 1))) / 3;
        let layer_offsets = pairWise(architecture).reduce((offsets, pair) => offsets.concat([offsets.last() + depthFn(pair[0]['depth'])/2 + betweenLayers + depthFn(pair[1]['depth'])/2]), [z_offset]);
        layer_offsets = layer_offsets.concat(architecture2.reduce((offsets, layer) => offsets.concat([offsets.last() + widthFn(2) + betweenLayers]), [layer_offsets.last() + depthFn(architecture.last()['depth'])/2 + betweenLayers + widthFn(2)]));

        architecture.forEach( function( layer, index ) {

            const layer_width = wf(layer);
            const layer_height = hf(layer);
            const layer_depth = depthFn(layer['depth']);
            const layer_z = layer_offsets[index];
            // Only layers that feed another convolutional layer get a filter box and a pyramid.
            const has_next_layer = index < architecture.length - 1;

            // Layer
            const layer_geometry = new THREE.BoxGeometry( layer_width, layer_height, layer_depth );
            const layer_object = new THREE.Mesh( layer_geometry, box_material );
            layer_object.position.set(0, 0, layer_z);
            layers.add( layer_object );

            const layer_edges_geometry = new THREE.EdgesGeometry( layer_geometry );
            const layer_edges_object = new THREE.LineSegments( layer_edges_geometry, line_material );
            layer_edges_object.position.set(0, 0, layer_z);
            layers.add( layer_edges_object );

            // Kept outside the branch because the filter labels below are positioned relative to it.
            let conv_object = null;

            if (has_next_layer) {

                const conv_width = convFn(layer['filterWidth']);
                const conv_height = convFn(layer['filterHeight']);
                const conv_x = layer['rel_x'] * layer_width;
                const conv_y = layer['rel_y'] * layer_height;

                // Conv
                const conv_geometry = new THREE.BoxGeometry( conv_width, conv_height, layer_depth );
                conv_object = new THREE.Mesh( conv_geometry, conv_material );
                conv_object.position.set(conv_x, conv_y, layer_z);
                convs.add( conv_object );

                const conv_edges_geometry = new THREE.EdgesGeometry( conv_geometry );
                const conv_edges_object = new THREE.LineSegments( conv_edges_geometry, line_material );
                conv_edges_object.position.set(conv_x, conv_y, layer_z);
                convs.add( conv_edges_object );

                // Pyramid
                const pyramid_geometry = new THREE.Geometry();

                const base_z = layer_z + (layer_depth / 2);
                const summit_z = base_z + betweenLayers;
                const next_layer = architecture[index+1];

                // The summit sits at the same relative position on the next layer, so x is
                // scaled by that layer's width and y by its height.
                pyramid_geometry.vertices = [
                    new THREE.Vector3( conv_x + (conv_width/2), conv_y + (conv_height/2), base_z ),  // base
                    new THREE.Vector3( conv_x + (conv_width/2), conv_y - (conv_height/2), base_z ),  // base
                    new THREE.Vector3( conv_x - (conv_width/2), conv_y - (conv_height/2), base_z ),  // base
                    new THREE.Vector3( conv_x - (conv_width/2), conv_y + (conv_height/2), base_z ),  // base
                    new THREE.Vector3( layer['rel_x'] * wf(next_layer), layer['rel_y'] * hf(next_layer), summit_z )  // summit
                ];
                pyramid_geometry.faces = [new THREE.Face3(0,1,2),new THREE.Face3(0,2,3),new THREE.Face3(1,0,4),new THREE.Face3(2,1,4),new THREE.Face3(3,2,4),new THREE.Face3(0,3,4)];

                const pyramid_object = new THREE.Mesh( pyramid_geometry, pyra_material );
                pyramids.add( pyramid_object );

                const pyramid_edges_geometry = new THREE.EdgesGeometry( pyramid_geometry );
                const pyramid_edges_object = new THREE.LineSegments( pyramid_edges_geometry, line_material );
                pyramids.add( pyramid_edges_object );

            }

            if (showDims) {

                // Dims
                const depth_label = makeTextSprite(layer['depth'].toString());
                depth_label.position.copy( layer_object.position ).sub( new THREE.Vector3( layer_width/2 + 2, layer_height/2 + 2, 0 ) );
                sprites.add( depth_label );

                const height_label = makeTextSprite(layer['height'].toString());
                height_label.position.copy( layer_object.position ).sub( new THREE.Vector3( layer_width/2 + 3, 0, layer_depth/2 + 3 ) );
                sprites.add( height_label );

                const width_label = makeTextSprite(layer['width'].toString());
                width_label.position.copy( layer_object.position ).sub( new THREE.Vector3( 0, -layer_height/2 - 3, layer_depth/2 + 3 ) );
                sprites.add( width_label );

            }

            if (showConvDims && has_next_layer) {

                // Conv Dims
                const filter_height_label = makeTextSprite(layer['filterHeight'].toString());
                filter_height_label.position.copy( conv_object.position ).sub( new THREE.Vector3( convFn(layer['filterWidth'])/2, -3, layer_depth/2 ) );
                sprites.add( filter_height_label );

                const filter_width_label = makeTextSprite(layer['filterWidth'].toString());
                filter_width_label.position.copy( conv_object.position ).sub( new THREE.Vector3( -1, convFn(layer['filterHeight'])/2, layer_depth/2 ) );
                sprites.add( filter_width_label );

            }

        });

        // Dense layers are drawn as thin columns: fixed footprint, height given by the layer size.
        const dense_side = widthFn(2);

        architecture2.forEach( function( layer, index ) {

            const layer_z = layer_offsets[architecture.length + index];

            // Dense
            const layer_geometry = new THREE.BoxGeometry( dense_side, depthFn(layer), dense_side );
            const layer_object = new THREE.Mesh( layer_geometry, box_material );
            layer_object.position.set(0, 0, layer_z);
            layers.add( layer_object );

            const layer_edges_geometry = new THREE.EdgesGeometry( layer_geometry );
            const layer_edges_object = new THREE.LineSegments( layer_edges_geometry, line_material );
            layer_edges_object.position.set(0, 0, layer_z);
            layers.add( layer_edges_object );

            // Arrow pointing along +z into this dense layer.
            const direction = new THREE.Vector3( 0, 0, 1 );
            const origin = new THREE.Vector3( 0, 0, layer_z - betweenLayers - dense_side/2 + 1 );
            const arrow_length = betweenLayers - 2;
            const headLength = betweenLayers/3;
            const headWidth = 5;
            const arrow = new THREE.ArrowHelper( direction, origin, arrow_length, 0x000000, headLength, headWidth );
            pyramids.add( arrow );

            if (showDims) {

                // Dims
                const size_label = makeTextSprite(layer.toString());
                size_label.position.copy( layer_object.position ).sub( new THREE.Vector3( 3, depthFn(layer)/2 + 3, 3 ) );
                sprites.add( size_label );

            }

        });

        scene.add( layers );
        scene.add( convs );
        scene.add( pyramids );
        scene.add( sprites );

    }

    /**
     * Recursively detaches every descendant of `obj` and disposes the geometry,
     * material(s) and textures found on each node, including `obj` itself.
     *
     * @param {Object} obj Node exposing `children` and `remove(child)`; `geometry`,
     *     `material` (single or array) and `texture` are disposed when present.
     */
    function clearThree(obj) {

        while (obj.children.length > 0) {
            const child = obj.children[0];
            clearThree( child );
            obj.remove( child );
        }

        if ( obj.geometry ) { obj.geometry.dispose(); }

        if ( obj.material ) {
            const materials = Array.isArray(obj.material) ? obj.material : [obj.material];
            materials.forEach( function( material ) {
                // Text sprites keep their canvas texture on material.map; without this the
                // texture created for every label would never be released.
                if ( material.map ) { material.map.dispose(); }
                material.dispose();
            });
        }

        if ( obj.texture ) { obj.texture.dispose(); }
    }


    /**
     * Renders `message` onto an off-screen canvas and wraps it in a sprite
     * anchored at its top-left corner.
     *
     * Labels that fit the canvas's default width keep the fixed 10 x 5 (times
     * fontScale) sprite size. Wider labels get a wider canvas and a proportionally
     * wider sprite, so the text is neither clipped nor squashed.
     *
     * @param {string} message Text to draw.
     * @param {Object} [opts]
     * @param {string} [opts.fontface='Helvetica'] Font family.
     * @param {number} [opts.fontsize=120] Font size in canvas pixels.
     * @returns {THREE.Sprite} Sprite textured with the rendered text.
     */
    function makeTextSprite(message, opts) {
        const parameters = opts || {};
        const fontface = parameters.fontface || 'Helvetica';
        const fontsize = parameters.fontsize || 120;
        const font = fontsize + "px " + fontface;

        const canvas = document.createElement('canvas');
        const context = canvas.getContext('2d');
        const defaultWidth = canvas.width;
        context.font = font;

        // get size data (height depends only on font size)
        const textWidth = Math.ceil(context.measureText(message).width);
        if (textWidth > canvas.width) {
            canvas.width = textWidth;
            // Resizing a canvas resets its 2D context state, so the font must be set again.
            context.font = font;
        }

        // text color
        context.fillStyle = 'rgba(0, 0, 0, 1.0)';
        context.fillText(message, 0, fontsize);

        // canvas contents will be used for a texture
        const texture = new THREE.Texture(canvas);
        texture.minFilter = THREE.LinearFilter;
        texture.needsUpdate = true;

        const spriteMaterial = new THREE.SpriteMaterial({ map: texture });
        const sprite = new THREE.Sprite( spriteMaterial );
        // Stretch horizontally by the same ratio the canvas was widened (1 when untouched).
        sprite.scale.set( 10 * fontScale * (canvas.width / defaultWidth), 5 * fontScale, 1.0 );
        sprite.center.set( 0, 1 );
        return sprite;
    }

    /**
     * Updates the stored colours, opacities and font scale, then pushes the
     * colours and opacities onto the three shared mesh materials.
     *
     * Every option defaults to the currently stored value. The font scale is only
     * stored here; it is read when text sprites are created.
     */
    function style({color1_=color1,
                    color2_=color2,
                    color3_=color3,
                    rectOpacity_=rectOpacity,
                    filterOpacity_=filterOpacity,
                    fontScale_ =fontScale,
                }={}) {
        // Remember the new settings.
        [color1, color2, color3] = [color1_, color2_, color3_];
        [rectOpacity, filterOpacity, fontScale] = [rectOpacity_, filterOpacity_, fontScale_];

        // Layer boxes use rectOpacity; filter boxes and pyramids share filterOpacity.
        const restyle = (material, color, opacity) => {
            material.color = new THREE.Color(color);
            material.opacity = opacity;
        };

        restyle(box_material, color1, rectOpacity);
        restyle(conv_material, color2, filterOpacity);
        restyle(pyra_material, color3, filterOpacity);
    }

    // /////////////////////////////////////////////////////////////////////////////
    //                  ///////    Window Resize    ///////
    // /////////////////////////////////////////////////////////////////////////////

    /**
     * Resizes the renderer to the window and re-centres the orthographic frustum on it.
     *
     * The frustum spans half the window size on each side of the origin, exactly as
     * when the camera is first created, so resizing does not change the apparent
     * scale of the diagram regardless of the display's device pixel ratio.
     */
    function onWindowResize() {

        renderer.setSize(window.innerWidth, window.innerHeight);

        const halfWidth = window.innerWidth / 2;
        const halfHeight = window.innerHeight / 2;
        camera.left = -halfWidth;
        camera.right = halfWidth;
        camera.top = halfHeight;
        camera.bottom = -halfHeight;
        camera.updateProjectionMatrix();

    }

    // Run onWindowResize on every browser 'resize' event (bubbling phase; the
    // listener is never removed by this function).
    window.addEventListener('resize', onWindowResize, false);


    /////////////////////////////////////////////////////////////////////////////
                          ///////    Return    ///////
    /////////////////////////////////////////////////////////////////////////////

    // Public API of the object returned by AlexNet(): only these three closures are
    // exposed; all other state and helpers stay private to the enclosing function.
    return {
        'redraw'           : redraw,
        'restartRenderer'  : restartRenderer,
        'style'            : style,

    }

}
back to top