Skip to content
Paul Westcott edited this page Sep 28, 2016 · 7 revisions
// optimized mutation-based implementation. This code is only valid in fslib, where mutation of private
// tail cons cells is permitted in carefully written library code.
//
// Variant 1: accepted elements are NOT copied as they are seen. `mixed` remembers the cell of the
// ORIGINAL list where the current run of accepted elements starts (`potentialTail`) and copies that
// run only when a later element is rejected. If the input ends inside such a run, the last fresh
// cell is linked straight to it, so the result shares that suffix with the input.
module OptimizedFilter =
    // Copies the elements of `current` up to, but not including, the cell `finish` (which must be
    // physically reachable from `current`). Each copy is chained after `cons`, then after the previous
    // copy. Returns the last cell written, or `cons` itself when nothing is copied.
    // Fails with "unexpected" if the end of the list is hit before reaching `finish`.
    let rec private copy<'T> (cons:'T list) (finish:'T list) (current:'T list) : 'T list=
        // physical identity: `finish` is a specific cell of the same list, not an equal value
        if obj.ReferenceEquals (current, finish) then cons
        else
            match current with
            | hd :: tl ->
                // fresh cell whose tail is filled in by the next setFreshConsTail
                let cons2 = freshConsNoTail hd
                setFreshConsTail cons cons2
                copy cons2 finish tl
            | _ -> failwith "unexpected"

    // Filters the remainder `l` once at least one element has been rejected.
    //   cons          - last fresh cell of the result built so far; its tail is not final yet.
    //   potentialTail - null (Unchecked.defaultof) when no run of accepted elements is pending,
    //                   otherwise the original-list cell where the pending run starts.
    // Lists are upcast to obj for the null test because the null sentinel is not a list value.
    // At the end of `l`: with no pending run the result is terminated with []; otherwise the last
    // fresh cell is linked to the pending run, reusing the original cells.
    let rec private mixed<'T> (cons:'T list) (f:'T->bool) (l:'T list) (potentialTail:'T list) : unit = 
        match l with 
        | [] ->
            match potentialTail :> obj with
            | null -> setFreshConsTail cons []
            | _    -> setFreshConsTail cons potentialTail
        | h::t -> 
            if f h then
                match potentialTail :> obj with
                // accepted and nothing pending: a new run starts at this cell
                | null -> mixed cons f t l
                // accepted inside a run: keep the run's original start
                | _    -> mixed cons f t potentialTail
            else
                match potentialTail :> obj with
                // rejected and nothing pending: just drop it
                | null -> mixed cons f t Unchecked.defaultof<_>
                | _ -> 
                    // rejected after a run: materialise the run [potentialTail, l) as fresh cells
                    let cons2 = copy cons l potentialTail
                    mixed cons2 f t Unchecked.defaultof<_>

    // we attempt to reuse as much of the list as we can
    // `l` walks a suffix of `first`; every element of `first` before `l` has been accepted.
    // While all elements are accepted nothing is allocated and `first` itself is returned.
    // At the first rejected element, the accepted prefix [first, l) is copied into fresh cells
    // and `mixed` handles the rest. The "unexpected" case needs an empty `first` with a non-empty `l`,
    // which cannot happen when `l` is a suffix of `first`.
    let rec filter<'T> (first:'T list) (f:'T->bool) (l:'T list) =
        match l with 
        | [] -> first
        | h::t -> 
            if f h then   
                filter first f t
            else
                match first with
                | hd::tl -> 
                    let cons = freshConsNoTail hd
                    let cons2 = copy cons l tl
                    mixed cons2 f t Unchecked.defaultof<_>
                    cons
                | _ -> failwith "unexpected"

/// Keeps the elements of `l` for which `f` returns true, in order.
/// Rejected leading elements are skipped here. From the first accepted element onward,
/// OptimizedFilter.filter takes over; it can return an all-accepted remainder of `l` without copying.
let rec filter f l =
    match l with
    | h :: t when f h -> OptimizedFilter.filter l f t
    | _ :: t -> filter f t
    | [] -> []

