diff --git a/.rubocop.yml b/.rubocop.yml index 5c614cd..3932dfe 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -9,6 +9,14 @@ AllCops: Lint/MissingSuper: Enabled: false +Metrics/MethodLength: + Enabled: true + Exclude: + - "lib/mars/workflows/parallel.rb" + +RSpec/ExampleLength: + Enabled: false + Style/Documentation: Enabled: false diff --git a/examples/parallel_workflow/diagram.md b/examples/parallel_workflow/diagram.md index 24a4829..0efc8cb 100644 --- a/examples/parallel_workflow/diagram.md +++ b/examples/parallel_workflow/diagram.md @@ -2,15 +2,15 @@ flowchart LR in((In)) out((Out)) -parallel_workflow_aggregator[Parallel workflow Aggregator] +aggregator[Aggregator] llm_1[LLM 1] llm_2[LLM 2] llm_3[LLM 3] in --> llm_1 in --> llm_2 in --> llm_3 -llm_1 --> parallel_workflow_aggregator -parallel_workflow_aggregator --> out -llm_2 --> parallel_workflow_aggregator -llm_3 --> parallel_workflow_aggregator +llm_1 --> aggregator +aggregator --> out +llm_2 --> aggregator +llm_3 --> aggregator ``` diff --git a/examples/parallel_workflow/examples.md b/examples/parallel_workflow/examples.md deleted file mode 100644 index 3399381..0000000 --- a/examples/parallel_workflow/examples.md +++ /dev/null @@ -1,26 +0,0 @@ -```mermaid -flowchart LR -In(("In")) --> -LLM_1["LLM 1"] & LLM_2["LLM 2"] & LLM_3["LLM 3"] -LLM_1 --> LLM_4["LLM_4"] -LLM_4 --> AGGREGATOR -LLM_2 --> AGGREGATOR -LLM_3 --> AGGREGATOR -AGGREGATOR --> Out(("Out")) -``` -# - - -```mermaid -flowchart LR -In(("In")) --> LLM_1["LLM 1"] -In(("In")) --> LLM_2["LLM 2"] -In(("In")) --> LLM_3["LLM 3"] - -LLM_1 --> LLM_4["LLM_4"] -LLM_4 --> AGGREGATOR -LLM_2 --> AGGREGATOR -LLM_3 --> AGGREGATOR -AGGREGATOR --> Out(("Out")) -``` -# diff --git a/examples/parallel_workflow/generator.rb b/examples/parallel_workflow/generator.rb index 069b542..106a385 100755 --- a/examples/parallel_workflow/generator.rb +++ b/examples/parallel_workflow/generator.rb @@ -10,10 +10,13 @@ llm3 = Mars::Agent.new(name: "LLM 3") +aggregator = Mars::Aggregator.new("Aggregator", operation: lambda(&:sum)) + # Create the parallel workflow (LLM 1, LLM 2, LLM 3) parallel_workflow = Mars::Workflows::Parallel.new( "Parallel workflow", - steps: [llm1, llm2, llm3] + steps: [llm1, llm2, llm3], + aggregator: aggregator ) # Generate and save the diagram diff --git a/lib/mars.rb b/lib/mars.rb index 831d67c..06cce5d 100644 --- a/lib/mars.rb +++ b/lib/mars.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true require "zeitwerk" +require "async" loader = Zeitwerk::Loader.for_gem loader.setup diff --git a/lib/mars/aggregator.rb b/lib/mars/aggregator.rb index b660b76..a4a7225 100644 --- a/lib/mars/aggregator.rb +++ b/lib/mars/aggregator.rb @@ -2,16 +2,15 @@ module Mars class Aggregator < Runnable - attr_reader :name + attr_reader :name, :operation - def initialize(name = "Aggregator") + def initialize(name = "Aggregator", operation: nil) @name = name + @operation = operation || ->(inputs) { inputs.join("\n") } end def run(inputs) - return yield if block_given? - - inputs.join("\n") + operation.call(inputs) end end end diff --git a/lib/mars/workflows/aggregate_error.rb b/lib/mars/workflows/aggregate_error.rb new file mode 100644 index 0000000..092545e --- /dev/null +++ b/lib/mars/workflows/aggregate_error.rb @@ -0,0 +1,14 @@ +# frozen_string_literal: true + +module Mars + module Workflows + class AggregateError < StandardError + attr_reader :errors + + def initialize(errors) + @errors = errors + super(errors.map { |error| "#{error[:step_name]}: #{error[:error].message}" }.join("\n")) + end + end + end +end diff --git a/lib/mars/workflows/parallel.rb b/lib/mars/workflows/parallel.rb index 6332827..41b3c34 100644 --- a/lib/mars/workflows/parallel.rb +++ b/lib/mars/workflows/parallel.rb @@ -12,11 +12,22 @@ def initialize(name, steps:, aggregator: nil) end def run(input) - inputs = @steps.map do |step| - step.run(input) - end + errors = [] + results = Async do |workflow| + tasks = @steps.map do |step| + workflow.async do + step.run(input) + rescue StandardError => e + errors << { error: e, step_name: step.name } + end + end - aggregator.run(inputs) + tasks.map(&:wait) + end.result + + raise AggregateError, errors if errors.any? + + aggregator.run(results) end private diff --git a/mars.gemspec b/mars.gemspec index 6317bac..3d4bd97 100644 --- a/mars.gemspec +++ b/mars.gemspec @@ -35,6 +35,7 @@ Gem::Specification.new do |spec| spec.require_paths = ["lib"] # Uncomment to register a new dependency of your gem + spec.add_dependency "async", "~> 2.34" spec.add_dependency "ruby_llm", "~> 1.0" spec.add_dependency "zeitwerk", "~> 2.7" diff --git a/spec/mars/workflows/parallel_spec.rb b/spec/mars/workflows/parallel_spec.rb new file mode 100644 index 0000000..37ac7be --- /dev/null +++ b/spec/mars/workflows/parallel_spec.rb @@ -0,0 +1,103 @@ +# frozen_string_literal: true + +RSpec.describe Mars::Workflows::Parallel do + let(:add_step_class) do + Class.new do + def initialize(value) + @value = value + end + + def run(input) + sleep 0.1 + puts "add step: #{input}" + input + @value + end + end + end + + let(:multiply_step_class) do + Class.new do + def initialize(multiplier) + @multiplier = multiplier + end + + def run(input) + puts "multiply step: #{input}" + input * @multiplier + end + end + end + + let(:error_step_class) do + Class.new do + attr_reader :name + + def initialize(message, name) + @message = message + @name = name + end + + def run(_input) + puts "error step: #{@name}" + raise StandardError, @message + end + end + end + + describe "#run" do + it "executes steps in parallel" do + add_five = add_step_class.new(5) + multiply_three = multiply_step_class.new(3) + add_two = add_step_class.new(2) + + workflow = described_class.new("math_workflow", steps: [add_five, multiply_three, add_two]) + + # 10 + 5 = 15, 10 * 3 = 30, 10 + 2 = 12 + expect(workflow.run(10)).to eq("15\n30\n12") + end + + it "executes steps in parallel with a custom aggregator" do + add_five = add_step_class.new(5) + multiply_three = multiply_step_class.new(3) + add_two = add_step_class.new(2) + aggregator = Mars::Aggregator.new("Custom Aggregator", operation: lambda(&:sum)) + workflow = described_class.new("math_workflow", steps: [add_five, multiply_three, add_two], + aggregator: aggregator) + + expect(workflow.run(10)).to eq(57) + end + + it "handles single step" do + multiply_step = multiply_step_class.new(7) + workflow = described_class.new("single_step", steps: [multiply_step]) + + expect(workflow.run(6)).to eq("42") + end + + it "returns input unchanged when no steps" do + workflow = described_class.new("empty", steps: []) + + expect(workflow.run(42)).to eq("") + end + + it "propagates errors from steps" do + add_step = add_step_class.new(5) + error_step = error_step_class.new("Step failed", "error_step_one") + error_step_two = error_step_class.new("Step failed two", "error_step_two") + + workflow = described_class.new("error_workflow", steps: [add_step, error_step, error_step_two]) + + expect { workflow.run(10) }.to raise_error( + Mars::Workflows::AggregateError, + "error_step_one: Step failed\nerror_step_two: Step failed two" + ) + end + end + + describe "inheritance" do + it "inherits from Mars::Runnable" do + workflow = described_class.new("test", steps: []) + expect(workflow).to be_a(Mars::Runnable) + end + end +end