Leetcode in Scala3 (Dotty): Repeated String Match with Knuth-Morris-Prath Algorithm

I wrote how to solve how to solve “Repeated String Match” problem from Leetcode for O(a_len * (a_len + b_len).

I told how to imlement Knuth-Morris-Prath algorithm in Scala. Having a faster search, the “Repeated String Problem” can be solved for O(a_len + b_len) as follows:

class KMP private(pattern: String, /* ... */)
def search(text: String): Option[Int] = // ...

object KMP
def apply(pattern: String): KMP =
// ...
new KMP(pattern, /* ... */)

given intOps: (v: Int) extended with
def divRoundUp(d: Int): Int = (v + d - 1) / d

object Solution
// time: O(a_len + b_len)
// space: O(a_len + b_len)
def repeatedStringMatch(a: String, b: String): Int =
val repeatCount = b.length.divRoundUp(a.length)
val aRepeated = a * repeatCount
val kmp = KMP(pattern = b)

if kmp.search(text = aRepeated).isDefined
repeatCount
else if kmp.search(text = aRepeated + a).isDefined
repeatCount + 1
else
-1

def main(args: Array[String]): Unit =
assert(repeatedStringMatch("abcd", "cdabcdab") == 3)


I used the KMP instance twice to search the pattern in two different texts. But there is a way to use a single run for both texts. The solution is to build prefix function and use it to search for a pattern in the string: pattern + '#' + text. The text is input string repeated ceil(b / a) times. The solution in total is as follows:

opaque type PrefixFunction = Array[Int]

object PrefixFunction
def apply(pattern: String): PrefixFunction =
val pi = Array.ofDim[Int](pattern.length)
var i = 1
while (i < pattern.size)
var j = pi(i - 1)
while (j > 0 && pattern(i) != pattern(j))
j = pi(j - 1)
if pattern(i) == pattern(j)
j += 1
pi(i) = j
i += 1
pi

given prefixFunctionOps: (x: PrefixFunction) extended with
def apply(idx: Int) = x(idx)

given intOps: (v: Int) extended with
def divRoundUp(d: Int): Int = (v + d - 1) / d

object Solution
// time: O(a_len + b_len)
// space: O(a_len + b_len)
def repeatedStringMatch(a: String, b: String): Int =
val repeatCount = b.length.divRoundUp(a.length)
val str = b + "#" + (a * (repeatCount + 1))
val pf = PrefixFunction(pattern = str)

var idx = b.size + 1
var found = false
while (!found && idx < str.size)
if pf(idx) == b.size
found = true
else
idx += 1

if !found
-1
else
(idx - b.size).divRoundUp(a.length)

def main(args: Array[String]): Unit =
assert(repeatedStringMatch("abcd", "cdabcdab") == 3)