Scala | for-comprehension 底层转换 | withFilter 解析

Scala 中的 for-comprehension 是一种方便的语法糖,它实际上是几种操作mapflatMapfilter的组合。for-comprehension 的 EBNF 表示如下:

1
2
3
4
5
Expr1 ::= ‘for’ (‘(’ Enumerators ‘)’ | ‘{’ Enumerators ‘}’)
{nl} [‘yield’] Expr
Enumerators ::= Generator {semi Generator}
Generator ::= Pattern1 ‘<-’ Expr {[semi] Guard | semi Pattern1 ‘=’ Expr}
Guard ::= ‘if’ PostfixExpr

Scala 中的 for-comprehension 与 Haskell 中的 do-notation 类似,都是对操作组合过程的简化,操作的对象都是 Monad。这里就类比 Haskell 中的 do-notation 来总结 Scala 中的 for-comprehension 转换规则。

First step

第一步 Scala 会处理 generator 中的 refutable pattern。所谓的 refutable pattern 就是模式匹配中可能失败的情况,而 irrefutable pattern 就是模式匹配中一定会匹配成功的情况(如variables)。对于每个可能匹配失败的 generator p <- e,Scala 会将其转化为:

1
p <- e.withFilter { case p => true; case _ => false }

比如 for (1 <- List(1, 2)) "ha" 这段表达式的转化结果为:(直接在REPL里通过宏查看AST)

1
2
3
4
5
6
7
8
scala> reify( for (1 <- List(1, 2)) "ha" )
res1: reflect.runtime.universe.Expr[Unit] =
Expr[Unit](List.apply(1, 2).withFilter(((check$ifrefutable$1) => check$ifrefutable$1: @unchecked match {
case 1 => true
case _ => false
})).foreach(((x$1) => x$1: @unchecked match {
case 1 => "ha"
})))

单个generator的for-comprehension

只有一个 generator 的 for-comprehension:

1
2
for (x <- e1)
yield e2

它会被转化为

1
e1 map {x => e2}

我们通过 Quasiquotes 获取AST来验证:

1
2
3
4
5
6
7
8
scala> val e1 = List(1, 2, 3, 4)
e1: List[Int] = List(1, 2, 3, 4)
scala> def f1(x: Int) = x * 2
f1: (x: Int)Int
scala> q" for (x <- e1) yield f1 _ "
res2: reflect.runtime.universe.Tree = e1.map(((x) => (f1: (() => <empty>))))

在Haskell中原表达式等价于:

1
2
3
do
x <- e1
return e2

转换为非do-notation:

1
2
e1 >>=
\x -> return e2

根据 Monad Laws 推导出的 fmap f ma = ma >>= (return . f) 转化:

1
(\x -> e2) <$> e1

多个generator的for-comprehension

多个generator其实就是mapflatMap(fmap>>=)的组合,比如:

1
2
for (x <- e1; y <- e2)
yield e3

会转化为

1
e1.flatMap(x => for (y <- e2) yield e3)

1
2
3
e1 flatMap { x =>
e2 map { y => e3 }
}

REPL里验证:

1
2
scala> q"for(x <- e1; y <- e2) yield x + y"
res3: reflect.runtime.universe.Tree = e1.flatMap(((x) => e2.map(((y) => x.$plus(y)))))

举例(Scala):

1
2
3
4
5
6
7
8
9
val e1 = List(1, 2, 3)
val e2 = List(4, 5, 6)
val f1 = for(x <- e1; y <- e2)
yield x + y // List(5, 6, 7, 6, 7, 8, 7, 8, 9)
val f2 = e1 flatMap { x =>
e2 map { y => x + y }
} // List(5, 6, 7, 6, 7, 8, 7, 8, 9)

在Haskell中原表达式等价于:

1
2
3
4
do
x <- e1
y <- e2
return e3

转换为非do-notation:

1
2
3
e1 >>=
\x -> e2 >>=
\y -> return e3

根据 Monad Laws 推导出的 fmap f ma = ma >>= (return . f) 转化:

1
2
e1 >>=
\x -> fmap (\y -> e3) e2

举例(Haskell):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
-- result: [5,6,7,6,7,8,7,8,9]
f2 :: (Num a) => [a]
f2 = do
x <- [1, 2, 3]
y <- [4, 5, 6]
return (x + y)
f3 :: (Num a) => [a]
f3 = [1, 2, 3] >>=
\x -> [4, 5, 6] >>=
\y -> return (x + y)
f4 :: (Num a) => [a]
f4 = [1, 2, 3] >>=
\x -> fmap (\y -> x + y) [4, 5, 6]

for-loop

Scala中,for表达式也有支持side effects的版本(for-loop),比如:

1
2
for(x <- e1; y <- e2)
println(x * y)

它的转化和含yield的差不多,只不过它用含副作用的foreach操作替代了mapflatMap算子:

1
2
3
4
5
e1 foreach {
x => e2 foreach {
y => println(x * y)
}
}

含条件的for表达式

Scala支持含有条件判断(if guard)的for表达式,其中if guard对应withFilter算子。

转换规则:p <- e if g 会转换为 p <- e.withFilter(p => g)

