Revision c22889e76cb9b7fd8a4710d9bf53e827aaa907e4 authored by Shuhei Kadowaki on 24 August 2021, 04:21:29 UTC, committed by Shuhei Kadowaki on 26 October 2021, 14:45:33 UTC
Currently our constant-prop' heuristics work in the following way:
1. `const_prop_entry_heuristic`
2. `const_prop_argument_heuristic` & `const_prop_rettype_heuristic`
3. `force_const_prop` custom heuristic & `!const_prop_function_heuristic`
4. `MethodInstance` specialization and `const_prop_methodinstance_heuristic`

This PR changes it so that the step 1. now works like:

1. `force_const_prop` custom heuristic & `const_prop_entry_heuristic`

Steps 2, 3, and 4 are unchanged.

In particular, this change lets us constant-propagate calls to `getproperty`
and `setproperty!` more forcibly, and therefore inline them more often, e.g.:
```julia
mutable struct Foo
    val
    _::Int
end

function setter(xs)
    for x in xs
        x.val = nothing # `setproperty!` can be inlined with this PR
    end
end
```

This might be useful because we can now intervene in the constant-prop'
heuristic more reliably through the `aggressive_constprop` interface.

I ran the simple benchmark below, and this change does not appear to
cause a latency regression for this particular example:
```zsh
~/julia master aviatesk@amdci2 6s
❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3))'
  3.708500 seconds (7.28 M allocations: 506.128 MiB, 3.45% gc time, 1.13% compilation time)
  2.817794 seconds (3.45 M allocations: 195.127 MiB, 7.84% gc time, 53.76% compilation time)

~/julia avi/forceconstantprop aviatesk@amdci2 6s
❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3))'
  3.622109 seconds (7.02 M allocations: 481.710 MiB, 4.19% gc time, 1.17% compilation time)
  2.863419 seconds (3.44 M allocations: 194.210 MiB, 8.02% gc time, 53.53% compilation time)
```
1 parent da71d29
Raw File
symbol.c
// This file is a part of Julia. License is MIT: https://julialang.org/license

/*
  Symbol table
*/

#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include "julia.h"
#include "julia_internal.h"
#include "julia_assert.h"

