modular_gemm.c
/* StarPU --- Runtime system for heterogeneous multicore architectures.
*
* Copyright (C) 2013-2015,2017 Inria
* Copyright (C) 2017 CNRS
* Copyright (C) 2014,2016-2019 Université de Bordeaux
* Copyright (C) 2013 Simon Archipoff
*
* StarPU is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation; either version 2.1 of the License, or (at
* your option) any later version.
*
* StarPU is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*
* See the GNU Lesser General Public License in COPYING.LGPL for more details.
*/
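/* This scheduler runs only GEMMs on GPUs, and tries to feed them with as many
 * GEMMs as possible. Other tasks, and GEMMs that no child accepted, are sent
 * to CPU workers (or to any worker if there are no CPU workers at all). */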
#include <math.h>
#include <stdlib.h>
#include <string.h>

#include <starpu_sched_component.h>
#include <starpu_scheduler.h>
/* Optionally, it can take memory affinity into account, to avoid too many GPU
 * data transfers: when MEMORY_AFFINITY is defined, the estimated data transfer
 * length of the task to a child is added to that child's expected end time
 * when ranking children for a GEMM task. Remove this definition to rank
 * children on availability time plus predicted execution length only. */
#define MEMORY_AFFINITY
/* Per-child estimates computed for one task when ranking the children of the
 * gemm component. */
struct child_data {
	/* Availability time of the child, predicted execution length, predicted
	 * data transfer length (only filled with MEMORY_AFFINITY), and the
	 * resulting expected end time (-1 when the child cannot execute the
	 * task). */
	double expected_start, predicted, predicted_transfer, expected_end;
	/* Index of this child in component->children */
	unsigned child;
};
/* qsort() comparator ordering struct child_data entries by increasing
 * expected_end.
 *
 * qsort() requires a consistent ordering, which plain < and == do not give
 * once a NaN expected_end is involved (both compar(NaN, x) and compar(x, NaN)
 * would claim "greater"). NaN values therefore sort before every other value
 * and compare equal to each other, so children without a usable estimate are
 * tried first instead of landing at an arbitrary position.
 * NOTE(review): a NaN prediction presumably comes from an uncalibrated
 * performance model — confirm against starpu_sched_component_execute_preds(). */
static int compar(const void *_a, const void *_b)
{
	const struct child_data *a = _a;
	const struct child_data *b = _b;
	int a_nan = isnan(a->expected_end) != 0;
	int b_nan = isnan(b->expected_end) != 0;

	if (a_nan || b_nan)
		return b_nan - a_nan;
	if (a->expected_end < b->expected_end)
		return -1;
	if (a->expected_end > b->expected_end)
		return 1;
	return 0;
}
static int gemm_push_task(struct starpu_sched_component * component, struct starpu_task * task)
{
unsigned n = component->nchildren;
unsigned i;
/* See if it's a GEMM task */
const char *name = starpu_task_get_model_name(task);
//fprintf(stderr, "it's %s\n", name);
if (name && (!strcmp(name, "gemm") ||
!strcmp(name, "dgemm") ||
!strcmp(name, "sgemm") ||
!strcmp(name, "chol_model_22") ||
!strcmp(name, "starpu_dlu_lu_model_22") ||
!strcmp(name, "starpu_slu_lu_model_22")))
{
/* It's a GEMM, try to push to GPUs */
struct child_data child_data[n];
for (i = 0; i < n; i++)
{
child_data[i].expected_end = -1;
child_data[i].child = i;
}
/* Look at GPU availability time */
for (i = 0; i < n; i++)
{
struct starpu_sched_component *child = component->children[i];
double predicted;
if (starpu_sched_component_execute_preds(child, task, &predicted))
{
double expected_start;
child_data[i].expected_start =
expected_start = child->estimated_end(child);
child_data[i].predicted = predicted;
child_data[i].expected_end = expected_start
+ predicted;
#ifdef MEMORY_AFFINITY
double predicted_transfer;
child_data[i].predicted_transfer =
predicted_transfer = starpu_sched_component_transfer_length(child, task);
child_data[i].expected_end += predicted_transfer;
#endif
}
}
/* Sort by increasing expected end */
qsort(child_data, n, sizeof(*child_data), compar);
/* Try to push to the GPU with minimum availability time, to balance the load. */
for (i = 0; i < n; i++)
{
if (child_data[i].expected_end != -1)
{
struct starpu_sched_component *child = component->children[child_data[i].child];
/* Note it in the task so that estimated_end() has it */
task->predicted = child_data[i].predicted;
task->predicted_transfer = child_data[i].predicted_transfer;
int ret = starpu_sched_component_push_task(component,child,task);
if (!ret)
/* Ok, this GPU took it */
return 0;
}
}
}
int workerid;
/* It's not a GEMM, or no GPU wanted to take it, find somebody else */
for(workerid = starpu_bitmap_first(component->workers_in_ctx);
workerid != -1;
workerid = starpu_bitmap_next(component->workers_in_ctx, workerid))
{
int nimpl;
for(nimpl = 0; nimpl < STARPU_MAXIMPLEMENTATIONS; nimpl++)
{
if(starpu_worker_can_execute_task(workerid,task,nimpl)
|| starpu_combined_worker_can_execute_task(workerid, task, nimpl))
{
for (i = 0; i < n; i++)
{
struct starpu_sched_component *child = component->children[i];
int idworker;
for(idworker = starpu_bitmap_first(component->children[i]->workers);
idworker != -1;
idworker = starpu_bitmap_next(component->children[i]->workers, idworker))
{
if (idworker == workerid)
{
if ((starpu_cpu_worker_get_count() == 0 ||
starpu_worker_get_type(workerid) == STARPU_CPU_WORKER)
&& (starpu_worker_can_execute_task(workerid,task,nimpl)
|| starpu_combined_worker_can_execute_task(workerid, task, nimpl)))
{
int ret = starpu_sched_component_push_task(component,child,task);
if (!ret)
return 0;
}
}
}
}
}
}
}
/* FIFOs are full */
return 1;
}
/* Create a component named "gemm" in TREE whose push method is
 * gemm_push_task(), and return it. PARAMS is not used. */
struct starpu_sched_component *starpu_sched_component_gemm_create(struct starpu_sched_tree *tree, void *params STARPU_ATTRIBUTE_UNUSED)
{
	struct starpu_sched_component *gemm_component;

	gemm_component = starpu_sched_component_create(tree, "gemm");
	gemm_component->push_task = gemm_push_task;
	return gemm_component;
}
static void initialize_gemm_center_policy(unsigned sched_ctx_id)
{
starpu_sched_component_initialize_simple_scheduler((starpu_sched_component_create_t) starpu_sched_component_gemm_create, NULL,
STARPU_SCHED_SIMPLE_DECIDE_MEMNODES |
STARPU_SCHED_SIMPLE_FIFO_ABOVE |
STARPU_SCHED_SIMPLE_FIFO_ABOVE_PRIO |
STARPU_SCHED_SIMPLE_FIFOS_BELOW |
STARPU_SCHED_SIMPLE_FIFOS_BELOW_PRIO |
STARPU_SCHED_SIMPLE_IMPL, sched_ctx_id);
}
/* Scheduling-context teardown: destroy the component tree stored as the
 * policy data of context SCHED_CTX_ID. */
static void deinitialize_gemm_center_policy(unsigned sched_ctx_id)
{
	starpu_sched_tree_destroy(starpu_sched_ctx_get_policy_data(sched_ctx_id));
}
/* Policy descriptor for "modular-gemm": context setup is done by the
 * functions above, task flow is delegated to the generic component-tree
 * entry points. */
struct starpu_sched_policy _starpu_sched_modular_gemm_policy =
{
	/* Identification */
	.policy_name = "modular-gemm",
	.policy_description = "gemm modular policy",
	.worker_type = STARPU_WORKER_LIST,

	/* Context setup/teardown and worker membership */
	.init_sched = initialize_gemm_center_policy,
	.deinit_sched = deinitialize_gemm_center_policy,
	.add_workers = starpu_sched_tree_add_workers,
	.remove_workers = starpu_sched_tree_remove_workers,

	/* Task flow through the component tree */
	.push_task = starpu_sched_tree_push_task,
	.pop_task = starpu_sched_tree_pop_task,
	.pre_exec_hook = starpu_sched_component_worker_pre_exec_hook,
	.post_exec_hook = starpu_sched_component_worker_post_exec_hook,
	.pop_every_task = NULL,
};