比如:

1
2
for (x <- e1 if p)
yield e2

会转化为:

1
2
for (x <- e1 withFilter {x => p})
yield e2

即:

1
e1 withFilter {x => f} map {x => e2}

REPL里验证:

1
2
3
4
5
6
7
8
scala> reify( for(x <- e1 if x > 2) yield f1 _ )
res20: reflect.runtime.universe.Expr[List[Int => Int]] =
Expr[List[Int => Int]]($read.e1.withFilter(((x) => x.$greater(2))).map(((x) => {
((x) => $read.f1(x))
}))(List.canBuildFrom))
scala> q" for(x <- e1 if x > 2) yield f1 _ "
res21: reflect.runtime.universe.Tree = e1.withFilter(((x) => x.$greater(2))).map(((x) => (f1: (() => <empty>))))

含有value definition的for表达式

这种情况下generator中含有value definition,比如:

1
2
p <- e
p1 = e1

这种转换要稍微啰嗦一点。对于 p <- e; p1 = e1 这样的generator,Scala会将其转换为:

1
2
3
4
(p, p1) <- for (x@p<- e) yield {
val x1@p1 = e1
(x, x1)
}

可以看到展开的结果是多了一次for-comprehension,也就是多了一层map,这可能会带来一些效率问题。

在REPL里验证:

1
2
3
4
5
6
7
8
9
10
11
12
scala> val list = List("+", "1", "s")
list: List[String] = List(+, 1, s)
scala> reify ( for(p <- list; x = p; y = p) yield y )
res29: reflect.runtime.universe.Expr[List[String]] =
Expr[List[String]]($read.list.map(((p) => {
val x = p;
val y = p;
Tuple3.apply(p, x, y)
}))(List.canBuildFrom).map(((x$1) => x$1: @unchecked match {
case Tuple3((p @ _), (x @ _), (y @ _)) => y
}))(List.canBuildFrom))

withFilter

最后再来谈一下上面出现的withFilter函数,它于Scala 2.8引入,是filter的lazy版本。那么为什么要引入withFilter呢?为何不能直接用filter呢?我们先来看两段代码:

1
2
3
var found = false
for (x <- List.range(1, 10); if x % 2 == 1 && !found)
if (x == 5) found = true else println(x)
1
2
3
var found = false
for (x <- Stream.range(1, 10); if x % 2 == 1 && !found)
if (x == 5) found = true else println(x)

其中,StreamList的lazy版本,只在需要的时候求值。按照上面总结的for-comprehension转换规则,我们可以将上面的代码转换为:

1
2
var found = false
List.range(1,10).f(_ % 2 == 1 && !found).foreach(x => if (x == 5) found = true else println(x))
1
2
var found = false
Stream.range(1,10).f(_ % 2 == 1 && !found).foreach(x => if (x == 5) found = true else println(x))

这里我们暂时用f来代表某个filter函数。如果令f = filter的话,上面两段程序的运行结果分别是:

1
2
3
4
5
// List
1
3
7
9
1
2
3
// Stream
1
3

可以看到Stream在对每个元素filter的时候,都会重新计算filter对应的predicate,所以上面的代码中found改变会对filter有影响;而List在filter元素的时候,对应的predicate是已经计算好的,不会再变更,因此found改变对filter没有影响。想必大家已经看出问题了,我们在使用for-comprehension的时候,总是希望if guard里的条件是按需求值的(on-demand),而不是一开始就计算好的,因此把if guard转换成filter函数的话语义会有问题。所以,为了保持filter的语义不变,同时确保for-comprehension语义正确,Scala 2.8引入了withFilter函数作为filter的lazy实现,它的predicate是on-demand的。这样,for-comprehension中的guard就可以转换成withFilter函数,从而实现正确的语义。

withFilter的实现也非常简单,既然需要on-demand evaluation,那么就把predicate函数保存下来,到需要的时候再调用。withFilter函数会生成一个WithFilter对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def withFilter(p: A => Boolean): FilterMonadic[A, Repr] = new WithFilter(p)
/** A class supporting filtered operations. Instances of this class are
* returned by method `withFilter`.
*/
class WithFilter(p: A => Boolean) extends FilterMonadic[A, Repr] {
def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Repr, B, That]): That = {
val b = bf(repr)
for (x <- self)
if (p(x)) b += f(x)
b.result
}
def flatMap[B, That](f: A => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Repr, B, That]): That = {
val b = bf(repr)
for (x <- self)
if (p(x)) b ++= f(x).seq
b.result
}
def foreach[U](f: A => U): Unit =
for (x <- self)
if (p(x)) f(x)
def withFilter(q: A => Boolean): WithFilter =
new WithFilter(x => p(x) && q(x))
}

可以看到WithFilter里只允许map, flatMap, foreachwithFilter这四种操作,其中mapflatMap会得到原来的集合类型。


References

文章目录
  1. 1. First step
  2. 2. 单个generator的for-comprehension
  3. 3. 多个generator的for-comprehension
  4. 4. for-loop
  5. 5. 含条件的for表达式
  6. 6. 含有value definition的for表达式
  7. 7. withFilter
  8. 8. References