From 80eab2135793027be7c2ac4fbe652cf45a1de5f7 Mon Sep 17 00:00:00 2001 From: Abhijit Sarkar Date: Sun, 17 Dec 2023 05:48:14 -0800 Subject: [PATCH] Complete chapter 3 --- README.md | 2 +- build.sc | 11 +- chapter02/test/src/LibSpec.scala | 29 ++-- chapter03/src/Lib.scala | 12 ++ chapter03/src/List.scala | 214 ++++++++++++++++++++++++++++++ chapter03/src/Tree.scala | 36 +++++ chapter03/test/src/LibSpec.scala | 8 ++ chapter03/test/src/ListSpec.scala | 104 +++++++++++++++ chapter03/test/src/TreeSpec.scala | 20 +++ 9 files changed, 425 insertions(+), 11 deletions(-) create mode 100644 chapter03/src/Lib.scala create mode 100644 chapter03/src/List.scala create mode 100644 chapter03/src/Tree.scala create mode 100644 chapter03/test/src/LibSpec.scala create mode 100644 chapter03/test/src/ListSpec.scala create mode 100644 chapter03/test/src/TreeSpec.scala diff --git a/README.md b/README.md index 9221589..6a2f9ef 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Install a BSP connection file: mill mill.bsp.BSP/install ``` -Then open VSCode command palette, and select "Metals: Switch build server". +Then open VSCode command palette, and select `Metals: Switch build server`. ## References diff --git a/build.sc b/build.sc index f60e675..8f6f879 100644 --- a/build.sc +++ b/build.sc @@ -28,7 +28,16 @@ object chapter02 extends AdvancedScalaModule { override def ivyDeps = Agg( ivy"org.scalactic::scalactic:$scalatestVersion", ivy"org.scalatest::scalatest:$scalatestVersion", - ivy"org.scalatestplus::scalacheck-1-17:$scalacheckVersion", + ) + } +} + +object chapter03 extends AdvancedScalaModule { + object test extends ScalaTests with TestModule.ScalaTest { + // // use `::` for scala deps, `:` for java deps + override def ivyDeps = Agg( + ivy"org.scalactic::scalactic:$scalatestVersion", + ivy"org.scalatest::scalatest:$scalatestVersion", ) } } diff --git a/chapter02/test/src/LibSpec.scala b/chapter02/test/src/LibSpec.scala index c2a5039..9c9097d 100644 --- a/chapter02/test/src/LibSpec.scala +++ b/chapter02/test/src/LibSpec.scala @@ -1,28 +1,39 @@ import org.scalatest.funspec.AnyFunSpec import Lib.* +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.prop.TableDrivenPropertyChecks.Table +import org.scalatest.matchers.should.Matchers.shouldBe -class LibSpec extends AnyFunSpec: +class LibSpec extends AnyFunSpec with TableDrivenPropertyChecks: describe("Chapter 2"): it("fib should return the nth Fibonacci number"): val fst20 = List(0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181) for ((expected, n) <- fst20.zipWithIndex) - assert(fib(n) == expected, s"for n=$n") + fib(n) shouldBe expected + + val isSortedInput = + Table( + ("as", "gt", "expected"), + (Array(1, 2, 3), (_: Int) > (_: Int), true), + (Array(1, 2, 1), (_: Int) > (_: Int), false), + (Array(3, 2, 1), (_: Int) < (_: Int), true), + (Array(1, 2, 3), (_: Int) < (_: Int), false) + ) it("isSorted should check if an array is sorted"): - assert(isSorted(Array(1, 2, 3), _ > _)) - assert(!isSorted(Array(1, 2, 1), _ > _)) - assert(isSorted(Array(3, 2, 1), _ < _)) - assert(!isSorted(Array(1, 2, 3), _ < _)) + forAll(isSortedInput) { (as: Array[Int], gt: (Int, Int) => Boolean, expected: Boolean) => + isSorted(as, gt) shouldBe expected + } it("curry should convert a two-argument function into an one-argument function that returns a function"): val add = (x: Int, y: Int) => x + y - assert(curry(add)(1)(2) == 3) + curry(add)(1)(2) shouldBe 3 it("uncurry should reverse curry"): val curriedAdd = (a: Int) => (b: Int) => a + b - assert(uncurry(curriedAdd)(1, 2) == 3) + uncurry(curriedAdd)(1, 2) shouldBe 3 it("compose should compose two functions"): val plus2 = (x: Int) => x + 2 val times2 = (x: Int) => x * 2 - assert(compose(times2, plus2)(2) == 8) + compose(times2, plus2)(2) shouldBe 8 diff --git a/chapter03/src/Lib.scala b/chapter03/src/Lib.scala new file mode 100644 index 0000000..5d3c603 --- /dev/null +++ b/chapter03/src/Lib.scala @@ -0,0 +1,12 @@ +import List.* + +object Lib: + /* + Exercise 3.1: What will be the result of the following match expression? + */ + val result = List(1, 2, 3, 4, 5) match + case Cons(x, Cons(2, Cons(4, _))) => x + case Nil => 42 + case Cons(x, Cons(y, Cons(3, Cons(4, _)))) => x + y + case Cons(h, t) => h + sum(t) + case null => 101 diff --git a/chapter03/src/List.scala b/chapter03/src/List.scala new file mode 100644 index 0000000..dc9a42c --- /dev/null +++ b/chapter03/src/List.scala @@ -0,0 +1,214 @@ +import scala.annotation.tailrec +enum List[+A]: + case Nil + case Cons(head: A, tail: List[A]) + +object List: + def apply[A](as: A*): List[A] = + if as.isEmpty then Nil + else Cons(as.head, apply(as.tail*)) + + def sum(ints: List[Int]): Int = ints match + case Nil => 0 + case Cons(x, xs) => x + sum(xs) + + def head[A](xs: List[A]): A = xs match + case Nil => sys.error("empty list") + case Cons(x, _) => x + + /* + Exercise 3.2: Implement the function tail for removing the first element of a List. + */ + def tail[A](xs: List[A]): List[A] = xs match + case Nil => sys.error("empty list") + case Cons(_, ys) => ys + + /* + Exercise 3.3: Implement the function setHead for replacing the first element of a list + with a different value. + */ + def setHead[A](xs: List[A], a: A): List[A] = xs match + case Nil => Cons(a, Nil) + case Cons(_, ys) => Cons(a, ys) + + /* + Exercise 3.4: Implement the function drop, which removes the first n elements from a list. + Dropping n element from an empty list should return the empty list. + */ + def drop[A](xs: List[A], n: Int): List[A] = + if n <= 0 then xs + else + xs match + case Cons(_, ys) => drop(ys, n - 1) + case Nil => Nil + + /* + Exercise 3.5: Implement dropWhile, which removes elements from the List prefix as + long as they match a predicate. + */ + def dropWhile[A](as: List[A], f: A => Boolean): List[A] = + as match + case Cons(hd, tl) if f(hd) => dropWhile(tl, f) + case _ => as + + /* + Exercise 3.6: Implement a function, init, that returns a list containing of all + but the last element of a list. + */ + def init[A](xs: List[A]): List[A] = xs match + case Nil => sys.error("empty list") + case Cons(_, Nil) => Nil + case Cons(x, xs) => Cons(x, init(xs)) + +// def foldRight[A, B](as: List[A], acc: B, f: (A, B) => B): B = as match +// case Nil => acc +// case Cons(x, xs) => f(x, foldRight(xs, acc, f)) + + /* + Exercise 3.7: Can product, implemented using foldRight, immediately halt the + recursion and return 0.0 if it encounters a 0.0? Why or why not? Consider how + any short circuiting might work if you call foldRight with a large list. + --- + No, foldRight traverses all the way to the end of the list before invoking the function. + */ + + /* + Exercise 3.8: See what happens when you pass Nil and Cons themselves to foldRight, + like this: foldRight(List(1, 2, 3), Nil: List[Int], Cons(_, _)). + What do you think this says about the relationship between foldRight and the data + constructors of List? + -- + Nothing happens, the original list is returned. + */ + + /* + Exercise 3.9: Compute the length of a list using foldRight. + */ + def length[A](xs: List[A]): Int = + foldRight(xs, 0, (_, acc) => acc + 1) + + /* + Exercise 3.10: foldRight is not stack safe. Write another general list-recursion function, + foldLeft, that is tail recursive. Start collapsing from the leftmost start of the list. + */ + @tailrec + def foldLeft[A, B](as: List[A], acc: B, f: (B, A) => B): B = as match + case Nil => acc + case Cons(x, xs) => foldLeft(xs, f(acc, x), f) + + /* + Exercise 3.11: Write sum, product and a function to compute the length + of a list using foldLeft. + */ + def sumViaFoldLeft(xs: List[Int]): Int = + foldLeft(xs, 0, _ + _) + + def productViaFoldLeft(xs: List[Int]): Int = + foldLeft(xs, 1, _ * _) + + def lengthViaFoldLeft(xs: List[Int]): Int = + foldLeft(xs, 0, (acc, _) => acc + 1) + + /* + Exercise 3.12: Write a function that returns the reverse of a list. + */ + def reverse[A](xs: List[A]): List[A] = + foldLeft(xs, Nil: List[A], (acc, x) => Cons(x, acc)) + + /* + Exercise 3.13: Can you write foldRight in terms of foldLeft? How about the other way around? + --- + At each iteration, we create a function that remembers the current list + element and awaits a value of type B before producing a result. + Upon receiving such a value, we apply f to produce a value of type B, + which is then fed to the function from the previous step. + + The output from foldLeft is a function which is fed the zero value, + at which point the function chain is evaluated in reverse. + + Note that this implementation is not stack safe due to the creation + of an anonymous function at each step. + For a stack safe implementation, we can reverse the list and then + apply foldLeft. + */ + def foldRight[A, B](as: List[A], z: B, f: (A, B) => B): B = foldLeft( + as, + (b: B) => b, + (acc, a) => (b: B) => acc(f(a, b)) + )(z) + + /* + Exercise 3.14: Implement append in terms of either foldLeft or foldRight. + */ + def append[A](a1: List[A], a2: List[A]): List[A] = + foldRight(a1, a2, Cons(_, _)) + + /* + Exercise 3.15: Write a function that concatenates a list of lists into a + single list. Its runtime should be linear in the total length of all lists. + */ + def concat[A](xxs: List[List[A]]): List[A] = + foldRight(xxs, List[A](), append) + + /* + Exercise 3.16: Write a function that transforms a list of integers by adding + 1 to each element. + */ + def add1(xs: List[Int]): List[Int] = + foldRight(xs, List[Int](), (x, acc) => Cons(x + 1, acc)) + + /* + Exercise 3.17: Write a function that turns each value in a List[Double] into + a String. + */ + def doubleToString(xs: List[Double]): List[String] = + foldRight(xs, List[String](), (x, acc) => Cons(f"$x%2.2f", acc)) + + /* + Exercise 3.18: Write a function, map, that generalizes modifying each element + in a list while maintaining the structure of the list. + */ + def map[A, B](as: List[A], f: A => B): List[B] = + foldRight(as, List[B](), (x, acc) => Cons(f(x), acc)) + + /* + Exercise 3.19: Write a function, filter, that removes elements from a list + unless they satisfy a given predicate. + */ +// def filter[A](as: List[A], f: A => Boolean): List[A] = +// foldRight(as, List[A](), (x, acc) => if f(x) then Cons(x, acc) else acc) + + /* + Exercise 3.20: Write a function, flatMap, that works like map except that + the function given will return a list instead of a single result, ensuring + that the list is inserted into the final resulting list. + */ + def flatMap[A, B](as: List[A], f: A => List[B]): List[B] = + foldRight(as, List[B](), (x, acc) => append(f(x), acc)) + + /* + Exercise 3.21: Use flatMap to implement filter. + */ + def filter[A](as: List[A], f: A => Boolean): List[A] = + flatMap(as, x => if f(x) then List(x) else List()) + + /* + Exercise 3.22: Write a function that accepts two lists and constructs a new + list by adding corresponding elements. + + Exercise 3.23: Generalize the function you just wrote so it's not specific + to integers or addition. + */ + // Not stack safe! + def zipWith[A, B, C](xs: List[A], ys: List[B], f: (A, B) => C): List[C] = (xs, ys) match + case (Nil, _) => Nil + case (_, Nil) => Nil + case (Cons(x, xxs), Cons(y, yys)) => Cons(f(x, y), zipWith(xxs, yys, f)) + + /* + Exercise 3.24: Implement hasSubsequence to check whether a List contains + another List as a sunsequence. + */ + // 1. Does subsequence mean potentially not consecutive elements? No. + // 2. Check at every position. + def hasSubsequence[A](sup: List[A], sub: List[A]): Boolean = ??? diff --git a/chapter03/src/Tree.scala b/chapter03/src/Tree.scala new file mode 100644 index 0000000..ed6cf8c --- /dev/null +++ b/chapter03/src/Tree.scala @@ -0,0 +1,36 @@ +enum Tree[+A]: + case Leaf(value: A) + case Branch(left: Tree[A], right: Tree[A]) + + def size: Int = + fold(_ => 1, 1 + _ + _) + /* + Exercise 3.26: Write a function, depth, that returns the maximum + path length from the root to any leaf. + */ + def depth: Int = + fold(_ => 0, (d1, d2) => 1 + (d1 max d2)) + + /* + Exercise 3.27: Write a function, map, analogous to the method of + the same name on List that modifies each element in a tree with + a given function. + */ + def map[B](f: A => B): Tree[B] = + fold(a => Leaf(f(a)), Branch(_, _)) + + /* + Exercise 3.28: Generalize size, maximum, depth, and map, writing + a new function, fold, that abstracts over their similarities. + */ + def fold[B](f: A => B, g: (B, B) => B): B = this match + case Leaf(a) => f(a) + case Branch(l, r) => g(l.fold(f, g), r.fold(f, g)) + +object Tree: + /* + Exercise 3.25: Write a function, maximum, that returns the + maximum element in a Tree[Int]. + */ + def maximum(t: Tree[Int]): Int = + t.fold(x => x, (x, y) => x.max(y)) diff --git a/chapter03/test/src/LibSpec.scala b/chapter03/test/src/LibSpec.scala new file mode 100644 index 0000000..d286162 --- /dev/null +++ b/chapter03/test/src/LibSpec.scala @@ -0,0 +1,8 @@ +import org.scalatest.funspec.AnyFunSpec +import Lib.* +import org.scalatest.matchers.should.Matchers.shouldBe + +class LibSpec extends AnyFunSpec: + describe("Chapter 3"): + it("list pattern match should add first two values"): + result shouldBe 3 diff --git a/chapter03/test/src/ListSpec.scala b/chapter03/test/src/ListSpec.scala new file mode 100644 index 0000000..544687a --- /dev/null +++ b/chapter03/test/src/ListSpec.scala @@ -0,0 +1,104 @@ +import org.scalatest.funspec.AnyFunSpec +import List.* +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.prop.TableDrivenPropertyChecks.Table +import org.scalatest.matchers.should.Matchers.{shouldBe, should, thrownBy, a, be} +import scala.math.Pi + +class ListSpec extends AnyFunSpec with TableDrivenPropertyChecks: + describe("List"): + it("tail should remove the first element of a non-empty list"): + val xs = List(1, 2) + tail(xs) shouldBe List(2) + + it("tail should throw on an empty list"): + a[RuntimeException] should be thrownBy tail(List()) + + it("setHead should replace the first element"): + setHead(List(1, 2), 3) shouldBe List(3, 2) + + val dropInput = + Table( + ("xs", "n", "expected"), + (List(1, 2, 3), 2, List(3)), + (List(1, 2), 2, List()), + (List(), 1, List()), + (List(1), 0, List(1)) + ) + + it("drop should remove the first n elements"): + forAll(dropInput) { (xs: List[Int], n: Int, expected: List[Int]) => + drop(xs, n) shouldBe expected + } + + val dropWhileInput = + Table( + ("xs", "f", "expected"), + (List(1, 2, 3), (_: Int) < 3, List(3)), + (List(1, 2, 3), (_: Int) > 1, List(1, 2, 3)), + (List(1, 2, 3), (_: Int) => true, List()) + ) + + it("dropWhile should remove elements as long as they match the predicate"): + forAll(dropWhileInput) { (xs: List[Int], f: Int => Boolean, expected: List[Int]) => + dropWhile(xs, f) shouldBe expected + } + + it("init should remove the last element"): + init(List(1, 2, 3, 4)) shouldBe List(1, 2, 3) + init(List(1)) shouldBe List() + + it("init should throw on an empty list"): + a[RuntimeException] should be thrownBy init(List()) + + it("length should return the length of the list"): + length(List()) shouldBe 0 + length(List(1, 2, 3)) shouldBe 3 + + it("sumViaFoldLeft should compute the sum using foldLeft"): + sumViaFoldLeft(List()) shouldBe 0 + sumViaFoldLeft(List(1, 2, 3)) shouldBe 6 + + it("productViaFoldLeft should compute the product using foldLeft"): + productViaFoldLeft(List()) shouldBe 1 + productViaFoldLeft(List(1, 2, 3)) shouldBe 6 + + it("lengthViaFoldLeft should compute the length using foldLeft"): + lengthViaFoldLeft(List()) shouldBe 0 + lengthViaFoldLeft(List(1, 2, 3)) shouldBe 3 + + it("reverse should reverse the list"): + reverse(List()) shouldBe List() + reverse(List(1, 2, 3)) shouldBe List(3, 2, 1) + + it("foldRight can be implemented using foldLeft"): + foldRight(List(8, 12, 24, 4), 2.0, _ / _) shouldBe 8.0 + + it("append should concatenate two lists"): + append(List(1, 2, 3), List(4, 5, 6)) shouldBe List(1, 2, 3, 4, 5, 6) + + it("concat should flatten a list of lists"): + concat(List(List(1, 2, 3), List(4, 5), List(6), List())) shouldBe List(1, 2, 3, 4, 5, 6) + + it("add1 should increment every integer element by 1"): + add1(List(1, 2, 3)) shouldBe List(2, 3, 4) + + it("doubleToString should convert every double element to a string"): + doubleToString(List(1, 2, Pi)) shouldBe List("1.00", "2.00", "3.14") + + it("map should transform each element using the given function"): + map(List(1, 2, 3), _ + 1) shouldBe List(2, 3, 4) + map(List(1, 2, Pi), x => f"$x%2.2f") shouldBe List("1.00", "2.00", "3.14") + + it("filter should remove the elements that don't satisfy the given predicate"): + filter(List(1, 2, 3), x => x % 2 == 0) shouldBe List(2) + + it("flatMap should work as expected"): + flatMap(List(1, 2, 3), i => List(i, i)) shouldBe List(1, 1, 2, 2, 3, 3) + + it("zipWith should work as expected"): + zipWith( + List(1, 2, 3), + List(true, false, true), + (x, y) => if y then x.toString() else y.toString() + ) shouldBe List("1", "false", "3") diff --git a/chapter03/test/src/TreeSpec.scala b/chapter03/test/src/TreeSpec.scala new file mode 100644 index 0000000..f4a2551 --- /dev/null +++ b/chapter03/test/src/TreeSpec.scala @@ -0,0 +1,20 @@ +import Tree.* +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe + +class TreeSpec extends AnyFunSpec: + describe("Tree"): + val t = Branch(Branch(Leaf(0), Leaf(1)), Branch(Leaf(2), Leaf(3))) + + it("size should return the number of nodes in the tree"): + t.size shouldBe 7 + + it("depth should return the maximum path length from the root to any leaf"): + t.depth shouldBe 2 + + it("map should transform each node but retain the tree structure"): + t.map('a' + _) shouldBe + Branch(Branch(Leaf('a'), Leaf('b')), Branch(Leaf('c'), Leaf('d'))) + + it("maximum should transform each node but retain the tree structure"): + maximum(t) shouldBe 3