#ifdef __cplusplus
extern "C" {
#endif

// Root of the global symbol table: a binary tree of symbols linked through
// each node's `left`/`right` fields (see symtab_lookup).
static jl_sym_t *symtab = NULL;

// Largest symbol name length accepted by _jl_symbol: INTPTR_MAX minus the
// tagged-value header, the jl_sym_t struct, and one extra byte (presumably the
// trailing NUL that mk_symbol appends).
#define MAX_SYM_LEN ((size_t)INTPTR_MAX - sizeof(jl_taggedvalue_t) - sizeof(jl_sym_t) - 1)

// Compute the value stored in a symbol's `hash` field for the `len` bytes at `str`.
static uintptr_t hash_symbol(const char *str, size_t len) JL_NOTSAFEPOINT
{
    // all-ones / 3 * 2 yields the alternating bit pattern 0xAA...AA at pointer width
    const uintptr_t mask = ~(uintptr_t)0 / 3 * 2;
    const uintptr_t oid = memhash(str, len) ^ mask;
    // compute the same hash value as v1.6 and earlier, which used `hash_uint(3h - objectid(sym))`
    return inthash(-oid);
}

// Number of bytes to allocate for a symbol whose name is `len` bytes long:
// tagged-value header + jl_sym_t + name + NUL, rounded up to a multiple of 8.
static size_t symbol_nbytes(size_t len) JL_NOTSAFEPOINT
{
    size_t unpadded = sizeof(jl_taggedvalue_t) + sizeof(jl_sym_t) + len + 1;
    return (unpadded + 7) & ~(size_t)7;
}

// Allocate and initialize a new, unlinked symbol node named by the `len`
// bytes at `str`. The node is not inserted into the tree here; the caller
// does that. Allocation uses the `_nolock` permanent allocator.
static jl_sym_t *mk_symbol(const char *str, size_t len) JL_NOTSAFEPOINT
{
    assert(jl_symbol_type && "not initialized");

    size_t nbytes = symbol_nbytes(len);
    jl_taggedvalue_t *tv = (jl_taggedvalue_t*)jl_gc_perm_alloc_nolock(nbytes, 0, sizeof(void*), 0);
    jl_sym_t *sym = (jl_sym_t*)jl_valueof(tv);

    // set to old marked so that we won't look at it in the GC or write barrier.
    tv->header = ((uintptr_t)jl_symbol_type) | GC_OLD_MARKED;

    // a fresh node has no children yet
    sym->left = NULL;
    sym->right = NULL;
    sym->hash = hash_symbol(str, len);

    // copy the name bytes and NUL-terminate them
    char *name = jl_symbol_name(sym);
    memcpy(name, str, len);
    name[len] = 0;
    return sym;
}

// Search the binary tree rooted at `*ptree` for the symbol whose name is the
// `len` bytes at `str`.
//
// Returns the matching node, or NULL when no node has that name.
// When `slot` is non-NULL, `*slot` receives the address of the last link that
// was followed: the link pointing at the returned node on a hit, or the NULL
// link where the search stopped on a miss.
static jl_sym_t *symtab_lookup(jl_sym_t **ptree, const char *str, size_t len, jl_sym_t ***slot) JL_NOTSAFEPOINT
{
    jl_sym_t *node = jl_atomic_load_acquire(ptree); // consume
    uintptr_t h = hash_symbol(str, len);

    // Tree nodes sorted by major key of (int(hash)) and minor key of (str).
    while (node != NULL) {
        // the sign of the hash difference picks the branch when the hashes differ
        intptr_t x = (intptr_t)(h - node->hash);
        if (x == 0) {
            // equal hashes: fall back to comparing the name bytes
            x = strncmp(str, jl_symbol_name(node), len);
            // strncmp only looks at the first `len` bytes, so additionally
            // require the node's name to end exactly there
            if (x == 0 && jl_symbol_name(node)[len] == 0) {
                if (slot != NULL)
                    *slot = ptree;
                return node;
            }
        }
        // x == 0 can still reach here (same hash, `str` is a proper prefix of
        // the node's name); that case descends to the right
        if (x < 0)
            ptree = &node->left;
        else
            ptree = &node->right;
        node = jl_atomic_load_acquire(ptree); // consume
    }
    // miss: `ptree` now addresses the NULL link where the search ended
    if (slot != NULL)
        *slot = ptree;
    return node;
}

// Return the symbol in the global table whose name is the `len` bytes at
// `str`, creating and inserting it when it is not present yet.
//
// `str` must not contain a NUL byte within the first `len` bytes (asserted).
// When `len` exceeds MAX_SYM_LEN, jl_exceptionf is called with
// jl_argumenterror_type (this check is compiled out under the clang analyzer).
jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT // (or throw)
{
#ifndef __clang_analyzer__
    // Hide the error throwing from the analyser since there isn't a way to express
    // "safepoint only when throwing error" currently.
    if (len > MAX_SYM_LEN)
        jl_exceptionf(jl_argumenterror_type, "Symbol name too long");
#endif
    assert(!memchr(str, 0, len));
    jl_sym_t **slot;
    // first attempt: search the tree without taking gc_perm_lock
    jl_sym_t *node = symtab_lookup(&symtab, str, len, &slot);
    if (node == NULL) {
        // miss: `slot` addresses the link where the search ended
        JL_LOCK_NOGC(&gc_perm_lock);
        // Someone might have updated it, check and look up again
        // (the second search resumes from `slot` rather than from the root)
        if (*slot != NULL && (node = symtab_lookup(slot, str, len, &slot))) {
            JL_UNLOCK_NOGC(&gc_perm_lock);
            return node;
        }
        // still absent: build the node while holding the lock (mk_symbol uses
        // the `_nolock` allocator), then store it into the tree with release
        // ordering; symtab_lookup reads the links with acquire loads
        node = mk_symbol(str, len);
        jl_atomic_store_release(slot, node);
        JL_UNLOCK_NOGC(&gc_perm_lock);
    }
    return node;
}

// Intern the NUL-terminated C string `str` and return its symbol.
JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT // (or throw)
{
    size_t len = strlen(str);
    return _jl_symbol(str, len);
}

// Look up the NUL-terminated C string `str` in the global symbol table
// without inserting it; returns NULL when no such symbol exists.
JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT
{
    size_t len = strlen(str);
    jl_sym_t *found = symtab_lookup(&symtab, str, len, NULL);
    return found;
}

// Intern the `len` bytes at `str` as a symbol. Unlike _jl_symbol, an embedded
// NUL byte is reported through jl_exceptionf with jl_argumenterror_type
// instead of being asserted.
JL_DLLEXPORT jl_sym_t *jl_symbol_n(const char *str, size_t len)
{
    const void *nul = memchr(str, 0, len);
    if (nul != NULL)
        jl_exceptionf(jl_argumenterror_type, "Symbol name may not contain \\0");
    return _jl_symbol(str, len);
}

// Return the root node of the global symbol tree (NULL while the table is empty).
JL_DLLEXPORT jl_sym_t *jl_get_root_symbol(void)
{
    jl_sym_t *root = symtab;
    return root;
}

// Counter that supplies the numeric suffix of generated symbols.
static uint32_t gs_ctr = 0;  // TODO: per-thread

// Read the current value of the gensym counter.
uint32_t jl_get_gs_ctr(void)
{
    return gs_ctr;
}

// Overwrite the gensym counter with `ctr`.
void jl_set_gs_ctr(uint32_t ctr)
{
    gs_ctr = ctr;
}

// Create a fresh symbol named "##<counter>", where <counter> is the value of
// the gensym counter (in base 10) before it is atomically incremented.
JL_DLLEXPORT jl_sym_t *jl_gensym(void)
{
    char buf[16];
    uint32_t id = jl_atomic_fetch_add(&gs_ctr, 1);
    // uint2str is handed the buffer starting at offset 2, which leaves room
    // to prepend the two '#' characters in front of the pointer it returns
    char *start = uint2str(&buf[2], sizeof(buf) - 2, id, 10);
    start -= 2;
    start[0] = '#';
    start[1] = '#';
    return jl_symbol(start);
}

// Create a fresh symbol named "##<tag>#<counter>", where <tag> is the `len`
// bytes at `str` and <counter> is the value of the gensym counter (base 10)
// before it is atomically incremented.
//
// Passing `len == (size_t)-1` means "`str` is NUL-terminated; measure it".
// Otherwise an embedded NUL byte, or a tag that is too long, is reported
// through jl_exceptionf with jl_argumenterror_type.
JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, size_t len)
{
    if (len == (size_t)-1) {
        len = strlen(str);
    }
    else if (memchr(str, 0, len)) {
        jl_exceptionf(jl_argumenterror_type, "Symbol name may not contain \\0");
    }

    char digits[14];
    // room for "##", the tag, '#', and everything the digit buffer can hold
    size_t bufsz = sizeof(digits) + len + 3;
    if (len > MAX_SYM_LEN || bufsz > MAX_SYM_LEN)
        jl_exceptionf(jl_argumenterror_type, "Symbol name too long");

    // short tags are assembled on the stack, long ones on the heap
    int on_heap = len >= 256;
    char *buf;
    if (on_heap)
        buf = (char*)malloc_s(bufsz);
    else
        buf = (char*)alloca(bufsz);

    // lay out "##" + tag + "#"
    buf[0] = '#';
    buf[1] = '#';
    memcpy(buf + 2, str, len);
    buf[2 + len] = '#';

    // append the counter text; `unused` is how far into `digits` it begins
    uint32_t id = jl_atomic_fetch_add(&gs_ctr, 1);
    char *num = uint2str(digits, sizeof(digits), id, 10);
    size_t unused = num - digits;
    memcpy(buf + 3 + len, num, sizeof(digits) - unused);

    // the symbol length excludes the unused digit slots and the final byte
    jl_sym_t *sym = _jl_symbol(buf, bufsz - unused - 1);
    if (on_heap)
        free(buf);
    return sym;
}

#ifdef __cplusplus
}
#endif
back to top