1use std::collections::HashMap;
8
9use axum::{
10    BoxError, Json,
11    extract::{
12        Form, FromRequest, FromRequestParts,
13        rejection::{FailedToDeserializeForm, FormRejection},
14    },
15    response::IntoResponse,
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, authorization::Basic};
19use http::{Request, StatusCode};
20use mas_data_model::{Client, JwksOrJwksUri};
21use mas_http::RequestBuilderExt;
22use mas_iana::oauth::OAuthClientAuthenticationMethod;
23use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
24use mas_keystore::Encrypter;
25use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
26use oauth2_types::errors::{ClientError, ClientErrorCode};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::Value;
29use thiserror::Error;
30
31use crate::record_error;
32
/// `client_assertion_type` value announcing that `client_assertion` is a JWT
/// bearer assertion (the value registered by RFC 7523).
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
34
/// Urlencoded body of a client-authenticated request.
///
/// The client-authentication fields are pulled out explicitly; every other
/// field of the form is flattened into `inner` and deserialized as `F`.
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    /// `client_id` form field. On its own it identifies a client presenting no
    /// secret. It may also accompany a secret, a `Basic` header or an
    /// assertion, in which case it must agree with them.
    client_id: Option<String>,
    /// `client_secret` form field (`client_secret_post`).
    client_secret: Option<String>,
    /// Declares the format of `client_assertion`.
    client_assertion_type: Option<String>,
    /// The client assertion itself (a JWT for the supported type).
    client_assertion: Option<String>,

    /// All remaining, non-authentication form fields.
    #[serde(flatten)]
    inner: F,
}
45
/// Client credentials extracted from a request by [`ClientAuthorization`].
///
/// Each variant records which way the client presented its credentials.
/// Nothing here has been checked yet; use [`Credentials::verify`] for that.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` form field was presented, with no secret or assertion.
    None {
        client_id: String,
    },
    /// Client ID and secret taken from an `Authorization: Basic` header.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// Client ID and secret taken from the `client_id` / `client_secret` form fields.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A JWT `client_assertion`; `client_id` is the assertion's `sub` claim.
    ///
    /// The JWT has only been parsed, not had its signature checked.
    ClientAssertionJwtBearer {
        client_id: String,
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
}
64
65impl Credentials {
66    #[must_use]
68    pub fn client_id(&self) -> &str {
69        match self {
70            Credentials::None { client_id }
71            | Credentials::ClientSecretBasic { client_id, .. }
72            | Credentials::ClientSecretPost { client_id, .. }
73            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
74        }
75    }
76
77    pub async fn fetch<E>(
84        &self,
85        repo: &mut impl RepositoryAccess<Error = E>,
86    ) -> Result<Option<Client>, E> {
87        let client_id = match self {
88            Credentials::None { client_id }
89            | Credentials::ClientSecretBasic { client_id, .. }
90            | Credentials::ClientSecretPost { client_id, .. }
91            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
92        };
93
94        repo.oauth2_client().find_by_client_id(client_id).await
95    }
96
97    #[tracing::instrument(skip_all)]
103    pub async fn verify(
104        &self,
105        http_client: &reqwest::Client,
106        encrypter: &Encrypter,
107        method: &OAuthClientAuthenticationMethod,
108        client: &Client,
109    ) -> Result<(), CredentialsVerificationError> {
110        match (self, method) {
111            (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
112
113            (
114                Credentials::ClientSecretPost { client_secret, .. },
115                OAuthClientAuthenticationMethod::ClientSecretPost,
116            )
117            | (
118                Credentials::ClientSecretBasic { client_secret, .. },
119                OAuthClientAuthenticationMethod::ClientSecretBasic,
120            ) => {
121                let encrypted_client_secret = client
123                    .encrypted_client_secret
124                    .as_ref()
125                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
126
127                let decrypted_client_secret = encrypter
128                    .decrypt_string(encrypted_client_secret)
129                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
130
131                if client_secret.as_bytes() != decrypted_client_secret {
133                    return Err(CredentialsVerificationError::ClientSecretMismatch);
134                }
135            }
136
137            (
138                Credentials::ClientAssertionJwtBearer { jwt, .. },
139                OAuthClientAuthenticationMethod::PrivateKeyJwt,
140            ) => {
141                let jwks = client
143                    .jwks
144                    .as_ref()
145                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
146
147                let jwks = fetch_jwks(http_client, jwks)
148                    .await
149                    .map_err(CredentialsVerificationError::JwksFetchFailed)?;
150
151                jwt.verify_with_jwks(&jwks)
152                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
153            }
154
155            (
156                Credentials::ClientAssertionJwtBearer { jwt, .. },
157                OAuthClientAuthenticationMethod::ClientSecretJwt,
158            ) => {
159                let encrypted_client_secret = client
161                    .encrypted_client_secret
162                    .as_ref()
163                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
164
165                let decrypted_client_secret = encrypter
166                    .decrypt_string(encrypted_client_secret)
167                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
168
169                jwt.verify_with_shared_secret(decrypted_client_secret)
170                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
171            }
172
173            (_, _) => {
174                return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
175            }
176        }
177        Ok(())
178    }
179}
180
181async fn fetch_jwks(
182    http_client: &reqwest::Client,
183    jwks: &JwksOrJwksUri,
184) -> Result<PublicJsonWebKeySet, BoxError> {
185    let uri = match jwks {
186        JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
187        JwksOrJwksUri::JwksUri(u) => u,
188    };
189
190    let response = http_client
191        .get(uri.as_str())
192        .send_traced()
193        .await?
194        .error_for_status()?
195        .json()
196        .await?;
197
198    Ok(response)
199}
200
201#[derive(Debug, Error)]
202pub enum CredentialsVerificationError {
203    #[error("failed to decrypt client credentials")]
204    DecryptionError,
205
206    #[error("invalid client configuration")]
207    InvalidClientConfig,
208
209    #[error("client secret did not match")]
210    ClientSecretMismatch,
211
212    #[error("authentication method mismatch")]
213    AuthenticationMethodMismatch,
214
215    #[error("invalid assertion signature")]
216    InvalidAssertionSignature,
217
218    #[error("failed to fetch jwks")]
219    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
220}
221
222impl CredentialsVerificationError {
223    #[must_use]
225    pub fn is_internal(&self) -> bool {
226        matches!(
227            self,
228            Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
229        )
230    }
231}
232
/// Axum extractor for an OAuth 2.0 client-authenticated request.
///
/// Holds the client credentials found in the request. When the body was a
/// urlencoded form, it also holds the remaining form fields deserialized as `F`.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    /// Credentials presented by the client; not yet verified.
    pub credentials: Credentials,
    /// The non-credential form fields, or `None` when the body did not have a
    /// form content type.
    pub form: Option<F>,
}
238
239impl<F> ClientAuthorization<F> {
240    #[must_use]
242    pub fn client_id(&self) -> &str {
243        self.credentials.client_id()
244    }
245}
246
/// Reasons the [`ClientAuthorization`] extractor can reject a request.
#[derive(Debug, Error)]
pub enum ClientAuthorizationError {
    /// An `Authorization` header was present but could not be parsed as
    /// `Basic` credentials.
    #[error("Invalid Authorization header")]
    InvalidHeader,

