Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add onEnter and onExit events to states
  • Loading branch information
Gray-Wind committed Mar 8, 2022
commit 578b1c9223d8a69dba95cb71d7d294e7be0ce30e
65 changes: 57 additions & 8 deletions Swift/Sources/StateMachine/StateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,33 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
private let states: States
private var observers: [Observer] = []

private typealias EnterExitAction = (State) throws -> Void

private var onEnterActions: [State.HashableIdentifier: EnterExitAction]
private var onExitActions: [State.HashableIdentifier: EnterExitAction]

private var isNotifying: Bool = false

public init(@DefinitionBuilder build: () -> Definition) {
let definition: Definition = build()
state = definition.initialState.state
states = definition.states.reduce(into: States()) {
$0[$1.state] = $1.events.reduce(into: Events()) {
$0[$1.event] = $1.action
var enterActions: [State.HashableIdentifier: EnterExitAction] = [:]
var exitActions: [State.HashableIdentifier: EnterExitAction] = [:]
states = definition.states.reduce(into: States()) { result, tuple in
let (state, events) = tuple
result[state] = events.reduce(into: Events()) {
switch $1.eventType {
case .onEnter(let action):
enterActions[state] = action
case .onExit(let action):
exitActions[state] = action
case .normal(let event, let action):
$0[event] = action
}
}
}
onEnterActions = enterActions
onExitActions = exitActions
observers = definition.callbacks.map {
Observer(object: self, callback: $0)
}
Expand Down Expand Up @@ -104,10 +121,18 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
event: event,
toState: action.toState ?? state,
sideEffects: action.sideEffects)
let fromState = state
if let toState: State = action.toState {
state = toState
}

result = .success(transition)

// if not `dontTransition`
if action.toState != nil {
try? onExitActions[stateIdentifier]?(fromState)
try? onEnterActions[state.hashableIdentifier]?(state)
}
} else {
result = .failure(Transition.Invalid())
}
Expand Down Expand Up @@ -172,25 +197,41 @@ extension StateMachineBuilder {
.state(state: state, events: build())
}

public static func onEnter(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter(perform))]
}

public static func onExit(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit(perform))]
}

public static func onEnter(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter({ _ in try perform() }))]
}

public static func onExit(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit({ _ in try perform() }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State, Event) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event, action: perform)]
[EventHandler(eventType: .normal(event, perform))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { state, _ in try perform(state) }]
[EventHandler(eventType: .normal(event, { state, _ in try perform(state) }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping () throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { _, _ in try perform() }]
[EventHandler(eventType: .normal(event, { _, _ in try perform() }))]
}

public static func transition(
Expand Down Expand Up @@ -277,8 +318,16 @@ public enum StateMachineTypes {

public struct EventHandler<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate let event: Event.HashableIdentifier
fileprivate let action: Action<State, Event, SideEffect>.Factory
fileprivate var eventType: EventType<State, Event, SideEffect>

fileprivate enum EventType<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate typealias EnterExitAction = (State) throws -> Void

case normal(Event.HashableIdentifier, Action<State, Event, SideEffect>.Factory)
case onEnter(EnterExitAction)
case onExit(EnterExitAction)
}
}

public struct Action<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {
Expand Down
10 changes: 10 additions & 0 deletions Swift/Tests/StateMachineTests/StateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,13 @@ func log(_ expectedMessages: String...) -> Predicate<Logger> {
return PredicateResult(bool: actualMessages == expectedMessages, message: message)
}
}

func noLog() -> Predicate<Logger> {
return Predicate {
let actualMessages: [String]? = try $0.evaluate()?.messages
let actualString: String = stringify(actualMessages?.joined(separator: "\\n"))
let message: ExpectationMessage = .expectedCustomValueTo("no logs",
actual: "<\(actualString)>")
return PredicateResult(bool: actualString.count == 0, message: message)
}
}
52 changes: 38 additions & 14 deletions Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,41 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
typealias ValidTransition = MatterStateMachine.Transition.Valid
typealias InvalidTransition = MatterStateMachine.Transition.Invalid

enum Message {

static let melted: String = "I melted"
static let frozen: String = "I froze"
static let vaporized: String = "I vaporized"
static let condensed: String = "I condensed"
enum Message: String {

case melted = "I melted"
case frozen = "I froze"
case vaporized = "I vaporized"
case condensed = "I condensed"
case enteredSolid
case exitedSolid
case enteredLiquid
case exitedLiquid
case enteredGas
case exitedGas
}

static func matterStateMachine(withInitialState _state: State, logger: Logger) -> MatterStateMachine {
MatterStateMachine {
initialState(_state)
state(.solid) {
onEnter { _ in
logger.log(Message.enteredSolid.rawValue)
}
onExit { _ in
logger.log(Message.exitedSolid.rawValue)
}
on(.melt) {
transition(to: .liquid, emit: .logMelted)
}
}
state(.liquid) {
onEnter { _ in
logger.log(Message.enteredLiquid.rawValue)
}
onExit { _ in
logger.log(Message.exitedLiquid.rawValue)
}
on(.freeze) {
transition(to: .solid, emit: .logFrozen)
}
Expand All @@ -53,6 +71,12 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
}
}
state(.gas) {
onEnter { _ in
logger.log(Message.enteredGas.rawValue)
}
onExit { _ in
logger.log(Message.exitedGas.rawValue)
}
on(.condense) {
transition(to: .liquid, emit: .logCondensed)
}
Expand All @@ -61,10 +85,10 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
guard case let .success(transition) = $0 else { return }
transition.sideEffects.forEach { sideEffect in
switch sideEffect {
case .logMelted: logger.log(Message.melted)
case .logFrozen: logger.log(Message.frozen)
case .logVaporized: logger.log(Message.vaporized)
case .logCondensed: logger.log(Message.condensed)
case .logMelted: logger.log(Message.melted.rawValue)
case .logFrozen: logger.log(Message.frozen.rawValue)
case .logVaporized: logger.log(Message.vaporized.rawValue)
case .logCondensed: logger.log(Message.condensed.rawValue)
}
}
}
Expand Down Expand Up @@ -103,7 +127,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .melt,
toState: .liquid,
sideEffects: [.logMelted])))
expect(self.logger).to(log(Message.melted))
expect(self.logger).to(log(Message.exitedSolid.rawValue, Message.enteredLiquid.rawValue, Message.melted.rawValue))
}