// optimized mutation-based implementation. This code is only valid in fslib, where mutation of private
// tail cons cells is permitted in carefully written library code.
//
// Variant 2: every accepted element is copied eagerly. `candidateForTail` is the last fresh cell
// before the current run of accepted elements. If the input ends inside a run, that cell is
// re-linked to the original run, and the copies made for the run become garbage.
module OptimizedFilter =
    // Filters the remainder `l` once at least one element has been rejected.
    //   candidateForTail - last fresh cell written before the current run of accepted elements.
    //   cons             - last fresh cell written so far.
    //   potentialTail    - null (Unchecked.defaultof) when no run is pending, otherwise the
    //                      original-list cell where the current run of accepted elements starts.
    // Relies on setFreshConsTail being able to overwrite a tail it has already set: the end case
    // rewrites candidateForTail's tail, which was first set when the run's first copy was chained.
    let rec private mixed<'T> (candidateForTail:'T list) (cons:'T list) (f:'T->bool) (l:'T list) (potentialTail:'T list) : unit = 
        match l with 
        | [] ->
            match potentialTail :> obj with
            | null -> setFreshConsTail cons []
            // ended inside a run: share the original suffix instead of the eager copies
            | _    -> setFreshConsTail candidateForTail potentialTail
        | h::t -> 
            if f h then
                let newCons = freshConsNoTail h
                setFreshConsTail cons newCons
                match potentialTail :> obj with
                // first accepted element of a new run
                | null -> mixed candidateForTail newCons f t l
                | _    -> mixed candidateForTail newCons f t potentialTail
            else
                // rejected: the run (if any) is closed; the current last cell becomes the new candidate
                mixed cons cons f t Unchecked.defaultof<_>

    // Copies the elements of `current` up to, but not including, the cell `finish` (which must be
    // physically reachable from `current`), chaining each copy after `cons`. Returns the last cell
    // written, or `cons` when nothing is copied. Fails with "unexpected" if the list ends first.
    let rec private copy<'T> (cons:'T list) (finish:'T list) (current:'T list) : 'T list=
        if obj.ReferenceEquals (current, finish) then cons
        else
            match current with
            | hd :: tl ->
                let cons2 = freshConsNoTail hd
                setFreshConsTail cons cons2
                copy cons2 finish tl
            | _ -> failwith "unexpected"

    // we attempt to reuse as much of the list as we can
    // `l` walks a suffix of `first` whose preceding elements were all accepted. If everything is
    // accepted, `first` is returned unchanged. Otherwise the accepted prefix [first, l) is copied
    // and `mixed` continues after the rejected element, starting with that copy as both cells.
    let rec filter<'T> (first:'T list) (f:'T->bool) (l:'T list) =
        match l with 
        | [] -> first
        | h::t -> 
            if f h then   
                filter first f t
            else
                match first with
                | hd::tl -> 
                    let cons = freshConsNoTail hd
                    let cons2 = copy cons l tl
                    mixed cons2 cons2 f t Unchecked.defaultof<_>
                    cons
                | _ -> failwith "unexpected"

/// Returns the elements of `l` satisfying `f`, preserving order.
/// Elements are dropped here until one is accepted; the rest of the work (and any reuse
/// of the original cells) is delegated to OptimizedFilter.filter from that cell onward.
let rec filter f l =
    match l with
    | [] -> []
    | h :: t when f h -> OptimizedFilter.filter l f t
    | _ :: rest -> filter f rest

