https://github.com/Microsoft/CNTK
Raw File
Tip revision: 534049087cf08c5e8c5c150a26c9a84a2af0ee28 authored by Vadim Mazalov on 27 December 2017, 16:41:29 UTC
Cont debug
Tip revision: 5340490
ssefloat4.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// ssefloat4.h -- float4 vector class wrapping the SSE intrinsics (used by the SSE-accelerated matrix code)
//

#pragma once

#include <cstdint> // std::uint32_t
#include <cstring> // std::memcpy

#ifdef _WIN32
#include <intrin.h> // for intrinsics
#endif
#ifdef __unix__
#if !defined(__aarch64__)
#include <x86intrin.h>
#else
#define _mm_free(p) free(p)
#define _mm_malloc(a, b) malloc(a)
#endif
#endif

namespace msra { namespace math {

// ===========================================================================
// float4 -- wrapper around the rather ugly SSE intrinsics for float[4]
//
// Do not use the intrinsics outside anymore; instead add all you need into this class.
//
// MSDN links:
// basic: http://msdn.microsoft.com/en-us/library/x5c07e2a%28v=VS.80%29.aspx
// load/store: (add this)
// newer ones: (seems no single list available)
// ===========================================================================

// The code in this file implements a float4 vector based on the SSE intrinsics available on Intel platforms.
// Since we don't have SSE on ARM64 (NEON has similar functionality but is not identical) we cannot
// use the SSE implementation on ARM64.
// TODO: In the future, we should provide a NEON based implementation instead.
#if defined(__aarch64__)
// ---------------------------------------------------------------------------
// Scalar emulation of the subset of SSE used by float4.
//
// Every function below mirrors the lane-wise semantics of the SSE intrinsic of
// the same name, so that float4 behaves identically on both code paths:
//  - arithmetic ops work on the four float lanes independently;
//  - _mm_and_ps/_mm_or_ps combine the raw 32-bit patterns of each lane
//    (they do NOT convert the lane values to integers);
//  - _mm_cmpge_ps/_mm_cmple_ps return a per-lane bit mask (all bits set when
//    the comparison holds, all bits clear otherwise), which is what makes the
//    usual "(a >= b) & x" select idiom work.
// ---------------------------------------------------------------------------

// Stand-in for the SSE register type: four packed single-precision lanes.
typedef struct __m128_t
{
    float f[4];
} __m128;

// All-lanes-zero constant returned by _mm_setzero_ps(); const so it cannot be modified by accident.
static const __m128 ZERO_M128 = {{0.0f, 0.0f, 0.0f, 0.0f}};

static_assert(sizeof(float) == sizeof(std::uint32_t), "bitwise lane emulation requires a 32-bit float");

// Reinterpret the bits of one lane as an unsigned integer (memcpy is the well-defined way to type-pun).
static inline std::uint32_t m128_lane_to_bits(float lane)
{
    std::uint32_t bits;
    std::memcpy(&bits, &lane, sizeof(bits));
    return bits;
}

// Reinterpret a 32-bit pattern as one float lane.
static inline float m128_lane_from_bits(std::uint32_t bits)
{
    float lane;
    std::memcpy(&lane, &bits, sizeof(lane));
    return lane;
}

// Build the lane value of a comparison result: all bits set if 'holds', all bits clear otherwise.
static inline float m128_lane_mask(bool holds)
{
    return m128_lane_from_bits(holds ? 0xFFFFFFFFu : 0u);
}

// Return a vector with all four lanes set to 0.
static inline __m128 _mm_setzero_ps()
{
    return ZERO_M128;
}

// Store the lowest lane of 'b' to '*a'.
static inline void _mm_store_ss(float *a, const __m128 &b)
{
    *a = b.f[0];
}

// Load the single float at 'a' and replicate it into all four lanes.
static inline __m128 _mm_load1_ps(const float *a)
{
    const float value = *a;
    __m128 result = {{value, value, value, value}};
    return result;
}

// Lane-wise a - b.
static inline __m128 _mm_sub_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = a.f[i] - b.f[i];
    return result;
}

// Lane-wise bitwise AND of the raw 32-bit patterns of a and b.
static inline __m128 _mm_and_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = m128_lane_from_bits(m128_lane_to_bits(a.f[i]) & m128_lane_to_bits(b.f[i]));
    return result;
}

