Add Joined combinator

Signed-off-by: Marcel Müller <neikos@neikos.email>
This commit is contained in:
Marcel Müller 2025-11-04 13:51:28 +01:00
parent 44e8280093
commit edc076efdd
2 changed files with 127 additions and 17 deletions

View file

@ -48,6 +48,10 @@ impl InternalMessage {
pub fn type_name(&self) -> &'static str {
self.name
}
pub fn type_id(&self) -> TypeId {
self.value.as_ref().type_id()
}
}
pub trait Address<MB> {
@ -116,6 +120,7 @@ pub struct BundleChain {
pub enum BundleOp {
Add(TypeId),
Remove(TypeId),
Chain(&'static BundleChain),
}
impl BundleChain {
@ -126,12 +131,31 @@ impl BundleChain {
}
}
pub const fn contains(&self, id: TypeId) -> bool {
check_is_contained(self, id)
}
pub const fn with<M: Message>(&'static self) -> BundleChain {
add_to_chain(self, TypeId::of::<M>())
let to_add = TypeId::of::<M>();
BundleChain {
op: BundleOp::Add(to_add),
next: Some(self),
}
}
pub const fn without<M: Message>(&'static self) -> BundleChain {
remove_from_chain(self, TypeId::of::<M>())
let to_remove = TypeId::of::<M>();
BundleChain {
op: BundleOp::Remove(to_remove),
next: Some(self),
}
}
pub const fn join(&'static self, ids: &'static BundleChain) -> BundleChain {
BundleChain {
op: BundleOp::Chain(self),
next: Some(ids),
}
}
}
@ -165,7 +189,7 @@ where
const IS_CONTAINED: bool = check_is_contained(&MB::IDS, TypeId::of::<M>());
}
const fn check_is_contained(ids: &'static BundleChain, id: TypeId) -> bool {
const fn check_is_contained(ids: &BundleChain, id: TypeId) -> bool {
match ids.op {
BundleOp::Add(added_id) => {
if check_type_id_equal(added_id, id) {
@ -177,6 +201,11 @@ const fn check_is_contained(ids: &'static BundleChain, id: TypeId) -> bool {
return false;
}
}
BundleOp::Chain(other_chain) => {
if check_is_contained(other_chain, id) {
return true;
}
}
}
if let Some(next) = ids.next {
@ -190,20 +219,6 @@ const fn check_type_id_equal(left: TypeId, right: TypeId) -> bool {
left == right
}
const fn add_to_chain(prev: &'static BundleChain, to_add: TypeId) -> BundleChain {
BundleChain {
op: BundleOp::Add(to_add),
next: Some(prev),
}
}
const fn remove_from_chain(prev: &'static BundleChain, to_remove: TypeId) -> BundleChain {
BundleChain {
op: BundleOp::Remove(to_remove),
next: Some(prev),
}
}
#[cfg(test)]
mod tests {
use macro_rules_attribute::apply;

View file

@ -21,6 +21,11 @@ pub trait AddressExt<MB> {
F: Fn(&M) -> U,
M: Message + IsContainedInBundle<MB>,
Self: Sized;
fn join<Other>(self, o: Other) -> Joined<Self, Other>
where
Joined<Self, Other>: InternalMessageHandler,
Self: Sized;
}
impl<MB: MessageBundle, A: InternalMessageHandler<HandledMessages = MB>> AddressExt<MB> for A {
@ -55,6 +60,52 @@ impl<MB: MessageBundle, A: InternalMessageHandler<HandledMessages = MB>> Address
_pd: PhantomData,
}
}
fn join<Other>(self, o: Other) -> Joined<Self, Other>
where
Joined<Self, Other>: InternalMessageHandler,
Self: Sized,
{
Joined {
left: self,
right: o,
}
}
}
pub struct Joined<L, R> {
left: L,
right: R,
}
pub struct JoinedMessages<L, R> {
_pd: PhantomData<fn(L, R)>,
}
impl<L, R> MessageBundle for JoinedMessages<L, R>
where
L: MessageBundle,
R: MessageBundle,
{
const IDS: tytix_core::BundleChain = <L as MessageBundle>::IDS.join(&<R as MessageBundle>::IDS);
}
impl<L, R> InternalMessageHandler for Joined<L, R>
where
L: InternalMessageHandler,
R: InternalMessageHandler,
{
type HandledMessages = JoinedMessages<L::HandledMessages, R::HandledMessages>;
async fn handle_message(&mut self, msg: InternalMessage) -> anyhow::Result<InternalMessage> {
let message_id = msg.type_id();
if L::HandledMessages::IDS.contains(message_id) {
self.left.handle_message(msg).await
} else {
self.right.handle_message(msg).await
}
}
}
pub struct Inspect<A, F, U, M> {
@ -162,6 +213,12 @@ mod tests {
type Reply = ();
}
struct Zap;
impl Message for Zap {
type Reply = usize;
}
struct SimpleAddress;
impl InternalMessageHandler for SimpleAddress {
@ -179,6 +236,19 @@ mod tests {
}
}
struct ZapAddress;
impl InternalMessageHandler for ZapAddress {
type HandledMessages = (Zap,);
async fn handle_message(
&mut self,
_msg: InternalMessage,
) -> anyhow::Result<InternalMessage> {
Ok(InternalMessage::new(45usize))
}
}
#[apply(test!)]
async fn check_mapping() {
static MSG: OnceLock<bool> = OnceLock::new();
@ -213,4 +283,29 @@ mod tests {
MSG.get().expect("The message was inspected!");
}
#[apply(test!)]
async fn check_join() {
static MSG_SA: OnceLock<bool> = OnceLock::new();
static MSG_ZAP: OnceLock<bool> = OnceLock::new();
let sa = SimpleAddress.inspect(|_b: &Bar| {
MSG_SA.set(true).unwrap();
async {}
});
let zap = ZapAddress.inspect(|_z: &Zap| async {
MSG_ZAP.set(true).unwrap();
});
let mut joined = sa.join(zap);
joined.send(Bar).await.unwrap();
MSG_SA.get().expect("The message was not :CC inspected!");
joined.send(Zap).await.unwrap();
MSG_ZAP.get().expect("The message was NOT inspected!");
}
}