// optimized mutation-based implementation. This code is only valid in fslib, where mutation of private
// tail cons cells is permitted in carefully written library code.
//
// Variant 3: the same algorithm as the eager-copy variant, written with loops and direct reads
// and writes of the cons-cell fields (`.(::).0` = head, `.(::).1` = tail) instead of recursion.
module OptimizedFilter =
    // Filters the remainder `l` once at least one element has been rejected; `cons` is the last
    // fresh cell of the result so far. Accepted elements are copied eagerly. `candidateForTail`
    // tracks the last fresh cell before the current accepted run, and `potentialTail` (null when
    // no run is pending) the original cell where that run starts. If the input ends inside a run,
    // candidateForTail is re-linked to the original run so the result shares that suffix.
    let private mixed<'T> (cons:'T list) (f:'T->bool) (l:'T list)  : unit = 
        let mutable cons = cons
        let mutable current = l
        let mutable candidateForTail = cons
        let mutable potentialTail = Unchecked.defaultof<_>
        // tail of `current`; the loop relies on the empty-list cell's tail field being null,
        // so it runs once per element of `l`
        let mutable tltl = current.(::).1
        while not (obj.ReferenceEquals (tltl, null)) do
            let hd = current.(::).0
            if f hd then
                let newCons = freshConsNoTail hd
                // direct tail write, equivalent to setFreshConsTail cons newCons
                cons.(::).1 <- newCons
                cons <- newCons
                if obj.ReferenceEquals (potentialTail, null) then
                    // first accepted element of a new run
                    potentialTail <- current
            else
                // rejected: close the run; the last copied cell is the new candidate
                candidateForTail <- cons
                potentialTail <- Unchecked.defaultof<_>
            current <- tltl
            tltl <- current.(::).1

        if obj.ReferenceEquals(potentialTail, null) then
            setFreshConsTail cons []
        else
            // ended inside a run: overwrite the candidate's tail to share the original suffix
            setFreshConsTail candidateForTail potentialTail

    // Copies the elements from `current` up to, but not including, the cell `finish`, chaining
    // each copy after `cons`. Returns the last cell written (or `cons` if none).
    // No end-of-list check: `finish` must be physically reachable from `current`.
    let private copy<'T> (cons:'T list) (finish:'T list) (current:'T list) : 'T list=
        let mutable current = current
        let mutable cons = cons
        while not (obj.ReferenceEquals (current, finish)) do
            let consNew = freshConsNoTail (current.(::).0)
            setFreshConsTail cons consNew
            cons <- consNew
            current <- current.(::).1
        cons

    // we attempt to reuse as much of the list as we can
    // `l` walks a suffix of `first` whose preceding elements were all accepted; returns `first`
    // unchanged if everything is accepted. Otherwise it copies the accepted prefix [first, l) and
    // hands the rest (after the rejected element) to `mixed`. `first` is read without a case check,
    // so it must be non-empty.
    let rec whileTrue<'T> (first:'T list) (f:'T->bool) (l:'T list) =
        match l with 
        | [] -> first
        | h::t -> 
            if f h then   
                whileTrue first f t
            else
                let cons = freshConsNoTail (first.(::).0)
                let consNew = copy cons l (first.(::).1)
                mixed consNew f t
                cons

/// Returns the elements of `l` for which `f` holds, in their original order.
/// Skips rejected leading elements; on the first accepted one it hands `l` to
/// OptimizedFilter.whileTrue, which may return an all-accepted remainder of `l` unchanged.
let rec filter f l =
    match l with
    | h :: t when f h -> OptimizedFilter.whileTrue l f t
    | _ :: remaining -> filter f remaining
    | [] -> []

open System.Collections
open System.Collections.Generic
open System

module SeqAssistant =
    // Identity on bool. Wrapping a call in this match takes the call out of tail position,
    // which (per the name) is the intent: NOTE(review) presumably to avoid `.tail`-prefixed
    // calls on the hot DoNext path — confirm against the generated IL.
    let inline avoidTailCall x =
        match x with
        | true -> true
        | false -> false

/// One stage of a push-style sequence pipeline, turning a 'T into (optionally) a 'U.
type ISeqDoNext<'T,'U> =
    /// Processes one input. Returns true and writes the result through the byref when an item is
    /// produced; returns false when the input is dropped.
    abstract DoNext : 'T * byref<'U> -> bool
    /// Returns a stage equivalent to this one followed by `next`; implementations may fuse the two.
    abstract AddDoNext : ISeqDoNext<'U,'V> -> ISeqDoNext<'T,'V>

// Helpers that fuse two adjacent stages of the same kind into a single stage.
type Factory =
    /// A Filter accepting an item only when both `f` and `g` accept it (`g` is skipped when `f` rejects).
    static member Filter f g = Filter (fun x -> f x && g x)
    /// A Map applying `f` and then `g`.
    static member Map f g = Map (f >> g)

// Stage that transforms every item with `map`; DoNext always returns true.
and Map<'T,'U> (map:'T->'U) =
    interface ISeqDoNext<'T,'U> with
        // Map followed by Map or Filter is fused into one stage; anything else is chained.
        member this.AddDoNext (next:ISeqDoNext<'U,'V>) : ISeqDoNext<'T,'V> =
            match next with 
            | :? Map<'U,'V> as mapU2V -> upcast (Factory.Map this.Map mapU2V.Map)
            // `next` is a Filter<'U>, so 'V = 'U at runtime; the static types cannot show it,
            // hence `unbox` rather than `upcast`.
            | :? Filter<'U> as filterU -> unbox (MapFilter (this, filterU))
            | _ -> upcast Composed (this, next)
    
        member __.DoNext (input:'T, output:byref<'U>) : bool = 
            output <- map input
            true

    member __.Map = map

