diff --git a/lib/flame/fly_backend.ex b/lib/flame/fly_backend.ex
index 52c57c2..63e008b 100644
--- a/lib/flame/fly_backend.ex
+++ b/lib/flame/fly_backend.ex
@@ -33,6 +33,8 @@ defmodule FLAME.FlyBackend do
 
   * `:gpus` - The number of runner GPUs. Defaults to `1` if `:gpu_kind` is set.
 
+  * `:mounts` - A list of volumes to mount. Refer to `FLAME.FlyBackend.Mount` for the options.
+
   * `:boot_timeout` - The boot timeout. Defaults to `30_000`.
 
   * `:app` – The name of the otp app. Defaults to `System.get_env("FLY_APP_NAME")`,
@@ -88,6 +90,8 @@ defmodule FLAME.FlyBackend do
   alias FLAME.FlyBackend
   alias FLAME.Parser.JSON
 
+  alias FLAME.FlyBackend.Mount
+
   require Logger
 
   @derive {Inspect,
@@ -99,6 +103,7 @@ defmodule FLAME.FlyBackend do
              :memory_mb,
              :gpu_kind,
              :gpus,
+             :mounts,
              :image,
              :app,
              :runner_id,
@@ -120,6 +125,7 @@ defmodule FLAME.FlyBackend do
             memory_mb: nil,
             gpu_kind: nil,
             gpus: nil,
+            mounts: [],
             image: nil,
             services: [],
             metadata: %{},
@@ -149,6 +155,7 @@ defmodule FLAME.FlyBackend do
     :memory_mb,
     :gpu_kind,
     :gpus,
+    :mounts,
     :boot_timeout,
     :env,
     :terminator_sup,
@@ -173,6 +180,7 @@ defmodule FLAME.FlyBackend do
       memory_mb: 4096,
       boot_timeout: 30_000,
       services: [],
+      mounts: [],
       metadata: %{},
       init: %{},
       log: Keyword.get(conf, :log, false)
@@ -191,7 +199,11 @@ defmodule FLAME.FlyBackend do
       end
     end
 
-    state = %{state | runner_node_base: "#{state.app}-flame-#{rand_id(20)}"}
+    # `List.wrap/1` turns an explicit `mounts: nil` into `[]`.
+    mounts = state.mounts |> List.wrap() |> Enum.map(&Mount.parse_opts/1)
+
+    state = %{state | mounts: mounts, runner_node_base: "#{state.app}-flame-#{rand_id(20)}"}
+
     parent_ref = make_ref()
 
     encoded_parent =
@@ -250,23 +262,87 @@ defmodule FLAME.FlyBackend do
     {result, div(micro, 1000)}
   end
 
+  # Resolves every configured mount to a concrete Fly volume id and returns the mounts
+  # as plain maps for the machine-create request body.
+  #
+  # Mounts without an explicit `:volume` are matched by `:name` against the app's
+  # unattached, healthy volumes in the runner's region (any region when none is set).
+  # Raises `ArgumentError` when no suitable volume is available.
+  defp allocate_volume_ids(%FlyBackend{mounts: []}), do: []
+
+  defp allocate_volume_ids(%FlyBackend{mounts: mounts} = state) when is_list(mounts) do
+    case get_volumes(state) do
+      [] ->
+        raise ArgumentError, "no Fly volumes found"
+
+      all_volumes ->
+        volume_ids_by_name =
+          all_volumes
+          |> Enum.filter(fn vol ->
+            vol["attached_machine_id"] == nil && vol["state"] == "created" &&
+              vol["host_status"] == "ok" &&
+              state.region in [nil, vol["region"]]
+          end)
+          |> Enum.shuffle()
+          |> Enum.group_by(& &1["name"], & &1["id"])
+
+        {new_mounts, _unused_vols} =
+          Enum.map_reduce(
+            mounts,
+            volume_ids_by_name,
+            fn
+              %{volume: nil} = mount, leftover_vols ->
+                case leftover_vols[mount.name] do
+                  [volume_id | rest] ->
+                    {%{mount | volume: volume_id}, %{leftover_vols | mount.name => rest}}
+
+                  _ ->
+                    raise ArgumentError,
+                          "no available Fly volumes with the name \"#{mount.name}\" in region \"#{state.region}\" found"
+                end
+
+              mount, leftover_vols ->
+                {mount, leftover_vols}
+            end
+          )
+
+        Enum.map(new_mounts, &Map.from_struct/1)
+    end
+  end
+
+  defp allocate_volume_ids(_) do
+    raise ArgumentError, "expected a list of mounts"
+  end
+
+  # Lists the app's volumes from the Fly API; expects the decoded JSON to be a list.
+  defp get_volumes(%FlyBackend{} = state) do
+    http_request!(:get, "#{state.host}/v1/apps/#{state.app}/volumes", @retry,
+      headers: [
+        {"Accept", "application/json"},
+        {"Authorization", "Bearer #{state.token}"}
+      ],
+      connect_timeout: state.boot_timeout
+    )
+  end
+
   @impl true
   def remote_boot(%FlyBackend{parent_ref: parent_ref} = state) do
     {resp, req_connect_time} =
       with_elapsed_ms(fn ->
-        http_post!("#{state.host}/v1/apps/#{state.app}/machines", @retry,
+        http_request!(:post, "#{state.host}/v1/apps/#{state.app}/machines", @retry,
           content_type: "application/json",
           headers: [
             {"Content-Type", "application/json"},
             {"Authorization", "Bearer #{state.token}"}
           ],
           connect_timeout: state.boot_timeout,
-          body:
+          body: fn ->
             JSON.encode!(%{
               name: state.runner_node_base,
               region: state.region,
               config: %{
                 image: state.image,
+                mounts: allocate_volume_ids(state),
                 init: state.init,
                 guest: %{
                   cpu_kind: state.cpu_kind,
@@ -282,6 +358,7 @@ defmodule FLAME.FlyBackend do
                 metadata: Map.put(state.metadata, :flame_parent_ip, state.local_ip)
               }
             })
+          end
         )
       end)
 
@@ -334,32 +411,19 @@ defmodule FLAME.FlyBackend do
     |> binary_part(0, len)
   end
 
-  defp http_post!(url, remaining_tries, opts) do
-    Keyword.validate!(opts, [:headers, :body, :connect_timeout, :content_type])
+  # Performs a GET or POST via :httpc, retrying on 429/412/409/422 responses while
+  # tries remain. Returns the decoded JSON body on success; raises otherwise.
+  defp http_request!(method, url, remaining_tries, opts) when method in [:get, :post] do
+    validate_request_opts!(method, opts)
 
     headers =
       for {field, val} <- Keyword.fetch!(opts, :headers),
           do: {String.to_charlist(field), val}
 
-    body = Keyword.fetch!(opts, :body)
-    connect_timeout = Keyword.fetch!(opts, :connect_timeout)
-    content_type = Keyword.fetch!(opts, :content_type)
+    request = make_request(method, url, headers, opts)
+    http_opts = make_http_opts(opts)
 
-    http_opts = [
-      ssl:
-        [
-          verify: :verify_peer,
-          depth: 2,
-          customize_hostname_check: [
-            match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
-          ]
-        ] ++ cacerts_options(),
-      connect_timeout: connect_timeout
-    ]
-
-    case :httpc.request(:post, {url, headers, ~c"#{content_type}", body}, http_opts,
-           body_format: :binary
-         ) do
+    case :httpc.request(method, request, http_opts, body_format: :binary) do
       {:ok, {{_, 200, _}, _, response_body}} ->
         JSON.decode!(response_body)
 
@@ -370,16 +434,60 @@ defmodule FLAME.FlyBackend do
       {:ok, {{_, status, _}, _, _response_body}}
       when status in [429, 412, 409, 422] and remaining_tries > 0 ->
         Process.sleep(1000)
-        http_post!(url, remaining_tries - 1, opts)
+        http_request!(method, url, remaining_tries - 1, opts)
 
       {:ok, {{_, status, reason}, _, resp_body}} ->
-        raise "failed POST #{url} with #{inspect(status)} (#{inspect(reason)}): #{inspect(resp_body)} #{inspect(headers)}"
+        raise "failed #{method} #{url} with #{inspect(status)} (#{inspect(reason)}): #{inspect(resp_body)} #{inspect(headers)}"
 
       {:error, reason} ->
-        raise "failed POST #{url} with #{inspect(reason)} #{inspect(headers)}"
+        raise "failed #{method} #{url} with #{inspect(reason)} #{inspect(headers)}"
     end
   end
 
+  # Rejects option keys that do not apply to the given HTTP method.
+  defp validate_request_opts!(:get, opts) do
+    Keyword.validate!(opts, [:headers, :connect_timeout])
+  end
+
+  defp validate_request_opts!(:post, opts) do
+    Keyword.validate!(opts, [:headers, :body, :connect_timeout, :content_type])
+  end
+
+  # Builds the request tuple passed to `:httpc.request/4`.
+  defp make_request(:get, url, headers, _opts) do
+    {url, headers}
+  end
+
+  defp make_request(:post, url, headers, opts) do
+    content_type = Keyword.fetch!(opts, :content_type)
+
+    body =
+      case Keyword.fetch!(opts, :body) do
+        # A function body is evaluated per attempt, so retries rebuild the payload.
+        body_func when is_function(body_func, 0) -> body_func.()
+        body -> body
+      end
+
+    {url, headers, ~c"#{content_type}", body}
+  end
+
+  # Builds the :httpc HTTP options: TLS peer verification plus the connect timeout.
+  defp make_http_opts(opts) do
+    connect_timeout = Keyword.fetch!(opts, :connect_timeout)
+
+    [
+      ssl:
+        [
+          verify: :verify_peer,
+          depth: 2,
+          customize_hostname_check: [
+            match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
+          ]
+        ] ++ cacerts_options(),
+      connect_timeout: connect_timeout
+    ]
+  end
+
   defp cacerts_options do
     cond do
       certs = otp_cacerts() ->
diff --git a/lib/flame/fly_backend/mount.ex b/lib/flame/fly_backend/mount.ex
new file mode 100644
index 0000000..6bd72ac
--- /dev/null
+++ b/lib/flame/fly_backend/mount.ex
@@ -0,0 +1,63 @@
+defmodule FLAME.FlyBackend.Mount do
+  @moduledoc """
+  A volume mount for a `FLAME.FlyBackend` runner machine.
+
+  Refer to the "mount:" section most of the way down this page for how to use these keys:
+  https://fly.io/docs/machines/api/machines-resource/
+
+  ## Options
+
+    * `:name` - Required. The volume name. When `:volume` is not given, the backend
+      picks an unattached volume with this name.
+
+    * `:path` - Required. Passed through to the Fly API as the mount `path`.
+
+    * `:volume` - The id of a specific volume to mount. Defaults to `nil`.
+
+    * `:extend_threshold_percent`, `:add_size_gb`, `:size_gb_limit` - Passed through
+      to the Fly API unchanged. Each defaults to `0`.
+  """
+
+  alias FLAME.FlyBackend.Mount
+
+  @derive {Inspect,
+           only: [
+             :volume,
+             :path,
+             :name,
+             :extend_threshold_percent,
+             :add_size_gb,
+             :size_gb_limit
+           ]}
+  defstruct volume: nil,
+            path: nil,
+            name: nil,
+            extend_threshold_percent: nil,
+            add_size_gb: nil,
+            size_gb_limit: nil
+
+  @valid_opts [:volume, :path, :name, :extend_threshold_percent, :add_size_gb, :size_gb_limit]
+
+  @required_opts [:path, :name]
+
+  @doc """
+  Builds a `Mount` struct from a keyword list of options.
+
+  Raises `ArgumentError` when an unknown option is given or when `:path` or `:name`
+  is missing.
+  """
+  def parse_opts(opts) do
+    default = %Mount{extend_threshold_percent: 0, add_size_gb: 0, size_gb_limit: 0}
+
+    # `Keyword.validate!/2` rejects unknown keys, so `struct!/2` only sets known fields.
+    state = struct!(default, Keyword.validate!(opts, @valid_opts))
+
+    for key <- @required_opts do
+      unless Map.get(state, key) do
+        raise ArgumentError, "missing :#{key} config for #{inspect(__MODULE__)}"
+      end
+    end
+
+    state
+  end
+end