// Lane-wise bitwise OR of the raw 32-bit patterns of a and b.
static inline __m128 _mm_or_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = m128_lane_from_bits(m128_lane_to_bits(a.f[i]) | m128_lane_to_bits(b.f[i]));
    return result;
}

// Lane-wise a + b.
static inline __m128 _mm_add_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = a.f[i] + b.f[i];
    return result;
}

// Lane-wise a * b.
static inline __m128 _mm_mul_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = a.f[i] * b.f[i];
    return result;
}

// Lane-wise a / b.
static inline __m128 _mm_div_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = a.f[i] / b.f[i];
    return result;
}

// Horizontal add: sums of adjacent lane pairs, { a0+a1, a2+a3, b0+b1, b2+b3 }.
static inline __m128 _mm_hadd_ps(const __m128 &a, const __m128 &b)
{
    __m128 result = {{
        a.f[0] + a.f[1],
        a.f[2] + a.f[3],
        b.f[0] + b.f[1],
        b.f[2] + b.f[3]}};
    return result;
}

// Lane-wise a >= b, returned as a bit mask per lane (all ones if true, all zeros if false or unordered).
static inline __m128 _mm_cmpge_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = m128_lane_mask(a.f[i] >= b.f[i]);
    return result;
}

// Lane-wise a <= b, returned as a bit mask per lane (all ones if true, all zeros if false or unordered).
static inline __m128 _mm_cmple_ps(const __m128 &a, const __m128 &b)
{
    __m128 result;
    for (int i = 0; i < 4; i++)
        result.f[i] = m128_lane_mask(a.f[i] <= b.f[i]);
    return result;
}

// Scalar replacement for the SSE _MM_TRANSPOSE4_PS macro, operating on four float4 lvalues.
// Treats c1..c4 as the rows of a 4x4 matrix and transposes it in place: afterwards lane j
// of the i-th argument holds what was lane i of the j-th argument. All 16 lanes are read
// into a scratch array before any argument is written.
#define _MM_TRANSPOSE4_PS(c1, c2, c3, c4) \
{ \
    float4* const args_[4] = { &(c1), &(c2), &(c3), &(c4) }; \
    float lanes_[4][4]; \
 \
    for (int dst_ = 0; dst_ < 4; ++dst_) \
        for (int src_ = 0; src_ < 4; ++src_) \
            lanes_[dst_][src_] = args_[src_]->v.f[dst_]; \
 \
    for (int dst_ = 0; dst_ < 4; ++dst_) \
        for (int lane_ = 0; lane_ < 4; ++lane_) \
            args_[dst_]->v.f[lane_] = lanes_[dst_][lane_]; \
}

// No-op replacement for _mm_prefetch on this fallback path: the macro expands to nothing,
// so both arguments (address and hint) are discarded without being evaluated.
#define _mm_prefetch(a, b) 
#endif

// float4 -- four packed single-precision values stored in one __m128.
// Each operator forwards to the _mm_*_ps operation of the same meaning; the raw __m128
// is only reachable from inside the class (conversions to/from __m128 are private).
class float4
{
    __m128 v; // the packed lanes

private:
    // wrap a raw __m128 (lets the operators below return intrinsic results directly)
    float4(const __m128& v)
        : v(v)
    {
    }

    // unwrap to the raw __m128
    operator __m128() const { return v; }

    // overwrite the lanes from a raw __m128 (for float4 locals that receive intrinsic results)
    float4& operator=(const __m128& other)
    {
        v = other;
        return *this;
    }

    // extract the lowest lane as a scalar
    float f0() const
    {
        float low;
        _mm_store_ss(&low, v);
        return low;
    }

public:
    // default construction leaves the lanes uninitialized
    float4() {}

    float4(const float4& f4)
        : v(f4.v)
    {
    }

    float4& operator=(const float4& other)
    {
        v = other.v;
        return *this;
    }

    // broadcast one scalar into all four lanes
    float4(float f)
        : v(_mm_load1_ps(&f))
    {
    }
    // float4 (float f) : v (_mm_set_ss (f)) {}  // code seems more complex than _mm_load1_ps()

    // ---- unary and binary operators (all lane-wise) ----

