diff --git a/src/main/scala/Bits.scala b/src/main/scala/Bits.scala index aa31bb87..e14ca8e4 100644 --- a/src/main/scala/Bits.scala +++ b/src/main/scala/Bits.scala @@ -71,6 +71,7 @@ abstract class Bits extends Data with proc { this assign node ; setIsTypeNode ; this } + def getLit( x : BigInt ) : Bits = { Lit(x, getWidth()) { UInt() } } def fromInt(x: Int): this.type def toSInt(): SInt = chiselCast(this){SInt()} def toUInt(): UInt = chiselCast(this){UInt()} diff --git a/src/main/scala/Bool.scala b/src/main/scala/Bool.scala index 65af4756..8133b310 100644 --- a/src/main/scala/Bool.scala +++ b/src/main/scala/Bool.scala @@ -60,6 +60,8 @@ class Bool extends UInt { Bool(x > 0).asInstanceOf[this.type] } + override def getLit( x : BigInt ) : Bool = { Lit(x, getWidth()) { Bool() } } + /** Implementation of := operator, assigns value to this Bool */ override protected def colonEquals(src: Bits): Unit = src match { case _: Bool => super.colonEquals(src(0)) diff --git a/src/main/scala/FP.scala b/src/main/scala/FP.scala index 4b4ec4d4..9a188a8d 100644 --- a/src/main/scala/FP.scala +++ b/src/main/scala/FP.scala @@ -73,6 +73,8 @@ class Flo extends Bits with Num[Flo] { case _ => illegalAssignment(that) } + override def getLit( x : BigInt ) : Flo = { Lit(x, getWidth()) { Flo() } } + /** Get Flo as an instance of T */ def gen[T <: Bits](): T = Flo().asInstanceOf[T]; @@ -150,6 +152,8 @@ class Dbl extends Bits with Num[Dbl] { case _ => illegalAssignment(that) } + override def getLit( x : BigInt ) : Dbl = { Lit(x, getWidth()) { Dbl() } } + def gen[T <: Bits](): T = Dbl().asInstanceOf[T]; def unary_-(): Dbl = newUnaryOp("d-") diff --git a/src/main/scala/Fixed.scala b/src/main/scala/Fixed.scala index da1d57a0..92b194fe 100644 --- a/src/main/scala/Fixed.scala +++ b/src/main/scala/Fixed.scala @@ -131,6 +131,8 @@ class Fixed(var fractionalWidth : Int = 0) extends Bits with Num[Fixed] { case _ => illegalAssignment(that) } + override def getLit( x : BigInt ) : Fixed = { Fixed(x, getWidth(), getFractionalWidth()) } + def getFractionalWidth() : Int = this.fractionalWidth private def truncate(f : Fixed, truncateAmount : Int) : Fixed = fromSInt(f.toSInt >> UInt(truncateAmount)) diff --git a/src/main/scala/SInt.scala b/src/main/scala/SInt.scala index c4cff659..9211f45f 100644 --- a/src/main/scala/SInt.scala +++ b/src/main/scala/SInt.scala @@ -99,6 +99,8 @@ class SInt extends Bits with Num[SInt] { case _ => super.colonEquals(that) } + override def getLit( x : BigInt ) : SInt = { Lit(x, getWidth()) { SInt() } } + def gen[T <: Bits](): T = SInt().asInstanceOf[T] // arithmetic operators diff --git a/src/main/scala/UInt.scala b/src/main/scala/UInt.scala index fc199a9a..a158e3e4 100644 --- a/src/main/scala/UInt.scala +++ b/src/main/scala/UInt.scala @@ -131,6 +131,8 @@ class UInt extends Bits with Num[UInt] { override def toBits: UInt = this + override def getLit( x : BigInt ) : UInt = { Lit(x, getWidth()) { UInt() } } + // to support implicit conversions def ===(b: UInt): Bool = LogicalOp(this, b, "===") diff --git a/src/test/scala/GenTypeTest.scala b/src/test/scala/GenTypeTest.scala new file mode 100644 index 00000000..32daaf79 --- /dev/null +++ b/src/test/scala/GenTypeTest.scala @@ -0,0 +1,141 @@ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer +import org.junit.Assert._ +import org.junit.Ignore +import org.junit.Test + +import Chisel._ + +class GenTypeTest extends TestSuite { + class SubMod[T <: Bits]( genType : T ) extends Module { + val io = new Bundle { + val in = genType.cloneType.asInput + val out = genType.cloneType.asOutput + } + val ZERO = genType.getLit(BigInt(0)) + val reg = Reg(init = ZERO) + reg := io.in + io.out := reg + } + + @Test def testUInt() { + class UserModUInt extends Module { + val io = new Bundle { + val in = UInt(INPUT, 4) + val out = UInt(OUTPUT, 4) + } + val sub = Module(new SubMod[UInt](UInt(0, 4))) + sub.io <> io + } + + class UIntTest(c : UserModUInt) extends Tester(c) { + poke(c.io.in, BigInt(3)) + expect(c.io.out, BigInt(0)) + step(1) + expect(c.io.out, BigInt(3)) + } + + launchCppTester( (c : UserModUInt) => { new UIntTest(c) }) + } + + @Test def testBool() { + class UserModBool extends Module { + val io = new Bundle { + val in = Bool(INPUT) + val out = Bool(OUTPUT) + } + val sub = Module(new SubMod[Bool](Bool(false))) + sub.io <> io + } + + class BoolTest(c : UserModBool) extends Tester(c) { + poke(c.io.in, BigInt(1)) + expect(c.io.out, BigInt(0)) + step(1) + expect(c.io.out, BigInt(1)) + } + + launchCppTester( (c : UserModBool) => { new BoolTest(c) }) + } + + @Test def testSInt() { + class UserModSInt extends Module { + val io = new Bundle { + val in = SInt(INPUT, 4) + val out = SInt(OUTPUT, 4) + } + val sub = Module(new SubMod[SInt](SInt(0, 4))) + sub.io <> io + } + + class SIntTest(c : UserModSInt) extends Tester(c) { + poke(c.io.in, BigInt(-1)) + expect(c.io.out, BigInt(0)) + step(1) + expect(c.io.out, BigInt(-1)) + } + + launchCppTester( (c : UserModSInt) => { new SIntTest(c) }) + } + + @Test def testFlo() { + class UserModFlo extends Module { + val io = new Bundle { + val in = Flo(INPUT) + val out = Flo(OUTPUT) + } + val sub = Module(new SubMod[Flo](Flo(0))) + sub.io <> io + } + + class FloTest(c : UserModFlo) extends Tester(c) { + poke(c.io.in, (1.5).toFloat) + expect(c.io.out, (0).toFloat) + step(1) + expect(c.io.out, (1.5).toFloat) + } + + launchCppTester( (c : UserModFlo) => { new FloTest(c) }) + } + + @Test def testDbl() { + class UserModDbl extends Module { + val io = new Bundle { + val in = Dbl(INPUT) + val out = Dbl(OUTPUT) + } + val sub = Module(new SubMod[Dbl](Dbl(0))) + sub.io <> io + } + + class DblTest(c : UserModDbl) extends Tester(c) { + poke(c.io.in, 1.5) + expect(c.io.out, 0.0) + step(1) + expect(c.io.out, 1.5) + } + + launchCppTester( (c : UserModDbl) => { new DblTest(c) }) + } + + @Test def testFixed() { + class UserModFixed extends Module { + val io = new Bundle { + val in = Fixed(INPUT, 18, 10) + val out = Fixed(OUTPUT, 18, 10) + } + val sub = Module(new SubMod[Fixed](Fixed(0, 18, 10))) + sub.io <> io + } + + class FixedTest(c : UserModFixed) extends Tester(c) { + poke(c.io.in, BigInt(10)) + expect(c.io.out, BigInt(0)) + step(1) + expect(c.io.out, BigInt(10)) + } + + launchCppTester( (c : UserModFixed) => { new FixedTest(c) }) + } +}