Add Joined combinator

Signed-off-by: Marcel Müller <neikos@neikos.email>
This commit is contained in:
Marcel Müller 2025-11-04 13:51:28 +01:00
parent 44e8280093
commit edc076efdd
2 changed files with 127 additions and 17 deletions

View file

@ -48,6 +48,10 @@ impl InternalMessage {
pub fn type_name(&self) -> &'static str { pub fn type_name(&self) -> &'static str {
self.name self.name
} }
pub fn type_id(&self) -> TypeId {
self.value.as_ref().type_id()
}
} }
pub trait Address<MB> { pub trait Address<MB> {
@ -116,6 +120,7 @@ pub struct BundleChain {
pub enum BundleOp { pub enum BundleOp {
Add(TypeId), Add(TypeId),
Remove(TypeId), Remove(TypeId),
Chain(&'static BundleChain),
} }
impl BundleChain { impl BundleChain {
@ -126,12 +131,31 @@ impl BundleChain {
} }
} }
pub const fn contains(&self, id: TypeId) -> bool {
check_is_contained(self, id)
}
pub const fn with<M: Message>(&'static self) -> BundleChain { pub const fn with<M: Message>(&'static self) -> BundleChain {
add_to_chain(self, TypeId::of::<M>()) let to_add = TypeId::of::<M>();
BundleChain {
op: BundleOp::Add(to_add),
next: Some(self),
}
} }
pub const fn without<M: Message>(&'static self) -> BundleChain { pub const fn without<M: Message>(&'static self) -> BundleChain {
remove_from_chain(self, TypeId::of::<M>()) let to_remove = TypeId::of::<M>();
BundleChain {
op: BundleOp::Remove(to_remove),
next: Some(self),
}
}
pub const fn join(&'static self, ids: &'static BundleChain) -> BundleChain {
BundleChain {
op: BundleOp::Chain(self),
next: Some(ids),
}
} }
} }
@ -165,7 +189,7 @@ where
const IS_CONTAINED: bool = check_is_contained(&MB::IDS, TypeId::of::<M>()); const IS_CONTAINED: bool = check_is_contained(&MB::IDS, TypeId::of::<M>());
} }
const fn check_is_contained(ids: &'static BundleChain, id: TypeId) -> bool { const fn check_is_contained(ids: &BundleChain, id: TypeId) -> bool {
match ids.op { match ids.op {
BundleOp::Add(added_id) => { BundleOp::Add(added_id) => {
if check_type_id_equal(added_id, id) { if check_type_id_equal(added_id, id) {
@ -177,6 +201,11 @@ const fn check_is_contained(ids: &'static BundleChain, id: TypeId) -> bool {
return false; return false;
} }
} }
BundleOp::Chain(other_chain) => {
if check_is_contained(other_chain, id) {
return true;
}
}
} }
if let Some(next) = ids.next { if let Some(next) = ids.next {
@ -190,20 +219,6 @@ const fn check_type_id_equal(left: TypeId, right: TypeId) -> bool {
left == right left == right
} }
const fn add_to_chain(prev: &'static BundleChain, to_add: TypeId) -> BundleChain {
BundleChain {
op: BundleOp::Add(to_add),
next: Some(prev),
}
}
const fn remove_from_chain(prev: &'static BundleChain, to_remove: TypeId) -> BundleChain {
BundleChain {
op: BundleOp::Remove(to_remove),
next: Some(prev),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use macro_rules_attribute::apply; use macro_rules_attribute::apply;

View file

@ -21,6 +21,11 @@ pub trait AddressExt<MB> {
F: Fn(&M) -> U, F: Fn(&M) -> U,
M: Message + IsContainedInBundle<MB>, M: Message + IsContainedInBundle<MB>,
Self: Sized; Self: Sized;
fn join<Other>(self, o: Other) -> Joined<Self, Other>
where
Joined<Self, Other>: InternalMessageHandler,
Self: Sized;
} }
impl<MB: MessageBundle, A: InternalMessageHandler<HandledMessages = MB>> AddressExt<MB> for A { impl<MB: MessageBundle, A: InternalMessageHandler<HandledMessages = MB>> AddressExt<MB> for A {
@ -55,6 +60,52 @@ impl<MB: MessageBundle, A: InternalMessageHandler<HandledMessages = MB>> Address
_pd: PhantomData, _pd: PhantomData,
} }
} }
fn join<Other>(self, o: Other) -> Joined<Self, Other>
where
Joined<Self, Other>: InternalMessageHandler,
Self: Sized,
{
Joined {
left: self,
right: o,
}
}
}
pub struct Joined<L, R> {
left: L,
right: R,
}
pub struct JoinedMessages<L, R> {
_pd: PhantomData<fn(L, R)>,
}
impl<L, R> MessageBundle for JoinedMessages<L, R>
where
L: MessageBundle,
R: MessageBundle,
{
const IDS: tytix_core::BundleChain = <L as MessageBundle>::IDS.join(&<R as MessageBundle>::IDS);
}
impl<L, R> InternalMessageHandler for Joined<L, R>
where
L: InternalMessageHandler,
R: InternalMessageHandler,
{
type HandledMessages = JoinedMessages<L::HandledMessages, R::HandledMessages>;
async fn handle_message(&mut self, msg: InternalMessage) -> anyhow::Result<InternalMessage> {
let message_id = msg.type_id();
if L::HandledMessages::IDS.contains(message_id) {
self.left.handle_message(msg).await
} else {
self.right.handle_message(msg).await
}
}
} }
pub struct Inspect<A, F, U, M> { pub struct Inspect<A, F, U, M> {
@ -162,6 +213,12 @@ mod tests {
type Reply = (); type Reply = ();
} }
struct Zap;
impl Message for Zap {
type Reply = usize;
}
struct SimpleAddress; struct SimpleAddress;
impl InternalMessageHandler for SimpleAddress { impl InternalMessageHandler for SimpleAddress {
@ -179,6 +236,19 @@ mod tests {
} }
} }
struct ZapAddress;
impl InternalMessageHandler for ZapAddress {
type HandledMessages = (Zap,);
async fn handle_message(
&mut self,
_msg: InternalMessage,
) -> anyhow::Result<InternalMessage> {
Ok(InternalMessage::new(45usize))
}
}
#[apply(test!)] #[apply(test!)]
async fn check_mapping() { async fn check_mapping() {
static MSG: OnceLock<bool> = OnceLock::new(); static MSG: OnceLock<bool> = OnceLock::new();
@ -213,4 +283,29 @@ mod tests {
MSG.get().expect("The message was inspected!"); MSG.get().expect("The message was inspected!");
} }
#[apply(test!)]
async fn check_join() {
static MSG_SA: OnceLock<bool> = OnceLock::new();
static MSG_ZAP: OnceLock<bool> = OnceLock::new();
let sa = SimpleAddress.inspect(|_b: &Bar| {
MSG_SA.set(true).unwrap();
async {}
});
let zap = ZapAddress.inspect(|_z: &Zap| async {
MSG_ZAP.set(true).unwrap();
});
let mut joined = sa.join(zap);
joined.send(Bar).await.unwrap();
MSG_SA.get().expect("The message was not :CC inspected!");
joined.send(Zap).await.unwrap();
MSG_ZAP.get().expect("The message was NOT inspected!");
}
} }