// Stage that passes through the items `filter` accepts and drops the rest.
and Filter<'T> (filter:'T->bool) =
    interface ISeqDoNext<'T,'T> with
        // Filter followed by Map or Filter is fused into one stage; anything else is chained.
        member this.AddDoNext (next:ISeqDoNext<'T,'V>) : ISeqDoNext<'T,'V> = 
            match next with
            | :? Map<'T,'V> as mapTV -> upcast FilterMap (this, mapTV)
            // `next` is a Filter<'T>, so 'V = 'T at runtime; hence `unbox`.
            | :? Filter<'T> as filterT2 -> unbox (Factory.Filter this.Filter filterT2.Filter)
            | _ -> upcast Composed (this, next)

        member __.DoNext (input:'T, output:byref<'T>) : bool = 
            if filter input then
                output <- input
                true
            else
                false

    member __.Filter = filter

// Generic two-stage pipeline: runs `stage1`, and only when it produces an item feeds it to `stage2`.
and Composed<'T,'U,'V> (stage1:ISeqDoNext<'T,'U>, stage2:ISeqDoNext<'U,'V>) =
    interface ISeqDoNext<'T,'V> with
        member __.DoNext (input:'T, output:byref<'V>) :bool = 
            let mutable temp = Unchecked.defaultof<'U>
            if stage1.DoNext (input, &temp) then
                SeqAssistant.avoidTailCall (stage2.DoNext (temp, &output))
            else
                false

        // Appends on the right so `stage2` gets the chance to fuse with `next`.
        member __.AddDoNext (next:ISeqDoNext<'V,'W>):ISeqDoNext<'T,'W> = 
            upcast Composed (stage1, stage2.AddDoNext next)

    member __.Stage1 = stage1
    member __.Stage2 = stage2

