https://github.com/halide/Halide
Raw File
Tip revision: ef4b2de8778cbeb3f68bf3148c0fbc28415f6e9a authored by Andrew Adams on 18 April 2024, 17:59:21 UTC
Elaborate on why we treat NaNs as equal
Tip revision: ef4b2de
Parameter.h
#ifndef HALIDE_PARAMETER_H
#define HALIDE_PARAMETER_H

/** \file
 * Defines the internal representation of parameters to halide pipelines
 */
#include <optional>
#include <string>

#include "Buffer.h"
#include "IntrusivePtr.h"
#include "Type.h"
#include "Util.h"                   // for HALIDE_NO_USER_CODE_INLINE
#include "runtime/HalideRuntime.h"  // for HALIDE_ALWAYS_INLINE

namespace Halide {

// Forward declarations of types that the declarations in this header
// refer to (e.g. as member, parameter, or return types).
struct ArgumentEstimates;
struct Expr;
struct Type;
enum class MemoryType;

/** The set of constraint expressions kept for a buffer parameter:
 * min, extent and stride, together with estimates of the min and
 * extent. Parameter::buffer_constraints() exposes a vector of these. */
struct BufferConstraint {
    /** Constraints on the min, the extent and the stride. */
    Expr min;
    Expr extent;
    Expr stride;

    /** Estimates of the min and of the extent. */
    Expr min_estimate;
    Expr extent_estimate;
};

namespace Internal {

#ifdef WITH_SERIALIZATION
// Only declared when serialization support is compiled in. Parameter
// befriends both classes so that they can reach scalar_data().
class Deserializer;
class Serializer;
#endif
// Opaque payload type that Parameter holds through an IntrusivePtr;
// only a declaration is needed in this header.
struct ParameterContents;

}  // namespace Internal

/** A reference-counted handle to a parameter to a halide
 * pipeline. May be a scalar parameter or a buffer */
class Parameter {
    // Internal validation helpers; their definitions are not in this header.
    void check_defined() const;
    void check_is_buffer() const;
    void check_is_scalar() const;
    void check_dim_ok(int dim) const;
    void check_type(const Type &t) const;

protected:
    /** The shared, reference-counted contents this handle refers to. */
    Internal::IntrusivePtr<Internal::ParameterContents> contents;

#ifdef WITH_SERIALIZATION
    friend class Internal::Deserializer;  //< for scalar_data()
    friend class Internal::Serializer;    //< for scalar_data()
#endif
    friend class Pipeline;  //< for read_only_scalar_address()

    /** Get the raw currently-bound buffer. null if unbound */
    const halide_buffer_t *raw_buffer() const;

    /** Get the pointer to the current value of the scalar
     * parameter. For a given parameter, this address will never
     * change. Note that this can only be used to *read* from -- it must
     * not be written to, so don't cast away the constness. Only relevant when jitting. */
    const void *read_only_scalar_address() const;

    /** If the Parameter is a scalar, and the scalar data is valid, return
     * the scalar data. Otherwise, return nullopt. */
    std::optional<halide_scalar_value_t> scalar_data() const;

    /** If the Parameter is a scalar and has a valid scalar value, return it.
     * Otherwise, assert-fail. */
    halide_scalar_value_t scalar_data_checked() const;

    /** If the Parameter is a scalar *of the given type* and has a valid scalar value, return it.
     * Otherwise, assert-fail. */
    halide_scalar_value_t scalar_data_checked(const Type &val_type) const;

    /** Construct a new buffer parameter via deserialization. */
    Parameter(const Type &t, int dimensions, const std::string &name,
              const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
              MemoryType memory_type);

    /** Construct a new scalar parameter via deserialization. */
    Parameter(const Type &t, int dimensions, const std::string &name,
              const std::optional<halide_scalar_value_t> &scalar_data, const Expr &scalar_default, const Expr &scalar_min,
              const Expr &scalar_max, const Expr &scalar_estimate);

public:
    /** Construct a new undefined handle */
    Parameter() = default;

    /** Construct a new parameter of the given type. If the second
     * argument is true, this is a buffer parameter of the given
     * dimensionality, otherwise, it is a scalar parameter (and the
     * dimensionality should be zero). The parameter will be given a
     * unique auto-generated name. */
    Parameter(const Type &t, bool is_buffer, int dimensions);

    /** Construct a new parameter of the given type with name given by
     * the fourth argument. If the second argument is true, this is a
     * buffer parameter, otherwise, it is a scalar parameter. The
     * third argument gives the dimensionality of the buffer
     * parameter. It should be zero for scalar parameters. */
    Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name);

    Parameter(const Parameter &) = default;
    Parameter &operator=(const Parameter &) = default;
    Parameter(Parameter &&) = default;
    Parameter &operator=(Parameter &&) = default;

    /** Get the type of this parameter */
    Type type() const;

    /** Get the dimensionality of this parameter. Zero for scalars. */
    int dimensions() const;

    /** Get the name of this parameter */
    const std::string &name() const;

    /** Does this parameter refer to a buffer/image? */
    bool is_buffer() const;

