Remove the nearly-redundant extra scheduling layer
lukstafi committed May 29, 2024
1 parent 09dc6e7 commit 8ed4b35
Showing 8 changed files with 64 additions and 91 deletions.
18 changes: 5 additions & 13 deletions arrayjit/lib/backends.ml
@@ -6,7 +6,7 @@ module Debug_runtime = Utils.Debug_runtime

type 'context routine = {
context : 'context;
- schedule : unit -> Tnode.work;
+ schedule : Tnode.task;
bindings : Indexing.lowered_bindings;
name : string;
}
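
With `schedule` now a ready-made `Tnode.task` rather than a thunk, call sites drop one application. A minimal before/after sketch (mirroring the `train.ml` changes later in this commit; `debug_rt` is the packed debug-runtime value used there):

```ocaml
(* Before: invoking [schedule] captured the bindings and built the work. *)
let work = routine.schedule () in
Tnode.run debug_rt work

(* After: the task was already built at link time; run it directly. *)
Tnode.run debug_rt routine.schedule
```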
@@ -126,7 +126,7 @@ let forget_printbox (module Runtime : Minidebug_runtime.PrintBox_runtime) =
module Multicore_backend (Backend : No_device_backend) : Backend = struct
module Domain = Domain [@warning "-3"]

- type task_list = (Tnode.work[@sexp.opaque]) Utils.mutable_list [@@deriving sexp_of]
+ type task_list = (Tnode.task[@sexp.opaque]) Utils.mutable_list [@@deriving sexp_of]

type device_state = {
mutable keep_spinning : bool;
@@ -223,21 +223,13 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct

let link { ctx; device } code =
let task = Backend.link ctx code in
- let%diagn_sexp schedule () =
- [%log_result "Scheduling", task.name];
- make_work device @@ task.schedule ()
- in
- { task with context = { ctx = task.context; device }; schedule }
+ { task with context = { ctx = task.context; device }; schedule = make_work device task.schedule }

let link_batch { ctx; device } code_batch =
Array.map (Backend.link_batch ctx code_batch)
~f:
(Option.map ~f:(fun task ->
- let%diagn_sexp schedule () =
- [%log_result "Scheduling from batch", task.name];
- make_work device @@ task.schedule ()
- in
- { task with context = { ctx = task.context; device }; schedule }))
+ { task with context = { ctx = task.context; device }; schedule = make_work device task.schedule }))
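
Both `link` and `link_batch` now delegate to `make_work device`, whose definition lies outside this diff. As a rough, hypothetical sketch of its role (the `enqueue` helper is an assumed placeholder for the device's queueing primitive, not actual API):

```ocaml
(* Hypothetical sketch, not the real definition: [make_work] should wrap a task
   so that running it merely enqueues the underlying task on the device. *)
let make_work (device : device) (task : Tnode.task) : Tnode.task =
  Tnode.Work (fun _debug_runtime () -> enqueue device task)
```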

let from_host ?rt context tn =
if Option.is_some rt then
@@ -338,7 +330,7 @@ module type Simple_backend = sig
procedure option array

val link_compiled :
- context -> procedure -> context * Indexing.lowered_bindings * (unit -> Tnode.work) * string
+ context -> procedure -> context * Indexing.lowered_bindings * Tnode.task * string

val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> unit
val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> unit
49 changes: 23 additions & 26 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -589,10 +589,10 @@ let link old_context (code : code) =
let context = { old_context with run_module = Some run_module; global_arrays; all_arrays } in
let idx_params = Indexing.bound_symbols code.bindings in
let idx_args = List.map idx_params ~f:(fun s -> (s, ref 0)) in
- let%diagn_sexp schedule () =
+ let%diagn_rt_sexp work () : unit =
let log_id = get_global_run_id () in
let log_id_prefix = Int.to_string log_id ^ ": " in
[%log_result "Scheduling", code.name, context.label, (log_id : int)];
[%log_result "Launching", code.name, context.label, (log_id : int)];
let module Cu = Cudajit in
let log_arg = if Utils.settings.debug_log_from_routines then [ Cu.Int log_id ] else [] in
let idx_args =
@@ -620,31 +620,28 @@ let link old_context (code : code) =
| Global, Some _, Some ptr -> Some (Cu.Tensor ptr)
| _ -> None)
in
- let%diagn_rt_sexp work () : unit =
- [%log "zeroing-out global memory"];
- set_ctx context.ctx;
- Map.iteri global_arrays ~f:(fun ~key ~data:ptr ->
- if Hash_set.mem code.info.used_tensors key then
- let node = Map.find_exn all_arrays key in
- if node.zero_initialized then Cu.memset_d8 ptr Unsigned.UChar.zero ~length:node.size_in_bytes);
- [%log "launching the kernel"];
- (* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE 4096; *)
- Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 Cu.no_stream
- @@ log_arg @ idx_args @ args;
- [%log "kernel launched"];
- if Utils.settings.debug_log_from_routines then
- let postprocess_logs ~output =
- let output = List.filter_map output ~f:(String.chop_prefix ~prefix:log_id_prefix) in
- [%log_entry
- context.label;
- Utils.log_trace_tree _debug_runtime output]
- in
- context.device.physical.postprocess_queue <-
- (context, postprocess_logs) :: context.device.physical.postprocess_queue
- in
- Tnode.Work work
+ [%log "zeroing-out global memory"];
+ set_ctx context.ctx;
+ Map.iteri global_arrays ~f:(fun ~key ~data:ptr ->
+ if Hash_set.mem code.info.used_tensors key then
+ let node = Map.find_exn all_arrays key in
+ if node.zero_initialized then Cu.memset_d8 ptr Unsigned.UChar.zero ~length:node.size_in_bytes);
+ [%log "launching the kernel"];
+ (* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE 4096; *)
+ Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 Cu.no_stream
+ @@ log_arg @ idx_args @ args;
+ [%log "kernel launched"];
+ if Utils.settings.debug_log_from_routines then
+ let postprocess_logs ~output =
+ let output = List.filter_map output ~f:(String.chop_prefix ~prefix:log_id_prefix) in
+ [%log_entry
+ context.label;
+ Utils.log_trace_tree _debug_runtime output]
+ in
+ context.device.physical.postprocess_queue <-
+ (context, postprocess_logs) :: context.device.physical.postprocess_queue
in
- (context, idx_args, schedule)
+ (context, idx_args, Tnode.Work work)

let link_batch old_context (code_batch : code_batch) =
let idx_params = Indexing.bound_symbols code_batch.bindings in
4 changes: 2 additions & 2 deletions arrayjit/lib/cuda_backend.missing.ml
@@ -10,8 +10,8 @@ let compile ?name:_ bindings _optimized = bindings

let link (Unimplemented : context) code =
let lowered_bindings = List.map ~f:(fun s -> (s, ref 0)) @@ Indexing.bound_symbols code in
- let work () = Tnode.Work (fun _debug_runtime () -> ()) in
- ((Unimplemented : context), lowered_bindings, work)
+ let task = Tnode.Work (fun _debug_runtime () -> ()) in
+ ((Unimplemented : context), lowered_bindings, task)

let unsafe_cleanup ?unsafe_shutdown:_ () = ()
let from_host ?rt:_ _context _arr = ()
7 changes: 2 additions & 5 deletions arrayjit/lib/cuda_backend.mli
@@ -14,11 +14,8 @@ val compile : ?name:string -> Indexing.unit_bindings -> Low_level.optimized -> c
val compile_batch :
names:string option array -> Indexing.unit_bindings -> Low_level.optimized option array -> code_batch

- val link : context -> code -> context * Indexing.lowered_bindings * (unit -> Tnode.work)
-
- val link_batch :
- context -> code_batch -> context * Indexing.lowered_bindings * (unit -> Tnode.work) option array
-
+ val link : context -> code -> context * Indexing.lowered_bindings * Tnode.task
+ val link_batch : context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array
val unsafe_cleanup : ?unsafe_shutdown:bool -> unit -> unit

val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> unit
46 changes: 18 additions & 28 deletions arrayjit/lib/gccjit_backend.ml
@@ -710,47 +710,37 @@ let%track_sexp link_compiled (old_context : context) (code : procedure) : contex
wrong order but that's OK because [link] only uses them to check the number of indices. *)
link code.bindings (List.rev code.params) Ctypes.(void @-> returning void)]
in
- let schedule () =
- let callback = Indexing.apply run_variadic in
- let%diagn_rt_sexp work () : unit =
- [%log_result name];
- callback ();
- if Utils.settings.debug_log_from_routines then (
- Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name);
- Stdlib.Sys.remove log_file_name)
- in
- Tn.Work work
+ let%diagn_rt_sexp work () : unit =
+ [%log_result name];
+ Indexing.apply run_variadic ();
+ if Utils.settings.debug_log_from_routines then (
+ Utils.log_trace_tree _debug_runtime (Stdio.In_channel.read_lines log_file_name);
+ Stdlib.Sys.remove log_file_name)
in
- (context, Indexing.lowered_bindings code.bindings run_variadic, schedule, name)
+ (context, Indexing.lowered_bindings code.bindings run_variadic, Tn.Work work, name)
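
The quadruple returned here matches the updated `Simple_backend.link_compiled` signature from `backends.ml`, so it slots straight into the routine record. A sketch of the presumed wiring (`old_context`, `proc`, and `debug_rt` are assumed to be in scope):

```ocaml
let context, bindings, schedule, name = link_compiled old_context proc in
let routine = { context; schedule; bindings; name } in
(* For this CPU backend, running the task blocks until the procedure finishes. *)
Tn.run debug_rt routine.schedule
```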

- let from_host ?rt (context : context) (tn : Tn.t) : unit =
+ let from_host ?rt:_ (context : context) (tn : Tn.t) : unit =
Option.iter (Map.find context.arrays tn) ~f:(fun c_arr ->
match tn.Tn.array with
| (lazy (Some h_arr)) ->
- Ndarray.map2 { f2 = Ndarray.A.blit } h_arr c_arr;
- if Utils.settings.with_debug_level > 0 then
- let module Debug_runtime =
- (val Option.value_or_thunk rt ~default:(fun () ->
- (module Debug_runtime : Minidebug_runtime.Debug_runtime)))
- in
- [%diagn_sexp
- [%log_entry
- "from_host " ^ Tn.unsafe_ident tn;
- [%log "copied", Tn.label tn, Tn.name tn, "from host"];
- if Utils.settings.with_debug_level > 1 then
- [%log_printbox
- let indices = Array.init (Array.length @@ Lazy.force tn.dims) ~f:(fun i -> i - 5) in
- Ndarray.render_array ~indices c_arr]]]
+ Ndarray.map2 { f2 = Ndarray.A.blit } h_arr c_arr
+ (* ; if Utils.settings.with_debug_level > 0 then let module Debug_runtime = (val
+ Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime :
+ Minidebug_runtime.Debug_runtime))) in [%diagn_sexp [%log_entry "from_host " ^ Tn.unsafe_ident tn;
+ [%log "copied", Tn.label tn, Tn.name tn, "from host"]; if Utils.settings.with_debug_level > 1
+ then [%log_printbox let indices = Array.init (Array.length @@ Lazy.force tn.dims) ~f:(fun i -> i
+ - 5) in Ndarray.render_array ~indices c_arr]]] *)
| (lazy None) -> ())

- let to_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (context : context) (tn : Tn.t) : unit =
+ let to_host ?rt (context : context) (tn : Tn.t) : unit =
Option.iter (Map.find context.arrays tn) ~f:(fun c_arr ->
match tn.Tn.array with
| (lazy (Some h_arr)) ->
Ndarray.map2 { f2 = Ndarray.A.blit } c_arr h_arr;
if Utils.settings.with_debug_level > 0 then
let module Debug_runtime =
- (val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime)))
+ (val Option.value_or_thunk rt ~default:(fun () ->
+ (module Debug_runtime : Minidebug_runtime.Debug_runtime)))
in
[%diagn_sexp
[%log_entry
2 changes: 1 addition & 1 deletion arrayjit/lib/tnode.ml
@@ -6,7 +6,7 @@ module Debug_runtime = Utils.Debug_runtime
[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

- type work = Work of ((module Minidebug_runtime.Debug_runtime) -> unit -> unit)
+ type task = Work of ((module Minidebug_runtime.Debug_runtime) -> unit -> unit)
[@@unboxed] [@@deriving sexp_of]

let[@inline] run debug_runtime (Work work) = work debug_runtime ()
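
Only the type's name changes; the `Work` payload still expects a packed debug runtime. A minimal sketch of constructing and running a `task` (the printing payload is illustrative):

```ocaml
let hello : task =
  Work (fun (module Debug_runtime : Minidebug_runtime.Debug_runtime) () ->
      Stdlib.print_endline "hello")

let () = run (module Utils.Debug_runtime : Minidebug_runtime.Debug_runtime) hello
```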
10 changes: 5 additions & 5 deletions arrayjit/lib/writing_a_backend.md
@@ -7,11 +7,11 @@ Currently, OCANNL integrates new backends via code in [backends.ml](backends.ml)
```ocaml
type lowered_bindings = (static_symbol, int ref) List.Assoc.t (* in indexing.ml *)
- type work = Work of ((module Debug_runtime) -> unit -> unit) (* in tnode.ml *)
+ type task = Work of ((module Debug_runtime) -> unit -> unit) (* in tnode.ml *)
type 'context routine = {
context : 'context;
- schedule : unit -> work;
+ schedule : task;
bindings : lowered_bindings;
name : string;
}
@@ -48,7 +48,7 @@ end

where `Simple_backend` implements a `No_device_backend` functionality, but only needs to deal with `Low_level.optimized` and its compilation result type `procedure`.

- `No_device_backend`s do not themselves deal with the device abstraction. There's the functor `Multicore_backend (Backend : No_device_backend) : Backend` that assigns a device to a domain, and manages the given `No_device_backend` on the domain-based devices. Scheduling a work captures the state of the bindings at that point. Scheduling should never ever block. Running a work on a `No_device_backend` _should block_ (till execution finishes), but it _should not block_ for a proper `Backend` -- it should just put the work on the device's queue.
+ `No_device_backend`s do not themselves deal with the device abstraction. There's the functor `Multicore_backend (Backend : No_device_backend) : Backend` that assigns a device to a domain, and manages the given `No_device_backend` on the domain-based devices. Running `schedule` on a `No_device_backend` _should block_ (till execution finishes), but it _should not block_ for a proper `Backend` -- it should just put the work on the device's queue.
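
A sketch of that contract from the caller's side, assuming linked routines `cpu_routine` (from a `No_device_backend`) and `dev_routine` (from a proper `Backend`), with `debug_rt` a packed debug runtime:

```ocaml
(* CPU backend: running the task blocks until it completes. *)
Tnode.run debug_rt cpu_routine.schedule;

(* Device backend: this only puts the task on the device's queue... *)
Tnode.run debug_rt dev_routine.schedule;
(* ...so block explicitly once the results are needed. *)
Backend.(await @@ get_ctx_device dev_routine.context)
```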

```ocaml
module type Backend = sig
@@ -149,7 +149,7 @@ module type Simple_backend = sig
...
val link_compiled :
- context -> procedure -> context * Indexing.lowered_bindings * (unit -> Tnode.work) * string
+ context -> procedure -> context * Indexing.lowered_bindings * Tnode.task * string
...
end
@@ -217,7 +217,7 @@ Currently, OCANNL expects backends to implement a FIFO queue scheduling mechanis

Since this is significantly simpler than what other frameworks do, it might evolve in the future. (In particular, scheduling in `tinygrad` expresses tensor graph dependencies with arbitrary queueing.) A natural next step would be to add "acknowledge" events that indirectly keep track of (and signal) which tasks a device has already executed.

- Scheduling a routine has two steps: the first step (invoking `schedule`) captures the state of the routine's static indexing bindings and returns its task, then, running the task for a device puts the task on the device's queue. Usually these steps are performed together. Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Implementations of `No_device_backend` and `Simple_backend` (i.e. CPU backends) should run the tasks by executing them directly.
+ Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Implementations of `No_device_backend` and `Simple_backend` (i.e. CPU backends) should run the tasks by executing them directly.
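
Putting it together, a typical synchronous cycle queues the transfers and the routine's task on the same device queue, then waits; a sketch modeled on `sync_run` in `train.ml` (`tn_in` and `tn_out` are assumed tensor nodes):

```ocaml
Backend.from_host routine.context tn_in;    (* host -> device, queued *)
Tnode.run debug_rt routine.schedule;        (* the routine's task, queued *)
Backend.to_host routine.context tn_out;     (* device -> host, queued *)
Backend.(await @@ get_ctx_device routine.context)
```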

### Data transfers

19 changes: 8 additions & 11 deletions lib/train.ml
@@ -39,7 +39,7 @@ module IDX = struct
end

let debug_rt = (module Debug_runtime : Minidebug_runtime.Debug_runtime)
- let run jitted = Tn.run debug_rt @@ jitted.Arrayjit.Backends.schedule ()
+ let run jitted = Tn.run debug_rt jitted.Arrayjit.Backends.schedule

(** Reinitializes a backend selected via a global [backend] flag. *)
let fresh_backend ?backend_name ?(config = `Physical_devices_only) () =
@@ -267,12 +267,11 @@ let%debug_sexp all_device_to_host (type context) (module Backend : Backend_type
let%track_sexp sync_run ?looping (type context) (module Backend : Backend_type with type context = context)
(routine : Backend.routine) t =
all_host_to_device (module Backend) routine.context t;
- let work = routine.schedule () in
(match looping with
- | None -> Tn.run debug_rt work
+ | None -> Tn.run debug_rt routine.schedule
| Some then_ ->
let f () =
- Tn.run debug_rt work;
+ Tn.run debug_rt routine.schedule;
then_ ()
in
sequential_loop ~f routine.bindings);
@@ -421,8 +420,7 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
let grad_merges =
Array.init num_devices ~f:(fun (to_ : int) ->
Array.init num_devices ~f:(fun (from : int) ->
- (* It is safe to cache scheduling, because merging does not use static indices. *)
- List.map grad_merges.(from) ~f:(fun c -> (Backend.link ctxs.(to_) c).schedule ())))
+ List.map grad_merges.(from) ~f:(fun c -> (Backend.link ctxs.(to_) c).schedule)))
in
(* We can cache scheduling, because merging and copying does not depend on static indexing. *)
let name_prefixes = Array.create ~len:num_devices "loss_merge" in
@@ -433,10 +431,9 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
let loss_merges =
Array.init num_devices ~f:(fun (to_ : int) ->
Array.init num_devices ~f:(fun (from : int) ->
- (* It is safe to cache scheduling, because merging does not use static indices. *)
match loss_merges.(from) with
| [] -> None
- | [ c ] -> Some ((Backend.link ctxs.(to_) c).schedule ())
+ | [ c ] -> Some (Backend.link ctxs.(to_) c).schedule
| _ -> assert false))
in
let merge ~(from : int) ~(to_ : int) : unit =
@@ -460,11 +457,11 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
in
let copies =
Array.init (num_devices - 1) ~f:(fun (to_m_1 : int) ->
- List.map copies ~f:(fun c -> (Backend.link ctxs.(to_m_1 + 1) c).schedule ()))
+ List.map copies ~f:(fun c -> (Backend.link ctxs.(to_m_1 + 1) c).schedule))
in
let%track_sexp sync (devices_to_sync : int) : unit =
Arrayjit.Utils.parallel_merge merge devices_to_sync;
- Tn.run debug_rt @@ sgd_update.schedule ();
+ Tn.run debug_rt sgd_update.schedule;
(* We need to wait, because copying happens on other devices. *)
Set.iter !needed_on_host ~f:(fun p -> Backend.to_host sgd_update.context p);
Backend.(await @@ get_ctx_device sgd_update.context);
@@ -475,7 +472,7 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
post_sync ~num_synced_devices:devices_to_sync
in
let lowered_bindings = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.bindings)] in
- let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Tn.run debug_rt @@ upd.schedule ())] in
+ let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Tn.run debug_rt upd.schedule)] in
fun () -> round_robin fs lowered_bindings sgd_update.bindings ~sync

let debug_name t = Tn.(debug_name ~id:t.Tensor.value.id ~label:t.value.label)