    // negation, computed as 0 - v
    // UNTESTED; setzero is a composite
    float4 operator-() const { return _mm_sub_ps(_mm_setzero_ps(), v); }

    float4 operator&(const float4& other) const { return _mm_and_ps(v, other.v); }
    float4 operator|(const float4& other) const { return _mm_or_ps(v, other.v); }
    float4 operator+(const float4& other) const { return _mm_add_ps(v, other.v); }
    float4 operator-(const float4& other) const { return _mm_sub_ps(v, other.v); }
    float4 operator*(const float4& other) const { return _mm_mul_ps(v, other.v); }
    float4 operator/(const float4& other) const { return _mm_div_ps(v, other.v); }

    // ---- compound assignment ----

    float4& operator&=(const float4& other)
    {
        v = _mm_and_ps(v, other.v);
        return *this;
    }
    float4& operator|=(const float4& other)
    {
        v = _mm_or_ps(v, other.v);
        return *this;
    }
    float4& operator+=(const float4& other)
    {
        v = _mm_add_ps(v, other.v);
        return *this;
    }
    float4& operator-=(const float4& other)
    {
        v = _mm_sub_ps(v, other.v);
        return *this;
    }
    float4& operator*=(const float4& other)
    {
        v = _mm_mul_ps(v, other.v);
        return *this;
    }
    float4& operator/=(const float4& other)
    {
        v = _mm_div_ps(v, other.v);
        return *this;
    }

    // ---- comparisons: the result is itself a float4 holding the per-lane outcome ----

    float4 operator>=(const float4& other) const { return _mm_cmpge_ps(v, other.v); }
    float4 operator<=(const float4& other) const { return _mm_cmple_ps(v, other.v); }

    // not yet implemented binary arithmetic ops: sqrt, rcp (reciprocal), rqsrt, min, max

    // other goodies I came across (intrin.h):
    //  - _mm_prefetch
    //  - _mm_stream_ps --store without polluting cache
    //  - unknown: _mm_addsub_ps, _mm_hsub_ps, _mm_movehdup_ps, _mm_moveldup_ps, _mm_blend_ps, _mm_blendv_ps, _mm_insert_ps, _mm_extract_ps, _mm_round_ps
    //  - _mm_dp_ps dot product! http://msdn.microsoft.com/en-us/library/bb514054.aspx
    //    Not so interesting for long vectors, we get better numerical precision with parallel adds and hadd at the end

    // issue a prefetch (hint _MM_HINT_T0) for the float4 located at 'p'
    static void prefetch(const float4* p)
    {
        _mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0);
    }

    // transpose a 4x4 matrix given as four columns, producing its four rows
    // Inputs are taken by const reference to ensure aligned-ness.
    static void transpose(const float4& col0, const float4& col1, const float4& col2, const float4& col3,
                          float4& row0, float4& row1, float4& row2, float4& row3)
    {
        // Work on local copies so that outputs may alias inputs;
        // the copies get completely eliminated by optimization.
        float4 a = col0;
        float4 b = col1;
        float4 c = col2;
        float4 d = col3;
        _MM_TRANSPOSE4_PS(a, b, c, d); // 8 instructions for 16 elements
        row0 = a;
        row1 = b;
        row2 = c;
        row3 = d;
    }

    // save a float4 to RAM; meant to bypass the cache ('without polluting the cache'),
    // but the streaming store is disabled, so this is currently a plain assignment
    void storewithoutcache(float4& r4) const
    {
        // _mm_stream_ps ((float*) &r4, v);
        r4.v = v;
    }

#if 0
    // save a float4 to RAM bypassing the cache ('without polluting the cache')
    void storewithoutcache (float4 * p4) const
    {
        // _mm_stream_ps ((float*) p4, v);
        *p4 = v;
    }

    // save a float to RAM bypassing the cache ('without polluting the cache')
    void storewithoutcache (float & r) const
    {
        _mm_stream_ss (&r, v);
    }
#endif

    // horizontal sum of all 4 lanes
    // ... return float4, use another mechanism to store the low word
    float sum() const
    {
        // first hadd leaves pairwise sums, second hadd leaves the full sum in every lane
        const float4 pairs = _mm_hadd_ps(v, v);
        const float4 total = _mm_hadd_ps(pairs.v, pairs.v);
        return total.f0();
    }

    // please add anything else you might need HERE
};
};
};
back to top