Commit ce2108f

initial code commit

1 parent: d472336

17 files changed, +3709 -0 lines

.gitignore

+47
@@ -0,0 +1,47 @@
*.class
*.log
*~
*.scala~
*.swp
*.iml
.idea/
.settings
.cache
/build/
work/
out/
.DS_Store
conf/java-opts
conf/spark-env.sh
docs/_site
docs/api
target/
reports/
.project
.classpath
lib_managed/
src_managed/
project/boot/
project/plugins/project/build.properties
project/build/target/
project/plugins/target/
project/plugins/lib_managed/
project/plugins/src_managed/
logs/
log/
spark-tests.log

# sbt specific
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/

# Scala-IDE specific
.idea*
.scala_dependencies

# Jars
*.jar

build.sbt

+51
@@ -0,0 +1,51 @@
import AssemblyKeys._

assemblySettings

name := "cocoa"

version := "0.1"

organization := "edu.berkeley.cs.amplab"

scalaVersion := "2.10.4"

parallelExecution in Test := false

libraryDependencies ++= Seq(
  "org.slf4j" % "slf4j-api" % "1.7.2",
  "org.slf4j" % "slf4j-log4j12" % "1.7.2",
  "org.scalatest" %% "scalatest" % "1.9.1" % "test",
  "org.apache.spark" % "spark-core_2.10" % "1.1.1",
  "org.apache.commons" % "commons-compress" % "1.7",
  "commons-io" % "commons-io" % "2.4",
  "org.jblas" % "jblas" % "1.2.3"
)

{
  val defaultHadoopVersion = "1.0.4"
  val hadoopVersion =
    scala.util.Properties.envOrElse("SPARK_HADOOP_VERSION", defaultHadoopVersion)
  libraryDependencies += "org.apache.hadoop" % "hadoop-client" % hadoopVersion
}

resolvers ++= Seq(
  "Typesafe" at "http://repo.typesafe.com/typesafe/releases",
  "Spray" at "http://repo.spray.cc"
)

mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
  {
    case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first
    case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first
    case "application.conf" => MergeStrategy.concat
    case "reference.conf" => MergeStrategy.concat
    case "log4j.properties" => MergeStrategy.discard
    case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
    case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
    case _ => MergeStrategy.first
  }
}

test in assembly := {}
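
The braced block above pins hadoop-client to whatever SPARK_HADOOP_VERSION is set to in the environment, falling back to 1.0.4. A minimal standalone sketch of that envOrElse fallback (the object name VersionCheck is illustrative, not part of this commit):

object VersionCheck {
  def main(args: Array[String]): Unit = {
    val defaultHadoopVersion = "1.0.4"
    // envOrElse returns the environment variable's value when set, else the default
    val hadoopVersion =
      scala.util.Properties.envOrElse("SPARK_HADOOP_VERSION", defaultHadoopVersion)
    println("hadoop-client version: " + hadoopVersion)
  }
}

Building with, e.g., SPARK_HADOOP_VERSION=2.0.0 sbt/sbt assembly (the version shown only as an example) would therefore swap in that hadoop-client release instead of 1.0.4.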

conf/log4j.properties

+10
@@ -0,0 +1,10 @@
# Set everything to be logged to the console
log4j.rootCategory=WARN, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n

# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN

log4j.logger.org.apache.spark=WARN
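
Note that the merge strategy in build.sbt discards any log4j.properties bundled inside dependency jars; this copy still takes effect because local-helper.sh puts $FWDIR/conf at the front of the classpath.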

data/small_test.dat

+600 lines (large data file; diff not rendered)

data/small_train.dat

+2,000 lines (large data file; diff not rendered)

local-helper.sh

+53
@@ -0,0 +1,53 @@
#!/usr/bin/env bash

SCALA_VERSION=2.10

# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"

if [ -z "$1" ]; then
  echo "Usage: local-helper.sh <class> [<args>]" >&2
  exit 1
fi

ASSEMBLY_DEPS_JAR="" #cocoa-assembly-0.1-deps.jar
if [ -e "$FWDIR"/target/scala-$SCALA_VERSION/cocoa-assembly-*-deps.jar ]; then
  export ASSEMBLY_DEPS_JAR=`ls "$FWDIR"/target/scala-$SCALA_VERSION/cocoa-assembly*-deps.jar`
fi

if [[ -z $ASSEMBLY_DEPS_JAR ]]; then
  ASSEMBLY_JAR=""
  if [ -e "$FWDIR"/target/scala-$SCALA_VERSION/cocoa-assembly-*.jar ]; then
    export ASSEMBLY_JAR=`ls "$FWDIR"/target/scala-$SCALA_VERSION/cocoa-assembly*.jar`
  fi

  if [[ -z $ASSEMBLY_JAR ]]; then
    echo "Failed to find assembly JAR in $FWDIR/target" >&2
    echo "You need to run sbt/sbt assembly or sbt/sbt assembly-package-dependency before running this program" >&2
    exit 1
  fi
  CLASSPATH="$FWDIR/conf:$ASSEMBLY_JAR"
else
  CLASSPATH="$FWDIR/conf:$ASSEMBLY_DEPS_JAR"
  CLASSPATH="$CLASSPATH:$FWDIR/target/scala-$SCALA_VERSION/classes"
fi

# Find java binary
if [ -n "${JAVA_HOME}" ]; then
  RUNNER="${JAVA_HOME}/bin/java"
else
  if [ `command -v java` ]; then
    RUNNER="java"
  else
    echo "JAVA_HOME is not set" >&2
    exit 1
  fi
fi

# Set SPARK_MEM if it isn't already set since we also use it for this process
SPARK_MEM=${SPARK_MEM:-512m}
export SPARK_MEM

JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM ""$SPARK_JAVA_OPTS"

exec "$RUNNER" -Djava.library.path=$FWDIR/lib -cp "$CLASSPATH" $JAVA_OPTS "$@"

project/build.properties

+2
@@ -0,0 +1,2 @@

sbt.version=0.13.5

project/plugins.sbt

+5
@@ -0,0 +1,5 @@
resolvers += "Sonatype snapshots" at "http://oss.sonatype.org/content/repositories/snapshots/"

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")

addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")

run-demo-cluster.sh

+10
@@ -0,0 +1,10 @@
#!/bin/bash

/root/spark/bin/spark-submit \
  --master `cat /root/spark-ec2/cluster-url` \
  --class "distopt.driver" \
  --driver-memory 80423M \
  --driver-java-options "-Dspark.local.dir=/mnt/spark,/mnt2/spark -XX:+UseG1GC" \
  target/scala-2.10/cocoa-assembly-0.1.jar \
  "$@"

run-demo-local.sh

+9
@@ -0,0 +1,9 @@
./local-helper.sh distopt.driver \
  --trainFile=data/small_train.dat \
  --testFile=data/small_test.dat \
  --numFeatures=9947 \
  --numRounds=100 \
  --localIterFrac=0.1 \
  --numSplits=4 \
  --lambda=.001 \
  --justCoCoA=false
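
For these flags, driver.scala computes the number of local iterations per round as localIters = localIterFrac * n / numPartitions. Assuming small_train.dat holds one example per line (so n = 2000, matching the +2,000 line count above) and the loaded RDD keeps numSplits = 4 partitions, that gives (0.1 * 2000 / 4).toInt = 50 local iterations in each of the 100 outer rounds.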

sbt/sbt

+54
@@ -0,0 +1,54 @@
#!/bin/bash

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This script launches sbt for this project. If present it uses the system
# version of sbt. If there is no system version of sbt it attempts to download
# sbt locally.
SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties`
URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
JAR=sbt/sbt-launch-${SBT_VERSION}.jar

# Download sbt launch jar if it hasn't been downloaded yet
if [ ! -f ${JAR} ]; then
  # Download
  printf "Attempting to fetch sbt\n"
  if hash curl 2>/dev/null; then
    curl --progress-bar ${URL1} > ${JAR} || curl --progress-bar ${URL2} > ${JAR}
  elif hash wget 2>/dev/null; then
    wget --progress=bar ${URL1} -O ${JAR} || wget --progress=bar ${URL2} -O ${JAR}
  else
    printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
    exit -1
  fi
fi
if [ ! -f ${JAR} ]; then
  # We failed to download
  printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
  exit -1
fi
printf "Launching sbt from ${JAR}\n"

FWDIR="$(cd `dirname $0`/..; pwd)"

java \
  -Djava.library.path="$FWDIR/lib" \
  -Xmx2400m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \
  -jar ${JAR} \
  "$@"

src/main/scala/driver.scala

+97
@@ -0,0 +1,97 @@
package distopt

import org.apache.spark.{SparkContext, SparkConf}
import distopt.utils._
import scala.collection.immutable.SortedMap
import distopt.solvers._

object driver {

  def main(args: Array[String]) {

    val options = args.map { arg =>
      arg.dropWhile(_ == '-').split('=') match {
        case Array(opt, v) => (opt -> v)
        case Array(opt) => (opt -> "true")
        case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
      }
    }.toMap

    // read in inputs
    val master = options.getOrElse("master", "local[4]")
    val trainFile = options.getOrElse("trainFile", "")
    val numFeatures = options.getOrElse("numFeatures", "0").toInt
    val numSplits = options.getOrElse("numSplits", "1").toInt
    val chkptDir = options.getOrElse("chkptDir", "")
    var chkptIter = options.getOrElse("chkptIter", "100").toInt
    val testFile = options.getOrElse("testFile", "")
    val justCoCoA = options.getOrElse("justCoCoA", "true").toBoolean // set to false to compare different methods

    // algorithm-specific inputs
    val lambda = options.getOrElse("lambda", "0.01").toDouble // regularization parameter
    val numRounds = options.getOrElse("numRounds", "200").toInt // number of outer iterations, called T in the paper
    val localIterFrac = options.getOrElse("localIterFrac", "1.0").toDouble // fraction of local points to be processed per round, H = localIterFrac * n
    val beta = options.getOrElse("beta", "1.0").toDouble // scaling parameter when combining the updates of the workers (1 = averaging)
    val debugIter = options.getOrElse("debugIter", "10").toInt // set to -1 to turn off debugging output

    // print out inputs
    println("master: " + master); println("trainFile: " + trainFile)
    println("numFeatures: " + numFeatures); println("numSplits: " + numSplits)
    println("chkptDir: " + chkptDir); println("chkptIter: " + chkptIter)
    println("testFile: " + testFile); println("justCoCoA: " + justCoCoA)
    println("lambda: " + lambda); println("numRounds: " + numRounds)
    println("localIterFrac: " + localIterFrac); println("beta: " + beta)
    println("debugIter: " + debugIter)

    // start spark context
    val conf = new SparkConf().setMaster(master)
      .setAppName("demoCoCoA")
      .setJars(SparkContext.jarOfObject(this).toSeq)
    val sc = new SparkContext(conf)
    if (chkptDir != "") {
      sc.setCheckpointDir(chkptDir)
    } else {
      chkptIter = numRounds + 1
    }

    // read in data
    val data = OptUtils.loadLIBSVMData(sc, trainFile, numSplits, numFeatures).cache()
    val n = data.count().toInt // number of data examples
    val testData = {
      if (testFile != "") { OptUtils.loadLIBSVMData(sc, testFile, numSplits, numFeatures).cache() }
      else { null }
    }

    // compute H, the number of local iterations
    var localIters = (localIterFrac * n / data.partitions.size).toInt
    localIters = Math.max(localIters, 1)

    // for the primal-dual algorithms to run correctly, the initial primal vector has to be zero
    // (corresponding to the dual alphas being zero)
    val wInit = Array.fill(numFeatures)(0.0)

    // run CoCoA
    val (finalwCoCoA, finalalphaCoCoA) = CoCoA.runCoCoA(sc, data, n, wInit, numRounds, localIters, lambda, beta, chkptIter, testData, debugIter)
    OptUtils.printSummaryStatsPrimalDual("CoCoA", data, finalwCoCoA, finalalphaCoCoA, lambda, testData)

    // optionally run other methods for comparison
    if (!justCoCoA) {

      // run Mini-batch CD
      val (finalwMbCD, finalalphaMbCD) = MinibatchCD.runMbCD(sc, data, n, wInit, numRounds, localIters, lambda, beta, chkptIter, testData, debugIter)
      OptUtils.printSummaryStatsPrimalDual("Mini-batch CD", data, finalwMbCD, finalalphaMbCD, lambda, testData)

      // run Mini-batch SGD
      val finalwMbSGD = SGD.runSGD(sc, data, n, wInit, numRounds, localIters, lambda, local = false, beta, chkptIter, testData, debugIter)
      OptUtils.printSummaryStats("Mini-batch SGD", data, finalwMbSGD, lambda, testData)

      // run Local SGD
      val finalwLocalSGD = SGD.runSGD(sc, data, n, wInit, numRounds, localIters, lambda, local = true, beta, chkptIter, testData, debugIter)
      OptUtils.printSummaryStats("Local SGD", data, finalwLocalSGD, lambda, testData)
    }

    sc.stop()
  }
}
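
The map built at the top of main accepts both --key=value pairs and bare --flag arguments, which default to "true". A minimal standalone sketch of the same parsing idiom (the object name ArgParseDemo is illustrative, not part of this commit):

object ArgParseDemo {
  def main(args: Array[String]): Unit = {
    // strip leading dashes, split on '=', default bare flags to "true"
    val options = args.map { arg =>
      arg.dropWhile(_ == '-').split('=') match {
        case Array(opt, v) => (opt -> v)
        case Array(opt) => (opt -> "true")
        case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
      }
    }.toMap
    println(options)
  }
}

Fed the arguments from run-demo-local.sh, this produces entries such as "lambda" -> ".001" and "justCoCoA" -> "false", which main then reads back with options.getOrElse and converts to the appropriate types.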
