From 23e87270df4203deedfff427ec97b2e7d048431b Mon Sep 17 00:00:00 2001 From: Matty Date: Fri, 9 Aug 2024 10:19:44 -0400 Subject: [PATCH] =?UTF-8?q?A=20Curve=20trait=20for=20general=20interoperat?= =?UTF-8?q?ion=20=E2=80=94=20Part=20I=20(#14630)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Objective This PR implements part of the [Curve RFC](https://github.com/bevyengine/rfcs/blob/main/rfcs/80-curve-trait.md). See that document for motivation, objectives, etc. ## Solution For purposes of reviewability, this PR excludes the entire part of the RFC related to taking multiple samples, resampling, and interpolation generally. (This means the entire `cores` submodule is also excluded.) On the other hand, the entire `Interval` type and all of the functional `Curve` adaptors are included. ## Testing Test modules are included and can be run locally (but they are also included in CI). --------- Co-authored-by: Alice Cecile --- crates/bevy_math/src/curve/interval.rs | 373 ++++++++++++++ crates/bevy_math/src/curve/mod.rs | 685 +++++++++++++++++++++++++ crates/bevy_math/src/lib.rs | 1 + 3 files changed, 1059 insertions(+) create mode 100644 crates/bevy_math/src/curve/interval.rs create mode 100644 crates/bevy_math/src/curve/mod.rs diff --git a/crates/bevy_math/src/curve/interval.rs b/crates/bevy_math/src/curve/interval.rs new file mode 100644 index 0000000000..dd263b3e68 --- /dev/null +++ b/crates/bevy_math/src/curve/interval.rs @@ -0,0 +1,373 @@ +//! The [`Interval`] type for nonempty intervals used by the [`Curve`](super::Curve) trait. + +use itertools::Either; +use std::{ + cmp::{max_by, min_by}, + ops::RangeInclusive, +}; +use thiserror::Error; + +#[cfg(feature = "bevy_reflect")] +use bevy_reflect::Reflect; +#[cfg(all(feature = "serialize", feature = "bevy_reflect"))] +use bevy_reflect::{ReflectDeserialize, ReflectSerialize}; + +/// A nonempty closed interval, possibly unbounded in either direction. +/// +/// In other words, the interval may stretch all the way to positive or negative infinity, but it +/// will always have some nonempty interior. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect), reflect(Debug, PartialEq))] +#[cfg_attr( + all(feature = "serialize", feature = "bevy_reflect"), + reflect(Serialize, Deserialize) +)] +pub struct Interval { + start: f32, + end: f32, +} + +/// An error that indicates that an operation would have returned an invalid [`Interval`]. +#[derive(Debug, Error)] +#[error("The resulting interval would be invalid (empty or with a NaN endpoint)")] +pub struct InvalidIntervalError; + +/// An error indicating that spaced points could not be extracted from an unbounded interval. +#[derive(Debug, Error)] +#[error("Cannot extract spaced points from an unbounded interval")] +pub struct SpacedPointsError; + +/// An error indicating that a linear map between intervals could not be constructed because of +/// unboundedness. +#[derive(Debug, Error)] +#[error("Could not construct linear function to map between intervals")] +pub(super) enum LinearMapError { + /// The source interval being mapped out of was unbounded. + #[error("The source interval is unbounded")] + SourceUnbounded, + + /// The target interval being mapped into was unbounded. + #[error("The target interval is unbounded")] + TargetUnbounded, +} + +impl Interval { + /// Create a new [`Interval`] with the specified `start` and `end`. The interval can be unbounded + /// but cannot be empty (so `start` must be less than `end`) and neither endpoint can be NaN; invalid + /// parameters will result in an error. + #[inline] + pub fn new(start: f32, end: f32) -> Result { + if start >= end || start.is_nan() || end.is_nan() { + Err(InvalidIntervalError) + } else { + Ok(Self { start, end }) + } + } + + /// An interval which stretches across the entire real line from negative infinity to infinity. + pub const EVERYWHERE: Self = Self { + start: f32::NEG_INFINITY, + end: f32::INFINITY, + }; + + /// Get the start of this interval. + #[inline] + pub fn start(self) -> f32 { + self.start + } + + /// Get the end of this interval. + #[inline] + pub fn end(self) -> f32 { + self.end + } + + /// Create an [`Interval`] by intersecting this interval with another. Returns an error if the + /// intersection would be empty (hence an invalid interval). + pub fn intersect(self, other: Interval) -> Result { + let lower = max_by(self.start, other.start, f32::total_cmp); + let upper = min_by(self.end, other.end, f32::total_cmp); + Self::new(lower, upper) + } + + /// Get the length of this interval. Note that the result may be infinite (`f32::INFINITY`). + #[inline] + pub fn length(self) -> f32 { + self.end - self.start + } + + /// Returns `true` if this interval is bounded — that is, if both its start and end are finite. + /// + /// Equivalently, an interval is bounded if its length is finite. + #[inline] + pub fn is_bounded(self) -> bool { + self.length().is_finite() + } + + /// Returns `true` if this interval has a finite start. + #[inline] + pub fn has_finite_start(self) -> bool { + self.start.is_finite() + } + + /// Returns `true` if this interval has a finite end. + #[inline] + pub fn has_finite_end(self) -> bool { + self.end.is_finite() + } + + /// Returns `true` if `item` is contained in this interval. + #[inline] + pub fn contains(self, item: f32) -> bool { + (self.start..=self.end).contains(&item) + } + + /// Returns `true` if the other interval is contained in this interval. + /// + /// This is non-strict: each interval will contain itself. + #[inline] + pub fn contains_interval(self, other: Self) -> bool { + self.start <= other.start && self.end >= other.end + } + + /// Clamp the given `value` to lie within this interval. + #[inline] + pub fn clamp(self, value: f32) -> f32 { + value.clamp(self.start, self.end) + } + + /// Get an iterator over equally-spaced points from this interval in increasing order. + /// If `points` is 1, the start of this interval is returned. If `points` is 0, an empty + /// iterator is returned. An error is returned if the interval is unbounded. + #[inline] + pub fn spaced_points( + self, + points: usize, + ) -> Result, SpacedPointsError> { + if !self.is_bounded() { + return Err(SpacedPointsError); + } + if points < 2 { + // If `points` is 1, this is `Some(self.start)` as an iterator, and if `points` is 0, + // then this is `None` as an iterator. This is written this way to avoid having to + // introduce a ternary disjunction of iterators. + let iter = (points == 1).then_some(self.start).into_iter(); + return Ok(Either::Left(iter)); + } + let step = self.length() / (points - 1) as f32; + let iter = (0..points).map(move |x| self.start + x as f32 * step); + Ok(Either::Right(iter)) + } + + /// Get the linear function which maps this interval onto the `other` one. Returns an error if either + /// interval is unbounded. + #[inline] + pub(super) fn linear_map_to(self, other: Self) -> Result f32, LinearMapError> { + if !self.is_bounded() { + return Err(LinearMapError::SourceUnbounded); + } + + if !other.is_bounded() { + return Err(LinearMapError::TargetUnbounded); + } + + let scale = other.length() / self.length(); + Ok(move |x| (x - self.start) * scale + other.start) + } +} + +impl TryFrom> for Interval { + type Error = InvalidIntervalError; + fn try_from(range: RangeInclusive) -> Result { + Interval::new(*range.start(), *range.end()) + } +} + +/// Create an [`Interval`] with a given `start` and `end`. Alias of [`Interval::new`]. +#[inline] +pub fn interval(start: f32, end: f32) -> Result { + Interval::new(start, end) +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::{assert_abs_diff_eq, AbsDiffEq}; + + #[test] + fn make_intervals() { + let ivl = Interval::new(2.0, -1.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(-0.0, 0.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NEG_INFINITY, 15.5); + assert!(ivl.is_ok()); + + let ivl = Interval::new(-2.0, f32::INFINITY); + assert!(ivl.is_ok()); + + let ivl = Interval::new(f32::NEG_INFINITY, f32::INFINITY); + assert!(ivl.is_ok()); + + let ivl = Interval::new(f32::INFINITY, f32::NEG_INFINITY); + assert!(ivl.is_err()); + + let ivl = Interval::new(-1.0, f32::NAN); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NAN, -42.0); + assert!(ivl.is_err()); + + let ivl = Interval::new(f32::NAN, f32::NAN); + assert!(ivl.is_err()); + + let ivl = Interval::new(0.0, 1.0); + assert!(ivl.is_ok()); + } + + #[test] + fn lengths() { + let ivl = interval(-5.0, 10.0).unwrap(); + assert!((ivl.length() - 15.0).abs() <= f32::EPSILON); + + let ivl = interval(5.0, 100.0).unwrap(); + assert!((ivl.length() - 95.0).abs() <= f32::EPSILON); + + let ivl = interval(0.0, f32::INFINITY).unwrap(); + assert_eq!(ivl.length(), f32::INFINITY); + + let ivl = interval(f32::NEG_INFINITY, 0.0).unwrap(); + assert_eq!(ivl.length(), f32::INFINITY); + + let ivl = Interval::EVERYWHERE; + assert_eq!(ivl.length(), f32::INFINITY); + } + + #[test] + fn intersections() { + let ivl1 = interval(-1.0, 1.0).unwrap(); + let ivl2 = interval(0.0, 2.0).unwrap(); + let ivl3 = interval(-3.0, 0.0).unwrap(); + let ivl4 = interval(0.0, f32::INFINITY).unwrap(); + let ivl5 = interval(f32::NEG_INFINITY, 0.0).unwrap(); + let ivl6 = Interval::EVERYWHERE; + + assert!(ivl1 + .intersect(ivl2) + .is_ok_and(|ivl| ivl == interval(0.0, 1.0).unwrap())); + assert!(ivl1 + .intersect(ivl3) + .is_ok_and(|ivl| ivl == interval(-1.0, 0.0).unwrap())); + assert!(ivl2.intersect(ivl3).is_err()); + assert!(ivl1 + .intersect(ivl4) + .is_ok_and(|ivl| ivl == interval(0.0, 1.0).unwrap())); + assert!(ivl1 + .intersect(ivl5) + .is_ok_and(|ivl| ivl == interval(-1.0, 0.0).unwrap())); + assert!(ivl4.intersect(ivl5).is_err()); + assert_eq!(ivl1.intersect(ivl6).unwrap(), ivl1); + assert_eq!(ivl4.intersect(ivl6).unwrap(), ivl4); + assert_eq!(ivl5.intersect(ivl6).unwrap(), ivl5); + } + + #[test] + fn containment() { + let ivl = interval(0.0, 1.0).unwrap(); + assert!(ivl.contains(0.0)); + assert!(ivl.contains(1.0)); + assert!(ivl.contains(0.5)); + assert!(!ivl.contains(-0.1)); + assert!(!ivl.contains(1.1)); + assert!(!ivl.contains(f32::NAN)); + + let ivl = interval(3.0, f32::INFINITY).unwrap(); + assert!(ivl.contains(3.0)); + assert!(ivl.contains(2.0e5)); + assert!(ivl.contains(3.5e6)); + assert!(!ivl.contains(2.5)); + assert!(!ivl.contains(-1e5)); + assert!(!ivl.contains(f32::NAN)); + } + + #[test] + fn interval_containment() { + let ivl = interval(0.0, 1.0).unwrap(); + assert!(ivl.contains_interval(interval(-0.0, 0.5).unwrap())); + assert!(ivl.contains_interval(interval(0.5, 1.0).unwrap())); + assert!(ivl.contains_interval(interval(0.25, 0.75).unwrap())); + assert!(!ivl.contains_interval(interval(-0.25, 0.5).unwrap())); + assert!(!ivl.contains_interval(interval(0.5, 1.25).unwrap())); + assert!(!ivl.contains_interval(interval(0.25, f32::INFINITY).unwrap())); + assert!(!ivl.contains_interval(interval(f32::NEG_INFINITY, 0.75).unwrap())); + + let big_ivl = interval(0.0, f32::INFINITY).unwrap(); + assert!(big_ivl.contains_interval(interval(0.0, 5.0).unwrap())); + assert!(big_ivl.contains_interval(interval(0.0, f32::INFINITY).unwrap())); + assert!(big_ivl.contains_interval(interval(1.0, 5.0).unwrap())); + assert!(!big_ivl.contains_interval(interval(-1.0, f32::INFINITY).unwrap())); + assert!(!big_ivl.contains_interval(interval(-2.0, 5.0).unwrap())); + } + + #[test] + fn boundedness() { + assert!(!Interval::EVERYWHERE.is_bounded()); + assert!(interval(0.0, 3.5e5).unwrap().is_bounded()); + assert!(!interval(-2.0, f32::INFINITY).unwrap().is_bounded()); + assert!(!interval(f32::NEG_INFINITY, 5.0).unwrap().is_bounded()); + } + + #[test] + fn linear_maps() { + let ivl1 = interval(-3.0, 5.0).unwrap(); + let ivl2 = interval(0.0, 1.0).unwrap(); + let map = ivl1.linear_map_to(ivl2); + assert!(map.is_ok_and(|f| f(-3.0).abs_diff_eq(&0.0, f32::EPSILON) + && f(5.0).abs_diff_eq(&1.0, f32::EPSILON) + && f(1.0).abs_diff_eq(&0.5, f32::EPSILON))); + + let ivl1 = interval(0.0, 1.0).unwrap(); + let ivl2 = Interval::EVERYWHERE; + assert!(ivl1.linear_map_to(ivl2).is_err()); + + let ivl1 = interval(f32::NEG_INFINITY, -4.0).unwrap(); + let ivl2 = interval(0.0, 1.0).unwrap(); + assert!(ivl1.linear_map_to(ivl2).is_err()); + } + + #[test] + fn spaced_points() { + let ivl = interval(0.0, 50.0).unwrap(); + let points_iter: Vec = ivl.spaced_points(1).unwrap().collect(); + assert_abs_diff_eq!(points_iter[0], 0.0); + assert_eq!(points_iter.len(), 1); + let points_iter: Vec = ivl.spaced_points(2).unwrap().collect(); + assert_abs_diff_eq!(points_iter[0], 0.0); + assert_abs_diff_eq!(points_iter[1], 50.0); + let points_iter = ivl.spaced_points(21).unwrap(); + let step = ivl.length() / 20.0; + for (index, point) in points_iter.enumerate() { + let expected = ivl.start() + step * index as f32; + assert_abs_diff_eq!(point, expected); + } + + let ivl = interval(-21.0, 79.0).unwrap(); + let points_iter = ivl.spaced_points(10000).unwrap(); + let step = ivl.length() / 9999.0; + for (index, point) in points_iter.enumerate() { + let expected = ivl.start() + step * index as f32; + assert_abs_diff_eq!(point, expected); + } + + let ivl = interval(-1.0, f32::INFINITY).unwrap(); + let points_iter = ivl.spaced_points(25); + assert!(points_iter.is_err()); + + let ivl = interval(f32::NEG_INFINITY, -25.0).unwrap(); + let points_iter = ivl.spaced_points(9); + assert!(points_iter.is_err()); + } +} diff --git a/crates/bevy_math/src/curve/mod.rs b/crates/bevy_math/src/curve/mod.rs new file mode 100644 index 0000000000..30d5fee9b3 --- /dev/null +++ b/crates/bevy_math/src/curve/mod.rs @@ -0,0 +1,685 @@ +//! The [`Curve`] trait, used to describe curves in a number of different domains. This module also +//! contains the [`Interval`] type, along with a selection of core data structures used to back +//! curves that are interpolated from samples. + +pub mod interval; + +pub use interval::{interval, Interval}; + +use interval::InvalidIntervalError; +use std::{marker::PhantomData, ops::Deref}; +use thiserror::Error; + +#[cfg(feature = "bevy_reflect")] +use bevy_reflect::Reflect; + +/// A trait for a type that can represent values of type `T` parametrized over a fixed interval. +/// Typical examples of this are actual geometric curves where `T: VectorSpace`, but other kinds +/// of output data can be represented as well. +pub trait Curve { + /// The interval over which this curve is parametrized. + /// + /// This is the range of values of `t` where we can sample the curve and receive valid output. + fn domain(&self) -> Interval; + + /// Sample a point on this curve at the parameter value `t`, extracting the associated value. + /// This is the unchecked version of sampling, which should only be used if the sample time `t` + /// is already known to lie within the curve's domain. + /// + /// Values sampled from outside of a curve's domain are generally considered invalid; data which + /// is nonsensical or otherwise useless may be returned in such a circumstance, and extrapolation + /// beyond a curve's domain should not be relied upon. + fn sample_unchecked(&self, t: f32) -> T; + + /// Sample a point on this curve at the parameter value `t`, returning `None` if the point is + /// outside of the curve's domain. + fn sample(&self, t: f32) -> Option { + match self.domain().contains(t) { + true => Some(self.sample_unchecked(t)), + false => None, + } + } + + /// Sample a point on this curve at the parameter value `t`, clamping `t` to lie inside the + /// domain of the curve. + fn sample_clamped(&self, t: f32) -> T { + let t = self.domain().clamp(t); + self.sample_unchecked(t) + } + + /// Create a new curve by mapping the values of this curve via a function `f`; i.e., if the + /// sample at time `t` for this curve is `x`, the value at time `t` on the new curve will be + /// `f(x)`. + fn map(self, f: F) -> MapCurve + where + Self: Sized, + F: Fn(T) -> S, + { + MapCurve { + preimage: self, + f, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] whose parameter space is related to the parameter space of this curve + /// by `f`. For each time `t`, the sample from the new curve at time `t` is the sample from + /// this curve at time `f(t)`. The given `domain` will be the domain of the new curve. The + /// function `f` is expected to take `domain` into `self.domain()`. + /// + /// Note that this is the opposite of what one might expect intuitively; for example, if this + /// curve has a parameter domain of `[0, 1]`, then stretching the parameter domain to + /// `[0, 2]` would be performed as follows, dividing by what might be perceived as the scaling + /// factor rather than multiplying: + /// ``` + /// # use bevy_math::curve::*; + /// let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let scaled_curve = my_curve.reparametrize(interval(0.0, 2.0).unwrap(), |t| t / 2.0); + /// ``` + /// This kind of linear remapping is provided by the convenience method + /// [`Curve::reparametrize_linear`], which requires only the desired domain for the new curve. + /// + /// # Examples + /// ``` + /// // Reverse a curve: + /// # use bevy_math::curve::*; + /// # use bevy_math::vec2; + /// let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let domain = my_curve.domain(); + /// let reversed_curve = my_curve.reparametrize(domain, |t| domain.end() - t); + /// + /// // Take a segment of a curve: + /// # let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// let curve_segment = my_curve.reparametrize(interval(0.0, 0.5).unwrap(), |t| 0.5 + t); + /// + /// // Reparametrize by an easing curve: + /// # let my_curve = constant_curve(interval(0.0, 1.0).unwrap(), 1.0); + /// # let easing_curve = constant_curve(interval(0.0, 1.0).unwrap(), vec2(1.0, 1.0)); + /// let domain = my_curve.domain(); + /// let eased_curve = my_curve.reparametrize(domain, |t| easing_curve.sample_unchecked(t).y); + /// ``` + fn reparametrize(self, domain: Interval, f: F) -> ReparamCurve + where + Self: Sized, + F: Fn(f32) -> f32, + { + ReparamCurve { + domain, + base: self, + f, + _phantom: PhantomData, + } + } + + /// Linearly reparametrize this [`Curve`], producing a new curve whose domain is the given + /// `domain` instead of the current one. This operation is only valid for curves with bounded + /// domains; if either this curve's domain or the given `domain` is unbounded, an error is + /// returned. + fn reparametrize_linear( + self, + domain: Interval, + ) -> Result, LinearReparamError> + where + Self: Sized, + { + if !self.domain().is_bounded() { + return Err(LinearReparamError::SourceCurveUnbounded); + } + + if !domain.is_bounded() { + return Err(LinearReparamError::TargetIntervalUnbounded); + } + + Ok(LinearReparamCurve { + base: self, + new_domain: domain, + _phantom: PhantomData, + }) + } + + /// Reparametrize this [`Curve`] by sampling from another curve. + /// + /// The resulting curve samples at time `t` by first sampling `other` at time `t`, which produces + /// another sample time `s` which is then used to sample this curve. The domain of the resulting + /// curve is the domain of `other`. + fn reparametrize_by_curve(self, other: C) -> CurveReparamCurve + where + Self: Sized, + C: Curve, + { + CurveReparamCurve { + base: self, + reparam_curve: other, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] which is the graph of this one; that is, its output echoes the sample + /// time as part of a tuple. + /// + /// For example, if this curve outputs `x` at time `t`, then the produced curve will produce + /// `(t, x)` at time `t`. In particular, if this curve is a `Curve`, the output of this method + /// is a `Curve<(f32, T)>`. + fn graph(self) -> GraphCurve + where + Self: Sized, + { + GraphCurve { + base: self, + _phantom: PhantomData, + } + } + + /// Create a new [`Curve`] by zipping this curve together with another. + /// + /// The sample at time `t` in the new curve is `(x, y)`, where `x` is the sample of `self` at + /// time `t` and `y` is the sample of `other` at time `t`. The domain of the new curve is the + /// intersection of the domains of its constituents. If the domain intersection would be empty, + /// an error is returned. + fn zip(self, other: C) -> Result, InvalidIntervalError> + where + Self: Sized, + C: Curve + Sized, + { + let domain = self.domain().intersect(other.domain())?; + Ok(ProductCurve { + domain, + first: self, + second: other, + _phantom: PhantomData, + }) + } + + /// Create a new [`Curve`] by composing this curve end-to-end with another, producing another curve + /// with outputs of the same type. The domain of the other curve is translated so that its start + /// coincides with where this curve ends. A [`ChainError`] is returned if this curve's domain + /// doesn't have a finite end or if `other`'s domain doesn't have a finite start. + fn chain(self, other: C) -> Result, ChainError> + where + Self: Sized, + C: Curve, + { + if !self.domain().has_finite_end() { + return Err(ChainError::FirstEndInfinite); + } + if !other.domain().has_finite_start() { + return Err(ChainError::SecondStartInfinite); + } + Ok(ChainCurve { + first: self, + second: other, + _phantom: PhantomData, + }) + } + + /// Borrow this curve rather than taking ownership of it. This is essentially an alias for a + /// prefix `&`; the point is that intermediate operations can be performed while retaining + /// access to the original curve. + /// + /// # Example + /// ```ignore + /// # use bevy_math::curve::*; + /// let my_curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t * t + 1.0); + /// // Borrow `my_curve` long enough to resample a mapped version. Note that `map` takes + /// // ownership of its input. + /// let samples = my_curve.by_ref().map(|x| x * 2.0).resample_auto(100).unwrap(); + /// // Do something else with `my_curve` since we retained ownership: + /// let new_curve = my_curve.reparametrize_linear(interval(-1.0, 1.0).unwrap()).unwrap(); + /// ``` + fn by_ref(&self) -> &Self + where + Self: Sized, + { + self + } + + /// Flip this curve so that its tuple output is arranged the other way. + fn flip(self) -> impl Curve<(V, U)> + where + Self: Sized + Curve<(U, V)>, + { + self.map(|(u, v)| (v, u)) + } +} + +impl Curve for D +where + C: Curve + ?Sized, + D: Deref, +{ + fn domain(&self) -> Interval { + >::domain(self) + } + + fn sample_unchecked(&self, t: f32) -> T { + >::sample_unchecked(self, t) + } +} + +/// An error indicating that a linear reparametrization couldn't be performed because of +/// malformed inputs. +#[derive(Debug, Error)] +#[error("Could not build a linear function to reparametrize this curve")] +pub enum LinearReparamError { + /// The source curve that was to be reparametrized had unbounded domain. + #[error("This curve has unbounded domain")] + SourceCurveUnbounded, + + /// The target interval for reparametrization was unbounded. + #[error("The target interval for reparametrization is unbounded")] + TargetIntervalUnbounded, +} + +/// An error indicating that an end-to-end composition couldn't be performed because of +/// malformed inputs. +#[derive(Debug, Error)] +#[error("Could not compose these curves together")] +pub enum ChainError { + /// The right endpoint of the first curve was infinite. + #[error("The first curve's domain has an infinite end")] + FirstEndInfinite, + + /// The left endpoint of the second curve was infinite. + #[error("The second curve's domain has an infinite start")] + SecondStartInfinite, +} + +/// A curve with a constant value over its domain. +/// +/// This is a curve that holds an inner value and always produces a clone of that value when sampled. +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ConstantCurve { + domain: Interval, + value: T, +} + +impl ConstantCurve +where + T: Clone, +{ + /// Create a constant curve, which has the given `domain` and always produces the given `value` + /// when sampled. + pub fn new(domain: Interval, value: T) -> Self { + Self { domain, value } + } +} + +impl Curve for ConstantCurve +where + T: Clone, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample_unchecked(&self, _t: f32) -> T { + self.value.clone() + } +} + +/// A curve defined by a function together with a fixed domain. +/// +/// This is a curve that holds an inner function `f` which takes numbers (`f32`) as input and produces +/// output of type `T`. The value of this curve when sampled at time `t` is just `f(t)`. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct FunctionCurve { + domain: Interval, + f: F, + _phantom: PhantomData, +} + +impl FunctionCurve +where + F: Fn(f32) -> T, +{ + /// Create a new curve with the given `domain` from the given `function`. When sampled, the + /// `function` is evaluated at the sample time to compute the output. + pub fn new(domain: Interval, function: F) -> Self { + FunctionCurve { + domain, + f: function, + _phantom: PhantomData, + } + } +} + +impl Curve for FunctionCurve +where + F: Fn(f32) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + (self.f)(t) + } +} + +/// A curve whose samples are defined by mapping samples from another curve through a +/// given function. Curves of this type are produced by [`Curve::map`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct MapCurve { + preimage: C, + f: F, + _phantom: PhantomData<(S, T)>, +} + +impl Curve for MapCurve +where + C: Curve, + F: Fn(S) -> T, +{ + #[inline] + fn domain(&self) -> Interval { + self.preimage.domain() + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + (self.f)(self.preimage.sample_unchecked(t)) + } +} + +/// A curve whose sample space is mapped onto that of some base curve's before sampling. +/// Curves of this type are produced by [`Curve::reparametrize`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ReparamCurve { + domain: Interval, + base: C, + f: F, + _phantom: PhantomData, +} + +impl Curve for ReparamCurve +where + C: Curve, + F: Fn(f32) -> f32, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + self.base.sample_unchecked((self.f)(t)) + } +} + +/// A curve that has had its domain changed by a linear reparametrization (stretching and scaling). +/// Curves of this type are produced by [`Curve::reparametrize_linear`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct LinearReparamCurve { + /// Invariants: The domain of this curve must always be bounded. + base: C, + /// Invariants: This interval must always be bounded. + new_domain: Interval, + _phantom: PhantomData, +} + +impl Curve for LinearReparamCurve +where + C: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.new_domain + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + // The invariants imply this unwrap always succeeds. + let f = self.new_domain.linear_map_to(self.base.domain()).unwrap(); + self.base.sample_unchecked(f(t)) + } +} + +/// A curve that has been reparametrized by another curve, using that curve to transform the +/// sample times before sampling. Curves of this type are produced by [`Curve::reparametrize_by_curve`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct CurveReparamCurve { + base: C, + reparam_curve: D, + _phantom: PhantomData, +} + +impl Curve for CurveReparamCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.reparam_curve.domain() + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + let sample_time = self.reparam_curve.sample_unchecked(t); + self.base.sample_unchecked(sample_time) + } +} + +/// A curve that is the graph of another curve over its parameter space. Curves of this type are +/// produced by [`Curve::graph`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct GraphCurve { + base: C, + _phantom: PhantomData, +} + +impl Curve<(f32, T)> for GraphCurve +where + C: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.base.domain() + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> (f32, T) { + (t, self.base.sample_unchecked(t)) + } +} + +/// A curve that combines the output data from two constituent curves into a tuple output. Curves +/// of this type are produced by [`Curve::zip`]. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "bevy_reflect", derive(Reflect))] +pub struct ProductCurve { + domain: Interval, + first: C, + second: D, + _phantom: PhantomData<(S, T)>, +} + +impl Curve<(S, T)> for ProductCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + self.domain + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> (S, T) { + ( + self.first.sample_unchecked(t), + self.second.sample_unchecked(t), + ) + } +} + +/// The curve that results from chaining one curve with another. The second curve is +/// effectively reparametrized so that its start is at the end of the first. +/// +/// For this to be well-formed, the first curve's domain must be right-finite and the second's +/// must be left-finite. +/// +/// Curves of this type are produced by [`Curve::chain`]. +pub struct ChainCurve { + first: C, + second: D, + _phantom: PhantomData, +} + +impl Curve for ChainCurve +where + C: Curve, + D: Curve, +{ + #[inline] + fn domain(&self) -> Interval { + // This unwrap always succeeds because `first` has a valid Interval as its domain and the + // length of `second` cannot be NAN. It's still fine if it's infinity. + Interval::new( + self.first.domain().start(), + self.first.domain().end() + self.second.domain().length(), + ) + .unwrap() + } + + #[inline] + fn sample_unchecked(&self, t: f32) -> T { + if t > self.first.domain().end() { + self.second.sample_unchecked( + // `t - first.domain.end` computes the offset into the domain of the second. + t - self.first.domain().end() + self.second.domain().start(), + ) + } else { + self.first.sample_unchecked(t) + } + } +} + +/// Create a [`Curve`] that constantly takes the given `value` over the given `domain`. +pub fn constant_curve(domain: Interval, value: T) -> ConstantCurve { + ConstantCurve { domain, value } +} + +/// Convert the given function `f` into a [`Curve`] with the given `domain`, sampled by +/// evaluating the function. +pub fn function_curve(domain: Interval, f: F) -> FunctionCurve +where + F: Fn(f32) -> T, +{ + FunctionCurve { + domain, + f, + _phantom: PhantomData, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Quat; + use approx::{assert_abs_diff_eq, AbsDiffEq}; + use std::f32::consts::TAU; + + #[test] + fn constant_curves() { + let curve = constant_curve(Interval::EVERYWHERE, 5.0); + assert!(curve.sample_unchecked(-35.0) == 5.0); + + let curve = constant_curve(interval(0.0, 1.0).unwrap(), true); + assert!(curve.sample_unchecked(2.0)); + assert!(curve.sample(2.0).is_none()); + } + + #[test] + fn function_curves() { + let curve = function_curve(Interval::EVERYWHERE, |t| t * t); + assert!(curve.sample_unchecked(2.0).abs_diff_eq(&4.0, f32::EPSILON)); + assert!(curve.sample_unchecked(-3.0).abs_diff_eq(&9.0, f32::EPSILON)); + + let curve = function_curve(interval(0.0, f32::INFINITY).unwrap(), f32::log2); + assert_eq!(curve.sample_unchecked(3.5), f32::log2(3.5)); + assert!(curve.sample_unchecked(-1.0).is_nan()); + assert!(curve.sample(-1.0).is_none()); + } + + #[test] + fn mapping() { + let curve = function_curve(Interval::EVERYWHERE, |t| t * 3.0 + 1.0); + let mapped_curve = curve.map(|x| x / 7.0); + assert_eq!(mapped_curve.sample_unchecked(3.5), (3.5 * 3.0 + 1.0) / 7.0); + assert_eq!( + mapped_curve.sample_unchecked(-1.0), + (-1.0 * 3.0 + 1.0) / 7.0 + ); + assert_eq!(mapped_curve.domain(), Interval::EVERYWHERE); + + let curve = function_curve(interval(0.0, 1.0).unwrap(), |t| t * TAU); + let mapped_curve = curve.map(Quat::from_rotation_z); + assert_eq!(mapped_curve.sample_unchecked(0.0), Quat::IDENTITY); + assert!(mapped_curve.sample_unchecked(1.0).is_near_identity()); + assert_eq!(mapped_curve.domain(), interval(0.0, 1.0).unwrap()); + } + + #[test] + fn reparametrization() { + let curve = function_curve(interval(1.0, f32::INFINITY).unwrap(), f32::log2); + let reparametrized_curve = curve + .by_ref() + .reparametrize(interval(0.0, f32::INFINITY).unwrap(), f32::exp2); + assert_abs_diff_eq!(reparametrized_curve.sample_unchecked(3.5), 3.5); + assert_abs_diff_eq!(reparametrized_curve.sample_unchecked(100.0), 100.0); + assert_eq!( + reparametrized_curve.domain(), + interval(0.0, f32::INFINITY).unwrap() + ); + + let reparametrized_curve = curve + .by_ref() + .reparametrize(interval(0.0, 1.0).unwrap(), |t| t + 1.0); + assert_abs_diff_eq!(reparametrized_curve.sample_unchecked(0.0), 0.0); + assert_abs_diff_eq!(reparametrized_curve.sample_unchecked(1.0), 1.0); + assert_eq!(reparametrized_curve.domain(), interval(0.0, 1.0).unwrap()); + } + + #[test] + fn multiple_maps() { + // Make sure these actually happen in the right order. + let curve = function_curve(interval(0.0, 1.0).unwrap(), f32::exp2); + let first_mapped = curve.map(f32::log2); + let second_mapped = first_mapped.map(|x| x * -2.0); + assert_abs_diff_eq!(second_mapped.sample_unchecked(0.0), 0.0); + assert_abs_diff_eq!(second_mapped.sample_unchecked(0.5), -1.0); + assert_abs_diff_eq!(second_mapped.sample_unchecked(1.0), -2.0); + } + + #[test] + fn multiple_reparams() { + // Make sure these happen in the right order too. + let curve = function_curve(interval(0.0, 1.0).unwrap(), f32::exp2); + let first_reparam = curve.reparametrize(interval(1.0, 2.0).unwrap(), f32::log2); + let second_reparam = first_reparam.reparametrize(interval(0.0, 1.0).unwrap(), |t| t + 1.0); + assert_abs_diff_eq!(second_reparam.sample_unchecked(0.0), 1.0); + assert_abs_diff_eq!(second_reparam.sample_unchecked(0.5), 1.5); + assert_abs_diff_eq!(second_reparam.sample_unchecked(1.0), 2.0); + } +} diff --git a/crates/bevy_math/src/lib.rs b/crates/bevy_math/src/lib.rs index 03726f5693..f34e452fce 100644 --- a/crates/bevy_math/src/lib.rs +++ b/crates/bevy_math/src/lib.rs @@ -17,6 +17,7 @@ pub mod bounding; pub mod common_traits; mod compass; pub mod cubic_splines; +pub mod curve; mod direction; mod float_ord; mod isometry;