    /** If the parameter is a scalar parameter, get its currently
     * bound value. Only relevant when jitting.
     *
     * T must be no larger than halide_scalar_value_t (checked at
     * compile time); the value is fetched via
     * scalar_data_checked(type_of<T>()). */
    template<typename T>
    HALIDE_NO_USER_CODE_INLINE T scalar() const {
        static_assert(sizeof(T) <= sizeof(halide_scalar_value_t),
                      "T is too large to be read out of a halide_scalar_value_t");
        const auto sv = scalar_data_checked(type_of<T>());
        T t;
        // Copy the leading sizeof(T) bytes of the scalar payload into the result.
        memcpy(&t, &sv.u.u64, sizeof(t));
        return t;
    }

    /** This returns the current value of scalar<type()>() as an Expr.
     * If the Parameter is not scalar, or its scalar data is not valid, this will assert-fail. */
    Expr scalar_expr() const;

    /** This returns true if scalar_expr() would return a valid Expr,
     * false if not. */
    bool has_scalar_value() const;

    /** If the parameter is a scalar parameter, set its current
     * value. Only relevant when jitting.
     *
     * T must be no larger than halide_scalar_value_t (checked at
     * compile time, mirroring scalar<T>()), because the bytes of val
     * are copied into the scalar payload before being forwarded to
     * set_scalar(const Type &, halide_scalar_value_t). */
    template<typename T>
    HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
        // Without this guard an oversized T would make the memcpy below
        // write past the end of the payload.
        static_assert(sizeof(T) <= sizeof(halide_scalar_value_t),
                      "T is too large to be stored in a halide_scalar_value_t");
        halide_scalar_value_t sv;
        memcpy(&sv.u.u64, &val, sizeof(val));
        set_scalar(type_of<T>(), sv);
    }

    /** If the parameter is a scalar parameter, set its current
     * value. Only relevant when jitting */
    void set_scalar(const Type &val_type, halide_scalar_value_t val);

    /** If the parameter is a buffer parameter, get its currently
     * bound buffer. Only relevant when jitting */
    Buffer<void> buffer() const;

    /** If the parameter is a buffer parameter, set its current
     * value. Only relevant when jitting */
    void set_buffer(const Buffer<void> &b);

    /** Tests if this handle is the same as another handle */
    bool same_as(const Parameter &other) const;

    /** Tests if this handle is non-nullptr */
    bool defined() const;

    /** Get and set constraints for the min, extent, stride, and estimates on
     * the min/extent, as well as the host alignment. */
    //@{
    void set_min_constraint(int dim, const Expr &e);
    void set_extent_constraint(int dim, const Expr &e);
    void set_stride_constraint(int dim, const Expr &e);
    void set_min_constraint_estimate(int dim, const Expr &min);
    void set_extent_constraint_estimate(int dim, const Expr &extent);
    void set_host_alignment(int bytes);
    Expr min_constraint(int dim) const;
    Expr extent_constraint(int dim) const;
    Expr stride_constraint(int dim) const;
    Expr min_constraint_estimate(int dim) const;
    Expr extent_constraint_estimate(int dim) const;
    int host_alignment() const;
    //@}

    /** Get buffer constraints for all dimensions,
     *  only relevant when serializing. */
    const std::vector<BufferConstraint> &buffer_constraints() const;

    /** Get and set constraints for scalar parameters. These are used
     * directly by Param, so they must be exported. */
    // @{
    void set_min_value(const Expr &e);
    Expr min_value() const;
    void set_max_value(const Expr &e);
    Expr max_value() const;
    void set_estimate(Expr e);
    Expr estimate() const;
    // @}

    /** Get and set the default values for scalar parameters. At present, these
     * are used only to emit the default values in the metadata. There isn't
     * yet a public API in Param<> for them (this is used internally by the
     * Generator code). */
    // @{
    void set_default_value(const Expr &e);
    Expr default_value() const;
    // @}

    /** Order Parameters by their IntrusivePtr so they can be used
     * to index maps. */
    bool operator<(const Parameter &other) const {
        return contents < other.contents;
    }

    /** Get the ArgumentEstimates appropriate for this Parameter. */
    ArgumentEstimates get_argument_estimates() const;

    /** Set and get the MemoryType associated with this Parameter. */
    // @{
    void store_in(MemoryType memory_type);
    MemoryType memory_type() const;
    // @}
};

namespace Internal {

/** Validate arguments to a call to a func, image or imageparam.
 *
 * \param name Name associated with the call. NOTE(review): presumably
 *             the name of the func/image/imageparam being called, used
 *             for diagnostics -- confirm against the definition.
 * \param args The call arguments. Passed by non-const pointer, so the
 *             definition is able to modify them in place -- confirm
 *             whether it does.
 * \param dims NOTE(review): presumably the number of arguments
 *             (dimensions) the callee expects -- confirm. */
void check_call_arg_types(const std::string &name, std::vector<Expr> *args, int dims);

}  // namespace Internal
}  // namespace Halide

#endif
back to top