// Map followed by Filter. Inherits Composed (so Stage1/Stage2 are exposed) but re-implements DoNext
// to call the two functions directly instead of going through the interface.
and MapFilter<'T,'U> (map:Map<'T,'U>, filter:Filter<'U>) =
    inherit Composed<'T,'U,'U>(map, filter)

    interface ISeqDoNext<'T,'U> with
        // `output` is written even when the filter rejects it; callers must ignore it on false.
        member __.DoNext (input:'T, output:byref<'U>) :bool = 
            output <- map.Map input
            SeqAssistant.avoidTailCall (filter.Filter output)

        // A following Filter is merged into this stage's predicate; anything else is chained.
        member this.AddDoNext (next:ISeqDoNext<'U,'V>):ISeqDoNext<'T,'V> = 
            match next with
            | :? Filter<'U> as filterU -> unbox (MapFilter(map, Factory.Filter filter.Filter filterU.Filter))
            | _ -> upcast Composed (this, next)

// Filter followed by Map: the mapping runs only for accepted items.
and FilterMap<'T,'U> (filter:Filter<'T>, map:Map<'T,'U>) =
    inherit Composed<'T,'T,'U>(filter, map)

    interface ISeqDoNext<'T,'U> with
        member __.DoNext (input:'T, output:byref<'U>) : bool = 
            if filter.Filter input then
                output <- map.Map input
                true
            else
                false

        // A following Map is merged into this stage's mapping function; anything else is chained.
        member this.AddDoNext (next:ISeqDoNext<'U,'V>) : ISeqDoNext<'T,'V> = 
            match next with
            | :? Map<'U,'V> as mapU2V -> upcast FilterMap(filter, Factory.Map map.Map mapU2V.Map)
            | _ -> upcast Composed(this, next)

/// Base class keyed only on a pipeline's output element type. `smap`/`sfilter` use a type test
/// against it to recognise an existing pipeline behind a plain seq and extend it instead of wrapping it.
[<AbstractClass>]
type SeqDoNextBase<'T> () =
    /// Returns a new sequence that runs this pipeline followed by the given stage.
    abstract member AddSeqDoNext : (ISeqDoNext<'T,'U>) -> IEnumerable<'U>

/// Sequence that pulls items from `source`, runs each through `t2u` and exposes the produced outputs.
/// The object is both the IEnumerable and its own IEnumerator, and `source` is an already-created
/// enumerator, so it can be enumerated only once (GetEnumerator returns this same instance).
/// NOTE(review): sequences created by AddSeqDoNext share the same `source` enumerator.
type SeqDoNext<'T,'U>(source:IEnumerator<'T>, t2u:ISeqDoNext<'T,'U>) =
    inherit SeqDoNextBase<'U>()

    let mutable current = Unchecked.defaultof<_>
    let mutable disposed = false

    // Advances `source` until `t2u` produces an item (true) or `source` is exhausted (false).
    // An explicit loop keeps stack usage constant however many consecutive items are rejected,
    // without relying on the compiler turning a self tail call into a jump.
    let moveNext () =
        let mutable produced = false
        let mutable exhausted = false
        while not (produced || exhausted) do
            if source.MoveNext () then
                produced <- t2u.DoNext (source.Current, &current)
            else
                exhausted <- true
        produced

    override __.AddSeqDoNext (u2v:ISeqDoNext<'U,'V>) =
        new SeqDoNext<'T,'V>(source, t2u.AddDoNext u2v) :> IEnumerable<'V>

    interface IDisposable with
        // Forwards disposal to the wrapped enumerator so resources behind it are released;
        // repeated Dispose calls are ignored.
        member x.Dispose():unit = 
            if not disposed then
                disposed <- true
                source.Dispose ()

    interface IEnumerator with
        member __.Current : obj = box current
        member __.MoveNext () = moveNext ()
        member __.Reset () : unit = failwith "Not implemented yet"

    interface IEnumerator<'U> with
        member x.Current = current

    interface IEnumerable with
        member x.GetEnumerator () : IEnumerator = upcast x
    
    interface IEnumerable<'U> with
        member x.GetEnumerator () : IEnumerator<'U> = upcast  x
      
/// Maps `f` over `x`. When `x` is already a SeqDoNext pipeline, the Map stage is appended to it
/// (allowing fusion with the previous stage); otherwise `x` is wrapped in a new pipeline.
/// NOTE(review): in the wrapping case `x.GetEnumerator()` runs here, eagerly, not at enumeration time.
let smap<'a,'b> (f:'a->'b) (x:seq<'a>) : seq<'b> =
    let stage = Map f
    match x with
    | :? SeqDoNextBase<'a> as pipeline -> pipeline.AddSeqDoNext (stage)
    | _ ->
        let wrapped = new SeqDoNext<'a,'b>(x.GetEnumerator (), stage)
        wrapped :> seq<'b>

/// Keeps the items of `x` accepted by `f`. An existing SeqDoNext pipeline gets the Filter stage
/// appended (allowing fusion); any other sequence is wrapped in a new pipeline.
/// NOTE(review): in the wrapping case `x.GetEnumerator()` runs here, eagerly, not at enumeration time.
let sfilter<'a> (f:'a->bool) (x:seq<'a>) : seq<'a> =
    let stage = Filter f
    match x with
    | :? SeqDoNextBase<'a> as pipeline -> pipeline.AddSeqDoNext (stage)
    | _ ->
        let wrapped = new SeqDoNext<'a,'a>(x.GetEnumerator (), stage)
        wrapped :> seq<'a>

module SeqComposer

    open System
    open System.Collections
    open System.Collections.Generic

    // Failure helpers for enumerator misuse; each raises a plain Failure carrying its own name.
    let noReset ()         = raise (Failure "noReset")
    let notStarted ()      = raise (Failure "notStarted")
    let alreadyFinished () = raise (Failure "alreadyFinished")

    /// Raises `Failure t` when `o` is null; otherwise does nothing.
    let checkNonNull t o   =
        match o with
        | null -> raise (Failure t)
        | _ -> ()

    module Helpers =
        // used for performance reasons; these are not recursive calls, so should be safe
        // Identity on bool. Wrapping a call in this match takes that call out of tail position
        // (hence the name). NOTE(review): presumably to avoid `.tail`-prefixed calls in
        // ProcessNext chains — confirm against the generated IL.
        let inline avoidTailCall x =
            match x with
            | true -> true
            | false -> false

    // Constructors that fuse two adjacent stages of the same kind into a single stage.
    type Factory =
        /// A Filter accepting an item only when both `f` and `g` accept it (`g` is skipped when `f` rejects).
        static member Filter f g = Filter (fun x -> f x && g x)
        /// A Map applying `f` and then `g`.
        static member Map f g = Map (f >> g)
    
    // Base class of every pipeline stage. Composition is double-dispatched: `first.Composer second`
    // lets `first` choose which `second.ComposeXxx first` to call, so known pairs of stages can be
    // fused. The defaults here fall back to a generic Composed pair.
    and [<AbstractClass>] SeqComponent<'T,'U> () =
        /// Processes one input. Returns true and writes `output` when an item is produced,
        /// false when the input is dropped.
        abstract ProcessNext : 'T * byref<'U> -> bool

        /// Returns a stage equivalent to this stage followed by the given one.
        abstract Composer : SeqComponent<'U,'V> -> SeqComponent<'T,'V>

        // Called on the following stage with a preceding stage of known shape; overrides
        // return a fused stage where possible.
        abstract ComposeMap<'S>       : Map<'S,'T>       -> SeqComponent<'S,'U>
        abstract ComposeFilter        : Filter<'T>       -> SeqComponent<'T,'U>
        abstract ComposeFilterMap<'S> : FilterMap<'S,'T> -> SeqComponent<'S,'U>
        abstract ComposeMapFilter<'S> : MapFilter<'S,'T> -> SeqComponent<'S,'U>

        // Base implementations of the abstract slots above (declared with `default`, like the
        // Compose* defaults); no fusion is attempted here.
        default first.Composer (second:SeqComponent<'U,'V>) : SeqComponent<'T,'V> = upcast Composed (first, second)

        default second.ComposeMap<'S>       (first:Map<'S,'T>)       : SeqComponent<'S,'U> = upcast Composed (first, second)
        default second.ComposeFilter        (first:Filter<'T>)       : SeqComponent<'T,'U> = upcast Composed (first, second)
        default second.ComposeFilterMap<'S> (first:FilterMap<'S,'T>) : SeqComponent<'S,'U> = upcast Composed (first, second)
        default second.ComposeMapFilter<'S> (first:MapFilter<'S,'T>) : SeqComponent<'S,'U> = upcast Composed (first, second)

    // Generic two-stage pipeline: runs `first`, and only when it produces an item feeds it to `second`.
    and Composed<'T,'U,'V> (first:SeqComponent<'T,'U>, second:SeqComponent<'U,'V>) =
        inherit SeqComponent<'T,'V>()

        override __.ProcessNext (input:'T, output:byref<'V>) :bool = 
            let mutable temp = Unchecked.defaultof<'U>
            if first.ProcessNext (input, &temp) then
                Helpers.avoidTailCall (second.ProcessNext (temp, &output))
            else
                false

        // Appends on the right so `second` gets the chance to fuse with `next`.
        override __.Composer (next:SeqComponent<'V,'W>) : SeqComponent<'T,'W> = 
            upcast Composed (first, second.Composer next)

        member __.First = first
        member __.Second = second

    // Stage that transforms every item with `map`; ProcessNext always returns true.
    and Map<'T,'U> (map:'T->'U) =
        inherit SeqComponent<'T,'U>()

        override first.Composer (second:SeqComponent<'U,'V>) : SeqComponent<'T,'V> =
            second.ComposeMap first

        // Map then Map: one Map with the composed function.
        override second.ComposeMap<'S> (first:Map<'S,'T>) : SeqComponent<'S,'U> =
            upcast Factory.Map first.Map second.Map
    
        // Filter then Map: a FilterMap.
        override second.ComposeFilter (first:Filter<'T>) : SeqComponent<'T,'U> =
            upcast FilterMap (first, second)

        // FilterMap then Map: same filter, composed mapping.
        override second.ComposeFilterMap<'S> (first:FilterMap<'S,'T>) : SeqComponent<'S,'U> =
            upcast FilterMap (first.Filter, Factory.Map first.Map.Map second.Map)

        override __.ProcessNext (input:'T, output:byref<'U>) : bool = 
            output <- map input
            true

        member __.Map :'T->'U = map

    // Stage that passes through the items `filter` accepts and drops the rest.
    and Filter<'T> (filter:'T->bool) =
        inherit SeqComponent<'T,'T>()

        override this.Composer (next:SeqComponent<'T,'V>) : SeqComponent<'T,'V> =
            next.ComposeFilter this

        // Map then Filter: a MapFilter.
        override second.ComposeMap<'S> (first:Map<'S,'T>) : SeqComponent<'S,'T> =
            upcast MapFilter (first, second)

        // Filter then Filter: one Filter with the conjoined predicate.
        override second.ComposeFilter (first:Filter<'T>) : SeqComponent<'T,'T> =
            upcast Factory.Filter first.Filter second.Filter

        // MapFilter then Filter: same mapping, conjoined predicate.
        override second.ComposeMapFilter<'S> (first:MapFilter<'S,'T>) : SeqComponent<'S,'T> =
            upcast MapFilter(first.Map, Factory.Filter first.Filter.Filter second.Filter)

        override __.ProcessNext (input:'T, output:byref<'T>) : bool = 
            if filter input then
                output <- input
                true
            else
                false

        member __.Filter :'T->bool = filter

    // Map followed by Filter, run as a single stage.
    and MapFilter<'T,'U> (map:Map<'T,'U>, filter:Filter<'U>) =
        inherit SeqComponent<'T,'U>()

        // `output` is written even when the filter rejects it; callers must ignore it on false.
        override __.ProcessNext (input:'T, output:byref<'U>) :bool = 
            output <- map.Map input
            Helpers.avoidTailCall (filter.Filter output)

        override first.Composer (second:SeqComponent<'U,'V>):SeqComponent<'T,'V> =
            second.ComposeMapFilter first

        member __.Map    : Map<'T,'U> = map
        member __.Filter : Filter<'U> = filter

    // Filter followed by Map, run as a single stage; the mapping runs only for accepted items.
    and FilterMap<'T,'U> (filter:Filter<'T>, map:Map<'T,'U>) =
        inherit SeqComponent<'T,'U>()

        override __.ProcessNext (input:'T, output:byref<'U>) : bool = 
            if filter.Filter input then
                output <- map.Map input
                true
            else
                false

        override this.Composer (next:SeqComponent<'U,'V>) : SeqComponent<'T,'V> = next.ComposeFilterMap this

        member __.Filter : Filter<'T> = filter
        member __.Map    : Map<'T,'U> = map

    // Emits (previous, current) for every input after the first; the first input only primes `lastValue`.
    // NOTE(review): `isFirst`/`lastValue` are instance state, so an instance reused across
    // enumerations keeps its state — confirm each enumeration gets a fresh instance.
    and Pairwise<'T> () =
        inherit SeqComponent<'T,'T*'T>()

        let mutable isFirst = true
        let mutable lastValue = Unchecked.defaultof<'T>

        override __.ProcessNext (input:'T, output:byref<'T*'T>) : bool = 
            if isFirst then
                lastValue <- input
                isFirst <- false
                false
            else
                output <- lastValue, input
                lastValue <- input
                true

    // Drops the first `skipCount` inputs and passes the rest through.
    // NOTE(review): `count` is instance state and is never reset — see the note on Pairwise.
    and Skip<'T> (skipCount:int) =
        inherit SeqComponent<'T,'T>()

        let mutable count = 0

        override __.ProcessNext (input:'T, output:byref<'T>) : bool = 
            if count < skipCount then
                count <- count + 1
                false
            else
                output <- input
                true

    /// Lifecycle of a SeqEnumerator, used to validate reads of Current.
    type SeqProcessNextStates =
        /// MoveNext has not been called yet.
        | NotStarted = 1
        /// At least one MoveNext call has been made and the source is not exhausted.
        | InProcess = 3
        /// The source reported no more items.
        | Finished = 2

    /// Enumerator that pulls items from `enumerator`, pushes each through `t2u` and exposes the items
    /// `t2u` produces. Current raises if read before the first MoveNext or after the end was reached.
    type SeqEnumerator<'T,'U>(enumerator:IEnumerator<'T>, t2u:SeqComponent<'T,'U>) =
        let mutable source = enumerator
        let mutable state = SeqProcessNextStates.NotStarted
        let mutable current = Unchecked.defaultof<_>

        // Advances `source` until `t2u` produces an item (true), or `source` is exhausted
        // (false, and the state becomes Finished). An explicit loop keeps stack usage constant
        // however many consecutive items `t2u` rejects.
        let moveNext () =
            let mutable produced = false
            let mutable exhausted = false
            while not (produced || exhausted) do
                if source.MoveNext () then
                    produced <- t2u.ProcessNext (source.Current, &current)
                else
                    exhausted <- true
            if exhausted then
                state <- SeqProcessNextStates.Finished
            produced

        interface IDisposable with
            // Disposes the wrapped enumerator once and drops the reference.
            member x.Dispose():unit =
                match source with
                | null -> ()
                | _ ->
                    source.Dispose ()
                    source <- Unchecked.defaultof<_>

        interface IEnumerator with
            member this.Current : obj = box (this:>IEnumerator<'U>).Current
            member __.MoveNext () =
                match state, source with
                // Past the end, keep answering false without calling into the source again.
                | SeqProcessNextStates.Finished, _ -> false
                // After Dispose released `source`, report the end instead of dereferencing null.
                | _, null ->
                    state <- SeqProcessNextStates.Finished
                    false
                | _ ->
                    state <- SeqProcessNextStates.InProcess
                    moveNext ()
            member __.Reset () : unit = noReset ()

        interface IEnumerator<'U> with
            member x.Current =
                match state with
                | SeqProcessNextStates.NotStarted -> notStarted()
                | SeqProcessNextStates.Finished -> alreadyFinished()
                | _ -> ()
                current

    /// Base class keyed only on a pipeline's output element type. `filter`/`map`/`pairwise`/`skip`
    /// use a type test against it to recognise an existing pipeline behind a plain seq and compose onto it.
    [<AbstractClass>]
    type SeqEnumerableBase<'T> () =
        /// Returns a new sequence that runs this pipeline followed by the given stage.
        abstract member Compose<'U> : SeqComponent<'T,'U> -> IEnumerable<'U>

    /// Sequence that runs `generator` through the pipeline `t2u`. Each GetEnumerator call starts a
    /// fresh enumeration of `generator`.
    /// NOTE(review): every enumerator reuses the same `t2u` instance, so stateful stages (Pairwise,
    /// Skip) carry their state into later enumerations — confirm this is intended.
    type SeqEnumerable<'T,'U>(generator:IEnumerable<'T>, t2u:SeqComponent<'T,'U>) =
        inherit SeqEnumerableBase<'U>()

        let freshEnumerator () : IEnumerator<'U> =
            let inner = generator.GetEnumerator ()
            new SeqEnumerator<'T,'U>(inner, t2u) :> IEnumerator<'U>

        override __.Compose (u2v:SeqComponent<'U,'V>) =
            let pipeline = t2u.Composer u2v
            new SeqEnumerable<'T,'V>(generator, pipeline) :> IEnumerable<'V>

        interface IEnumerable<'U> with
            member __.GetEnumerator () : IEnumerator<'U> = freshEnumerator ()

        interface IEnumerable with
            member __.GetEnumerator () : IEnumerator = freshEnumerator () :> IEnumerator

    /// Returns the items of `source` for which `f` returns true. If `source` is already a
    /// SeqEnumerable pipeline, the Filter stage is composed into it rather than wrapped around it.
    /// Raises Failure "source" (via checkNonNull) when `source` is null.
    [<CompiledName("Filter")>]
    let filter<'T> (f:'T->bool) (source:seq<'T>) : seq<'T> =
        checkNonNull "source" source
        let stage = Filter f
        match source with
        | :? SeqEnumerableBase<'T> as pipeline -> pipeline.Compose (stage)
        | _ -> new SeqEnumerable<'T,'T>(source, stage) :> seq<'T>

    /// Applies `f` to every item of `source`. If `source` is already a SeqEnumerable pipeline, the
    /// Map stage is composed into it rather than wrapped around it.
    /// Raises Failure "source" (via checkNonNull) when `source` is null.
    [<CompiledName("Map")>]
    let map<'T,'U> (f:'T->'U) (source:seq<'T>) : seq<'U> =
        checkNonNull "source" source
        let stage = Map f
        match source with
        | :? SeqEnumerableBase<'T> as pipeline -> pipeline.Compose (stage)
        | _ -> new SeqEnumerable<'T,'U>(source, stage) :> seq<'U>

    /// Returns (previous, current) pairs of consecutive items of `source`. Composes onto an
    /// existing SeqEnumerable pipeline when possible.
    /// Raises Failure "source" (via checkNonNull) when `source` is null.
    [<CompiledName("Pairwise")>]
    let pairwise<'T> (source:seq<'T>) : seq<'T*'T> =
        checkNonNull "source" source
        let stage = Pairwise<'T> ()
        match source with
        | :? SeqEnumerableBase<'T> as pipeline -> pipeline.Compose (stage)
        | _ -> new SeqEnumerable<'T,'T*'T>(source, stage) :> seq<'T*'T>

    /// Drops the first `count` items of `source` and yields the rest. Composes onto an existing
    /// SeqEnumerable pipeline when possible.
    /// Raises Failure "source" (via checkNonNull) when `source` is null.
    [<CompiledName("Skip")>]
    let skip<'T> (count:int) (source:seq<'T>) : seq<'T> =
        checkNonNull "source" source
        let stage = Skip<'T>(count)
        match source with
        | :? SeqEnumerableBase<'T> as pipeline -> pipeline.Compose (stage)
        | _ -> new SeqEnumerable<'T,'T>(source, stage) :> seq<'T>
Clone this wiki locally