Context #

Bases: Generic[MODEL_T]

A global object representing a context for a given workflow run.

The Context object can be used to store data that needs to be available across iterations during a workflow execution, and across multiple workflow runs. Every context instance offers two types of data storage: a global one that's shared among all the steps within a workflow, and a private one that's only accessible from a single step.

Both set and get operations on global data are governed by a lock and are considered coroutine-safe.
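
For example, a step can read and update shared state through ctx.store. A minimal sketch (the workflow class and the "runs" key below are illustrative):

```python
from workflows import Workflow, Context, step
from workflows.events import StartEvent, StopEvent


class CounterWorkflow(Workflow):
    @step
    async def count(self, ctx: Context, ev: StartEvent) -> StopEvent:
        # Read from the shared store (with a default for the first run),
        # then write the updated value back. Both calls are lock-protected.
        runs = await ctx.store.get("runs", default=0)
        await ctx.store.set("runs", runs + 1)
        return StopEvent(result=runs + 1)
```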

Source code in workflows/context/context.py
class Context(Generic[MODEL_T]):
    """
    A global object representing a context for a given workflow run.

    The Context object can be used to store data that needs to be available across iterations during a workflow
    execution, and across multiple workflow runs.
    Every context instance offers two types of data storage: a global one that's shared among all the steps within a
    workflow, and a private one that's only accessible from a single step.

    Both `set` and `get` operations on global data are governed by a lock and are considered coroutine-safe.
    """

    # These keys are set by pre-built workflows and
    # are known to be unserializable in some cases.
    known_unserializable_keys = ("memory",)

    def __init__(
        self,
        workflow: "Workflow",
        stepwise: bool = False,
    ) -> None:
        self.stepwise = stepwise
        self.is_running = False
        # Store the step configs of this workflow, to be used in send_event
        self._step_configs: dict[str, StepConfig | None] = {}
        for step_name, step_func in workflow._get_steps().items():
            self._step_configs[step_name] = getattr(step_func, "__step_config", None)

        # Init broker machinery
        self._init_broker_data()

        # Global data storage
        self._lock = asyncio.Lock()
        self._state_store: InMemoryStateStore[MODEL_T] | None = None

        # instrumentation
        self._dispatcher = workflow._dispatcher

    async def _init_state_store(self, state_class: MODEL_T) -> None:
        # If a state manager already exists, ensure the requested state type is compatible
        if self._state_store is not None:
            existing_state = await self._state_store.get_state()
            if type(state_class) is not type(existing_state):
                # Existing state type differs from the requested one – this is not allowed
                raise ValueError(
                    f"Cannot initialize with state class {type(state_class)} because it already has a state class {type(existing_state)}"
                )

            # State manager already initialised and compatible – nothing to do
            return

        # First-time initialisation
        self._state_store = InMemoryStateStore(state_class)

    @property
    def store(self) -> InMemoryStateStore[MODEL_T]:
        # Default to DictState if no state manager is initialized
        if self._state_store is None:
            self._state_store = InMemoryStateStore(DictState())

        return self._state_store

    def _init_broker_data(self) -> None:
        self._queues: dict[str, asyncio.Queue] = {}
        self._tasks: set[asyncio.Task] = set()
        self._broker_log: list[Event] = []
        self._cancel_flag: asyncio.Event = asyncio.Event()
        self._step_flags: dict[str, asyncio.Event] = {}
        self._step_events_holding: list[Event] | None = None
        self._step_lock: asyncio.Lock = asyncio.Lock()
        self._step_condition: asyncio.Condition = asyncio.Condition(
            lock=self._step_lock
        )
        self._step_event_written: asyncio.Condition = asyncio.Condition(
            lock=self._step_lock
        )
        self._accepted_events: list[Tuple[str, str]] = []
        self._retval: RunResultT = None
        # Map the step names that were executed to a list of events they received.
        # This will be serialized, and is needed to resume a Workflow run passing
        # an existing context.
        self._in_progress: dict[str, list[Event]] = defaultdict(list)
        # Keep track of the steps currently running. This is only valid when a
        # workflow is running and won't be serialized. Note that a single step
        # might have multiple workers, so we keep a counter.
        self._currently_running_steps: DefaultDict[str, int] = defaultdict(int)
        # Streaming machinery
        self._streaming_queue: asyncio.Queue = asyncio.Queue()
        # Step-specific instance
        self._event_buffers: dict[str, EventBuffer] = {}

    def _serialize_queue(self, queue: asyncio.Queue, serializer: BaseSerializer) -> str:
        queue_items = list(queue._queue)  # type: ignore
        queue_objs = [serializer.serialize(obj) for obj in queue_items]
        return json.dumps(queue_objs)  # type: ignore

    def _deserialize_queue(
        self,
        queue_str: str,
        serializer: BaseSerializer,
        prefix_queue_objs: list[Any] = [],
    ) -> asyncio.Queue:
        queue_objs = json.loads(queue_str)
        queue_objs = prefix_queue_objs + queue_objs
        queue = asyncio.Queue()  # type: ignore
        for obj in queue_objs:
            event_obj = serializer.deserialize(obj)
            queue.put_nowait(event_obj)
        return queue

    def _deserialize_globals(
        self, serialized_globals: dict[str, Any], serializer: BaseSerializer
    ) -> dict[str, Any]:
        """
        DEPRECATED: Kept to support reloading a Context from an old serialized payload.

        This method is deprecated and will be removed in a future version.
        """
        deserialized_globals = {}
        for key, value in serialized_globals.items():
            try:
                deserialized_globals[key] = serializer.deserialize(value)
            except Exception as e:
                raise ValueError(f"Failed to deserialize value for key {key}: {e}")
        return deserialized_globals

    def to_dict(self, serializer: BaseSerializer | None = None) -> dict[str, Any]:
        serializer = serializer or JsonSerializer()

        # Serialize state using the state manager's method
        state_data = {}
        if self._state_store is not None:
            state_data = self._state_store.to_dict(serializer)

        return {
            "state": state_data,  # Use state manager's serialize method
            "streaming_queue": self._serialize_queue(self._streaming_queue, serializer),
            "queues": {
                k: self._serialize_queue(v, serializer) for k, v in self._queues.items()
            },
            "stepwise": self.stepwise,
            "event_buffers": {
                k: {
                    inner_k: [serializer.serialize(ev) for ev in inner_v]
                    for inner_k, inner_v in v.items()
                }
                for k, v in self._event_buffers.items()
            },
            "in_progress": {
                k: [serializer.serialize(ev) for ev in v]
                for k, v in self._in_progress.items()
            },
            "accepted_events": self._accepted_events,
            "broker_log": [serializer.serialize(ev) for ev in self._broker_log],
            "is_running": self.is_running,
        }

    @classmethod
    def from_dict(
        cls,
        workflow: "Workflow",
        data: dict[str, Any],
        serializer: BaseSerializer | None = None,
    ) -> "Context":
        serializer = serializer or JsonSerializer()

        try:
            context = cls(workflow, stepwise=data["stepwise"])

            # Deserialize state manager using the state manager's method
            if "state" in data:
                context._state_store = InMemoryStateStore.from_dict(
                    data["state"], serializer
                )
            elif "globals" in data:
                # Deserialize legacy globals for backward compatibility
                globals = context._deserialize_globals(data["globals"], serializer)
                context._state_store = InMemoryStateStore(DictState(**globals))

            context._streaming_queue = context._deserialize_queue(
                data["streaming_queue"], serializer
            )

            context._event_buffers = {}
            for buffer_id, type_events_map in data["event_buffers"].items():
                context._event_buffers[buffer_id] = {}
                for event_type, events_list in type_events_map.items():
                    context._event_buffers[buffer_id][event_type] = [
                        serializer.deserialize(ev) for ev in events_list
                    ]

            context._accepted_events = data["accepted_events"]
            context._broker_log = [
                serializer.deserialize(ev) for ev in data["broker_log"]
            ]
            context.is_running = data["is_running"]
            # load back up whatever was in the queue as well as the events whose steps
            # were in progress when the serialization of the Context took place
            context._queues = {
                k: context._deserialize_queue(
                    v, serializer, prefix_queue_objs=data["in_progress"].get(k, [])
                )
                for k, v in data["queues"].items()
            }
            context._in_progress = defaultdict(list)
            return context
        except KeyError as e:
            msg = "Error creating a Context instance: the provided payload has a wrong or old format."
            raise ContextSerdeError(msg) from e

    async def set(self, key: str, value: Any, make_private: bool = False) -> None:
        """
        Store `value` into the Context under `key`.

        DEPRECATED: Use `await ctx.store.set(key, value)` instead.
        This method is deprecated and will be removed in a future version.

        Args:
            key: A unique string to identify the value stored.
            value: The data to be stored.

        Raises:
            ValueError: When make_private is True but a key already exists in the global storage.

        """
        warnings.warn(
            "Context.set(key, value) is deprecated. Use 'await ctx.store.set(key, value)' instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        if make_private:
            warnings.warn(
                "`make_private` is deprecated and will be ignored", DeprecationWarning
            )

        # Delegate to state manager
        await self.store.set(key, value)

    async def mark_in_progress(self, name: str, ev: Event) -> None:
        """
        Add input event to in_progress dict.

        Args:
            name (str): The name of the step that is in progress.
            ev (Event): The input event that kicked off this step.

        """
        async with self.lock:
            self._in_progress[name].append(ev)

    async def remove_from_in_progress(self, name: str, ev: Event) -> None:
        """
        Remove input event from active steps.

        Args:
            name (str): The name of the step that has been completed.
            ev (Event): The associated input event that kicked off this completed step.

        """
        async with self.lock:
            events = [e for e in self._in_progress[name] if e != ev]
            self._in_progress[name] = events

    async def add_running_step(self, name: str) -> None:
        async with self.lock:
            self._currently_running_steps[name] += 1

    async def remove_running_step(self, name: str) -> None:
        async with self.lock:
            self._currently_running_steps[name] -= 1
            if self._currently_running_steps[name] == 0:
                del self._currently_running_steps[name]

    async def running_steps(self) -> list[str]:
        async with self.lock:
            return list(self._currently_running_steps)

    async def get(self, key: str, default: Any | None = Ellipsis) -> Any:
        """
        Get the value corresponding to `key` from the Context.

        DEPRECATED: Use `await ctx.store.get(key)` instead.
        This method is deprecated and will be removed in a future version.

        Args:
            key: A unique string to identify the value stored.
            default: The value to return when `key` is missing instead of raising an exception.

        Raises:
            ValueError: When there's no value accessible corresponding to `key`.

        """
        warnings.warn(
            "Context.get() is deprecated. Use 'await ctx.store.get()' instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        return await self.store.get(key, default=default)

    @property
    def lock(self) -> asyncio.Lock:
        """Returns a mutex to lock the Context."""
        return self._lock

    @property
    def session(self) -> "Context":  # pragma: no cover
        """This property is provided for backward compatibility."""
        msg = "`session` is deprecated, please use the Context instance directly."
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        return self

    def _get_full_path(self, ev_type: Type[Event]) -> str:
        return f"{ev_type.__module__}.{ev_type.__name__}"

    def _get_event_buffer_id(self, events: list[Type[Event]]) -> str:
        # Try getting the current task name
        try:
            current_task = asyncio.current_task()
            if current_task:
                t_name = current_task.get_name()
                # Do not use the default value 'Task'
                if t_name != "Task":
                    return t_name
        except RuntimeError:
            # This is a sync step, fallback to using events list
            pass

        # Fall back to creating a stable identifier from expected events
        return ":".join(sorted(self._get_full_path(e_type) for e_type in events))

    def collect_events(
        self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
    ) -> list[Event] | None:
        """
        Collects events for buffering in workflows.

        This method adds the current event to the internal buffer and attempts to collect all
        expected event types. If all expected events are found, they will be returned in order.
        Otherwise, it returns None and restores any collected events back to the buffer.

        Args:
            ev (Event): The current event to add to the buffer.
            expected (list[Type[Event]]): list of expected event types to collect.
            buffer_id (str): A unique identifier for the events collected. Ideally this should be
            the step name, so as to avoid any interference between different steps. If not provided,
            a stable identifier will be created using the list of expected events.

        Returns:
            list[Event] | None: list of collected events in the order of expected types if all
                                  expected events are found; otherwise None.

        """
        buffer_id = buffer_id or self._get_event_buffer_id(expected)

        if buffer_id not in self._event_buffers:
            self._event_buffers[buffer_id] = defaultdict(list)

        event_type_path = self._get_full_path(type(ev))
        self._event_buffers[buffer_id][event_type_path].append(ev)

        retval: list[Event] = []
        for e_type in expected:
            e_type_path = self._get_full_path(e_type)
            e_instance_list = self._event_buffers[buffer_id].get(e_type_path, [])
            if e_instance_list:
                retval.append(e_instance_list.pop(0))
            else:
                # We already know we don't have all the events
                break

        if len(retval) == len(expected):
            return retval

        # put back the events if unable to collect all
        for i, ev_to_restore in enumerate(retval):
            e_type = type(retval[i])
            e_type_path = self._get_full_path(e_type)
            self._event_buffers[buffer_id][e_type_path].append(ev_to_restore)

        return None

    def add_holding_event(self, event: Event) -> None:
        """
        Add an event to the list of those collected in current step.

        This is only relevant for stepwise execution.
        """
        if self.stepwise:
            if self._step_events_holding is None:
                self._step_events_holding = []

            self._step_events_holding.append(event)

    def get_holding_events(self) -> list[Event]:
        """Returns a copy of the list of events holding the stepwise execution."""
        if self._step_events_holding is None:
            return []

        return list(self._step_events_holding)

    def send_event(self, message: Event, step: str | None = None) -> None:
        """
        Sends an event to a specific step in the workflow.

        If step is None, the event is sent to all the receivers and we let
        them discard events they don't want.
        """
        self.add_holding_event(message)

        if step is None:
            for queue in self._queues.values():
                queue.put_nowait(message)
        else:
            if step not in self._step_configs:
                raise WorkflowRuntimeError(f"Step {step} does not exist")

            step_config = self._step_configs[step]
            if step_config and type(message) in step_config.accepted_events:
                self._queues[step].put_nowait(message)
            else:
                raise WorkflowRuntimeError(
                    f"Step {step} does not accept event of type {type(message)}"
                )

        self._broker_log.append(message)

    async def wait_for_event(
        self,
        event_type: Type[T],
        waiter_event: Event | None = None,
        waiter_id: str | None = None,
        requirements: dict[str, Any] | None = None,
        timeout: float | None = 2000,
    ) -> T:
        """
        Asynchronously wait for a specific event type to be received.

        If provided, `waiter_event` will be written to the event stream to let the caller know that we're waiting for a response.

        Args:
            event_type: The type of event to wait for
            waiter_event: The event to emit to the event stream to let the caller know that we're waiting for a response
            waiter_id: A unique identifier for this specific wait call. It helps ensure that we only send one `waiter_event` for each `waiter_id`.
            requirements: Optional dict of requirements the event must match
            timeout: Optional timeout in seconds. Defaults to 2000s.

        Returns:
            The event type that was requested.

        Raises:
            asyncio.TimeoutError: If the timeout is reached before receiving matching event

        """
        requirements = requirements or {}

        # Generate a unique key for the waiter
        event_str = self._get_full_path(event_type)
        requirements_str = str(requirements)
        waiter_id = waiter_id or f"waiter_{event_str}_{requirements_str}"

        if waiter_id not in self._queues:
            self._queues[waiter_id] = asyncio.Queue()

        # send the waiter event if it's not already sent
        if waiter_event is not None:
            is_waiting = await self.get(waiter_id, default=False)
            if not is_waiting:
                await self.set(waiter_id, True)
                self.write_event_to_stream(waiter_event)

        while True:
            try:
                event = await asyncio.wait_for(
                    self._queues[waiter_id].get(), timeout=timeout
                )
                if type(event) is event_type:
                    if all(
                        event.get(k, default=None) == v for k, v in requirements.items()
                    ):
                        return event
                    else:
                        continue
            finally:
                await self.set(waiter_id, False)

    def write_event_to_stream(self, ev: Event | None) -> None:
        self._streaming_queue.put_nowait(ev)

    def get_result(self) -> RunResultT:
        """Returns the result of the workflow."""
        return self._retval

    @property
    def streaming_queue(self) -> asyncio.Queue:
        return self._streaming_queue

    def clear(self) -> None:
        """Clear any data stored in the context.

        DEPRECATED: Use `await ctx.store.set_state(StateCLS())` instead.
        This method is deprecated and will be removed in a future version.
        """
        warnings.warn(
            "Context.clear() is deprecated. Use 'await ctx.store.set(StateCLS())' instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        # Clear the user data storage
        if self._state_store is not None:
            self._state_store._state = self._state_store._state.__class__()

    async def shutdown(self) -> None:
        """
        To be called when a workflow ends.

        We clear all the tasks and set the is_running flag. Note that we
        don't clear _globals or _queues so that the context can be still
        used after the shutdown to fetch data or consume leftover events.
        """
        self.is_running = False
        # Cancel all running tasks
        for task in self._tasks:
            task.cancel()
        # Wait for all tasks to complete
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()

    def add_step_worker(
        self,
        name: str,
        step: Callable,
        config: StepConfig,
        stepwise: bool,
        verbose: bool,
        checkpoint_callback: "CheckpointCallback | None",
        run_id: str,
        service_manager: ServiceManager,
        resource_manager: ResourceManager,
    ) -> None:
        self._tasks.add(
            asyncio.create_task(
                self._step_worker(
                    name=name,
                    step=step,
                    config=config,
                    stepwise=stepwise,
                    verbose=verbose,
                    checkpoint_callback=checkpoint_callback,
                    run_id=run_id,
                    service_manager=service_manager,
                    resource_manager=resource_manager,
                ),
                name=name,
            )
        )

    async def _step_worker(
        self,
        name: str,
        step: Callable,
        config: StepConfig,
        stepwise: bool,
        verbose: bool,
        checkpoint_callback: "CheckpointCallback | None",
        run_id: str,
        service_manager: ServiceManager,
        resource_manager: ResourceManager,
    ) -> None:
        while True:
            ev = await self._queues[name].get()
            if type(ev) not in config.accepted_events:
                continue

            # do we need to wait for the step flag?
            if stepwise:
                await self._step_flags[name].wait()

                # clear all flags so that we only run one step
                for flag in self._step_flags.values():
                    flag.clear()

            if verbose and name != "_done":
                print(f"Running step {name}")

            # run step
            # Initialize state manager if needed
            if self._state_store is None:
                if (
                    hasattr(config, "context_state_type")
                    and config.context_state_type is not None
                ):
                    # Instantiate the state class and initialize the state manager
                    try:
                        # Try to instantiate the state class
                        state_instance = config.context_state_type()
                        await self._init_state_store(state_instance)
                    except Exception as e:
                        raise WorkflowRuntimeError(
                            f"Failed to initialize state of type {config.context_state_type}: {e}"
                        ) from e
                else:
                    # Initialize state manager with DictState by default
                    await self._init_state_store(DictState())

            kwargs: dict[str, Any] = {}
            if config.context_parameter:
                kwargs[config.context_parameter] = self
            for service_definition in config.requested_services:
                service = service_manager.get(
                    service_definition.name, service_definition.default_value
                )
                kwargs[service_definition.name] = service
            for resource in config.resources:
                kwargs[resource.name] = await resource_manager.get(
                    resource=resource.resource
                )
            kwargs[config.event_name] = ev

            # wrap the step with instrumentation
            instrumented_step = self._dispatcher.span(step)

            # - check if its async or not
            # - if not async, run it in an executor
            if asyncio.iscoroutinefunction(step):
                retry_start_at = time.time()
                attempts = 0
                while True:
                    await self.mark_in_progress(name=name, ev=ev)
                    await self.add_running_step(name)
                    try:
                        new_ev = await instrumented_step(**kwargs)
                        kwargs.clear()
                        break  # exit the retrying loop

                    except WorkflowDone:
                        await self.remove_from_in_progress(name=name, ev=ev)
                        raise
                    except Exception as e:
                        if config.retry_policy is None:
                            raise WorkflowRuntimeError(
                                f"Error in step '{name}': {e!s}"
                            ) from e

                        delay = config.retry_policy.next(
                            retry_start_at + time.time(), attempts, e
                        )
                        if delay is None:
                            # We're done retrying
                            raise WorkflowRuntimeError(
                                f"Error in step '{name}': {e!s}"
                            ) from e

                        attempts += 1
                        if verbose:
                            print(
                                f"Step {name} produced an error, retry in {delay} seconds"
                            )
                        await asyncio.sleep(delay)
                    finally:
                        await self.remove_running_step(name)

            else:
                try:
                    run_task = functools.partial(instrumented_step, **kwargs)
                    kwargs.clear()
                    new_ev = await asyncio.get_event_loop().run_in_executor(
                        None, run_task
                    )
                except WorkflowDone:
                    await self.remove_from_in_progress(name=name, ev=ev)
                    raise
                except Exception as e:
                    raise WorkflowRuntimeError(f"Error in step '{name}': {e!s}") from e

            if verbose and name != "_done":
                if new_ev is not None:
                    print(f"Step {name} produced event {type(new_ev).__name__}")
                else:
                    print(f"Step {name} produced no event")

            # Store the accepted event for the drawing operations
            if new_ev is not None:
                self._accepted_events.append((name, type(ev).__name__))

            # Fail if the step returned something that's not an event
            if new_ev is not None and not isinstance(new_ev, Event):
                msg = f"Step function {name} returned {type(new_ev).__name__} instead of an Event instance."
                raise WorkflowRuntimeError(msg)

            if stepwise:
                async with self._step_condition:
                    await self._step_condition.wait()

                    if new_ev is not None:
                        self.add_holding_event(new_ev)
                    self._step_event_written.notify()  # shares same lock

                    await self.remove_from_in_progress(name=name, ev=ev)

                    # for stepwise Checkpoint after handler.run_step() call
                    if checkpoint_callback:
                        await checkpoint_callback(
                            run_id=run_id,
                            ctx=self,
                            last_completed_step=name,
                            input_ev=ev,
                            output_ev=new_ev,
                        )
            else:
                # for regular execution, Checkpoint just before firing the next event
                await self.remove_from_in_progress(name=name, ev=ev)
                if checkpoint_callback:
                    await checkpoint_callback(
                        run_id=run_id,
                        ctx=self,
                        last_completed_step=name,
                        input_ev=ev,
                        output_ev=new_ev,
                    )

                # InputRequiredEvent's are special case and need to be written to the stream
                # this way, the user can access and respond to the event
                if isinstance(new_ev, InputRequiredEvent):
                    self.write_event_to_stream(new_ev)
                elif new_ev is not None:
                    self.send_event(new_ev)

    def add_cancel_worker(self) -> None:
        self._tasks.add(asyncio.create_task(self._cancel_worker()))

    async def _cancel_worker(self) -> None:
        try:
            await self._cancel_flag.wait()
            raise WorkflowCancelledByUser
        except asyncio.CancelledError:
            return

is_running instance-attribute #

is_running = False

store property #

store: InMemoryStateStore[MODEL_T]

__init__ #

__init__(
    workflow: "Workflow", stepwise: bool = False
) -> None
Source code in workflows/context/context.py
def __init__(
    self,
    workflow: "Workflow",
    stepwise: bool = False,
) -> None:
    self.stepwise = stepwise
    self.is_running = False
    # Store the step configs of this workflow, to be used in send_event
    self._step_configs: dict[str, StepConfig | None] = {}
    for step_name, step_func in workflow._get_steps().items():
        self._step_configs[step_name] = getattr(step_func, "__step_config", None)

    # Init broker machinery
    self._init_broker_data()

    # Global data storage
    self._lock = asyncio.Lock()
    self._state_store: InMemoryStateStore[MODEL_T] | None = None

    # instrumentation
    self._dispatcher = workflow._dispatcher

collect_events #

collect_events(
    ev: Event,
    expected: list[Type[Event]],
    buffer_id: str | None = None,
) -> list[Event] | None

Collects events for buffering in workflows.

This method adds the current event to the internal buffer and attempts to collect all expected event types. If all expected events are found, they will be returned in order. Otherwise, it returns None and restores any collected events back to the buffer.

Parameters:

- ev (Event, required): The current event to add to the buffer.
- expected (list[Type[Event]], required): List of expected event types to collect.
- buffer_id (str | None, default None): A unique identifier for the events collected. Ideally this should be the step name, so as to avoid any interference between different steps. If not provided, a stable identifier will be created using the list of expected events.

Returns:

- list[Event] | None: List of collected events in the order of expected types if all expected events are found; otherwise None.
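
A minimal sketch of the typical fan-in pattern, assuming two illustrative event types and a step that accepts both:

```python
from workflows import Workflow, Context, step
from workflows.events import Event, StopEvent


class QueryEvent(Event):
    pass


class RetrieveEvent(Event):
    pass


class GatherExample(Workflow):
    @step
    async def gather(self, ctx: Context, ev: QueryEvent | RetrieveEvent) -> StopEvent | None:
        # Buffer this event; collect_events returns None until one event of
        # each expected type has arrived, then returns them in `expected` order.
        events = ctx.collect_events(ev, [QueryEvent, RetrieveEvent])
        if events is None:
            return None
        query_ev, retrieve_ev = events
        return StopEvent(result=(query_ev, retrieve_ev))
```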

Source code in workflows/context/context.py
def collect_events(
    self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
) -> list[Event] | None:
    """
    Collects events for buffering in workflows.

    This method adds the current event to the internal buffer and attempts to collect all
    expected event types. If all expected events are found, they will be returned in order.
    Otherwise, it returns None and restores any collected events back to the buffer.

    Args:
        ev (Event): The current event to add to the buffer.
        expected (list[Type[Event]]): list of expected event types to collect.
        buffer_id (str): A unique identifier for the events collected. Ideally this should be
        the step name, so as to avoid any interference between different steps. If not provided,
        a stable identifier will be created using the list of expected events.

    Returns:
        list[Event] | None: list of collected events in the order of expected types if all
                              expected events are found; otherwise None.

    """
    buffer_id = buffer_id or self._get_event_buffer_id(expected)

    if buffer_id not in self._event_buffers:
        self._event_buffers[buffer_id] = defaultdict(list)

    event_type_path = self._get_full_path(type(ev))
    self._event_buffers[buffer_id][event_type_path].append(ev)

    retval: list[Event] = []
    for e_type in expected:
        e_type_path = self._get_full_path(e_type)
        e_instance_list = self._event_buffers[buffer_id].get(e_type_path, [])
        if e_instance_list:
            retval.append(e_instance_list.pop(0))
        else:
            # We already know we don't have all the events
            break

    if len(retval) == len(expected):
        return retval

    # put back the events if unable to collect all
    for i, ev_to_restore in enumerate(retval):
        e_type = type(retval[i])
        e_type_path = self._get_full_path(e_type)
        self._event_buffers[buffer_id][e_type_path].append(ev_to_restore)

    return None

from_dict classmethod #

from_dict(
    workflow: "Workflow",
    data: dict[str, Any],
    serializer: BaseSerializer | None = None,
) -> "Context"
Source code in workflows/context/context.py
@classmethod
def from_dict(
    cls,
    workflow: "Workflow",
    data: dict[str, Any],
    serializer: BaseSerializer | None = None,
) -> "Context":
    serializer = serializer or JsonSerializer()

    try:
        context = cls(workflow, stepwise=data["stepwise"])

        # Deserialize state manager using the state manager's method
        if "state" in data:
            context._state_store = InMemoryStateStore.from_dict(
                data["state"], serializer
            )
        elif "globals" in data:
            # Deserialize legacy globals for backward compatibility
            globals = context._deserialize_globals(data["globals"], serializer)
            context._state_store = InMemoryStateStore(DictState(**globals))

        context._streaming_queue = context._deserialize_queue(
            data["streaming_queue"], serializer
        )

        context._event_buffers = {}
        for buffer_id, type_events_map in data["event_buffers"].items():
            context._event_buffers[buffer_id] = {}
            for event_type, events_list in type_events_map.items():
                context._event_buffers[buffer_id][event_type] = [
                    serializer.deserialize(ev) for ev in events_list
                ]

        context._accepted_events = data["accepted_events"]
        context._broker_log = [
            serializer.deserialize(ev) for ev in data["broker_log"]
        ]
        context.is_running = data["is_running"]
        # load back up whatever was in the queue as well as the events whose steps
        # were in progress when the serialization of the Context took place
        context._queues = {
            k: context._deserialize_queue(
                v, serializer, prefix_queue_objs=data["in_progress"].get(k, [])
            )
            for k, v in data["queues"].items()
        }
        context._in_progress = defaultdict(list)
        return context
    except KeyError as e:
        msg = "Error creating a Context instance: the provided payload has a wrong or old format."
        raise ContextSerdeError(msg) from e

get_result #

get_result() -> RunResultT

Returns the result of the workflow.

Source code in workflows/context/context.py
def get_result(self) -> RunResultT:
    """Returns the result of the workflow."""
    return self._retval

send_event #

send_event(message: Event, step: str | None = None) -> None

Sends an event to a specific step in the workflow.

If step is None, the event is sent to all the receivers and we let them discard events they don't want.
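
For instance, inside a step body (RetrieveEvent and the step name "retrieve" are hypothetical):

```python
from workflows.events import Event


class RetrieveEvent(Event):
    pass


# Broadcast: every step's queue receives the event; steps that don't accept
# this type simply discard it.
ctx.send_event(RetrieveEvent())

# Targeted: only the named step receives it; a WorkflowRuntimeError is raised
# if the step doesn't exist or doesn't accept this event type.
ctx.send_event(RetrieveEvent(), step="retrieve")
```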

Source code in workflows/context/context.py
def send_event(self, message: Event, step: str | None = None) -> None:
    """
    Sends an event to a specific step in the workflow.

    If step is None, the event is sent to all the receivers and we let
    them discard events they don't want.
    """
    self.add_holding_event(message)

    if step is None:
        for queue in self._queues.values():
            queue.put_nowait(message)
    else:
        if step not in self._step_configs:
            raise WorkflowRuntimeError(f"Step {step} does not exist")

        step_config = self._step_configs[step]
        if step_config and type(message) in step_config.accepted_events:
            self._queues[step].put_nowait(message)
        else:
            raise WorkflowRuntimeError(
                f"Step {step} does not accept event of type {type(message)}"
            )

    self._broker_log.append(message)

to_dict #

to_dict(
    serializer: BaseSerializer | None = None,
) -> dict[str, Any]
Source code in workflows/context/context.py
def to_dict(self, serializer: BaseSerializer | None = None) -> dict[str, Any]:
    serializer = serializer or JsonSerializer()

    # Serialize state using the state manager's method
    state_data = {}
    if self._state_store is not None:
        state_data = self._state_store.to_dict(serializer)

    return {
        "state": state_data,  # Use state manager's serialize method
        "streaming_queue": self._serialize_queue(self._streaming_queue, serializer),
        "queues": {
            k: self._serialize_queue(v, serializer) for k, v in self._queues.items()
        },
        "stepwise": self.stepwise,
        "event_buffers": {
            k: {
                inner_k: [serializer.serialize(ev) for ev in inner_v]
                for inner_k, inner_v in v.items()
            }
            for k, v in self._event_buffers.items()
        },
        "in_progress": {
            k: [serializer.serialize(ev) for ev in v]
            for k, v in self._in_progress.items()
        },
        "accepted_events": self._accepted_events,
        "broker_log": [serializer.serialize(ev) for ev in self._broker_log],
        "is_running": self.is_running,
    }

wait_for_event async #

wait_for_event(
    event_type: Type[T],
    waiter_event: Event | None = None,
    waiter_id: str | None = None,
    requirements: dict[str, Any] | None = None,
    timeout: float | None = 2000,
) -> T

Asynchronously wait for a specific event type to be received.

If provided, waiter_event will be written to the event stream to let the caller know that we're waiting for a response.

Parameters:

- event_type (Type[T], required): The type of event to wait for.
- waiter_event (Event | None, default None): The event to emit to the event stream to let the caller know that we're waiting for a response.
- waiter_id (str | None, default None): A unique identifier for this specific wait call. It helps ensure that we only send one waiter_event for each waiter_id.
- requirements (dict[str, Any] | None, default None): Optional dict of requirements the event must match.
- timeout (float | None, default 2000): Optional timeout in seconds. Defaults to 2000s.

Returns:

- T: The event type that was requested.

Raises:

- asyncio.TimeoutError: If the timeout is reached before receiving a matching event.
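
A sketch of a human-in-the-loop wait, assuming a hypothetical ApprovalEvent that the caller later sends back into the workflow via send_event:

```python
from workflows.events import Event, InputRequiredEvent


class ApprovalEvent(Event):
    approved: bool


# Inside a step body: announce that we're waiting by streaming an
# InputRequiredEvent, then block (up to 10 minutes) for the response.
response = await ctx.wait_for_event(
    ApprovalEvent,
    waiter_event=InputRequiredEvent(),
    waiter_id="plan-approval",
    timeout=600,
)
if not response.approved:
    raise RuntimeError("plan was rejected")
```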

Source code in workflows/context/context.py
async def wait_for_event(
    self,
    event_type: Type[T],
    waiter_event: Event | None = None,
    waiter_id: str | None = None,
    requirements: dict[str, Any] | None = None,
    timeout: float | None = 2000,
) -> T:
    """
    Asynchronously wait for a specific event type to be received.

    If provided, `waiter_event` will be written to the event stream to let the caller know that we're waiting for a response.

    Args:
        event_type: The type of event to wait for
        waiter_event: The event to emit to the event stream to let the caller know that we're waiting for a response
        waiter_id: A unique identifier for this specific wait call. It helps ensure that we only send one `waiter_event` for each `waiter_id`.
        requirements: Optional dict of requirements the event must match
        timeout: Optional timeout in seconds. Defaults to 2000s.

    Returns:
        The event type that was requested.

    Raises:
        asyncio.TimeoutError: If the timeout is reached before receiving matching event

    """
    requirements = requirements or {}

    # Generate a unique key for the waiter
    event_str = self._get_full_path(event_type)
    requirements_str = str(requirements)
    waiter_id = waiter_id or f"waiter_{event_str}_{requirements_str}"

    if waiter_id not in self._queues:
        self._queues[waiter_id] = asyncio.Queue()

    # send the waiter event if it's not already sent
    if waiter_event is not None:
        is_waiting = await self.get(waiter_id, default=False)
        if not is_waiting:
            await self.set(waiter_id, True)
            self.write_event_to_stream(waiter_event)

    while True:
        try:
            event = await asyncio.wait_for(
                self._queues[waiter_id].get(), timeout=timeout
            )
            if type(event) is event_type:
                if all(
                    event.get(k, default=None) == v for k, v in requirements.items()
                ):
                    return event
                else:
                    continue
        finally:
            await self.set(waiter_id, False)

write_event_to_stream #

write_event_to_stream(ev: Event | None) -> None
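
Writes an event to the streaming queue so callers can observe progress while the workflow is still running. A sketch, assuming the handler returned by run() exposes stream_events() as an async iterator (as in the library's typical usage); the ProgressEvent class is illustrative:

```python
import asyncio

from workflows import Workflow, Context, step
from workflows.events import Event, StartEvent, StopEvent


class ProgressEvent(Event):
    message: str


class StreamingExample(Workflow):
    @step
    async def work(self, ctx: Context, ev: StartEvent) -> StopEvent:
        ctx.write_event_to_stream(ProgressEvent(message="halfway there"))
        return StopEvent(result="done")


async def main() -> None:
    handler = StreamingExample().run()
    async for ev in handler.stream_events():
        if isinstance(ev, ProgressEvent):
            print(ev.message)
    print(await handler)


asyncio.run(main())
```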
Source code in workflows/context/context.py
def write_event_to_stream(self, ev: Event | None) -> None:
    self._streaming_queue.put_nowait(ev)

DictState #

Bases: Event

A dynamic state class that behaves like a dictionary.

This is used as the default state type when no specific state class is provided. It allows storing arbitrary key-value pairs while still being a Pydantic model.
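
For example, with an untyped Context the store defaults to DictState, so arbitrary keys can be attached to the state (the key names below are illustrative):

```python
# Inside a step body where `ctx` is an untyped Context:
state = await ctx.store.get_state()   # a DictState instance
state.progress = 0.5                  # dynamic, dict-like assignment
await ctx.store.set_state(state)

value = await ctx.store.get("progress")  # read it back by key
```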

Source code in workflows/context/state_store.py
class DictState(Event):
    """
    A dynamic state class that behaves like a dictionary.

    This is used as the default state type when no specific state class is provided.
    It allows storing arbitrary key-value pairs while still being a Pydantic model.
    """

    pass

InMemoryStateStore #

Bases: Generic[MODEL_T]

State manager for a workflow that provides type-safe state management.

By using Context[StateType] as the parameter type annotation, the state manager is automatically initialized with the correct type, providing full type safety and IDE autocompletion.

When no state type is specified (just Context), it defaults to DictState which behaves like a regular dictionary.

Example with typed state:

from pydantic import BaseModel
from workflows import Workflow, Context, step
from workflows.events import StartEvent, StopEvent

class MyState(BaseModel):
    name: str = "Unknown"
    age: int = 0

class MyWorkflow(Workflow):
    @step
    async def step_1(self, ctx: Context[MyState], ev: StartEvent) -> StopEvent:
        # ctx._state.get() is now properly typed as MyState
        state = await ctx._state.get()
        state.name = "John"  # Type-safe: IDE knows this is a string
        state.age = 30       # Type-safe: IDE knows this is an int
        await ctx._state.set(state)
        return StopEvent()

Example with untyped dict-like state:

class MyWorkflow(Workflow):
    @step
    async def step_1(self, ctx: Context, ev: StartEvent) -> StopEvent:
        # ctx._state behaves like a dict
        state = await ctx._state.get()
        state.name = "John"     # Works like a dict
        state.age = 30          # Dynamic assignment
        await ctx._state.set(state)
        return StopEvent()

The state manager provides:

- Type-safe access to state properties with full IDE support (when typed)
- Dict-like behavior for dynamic state management (when untyped)
- Automatic state initialization based on the generic type parameter
- Thread-safe state access with async locking
- Deep path-based state access and modification
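
For example, deep paths can be read and written with dot-separated keys; a sketch inside a step body, where ctx is the Context and the keys are illustrative:

```python
# Intermediate dicts are created as needed when setting a deep path.
await ctx.store.set("user.profile.name", "Ada")
await ctx.store.set("user.profile.tags", ["admin"])

name = await ctx.store.get("user.profile.name")            # "Ada"
first_tag = await ctx.store.get("user.profile.tags.0")     # list index access
email = await ctx.store.get("user.profile.email", default=None)
```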

Source code in workflows/context/state_store.py
class InMemoryStateStore(Generic[MODEL_T]):
    """
    State manager for a workflow that provides type-safe state management.

    By using Context[StateType] as the parameter type annotation, the state manager
    is automatically initialized with the correct type, providing full type safety
    and IDE autocompletion.

    When no state type is specified (just Context), it defaults to DictState which
    behaves like a regular dictionary.

    Example with typed state:
    ```python
    from pydantic import BaseModel
    from workflows import Workflow, Context, step
    from workflows.events import StartEvent, StopEvent

    class MyState(BaseModel):
        name: str = "Unknown"
        age: int = 0

    class MyWorkflow(Workflow):
        @step
        async def step_1(self, ctx: Context[MyState], ev: StartEvent) -> StopEvent:
            # ctx._state.get() is now properly typed as MyState
            state = await ctx._state.get()
            state.name = "John"  # Type-safe: IDE knows this is a string
            state.age = 30       # Type-safe: IDE knows this is an int
            await ctx._state.set(state)
            return StopEvent()
    ```

    Example with untyped dict-like state:
    ```python
    class MyWorkflow(Workflow):
        @step
        async def step_1(self, ctx: Context, ev: StartEvent) -> StopEvent:
            # ctx._state behaves like a dict
            state = await ctx._state.get()
            state.name = "John"     # Works like a dict
            state.age = 30          # Dynamic assignment
            await ctx._state.set(state)
            return StopEvent()
    ```

    The state manager provides:
    - Type-safe access to state properties with full IDE support (when typed)
    - Dict-like behavior for dynamic state management (when untyped)
    - Automatic state initialization based on the generic type parameter
    - Thread-safe state access with async locking
    - Deep path-based state access and modification
    """

    # These keys are set by pre-built workflows and
    # are known to be unserializable in some cases.
    known_unserializable_keys = ("memory",)

    def __init__(self, initial_state: MODEL_T):
        self._state = initial_state
        self._lock = asyncio.Lock()

    async def get_state(self) -> MODEL_T:
        """Get a copy of the current state."""
        return self._state.model_copy()

    async def set_state(self, state: MODEL_T) -> None:
        """Set the current state."""
        if not isinstance(state, type(self._state)):
            raise ValueError(f"State must be of type {type(self._state)}")

        async with self._lock:
            self._state = state

    def to_dict(self, serializer: "BaseSerializer") -> dict[str, Any]:
        """
        Serialize the state manager's state.

        For DictState, uses the BaseSerializer for individual items since they can be arbitrary types.
        For other Pydantic models, leverages Pydantic's serialization but uses BaseSerializer for complex types.
        """
        # Special handling for DictState - serialize each item in _data
        if isinstance(self._state, DictState):
            serialized_data = {}
            for key, value in self._state.items():
                try:
                    serialized_data[key] = serializer.serialize(value)
                except Exception as e:
                    if key in self.known_unserializable_keys:
                        warnings.warn(
                            f"Skipping serialization of known unserializable key: {key} -- "
                            "This is expected but will require this item to be set manually after deserialization.",
                            category=UnserializableKeyWarning,
                        )
                        continue
                    raise ValueError(
                        f"Failed to serialize state value for key {key}: {e}"
                    )

            return {
                "state_data": {"_data": serialized_data},
                "state_type": type(self._state).__name__,
                "state_module": type(self._state).__module__,
            }
        else:
            # For regular Pydantic models, rely on pydantic's serialization
            serialized_state = serializer.serialize(self._state)

            return {
                "state_data": serialized_state,
                "state_type": type(self._state).__name__,
                "state_module": type(self._state).__module__,
            }

    @classmethod
    def from_dict(
        cls, serialized_state: dict[str, Any], serializer: "BaseSerializer"
    ) -> "InMemoryStateStore[MODEL_T]":
        """
        Deserialize and restore a state manager.
        """
        if not serialized_state:
            # Return a default DictState manager
            return cls(DictState())  # type: ignore

        state_data = serialized_state.get("state_data", {})
        state_type = serialized_state.get("state_type", "DictState")

        # Deserialize the state data
        if state_type == "DictState":
            # Special handling for DictState - deserialize each item in _data
            _data_serialized = state_data.get("_data", {})
            deserialized_data = {}
            for key, value in _data_serialized.items():
                try:
                    deserialized_data[key] = serializer.deserialize(value)
                except Exception as e:
                    raise ValueError(
                        f"Failed to deserialize state value for key {key}: {e}"
                    )

            state_instance = DictState(_data=deserialized_data)
        else:
            state_instance = serializer.deserialize(state_data)

        return cls(state_instance)  # type: ignore

    async def get(self, path: str, default: Optional[Any] = Ellipsis) -> Any:
        """
        Return a value from *path*, where path is a dot-separated string.
        Example: await sm.get("user.profile.name")
        """
        segments = path.split(".") if path else []
        if len(segments) > MAX_DEPTH:
            raise ValueError(f"Path length exceeds {MAX_DEPTH} segments")

        async with self._lock:
            try:
                value: Any = self._state
                for segment in segments:
                    value = self._traverse_step(value, segment)
            except Exception:
                if default is not Ellipsis:
                    return default

                msg = f"Path '{path}' not found in state"
                raise ValueError(msg)

        return value

    async def set(self, path: str, value: Any) -> None:
        """Set *value* at the location designated by *path* (dot-separated)."""
        if not path:
            raise ValueError("Path cannot be empty")

        segments = path.split(".")
        if len(segments) > MAX_DEPTH:
            raise ValueError(f"Path length exceeds {MAX_DEPTH} segments")

        async with self._lock:
            current = self._state

            # Navigate/create intermediate segments
            for segment in segments[:-1]:
                try:
                    current = self._traverse_step(current, segment)
                except (KeyError, AttributeError, IndexError, TypeError):
                    # Create intermediate object and assign it
                    intermediate: Any = {}
                    self._assign_step(current, segment, intermediate)
                    current = intermediate

            # Assign the final value
            self._assign_step(current, segments[-1], value)

    def _traverse_step(self, obj: Any, segment: str) -> Any:
        """Follow one segment into *obj* (dict key, list index, or attribute)."""
        if isinstance(obj, dict):
            return obj[segment]

        # attempt list/tuple index
        try:
            idx = int(segment)
            return obj[idx]
        except (ValueError, TypeError, IndexError):
            pass

        # fallback to attribute access (Pydantic models, normal objects)
        return getattr(obj, segment)

    def _assign_step(self, obj: Any, segment: str, value: Any) -> None:
        """Assign *value* to *segment* of *obj* (dict key, list index, or attribute)."""
        if isinstance(obj, dict):
            obj[segment] = value
            return

        # attempt list/tuple index assignment
        try:
            idx = int(segment)
            obj[idx] = value
            return
        except (ValueError, TypeError, IndexError):
            pass

        # fallback to attribute assignment
        setattr(obj, segment, value)
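
A standalone sketch of the dot-path rules used by the two helpers above: each segment is tried as a dict key first, then as a list/tuple index, then as an attribute. The traverse helper and sample objects below are hypothetical illustrations, not part of the library API.

from dataclasses import dataclass

@dataclass
class Profile:
    name: str

def traverse(obj, segment):
    """Resolve one path segment: dict key, then list/tuple index, then attribute."""
    if isinstance(obj, dict):
        return obj[segment]
    try:
        return obj[int(segment)]
    except (ValueError, TypeError, IndexError):
        pass
    return getattr(obj, segment)

state = {"users": [{"profile": Profile(name="Ada")}]}

value = state
for segment in "users.0.profile.name".split("."):
    value = traverse(value, segment)

print(value)  # Ada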

get_state async #

get_state() -> MODEL_T

Get a copy of the current state.

Source code in workflows/context/state_store.py
async def get_state(self) -> MODEL_T:
    """Get a copy of the current state."""
    return self._state.model_copy()

set_state async #

set_state(state: MODEL_T) -> None

Set the current state.

Source code in workflows/context/state_store.py
async def set_state(self, state: MODEL_T) -> None:
    """Set the current state."""
    if not isinstance(state, type(self._state)):
        raise ValueError(f"State must be of type {type(self._state)}")

    async with self._lock:
        self._state = state
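
A minimal copy-modify-set sketch, assuming InMemoryStateStore is importable from workflows.context.state_store (the source path shown on this page) and accepts the initial state instance as its constructor argument, as from_dict below suggests.

import asyncio

from pydantic import BaseModel

from workflows.context.state_store import InMemoryStateStore

class CounterState(BaseModel):
    count: int = 0

async def main() -> None:
    store = InMemoryStateStore(CounterState())

    state = await store.get_state()  # a copy; mutating it alone changes nothing
    state.count += 1
    await store.set_state(state)     # write the modified copy back

    print((await store.get_state()).count)  # 1

asyncio.run(main())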

to_dict #

to_dict(serializer: BaseSerializer) -> dict[str, Any]

Serialize the state manager's state.

For DictState, uses the BaseSerializer for individual items since they can be arbitrary types. For other Pydantic models, leverages Pydantic's serialization but uses BaseSerializer for complex types.

Source code in workflows/context/state_store.py
def to_dict(self, serializer: "BaseSerializer") -> dict[str, Any]:
    """
    Serialize the state manager's state.

    For DictState, uses the BaseSerializer for individual items since they can be arbitrary types.
    For other Pydantic models, leverages Pydantic's serialization but uses BaseSerializer for complex types.
    """
    # Special handling for DictState - serialize each item in _data
    if isinstance(self._state, DictState):
        serialized_data = {}
        for key, value in self._state.items():
            try:
                serialized_data[key] = serializer.serialize(value)
            except Exception as e:
                if key in self.known_unserializable_keys:
                    warnings.warn(
                        f"Skipping serialization of known unserializable key: {key} -- "
                        "This is expected but will require this item to be set manually after deserialization.",
                        category=UnserializableKeyWarning,
                    )
                    continue
                raise ValueError(
                    f"Failed to serialize state value for key {key}: {e}"
                )

        return {
            "state_data": {"_data": serialized_data},
            "state_type": type(self._state).__name__,
            "state_module": type(self._state).__module__,
        }
    else:
        # For regular Pydantic models, rely on pydantic's serialization
        serialized_state = serializer.serialize(self._state)

        return {
            "state_data": serialized_state,
            "state_type": type(self._state).__name__,
            "state_module": type(self._state).__module__,
        }
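
A hedged sketch of to_dict with a regular Pydantic state model, assuming the imports below match the source paths shown on this page; the serialized payload carries the state type name alongside the data.

from pydantic import BaseModel

from workflows.context.serializers import JsonSerializer
from workflows.context.state_store import InMemoryStateStore

class AppState(BaseModel):
    run_id: str = "abc"
    attempts: int = 2

store = InMemoryStateStore(AppState())
payload = store.to_dict(JsonSerializer())

# payload["state_type"] == "AppState"; payload["state_data"] is the
# JSON string produced by JsonSerializer for the whole model.
print(payload["state_type"], payload["state_data"])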

from_dict classmethod #

from_dict(
    serialized_state: dict[str, Any],
    serializer: BaseSerializer,
) -> InMemoryStateStore[MODEL_T]

Deserialize and restore a state manager.

Source code in workflows/context/state_store.py
@classmethod
def from_dict(
    cls, serialized_state: dict[str, Any], serializer: "BaseSerializer"
) -> "InMemoryStateStore[MODEL_T]":
    """
    Deserialize and restore a state manager.
    """
    if not serialized_state:
        # Return a default DictState manager
        return cls(DictState())  # type: ignore

    state_data = serialized_state.get("state_data", {})
    state_type = serialized_state.get("state_type", "DictState")

    # Deserialize the state data
    if state_type == "DictState":
        # Special handling for DictState - deserialize each item in _data
        _data_serialized = state_data.get("_data", {})
        deserialized_data = {}
        for key, value in _data_serialized.items():
            try:
                deserialized_data[key] = serializer.deserialize(value)
            except Exception as e:
                raise ValueError(
                    f"Failed to deserialize state value for key {key}: {e}"
                )

        state_instance = DictState(_data=deserialized_data)
    else:
        state_instance = serializer.deserialize(state_data)

    return cls(state_instance)  # type: ignore
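
A hedged restore sketch: an empty payload falls back to a fresh DictState-backed store, while a DictState payload has each stored value deserialized individually. The hand-built payload below only mirrors the shape produced by to_dict and is for illustration; from_dict reads only "state_type" and "state_data".

from workflows.context.serializers import JsonSerializer
from workflows.context.state_store import InMemoryStateStore

serializer = JsonSerializer()

# Empty payload -> default DictState-backed store.
default_store = InMemoryStateStore.from_dict({}, serializer)

# DictState payload: each value in _data is an individually serialized item.
payload = {
    "state_type": "DictState",
    "state_data": {"_data": {"greeting": serializer.serialize("hello")}},
}
restored = InMemoryStateStore.from_dict(payload, serializer)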

get async #

get(path: str, default: Optional[Any] = Ellipsis) -> Any

Return a value from path, where path is a dot-separated string. Example: await sm.get("user.profile.name")

Source code in workflows/context/state_store.py
async def get(self, path: str, default: Optional[Any] = Ellipsis) -> Any:
    """
    Return a value from *path*, where path is a dot-separated string.
    Example: await sm.get("user.profile.name")
    """
    segments = path.split(".") if path else []
    if len(segments) > MAX_DEPTH:
        raise ValueError(f"Path length exceeds {MAX_DEPTH} segments")

    async with self._lock:
        try:
            value: Any = self._state
            for segment in segments:
                value = self._traverse_step(value, segment)
        except Exception:
            if default is not Ellipsis:
                return default

            msg = f"Path '{path}' not found in state"
            raise ValueError(msg)

    return value
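
A hedged usage sketch for path-based reads, assuming InMemoryStateStore imports from workflows.context.state_store: nested attributes, list indices, and a default for missing paths.

import asyncio

from pydantic import BaseModel, Field

from workflows.context.state_store import InMemoryStateStore

class Profile(BaseModel):
    name: str = "Ada"

class AppState(BaseModel):
    profile: Profile = Field(default_factory=Profile)
    tags: list[str] = Field(default_factory=lambda: ["alpha", "beta"])

async def main() -> None:
    store = InMemoryStateStore(AppState())

    print(await store.get("profile.name"))                # Ada
    print(await store.get("tags.1"))                       # beta (list index segment)
    print(await store.get("missing.path", default=None))  # None instead of ValueError

asyncio.run(main())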

set async #

set(path: str, value: Any) -> None

Set value at the location designated by path (dot-separated).

Source code in workflows/context/state_store.py
async def set(self, path: str, value: Any) -> None:
    """Set *value* at the location designated by *path* (dot-separated)."""
    if not path:
        raise ValueError("Path cannot be empty")

    segments = path.split(".")
    if len(segments) > MAX_DEPTH:
        raise ValueError(f"Path length exceeds {MAX_DEPTH} segments")

    async with self._lock:
        current = self._state

        # Navigate/create intermediate segments
        for segment in segments[:-1]:
            try:
                current = self._traverse_step(current, segment)
            except (KeyError, AttributeError, IndexError, TypeError):
                # Create intermediate object and assign it
                intermediate: Any = {}
                self._assign_step(current, segment, intermediate)
                current = intermediate

        # Assign the final value
        self._assign_step(current, segments[-1], value)
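
A hedged write sketch: attribute assignment on nested models, plus on-demand creation of intermediate dicts when a path segment does not exist yet. Imports assume the source path shown above.

import asyncio

from pydantic import BaseModel, Field

from workflows.context.state_store import InMemoryStateStore

class Profile(BaseModel):
    name: str = "Ada"

class AppState(BaseModel):
    profile: Profile = Field(default_factory=Profile)
    meta: dict = Field(default_factory=dict)

async def main() -> None:
    store = InMemoryStateStore(AppState())

    await store.set("profile.name", "Grace")      # attribute assignment on a nested model
    await store.set("meta.retries.count", 3)      # "retries" dict is created on the fly

    print(await store.get("profile.name"))        # Grace
    print(await store.get("meta.retries.count"))  # 3

asyncio.run(main())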

BaseSerializer #

Bases: ABC

Source code in workflows/context/serializers.py
class BaseSerializer(ABC):
    @abstractmethod
    def serialize(self, value: Any) -> str: ...

    @abstractmethod
    def deserialize(self, value: str) -> Any: ...
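
A minimal sketch of a custom serializer, assuming BaseSerializer imports from workflows.context.serializers as the source path indicates; PrettyJsonSerializer is a hypothetical example, not part of the library.

import json
from typing import Any

from workflows.context.serializers import BaseSerializer

class PrettyJsonSerializer(BaseSerializer):
    """Hypothetical serializer storing values as indented JSON text."""

    def serialize(self, value: Any) -> str:
        return json.dumps(value, indent=2, default=str)

    def deserialize(self, value: str) -> Any:
        return json.loads(value)

blob = PrettyJsonSerializer().serialize({"run_id": "abc", "attempts": 2})
print(PrettyJsonSerializer().deserialize(blob)["attempts"])  # 2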

JsonSerializer #

Bases: BaseSerializer

Source code in workflows/context/serializers.py
class JsonSerializer(BaseSerializer):
    def _serialize_value(self, value: Any) -> Any:
        """Helper to serialize a single value."""
        # Note: to avoid circular dependencies we cannot import BaseComponent from llama_index.core
        # if we want to use isinstance(value, BaseComponent) instead of guessing type from the presence
        # of class_name, we need to move BaseComponent out of core
        if hasattr(value, "class_name"):
            retval = {
                "__is_component": True,
                "value": value.to_dict(),
                "qualified_name": get_qualified_name(value),
            }
            return retval

        if isinstance(value, BaseModel):
            return {
                "__is_pydantic": True,
                "value": value.model_dump(mode="json"),
                "qualified_name": get_qualified_name(value),
            }

        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}

        if isinstance(value, list):
            return [self._serialize_value(item) for item in value]

        return value

    def serialize(self, value: Any) -> str:
        try:
            serialized_value = self._serialize_value(value)
            return json.dumps(serialized_value)
        except Exception:
            raise ValueError(f"Failed to serialize value: {type(value)}: {value!s}")

    def _deserialize_value(self, data: Any) -> Any:
        """Helper to deserialize a single value."""
        if isinstance(data, dict):
            if data.get("__is_pydantic") and data.get("qualified_name"):
                module_class = import_module_from_qualified_name(data["qualified_name"])
                return module_class.model_validate(data["value"])
            elif data.get("__is_component") and data.get("qualified_name"):
                module_class = import_module_from_qualified_name(data["qualified_name"])
                return module_class.from_dict(data["value"])
            return {k: self._deserialize_value(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._deserialize_value(item) for item in data]
        return data

    def deserialize(self, value: str) -> Any:
        data = json.loads(value)
        return self._deserialize_value(data)
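
A hedged round-trip sketch: JsonSerializer tags Pydantic models with their qualified name and re-imports that class on deserialization, so the model must be importable wherever the payload is deserialized (here it is defined in the running script).

from pydantic import BaseModel

from workflows.context.serializers import JsonSerializer

class User(BaseModel):
    name: str
    age: int

serializer = JsonSerializer()
blob = serializer.serialize({"owner": User(name="Ada", age=36), "tags": ["x", "y"]})

restored = serializer.deserialize(blob)
print(type(restored["owner"]).__name__, restored["tags"])  # User ['x', 'y']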

PickleSerializer #

Bases: JsonSerializer

Source code in workflows/context/serializers.py
class PickleSerializer(JsonSerializer):
    def serialize(self, value: Any) -> str:
        """Serialize while prioritizing JSON, falling back to Pickle."""
        try:
            return super().serialize(value)
        except Exception:
            return base64.b64encode(pickle.dumps(value)).decode("utf-8")

    def deserialize(self, value: str) -> Any:
        """
        Deserialize while prioritizing Pickle, falling back to JSON.
        To avoid malicious exploits of the deserialization, deserialize objects
        only when you deem it safe to do so.
        """
        try:
            return pickle.loads(base64.b64decode(value))
        except Exception:
            return super().deserialize(value)
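
A hedged sketch of the fallback behaviour: values JSON can encode stay as JSON text, while values it cannot (a set here) are stored as base64-encoded pickle. As the docstring warns, only deserialize payloads you trust.

from workflows.context.serializers import PickleSerializer

serializer = PickleSerializer()

as_json = serializer.serialize({"retries": 3})  # stays plain JSON text
as_pickle = serializer.serialize({1, 2, 3})     # JSON fails, falls back to pickle

print(serializer.deserialize(as_pickle))  # {1, 2, 3}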

serialize #

serialize(value: Any) -> str

Serialize while prioritizing JSON, falling back to Pickle.

Source code in workflows/context/serializers.py
def serialize(self, value: Any) -> str:
    """Serialize while prioritizing JSON, falling back to Pickle."""
    try:
        return super().serialize(value)
    except Exception:
        return base64.b64encode(pickle.dumps(value)).decode("utf-8")

deserialize #

deserialize(value: str) -> Any

Deserialize while prioritizing Pickle, falling back to JSON. To avoid malicious exploits of the deserialization, deserialize objects only when you deem it safe to do so.

Source code in workflows/context/serializers.py
def deserialize(self, value: str) -> Any:
    """
    Deserialize while prioritizing Pickle, falling back to JSON.
    To avoid malicious exploits of the deserialization, deserialize objects
    only when you deem it safe to do so.
    """
    try:
        return pickle.loads(base64.b64decode(value))
    except Exception:
        return super().deserialize(value)