Skip to content

Commit

Permalink
Custom transforms (#1246)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Jul 11, 2024
1 parent a3c2873 commit 5c1fa64
Show file tree
Hide file tree
Showing 16 changed files with 734 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/src/python/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Transforms

eval
compile
custom_function
disable_compile
enable_compile
grad
Expand Down
6 changes: 5 additions & 1 deletion mlx/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ bool in_tracing() {
return detail::InTracing::in_tracing();
}

bool retain_graph() {
return detail::RetainGraph::retain_graph();
}

} // namespace

array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
Expand Down Expand Up @@ -102,7 +106,7 @@ void array::eval() {
}

bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
return array_desc_->is_tracer && in_tracing() || retain_graph();
}

void array::set_data(allocator::Buffer buffer, deleter_t d) {
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/accelerate/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]);
}

void CustomVJP::eval(
void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/common/default_primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

void CustomVJP::eval_gpu(
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomVJP)
NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/no_metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
Expand Down
17 changes: 9 additions & 8 deletions mlx/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
if (j < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
Expand All @@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
return jvp_outs;
auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
return jvps;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
Expand Down
28 changes: 26 additions & 2 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
return {{cosh(inputs[0], stream())}, axes};
}

std::vector<array> CustomVJP::vjp(
std::vector<array> CustomTransforms::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
// Extract the inputs to the VJP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);

// Compute all the vjps
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
for (const auto& cot : cotangents) {
all_vjps.emplace_back(cot);
}

// Select the vjps requested
std::vector<array> vjps;
vjps.reserve(argnums.size());
for (auto arg : argnums) {
Expand All @@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
return vjps;
}

std::vector<array> CustomTransforms::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Extract the inputs to the JVP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);

// Compute the jvps
return jvp_fun_(inputs, tangents, argnums);
}

std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
const std::vector<array>& inputs_,
const std::vector<int>& axes_) {
// Extract the inputs to the vmap function
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
return vmap_fun_(inputs, axes);
}

std::vector<array> Depends::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down
41 changes: 30 additions & 11 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};

class CustomVJP : public Primitive {
class CustomTransforms : public Primitive {
public:
explicit CustomVJP(
explicit CustomTransforms(
Stream stream,
int num_outputs,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun)
: Primitive(stream), vjp_fun_(std::move(fun)) {}
const std::vector<array>&)> vjp,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)> jvp,
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)> vmap)
: Primitive(stream),
num_outputs_(num_outputs),
vjp_fun_(std::move(vjp)),
jvp_fun_(std::move(jvp)),
vmap_fun_(std::move(vmap)) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;

DEFINE_PRINT(CustomVJP);
DEFINE_GRADS();
DEFINE_VMAP();
DEFINE_PRINT(CustomTransforms);

private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);

int num_outputs_;

std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>
vjp_fun_;
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>
jvp_fun_;
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>
vmap_fun_;
};

class Depends : public Primitive {
Expand Down
86 changes: 79 additions & 7 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
// Initialize the static tracing counter from transforms_impl.h .
//
// This is used to implement the in_tracing() function the returns true if we
// are currently under a function transformation.
// are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during
// evaluation.
int detail::InTracing::tracing_counter{0};
int detail::RetainGraph::tracing_counter{0};

array eval_impl(std::vector<array> outputs, bool async) {
std::queue<array> tape;
Expand Down Expand Up @@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}

auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
std::vector<array> vjps;
{
detail::RetainGraph retain;
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
}
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
Expand Down Expand Up @@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}

std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
return fun;
}

return [fun = std::move(fun),
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
fun_vjp = std::move(fun_vjp),
fun_jvp = std::move(fun_jvp),
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
// Compute the outputs
auto outputs = fun(args);
for (auto& out : outputs) {
Expand Down Expand Up @@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
return array::make_arrays(
std::move(shapes),
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
std::make_shared<CustomTransforms>(
to_stream(s),
outputs.size(),

// We use the passed vjp function or compute it from the inputs and
// passed cotangents. Note that this may be less efficient than
// using `fun` directly because we may not be able to fully reuse
// the outputs of the forward pass.
fun_vjp.value_or(
[fun](auto primals, auto cotangents, auto outputs) {
auto [__, vjps] = vjp(fun, primals, cotangents);
return vjps;
}),

// We use the passed jvp function or compute it from the primals
// and tangents. Similarly we can't take full advantage of the
// argnums so it is best to use `fun` directly if we don't need a
// custom transform.
//
// TODO: Use stop_gradient to make full use of argnums and not
// waste computation.
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
auto [__, jvps] = jvp(fun, primals, all_tangents);
return jvps;
}),

// Same as above, we use the passed vmap function or we compute it
// from `fun`. The output axes is selected to be all 0s which again
// may be suboptimal but the only thing we can do without any
// information for `fun`.
fun_vmap.value_or(
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
-> std::pair<std::vector<array>, std::vector<int>> {
std::vector<int> out_axes(out_size, 0);
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
})),
inputs);
};
}

std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
}

std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun) {
auto vjp_fun = [fun](
Expand Down
29 changes: 27 additions & 2 deletions mlx/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#pragma once

#include <optional>

#include "mlx/array.h"

namespace mlx::core {
Expand Down Expand Up @@ -179,8 +181,31 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::vector<int>& out_axes = {});

/**
* Return the results of calling fun with args but if their vjp is computed it
* will be computed by fun_vjp.
* Redefine the transformations of `fun` according to the provided functions.
*
* Namely when calling the vjp of `fun` then `fun_vjp` will be called,
* `fun_jvp` for the jvp and `fun_vmap` for vmap.
*
* If any transformation is not provided, then a default one is created by
* calling `vjp`, `jvp` and `vmap` on the function directly.
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>> fun_vjp = std::nullopt,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp = std::nullopt,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap = std::nullopt);

/**
* Return a function that behaves exactly like `fun` but if the vjp of the
* results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
Expand Down
Loading

0 comments on commit 5c1fa64

Please sign in to comment.