func test_givenStateIsSolid_whenFrozen_shouldThrowInvalidTransitionError() throws {
Expand Down Expand Up @@ -136,7 +160,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .freeze,
toState: .solid,
sideEffects: [.logFrozen])))
expect(self.logger).to(log(Message.frozen))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredSolid.rawValue, Message.frozen.rawValue))
}

func test_givenStateIsLiquid_whenVaporized_shouldTransitionToGasState() throws {
Expand All @@ -153,7 +177,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .vaporize,
toState: .gas,
sideEffects: [.logVaporized])))
expect(self.logger).to(log(Message.vaporized))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredGas.rawValue, Message.vaporized.rawValue))
}

func test_givenStateIsGas_whenCondensed_shouldTransitionToLiquidState() throws {
Expand All @@ -170,6 +194,6 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .condense,
toState: .liquid,
sideEffects: [.logCondensed])))
expect(self.logger).to(log(Message.condensed))
expect(self.logger).to(log(Message.exitedGas.rawValue, Message.enteredLiquid.rawValue, Message.condensed.rawValue))
}
}
34 changes: 34 additions & 0 deletions Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,25 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
typealias TurnstileStateMachine = StateMachine<State, Event, SideEffect>
typealias ValidTransition = TurnstileStateMachine.Transition.Valid

enum Message: String {
case enteredLocked
case exitedLocked
case enteredUnlocked
case exitedUnlocked
case enteredBroken
case exitedBroken
}

static func turnstileStateMachine(withInitialState _state: State, logger: Logger) -> TurnstileStateMachine {
TurnstileStateMachine {
initialState(_state)
state(.locked) {
onEnter { state in
logger.log("\(Message.enteredLocked.rawValue) \(try state.credit() as Int)")
}
onExit {
logger.log(Message.exitedLocked.rawValue)
}
on(.insertCoin) { locked, insertCoin in
let newCredit: Int = try locked.credit() + insertCoin.value()
if newCredit >= Constant.farePrice {
Expand All @@ -52,11 +67,23 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
}
}
state(.unlocked) {
onEnter {
logger.log(Message.enteredUnlocked.rawValue)
}
onExit {
logger.log(Message.exitedUnlocked.rawValue)
}
on(.admitPerson) {
transition(to: .locked(credit: 0), emit: .closeDoors)
}
}
state(.broken) {
onEnter {
logger.log(Message.enteredBroken.rawValue)
}
onExit {
logger.log(Message.exitedBroken.rawValue)
}
on(.machineRepairDidComplete) { broken in
transition(to: try broken.oldState())
}
Expand Down Expand Up @@ -96,6 +123,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(10),
toState: .locked(credit: 10),
sideEffects: [])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, "\(Message.enteredLocked.rawValue) 10"))
}

func test_givenStateIsLocked_whenInsertCoin_andCreditEqualsFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
Expand All @@ -112,6 +140,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(15),
toState: .unlocked,
sideEffects: [.openDoors])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
}

func test_givenStateIsLocked_whenInsertCoin_andCreditMoreThanFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
Expand All @@ -128,6 +157,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(20),
toState: .unlocked,
sideEffects: [.openDoors])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
}

func test_givenStateIsLocked_whenAdmitPerson_shouldTransitionToLockedStateAndSoundAlarm() throws {
Expand All @@ -144,6 +174,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .admitPerson,
toState: .locked(credit: 35),
sideEffects: [.soundAlarm])))
expect(self.logger).to(noLog())
}

func test_givenStateIsLocked_whenMachineDidFail_shouldTransitionToBrokenStateAndOrderRepair() throws {
Expand All @@ -160,6 +191,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .machineDidFail,
toState: .broken(oldState: .locked(credit: 15)),
sideEffects: [.orderRepair])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredBroken.rawValue))
}

func test_givenStateIsUnlocked_whenAdmitPerson_shouldTransitionToLockedStateAndCloseDoors() throws {
Expand All @@ -176,6 +208,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .admitPerson,
toState: .locked(credit: 0),
sideEffects: [.closeDoors])))
expect(self.logger).to(log(Message.exitedUnlocked.rawValue, "\(Message.enteredLocked.rawValue) 0"))
}

func test_givenStateIsBroken_whenMachineRepairDidComplete_shouldTransitionToLockedState() throws {
Expand All @@ -192,6 +225,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .machineRepairDidComplete,
toState: .locked(credit: 15),
sideEffects: [])))
expect(self.logger).to(log(Message.exitedBroken.rawValue, "\(Message.enteredLocked.rawValue) 15"))
}
}

Expand Down