    /// The urlencoded body could not be deserialized.
    #[error("Could not deserialize request body")]
    BadForm(#[source] FailedToDeserializeForm),

    /// The `client_id` form field disagrees with the client ID carried by the
    /// credential (`Basic` header or assertion `sub`).
    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
    ClientIdMismatch { credential: String, form: String },

    /// A `client_assertion` was presented with an unsupported
    /// `client_assertion_type`.
    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
    UnsupportedClientAssertion { client_assertion_type: String },

    /// Neither the header nor the body carried any client credentials.
    #[error("No credentials were presented")]
    MissingCredentials,

    /// The presented credential fields form an incomplete or ambiguous
    /// combination.
    #[error("Invalid request")]
    InvalidRequest,

    /// The `client_assertion` is not a parseable JWT or lacks a string `sub`
    /// claim.
    #[error("Invalid client_assertion")]
    InvalidAssertion,

    /// Any other failure of the form extractor.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error>),
}
273
274impl IntoResponse for ClientAuthorizationError {
275    fn into_response(self) -> axum::response::Response {
276        let sentry_event_id = record_error!(self, Self::Internal(_));
277        match &self {
278            ClientAuthorizationError::InvalidHeader => (
279                StatusCode::BAD_REQUEST,
280                sentry_event_id,
281                Json(ClientError::new(
282                    ClientErrorCode::InvalidRequest,
283                    "Invalid Authorization header",
284                )),
285            ),
286
287            ClientAuthorizationError::BadForm(err) => (
288                StatusCode::BAD_REQUEST,
289                sentry_event_id,
290                Json(
291                    ClientError::from(ClientErrorCode::InvalidRequest)
292                        .with_description(format!("{err}")),
293                ),
294            ),
295
296            ClientAuthorizationError::ClientIdMismatch { .. } => (
297                StatusCode::BAD_REQUEST,
298                sentry_event_id,
299                Json(
300                    ClientError::from(ClientErrorCode::InvalidGrant)
301                        .with_description(format!("{self}")),
302                ),
303            ),
304
305            ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
306                StatusCode::BAD_REQUEST,
307                sentry_event_id,
308                Json(
309                    ClientError::from(ClientErrorCode::InvalidRequest)
310                        .with_description(format!("{self}")),
311                ),
312            ),
313
314            ClientAuthorizationError::MissingCredentials => (
315                StatusCode::BAD_REQUEST,
316                sentry_event_id,
317                Json(ClientError::new(
318                    ClientErrorCode::InvalidRequest,
319                    "No credentials were presented",
320                )),
321            ),
322
323            ClientAuthorizationError::InvalidRequest => (
324                StatusCode::BAD_REQUEST,
325                sentry_event_id,
326                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
327            ),
328
329            ClientAuthorizationError::InvalidAssertion => (
330                StatusCode::BAD_REQUEST,
331                sentry_event_id,
332                Json(ClientError::new(
333                    ClientErrorCode::InvalidRequest,
334                    "Invalid client_assertion",
335                )),
336            ),
337
338            ClientAuthorizationError::Internal(e) => (
339                StatusCode::INTERNAL_SERVER_ERROR,
340                sentry_event_id,
341                Json(
342                    ClientError::from(ClientErrorCode::ServerError)
343                        .with_description(format!("{e}")),
344                ),
345            ),
346        }
347        .into_response()
348    }
349}
350
351impl<S, F> FromRequest<S> for ClientAuthorization<F>
352where
353    F: DeserializeOwned,
354    S: Send + Sync,
355{
356    type Rejection = ClientAuthorizationError;
357
358    #[allow(clippy::too_many_lines)]
359    async fn from_request(
360        req: Request<axum::body::Body>,
361        state: &S,
362    ) -> Result<Self, Self::Rejection> {
363        let (mut parts, body) = req.into_parts();
365
366        let header =
367            TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
368
369        let credentials_from_header = match header {
371            Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
372            Err(err) => match err.reason() {
373                TypedHeaderRejectionReason::Missing => None,
375                _ => return Err(ClientAuthorizationError::InvalidHeader),
377            },
378        };
379
380        let req = Request::from_parts(parts, body);
382
383        let (
385            client_id_from_form,
386            client_secret_from_form,
387            client_assertion_type,
388            client_assertion,
389            form,
390        ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
391            Ok(Form(form)) => (
392                form.client_id,
393                form.client_secret,
394                form.client_assertion_type,
395                form.client_assertion,
396                Some(form.inner),
397            ),
398            Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
400            Err(FormRejection::FailedToDeserializeForm(err)) => {
402                return Err(ClientAuthorizationError::BadForm(err));
403            }
404            Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
406        };
407
408        let credentials = match (
410            credentials_from_header,
411            client_id_from_form,
412            client_secret_from_form,
413            client_assertion_type,
414            client_assertion,
415        ) {
416            (Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
417                if let Some(client_id_from_form) = client_id_from_form {
418                    if client_id != client_id_from_form {
420                        return Err(ClientAuthorizationError::ClientIdMismatch {
421                            credential: client_id,
422                            form: client_id_from_form,
423                        });
424                    }
425                }
426
427                Credentials::ClientSecretBasic {
428                    client_id,
429                    client_secret,
430                }
431            }
432
433            (None, Some(client_id), Some(client_secret), None, None) => {
434                Credentials::ClientSecretPost {
436                    client_id,
437                    client_secret,
438                }
439            }
440
441            (None, Some(client_id), None, None, None) => {
442                Credentials::None { client_id }
444            }
445
446            (
447                None,
448                client_id_from_form,
449                None,
450                Some(client_assertion_type),
451                Some(client_assertion),
452            ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
453                let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
455                    .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
456
457                let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
458                    client_id.clone()
459                } else {
460                    return Err(ClientAuthorizationError::InvalidAssertion);
461                };
462
463                if let Some(client_id_from_form) = client_id_from_form {
464                    if client_id != client_id_from_form {
466                        return Err(ClientAuthorizationError::ClientIdMismatch {
467                            credential: client_id,
468                            form: client_id_from_form,
469                        });
470                    }
471                }
472
473                Credentials::ClientAssertionJwtBearer {
474                    client_id,
475                    jwt: Box::new(jwt),
476                }
477            }
478
479            (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
480                return Err(ClientAuthorizationError::UnsupportedClientAssertion {
482                    client_assertion_type,
483                });
484            }
485
486            (None, None, None, None, None) => {
487                return Err(ClientAuthorizationError::MissingCredentials);
489            }
490
491            _ => {
492                return Err(ClientAuthorizationError::InvalidRequest);
494            }
495        };
496
497        Ok(ClientAuthorization { credentials, form })
498    }
499}
500
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization` value carrying `client-id:client-secret` in base64.
    const BASIC_CLIENT_ID_SECRET: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a `POST` request with a urlencoded `body` and an optional
    /// `Authorization` header.
    fn form_request(authorization: Option<&str>, body: impl Into<String>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(authorization) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, authorization);
        }
        builder.body(Body::new(body.into())).unwrap()
    }

    /// Run the extractor with `serde_json::Value` as the inner form type.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    #[tokio::test]
    async fn none_test() {
        let req = form_request(None, "client_id=client-id&foo=bar");

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: Some(serde_json::json!({"foo": "bar"})),
        };

        // Credentials only in the header
        let req = form_request(Some(BASIC_CLIENT_ID_SECRET), "foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected);

        // A matching `client_id` in the body is accepted
        let req = form_request(Some(BASIC_CLIENT_ID_SECRET), "client_id=client-id&foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected);

        // A different `client_id` in the body is rejected
        let req = form_request(Some(BASIC_CLIENT_ID_SECRET), "client_id=mismatch-id&foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // A malformed header is rejected
        let req = form_request(Some("Basic invalid"), "foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let req = form_request(
            None,
            "client_id=client-id&client_secret=client-secret&foo=bar",
        );

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let req = form_request(
            None,
            format!(
                "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
            ),
        );

        let authz = extract(req).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn missing_credentials_test() {
        let req = form_request(None, "foo=bar");

        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::MissingCredentials),
        ));
    }

    #[tokio::test]
    async fn unsupported_client_assertion_test() {
        let req = form_request(
            None,
            "client_assertion_type=urn:example:unsupported&client_assertion=whatever&foo=bar",
        );

        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::UnsupportedClientAssertion { .. }),
        ));
    }

    #[tokio::test]
    async fn invalid_client_assertion_test() {
        let req = form_request(
            None,
            format!(
                "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion=not-a-jwt"
            ),
        );

        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidAssertion),
        ));
    }

    #[tokio::test]
    async fn mixed_methods_test() {
        // A `Basic` header together with a `client_secret` in the body is ambiguous
        let req = form_request(
            Some(BASIC_CLIENT_ID_SECRET),
            "client_secret=client-secret&foo=bar",
        );

        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidRequest),
        ));
    }
}