Implement DAV Push

This commit is contained in:
Lennart
2025-06-14 20:24:50 +02:00
parent 0c48507f0c
commit 03ae492483
23 changed files with 882 additions and 308 deletions

View File

@@ -0,0 +1,23 @@
use axum::{
Router,
extract::{Path, State},
response::{IntoResponse, Response},
routing::delete,
};
use http::StatusCode;
use rustical_store::SubscriptionStore;
use std::sync::Arc;
async fn handle_delete<S: SubscriptionStore>(
State(store): State<Arc<S>>,
Path(id): Path<String>,
) -> Result<Response, rustical_store::Error> {
store.delete_subscription(&id).await?;
Ok((StatusCode::NO_CONTENT, "Unregistered").into_response())
}
pub fn subscription_service<S: SubscriptionStore>(sub_store: Arc<S>) -> Router {
Router::new()
.route("/push_subscription/{id}", delete(handle_delete::<S>))
.with_state(sub_store)
}

View File

@@ -1,14 +1,41 @@
mod extension;
pub mod notifier;
mod prop;
pub mod register;
use base64::Engine;
use derive_more::Constructor;
pub use extension::*;
use http::{HeaderValue, Method, header};
pub use prop::*;
use rustical_store::{CollectionOperation, SubscriptionStore};
use std::sync::Arc;
use reqwest::{Body, Url};
use rustical_store::{
CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
};
use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::sync::mpsc::Receiver;
use tracing::error;
use tracing::{error, warn};
mod endpoints;
pub use endpoints::subscription_service;
#[derive(XmlSerialize, Debug)]
pub struct ContentUpdate {
#[xml(ns = "rustical_dav::namespace::NS_DAV")]
sync_token: Option<String>,
}
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = b"push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
rustical_dav::namespace::NS_DAVPUSH = b"",
rustical_dav::namespace::NS_DAV = b"D",
))]
struct PushMessage {
#[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
topic: String,
#[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
content_update: Option<ContentUpdate>,
}
#[derive(Debug, Constructor)]
pub struct DavPushController<S: SubscriptionStore> {
@@ -18,14 +45,176 @@ pub struct DavPushController<S: SubscriptionStore> {
impl<S: SubscriptionStore> DavPushController<S> {
pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
while let Some(message) = recv.recv().await {
let subscribers = match self.sub_store.get_subscriptions(&message.topic).await {
Ok(subs) => subs,
Err(err) => {
error!("{err}");
continue;
loop {
// Make sure we don't flood the subscribers
tokio::time::sleep(Duration::from_secs(10)).await;
let mut messages = vec![];
recv.recv_many(&mut messages, 100).await;
// Right now we just have to show the latest content update by topic
// This might become more complicated in the future depending on what kind of updates
// we add
let mut latest_messages = HashMap::new();
for message in messages {
if matches!(message.data, CollectionOperationInfo::Content { .. }) {
latest_messages.insert(message.topic.to_string(), message);
}
};
}
let messages = latest_messages.into_values();
for message in messages {
self.send_message(message).await;
}
}
}
async fn send_message(&self, message: CollectionOperation) {
let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
Ok(subs) => subs,
Err(err) => {
error!("{err}");
return;
}
};
if subscriptions.is_empty() {
return;
}
if matches!(message.data, CollectionOperationInfo::Delete) {
// Collection has been deleted, but we cannot handle that
return;
}
let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
Some(ContentUpdate {
sync_token: Some(sync_token),
})
} else {
None
};
let push_message = PushMessage {
topic: message.topic,
content_update,
};
let mut output: Vec<_> = b"<?xml version=\"1.0\" encoding=\"utf-8\"?>\n".into();
let mut writer = quick_xml::Writer::new_with_indent(&mut output, b' ', 4);
if let Err(err) = push_message.serialize_root(&mut writer) {
error!("Could not serialize push message: {}", err);
return;
}
let payload = String::from_utf8(output).unwrap();
for subsciption in subscriptions {
if let Some(allowed_push_servers) = &self.allowed_push_servers {
if let Ok(url) = Url::parse(&subsciption.push_resource) {
let origin = url.origin().unicode_serialization();
if !allowed_push_servers.contains(&origin) {
warn!(
"Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
subsciption.id, subsciption.topic
);
self.try_delete_subscription(&subsciption.id).await;
}
} else {
warn!(
"Deleting subscription {} on topic {} because of invalid URL",
subsciption.id, subsciption.topic
);
self.try_delete_subscription(&subsciption.id).await;
};
}
if let Err(err) = self.send_payload(&payload, &subsciption).await {
error!("An error occured sending out a push notification: {err}");
if err.is_permament_error() {
warn!(
"Deleting subscription {} on topic {}",
subsciption.id, subsciption.topic
);
self.try_delete_subscription(&subsciption.id).await;
}
}
}
}
async fn try_delete_subscription(&self, sub_id: &str) {
if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
error!("Error deleting subsciption: {err}");
}
}
async fn send_payload(
&self,
payload: &str,
subsciption: &Subscription,
) -> Result<(), NotifierError> {
if subsciption.public_key_type != "p256dh" {
return Err(NotifierError::InvalidPublicKeyType(
subsciption.public_key_type.to_string(),
));
}
let endpoint = subsciption.push_resource.parse().map_err(|_| {
NotifierError::InvalidEndpointUrl(subsciption.push_resource.to_string())
})?;
let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(&subsciption.public_key)
.map_err(|_| NotifierError::InvalidKeyEncoding)?;
let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(&subsciption.auth_secret)
.map_err(|_| NotifierError::InvalidKeyEncoding)?;
let client = reqwest::ClientBuilder::new()
.build()
.map_err(NotifierError::from)?;
let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
let mut request = reqwest::Request::new(Method::POST, endpoint);
*request.body_mut() = Some(Body::from(payload));
let hdrs = request.headers_mut();
hdrs.insert(
header::CONTENT_ENCODING,
HeaderValue::from_static("aes128gcm"),
);
hdrs.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
client.execute(request).await?;
Ok(())
}
}
#[derive(Debug, thiserror::Error)]
enum NotifierError {
#[error("Invalid public key type: {0}")]
InvalidPublicKeyType(String),
#[error("Invalid endpoint URL: {0}")]
InvalidEndpointUrl(String),
#[error("Invalid key encoding")]
InvalidKeyEncoding,
#[error(transparent)]
EceError(#[from] ece::Error),
#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
}
impl NotifierError {
// Decide whether the error should cause the subscription to be removed
pub fn is_permament_error(&self) -> bool {
match self {
Self::InvalidPublicKeyType(_)
| Self::InvalidEndpointUrl(_)
| Self::InvalidKeyEncoding => true,
Self::EceError(err) => matches!(
err,
ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
),
Self::ReqwestError(_) => false,
}
}
}

View File

@@ -1,147 +0,0 @@
use http::StatusCode;
use reqwest::{
Method, Request,
header::{self, HeaderName, HeaderValue},
};
use rustical_dav::xml::multistatus::PropstatElement;
use rustical_store::{CollectionOperation, CollectionOperationType, SubscriptionStore};
use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
use std::{str::FromStr, sync::Arc};
use tokio::sync::mpsc::Receiver;
use tracing::{error, info, warn};
// use web_push::{SubscriptionInfo, WebPushMessage, WebPushMessageBuilder};
#[derive(XmlSerialize, Debug)]
struct PushMessageProp {
#[xml(ns = "rustical_dav::namespace::NS_DAV")]
topic: String,
#[xml(ns = "rustical_dav::namespace::NS_DAV")]
sync_token: Option<String>,
}
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = b"push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
rustical_dav::namespace::NS_DAVPUSH = b"",
rustical_dav::namespace::NS_DAV = b"D",
))]
struct PushMessage {
#[xml(ns = "rustical_dav::namespace::NS_DAV")]
propstat: PropstatElement<PushMessageProp>,
}
// pub fn build_request(message: WebPushMessage) -> Request {
// // A little janky :)
// let url = reqwest::Url::from_str(&message.endpoint.to_string()).unwrap();
// let mut builder = Request::new(Method::POST, url);
//
// if let Some(topic) = message.topic {
// builder
// .headers_mut()
// .insert("Topic", HeaderValue::from_str(topic.as_str()).unwrap());
// }
//
// if let Some(payload) = message.payload {
// builder.headers_mut().insert(
// header::CONTENT_ENCODING,
// HeaderValue::from_static(payload.content_encoding.to_str()),
// );
// builder.headers_mut().insert(
// header::CONTENT_TYPE,
// HeaderValue::from_static("application/octet-stream"),
// );
//
// for (k, v) in payload.crypto_headers.into_iter() {
// let v: &str = v.as_ref();
// builder.headers_mut().insert(
// HeaderName::from_static(k),
// HeaderValue::from_str(&v).unwrap(),
// );
// }
//
// *builder.body_mut() = Some(reqwest::Body::from(payload.content));
// }
// builder
// }
pub async fn push_notifier(
allowed_push_servers: Option<Vec<String>>,
mut recv: Receiver<CollectionOperation>,
sub_store: Arc<impl SubscriptionStore>,
) {
let client = reqwest::Client::new();
while let Some(message) = recv.recv().await {
let subscribers = match sub_store.get_subscriptions(&message.topic).await {
Ok(subs) => subs,
Err(err) => {
error!("{err}");
continue;
}
};
let status = match message.r#type {
CollectionOperationType::Object => StatusCode::OK,
CollectionOperationType::Delete => StatusCode::NOT_FOUND,
};
let push_message = PushMessage {
propstat: PropstatElement {
prop: PushMessageProp {
topic: message.topic,
sync_token: message.sync_token,
},
status,
},
};
let mut output: Vec<_> = b"<?xml version=\"1.0\" encoding=\"utf-8\"?>\n".into();
let mut writer = quick_xml::Writer::new_with_indent(&mut output, b' ', 4);
if let Err(err) = push_message.serialize_root(&mut writer) {
error!("Could not serialize push message: {}", err);
continue;
}
let payload = String::from_utf8(output).unwrap();
// for subscriber in subscribers {
// let push_resource = subscriber.push_resource;
//
// let sub_info = SubscriptionInfo {
// endpoint: push_resource.to_owned(),
// keys: web_push::SubscriptionKeys {
// p256dh: subscriber.public_key,
// auth: subscriber.auth_secret,
// },
// };
// let mut builder = WebPushMessageBuilder::new(&sub_info);
// builder.set_payload(web_push::ContentEncoding::Aes128Gcm, payload.as_bytes());
// let push_message = builder.build().unwrap();
// let request = build_request(push_message);
//
// let allowed = if let Some(allowed_push_servers) = &allowed_push_servers {
// if let Ok(resource_url) = reqwest::Url::parse(&push_resource) {
// let origin = resource_url.origin().ascii_serialization();
// allowed_push_servers
// .iter()
// .any(|allowed_push_server| allowed_push_server == &origin)
// } else {
// warn!("Invalid push url: {push_resource}");
// false
// }
// } else {
// true
// };
//
// if allowed {
// info!("Sending a push message to {}: {}", push_resource, payload);
// if let Err(err) = client.execute(request).await {
// error!("{err}");
// }
// } else {
// warn!(
// "Not sending a push notification to {} since it's not allowed in dav_push::allowed_push_servers",
// push_resource
// );
// }
// }
}
}