pointer.rs 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. use super::{error::HamtError, hash::Hasher, Node, HAMT_VALUES_BUCKET_SIZE};
  2. use crate::serializable::PointerSerializable;
  3. use anyhow::Result;
  4. use libipld::Cid;
  5. use serde::{de::DeserializeOwned, Serialize};
  6. use std::fmt::Debug;
  7. use wnfs_common::{
  8. utils::{error, Arc, CondSync},
  9. BlockStore, Link, Storable,
  10. };
  11. //--------------------------------------------------------------------------------------------------
  12. // Type Definitions
  13. //--------------------------------------------------------------------------------------------------
  /// A key together with the value stored under it.
  ///
  /// Collections of `Pair`s make up the `Values` buckets of a `Pointer`.
  ///
  /// # Examples
  ///
  /// ```
  /// use wnfs_hamt::Pair;
  ///
  /// let pair = Pair::new("key", "value");
  ///
  /// assert_eq!(pair.key, "key");
  /// assert_eq!(pair.value, "value");
  /// ```
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub struct Pair<K, V> {
      /// The key of this entry.
      pub key: K,
      /// The value associated with `key`.
      pub value: V,
  }
  31. /// Each bit in the bitmask of a node maps a `Pointer` in the HAMT structure.
  32. /// A `Pointer` can be either a link to a child node or a collection of key-value pairs.
  33. pub(crate) enum Pointer<K: CondSync, V: CondSync, H: Hasher + CondSync> {
  34. Values(Vec<Pair<K, V>>),
  35. Link(Link<Arc<Node<K, V, H>>>),
  36. }
  37. //--------------------------------------------------------------------------------------------------
  38. // Implementations
  39. //--------------------------------------------------------------------------------------------------
  40. impl<K, V> Pair<K, V> {
  41. /// Create a new `Pair` from a key and value.
  42. ///
  43. /// # Examples
  44. ///
  45. /// ```
  46. /// use wnfs_hamt::Pair;
  47. ///
  48. /// let pair = Pair::new("key", "value");
  49. ///
  50. /// assert_eq!(pair.key, "key");
  51. /// assert_eq!(pair.value, "value");
  52. /// ```
  53. pub fn new(key: K, value: V) -> Self {
  54. Self { key, value }
  55. }
  56. }
  57. impl<K: CondSync, V: CondSync, H: Hasher + CondSync> Pointer<K, V, H> {
  58. /// Converts a Link pointer to a canonical form to ensure consistent tree representation after deletes.
  59. pub async fn canonicalize(self, store: &impl BlockStore) -> Result<Option<Self>>
  60. where
  61. K: Storable + Clone + AsRef<[u8]>,
  62. V: Storable + Clone,
  63. K::Serializable: Serialize + DeserializeOwned,
  64. V::Serializable: Serialize + DeserializeOwned,
  65. H: CondSync,
  66. {
  67. match self {
  68. Pointer::Link(link) => {
  69. let node = link.resolve_owned_value(store).await?;
  70. match node.pointers.len() {
  71. 0 => Ok(None),
  72. 1 if matches!(node.pointers[0], Pointer::Values(_)) => {
  73. Ok(Some(node.pointers[0].clone()))
  74. }
  75. 2..=HAMT_VALUES_BUCKET_SIZE if node.count_values().is_ok() => {
  76. // Collect all the values of the node.
  77. let mut values = node
  78. .pointers
  79. .iter()
  80. .filter_map(|p| match p {
  81. Pointer::Values(values) => Some(values.clone()),
  82. _ => None,
  83. })
  84. .flatten()
  85. .collect::<Vec<_>>();
  86. // Bail if it's more values that we can fit into a bucket
  87. if values.len() > HAMT_VALUES_BUCKET_SIZE {
  88. return Ok(Some(Pointer::Link(Link::from(node))));
  89. }
  90. values.sort_unstable_by(|a, b| {
  91. H::hash(&a.key).partial_cmp(&H::hash(&b.key)).unwrap()
  92. });
  93. Ok(Some(Pointer::Values(values)))
  94. }
  95. _ => Ok(Some(Pointer::Link(Link::from(node)))),
  96. }
  97. }
  98. _ => error(HamtError::NonCanonicalizablePointer),
  99. }
  100. }
  101. }
  102. impl<K, V, H> Storable for Pointer<K, V, H>
  103. where
  104. K: Storable + CondSync,
  105. V: Storable + CondSync,
  106. K::Serializable: Serialize + DeserializeOwned,
  107. V::Serializable: Serialize + DeserializeOwned,
  108. H: Hasher + CondSync,
  109. {
  110. type Serializable = PointerSerializable<K::Serializable, V::Serializable>;
  111. async fn to_serializable(&self, store: &impl BlockStore) -> Result<Self::Serializable> {
  112. Ok(match self {
  113. Pointer::Values(values) => {
  114. let mut serializables = Vec::with_capacity(values.len());
  115. for pair in values.iter() {
  116. serializables.push(pair.to_serializable(store).await?);
  117. }
  118. PointerSerializable::Values(serializables)
  119. }
  120. Pointer::Link(link) => {
  121. let cid = link.resolve_cid(store).await?;
  122. PointerSerializable::Link(cid)
  123. }
  124. })
  125. }
  126. async fn from_serializable(
  127. _cid: Option<&Cid>,
  128. serializable: Self::Serializable,
  129. ) -> Result<Self> {
  130. Ok(match serializable {
  131. PointerSerializable::Values(serializables) => {
  132. let mut values = Vec::with_capacity(serializables.len());
  133. for serializable in serializables {
  134. values.push(Pair::from_serializable(None, serializable).await?);
  135. }
  136. Self::Values(values)
  137. }
  138. PointerSerializable::Link(cid) => Self::Link(Link::from_cid(cid)),
  139. })
  140. }
  141. }
  142. impl<K, V> Storable for Pair<K, V>
  143. where
  144. K: Storable + CondSync,
  145. V: Storable + CondSync,
  146. K::Serializable: Serialize + DeserializeOwned,
  147. V::Serializable: Serialize + DeserializeOwned,
  148. {
  149. type Serializable = (K::Serializable, V::Serializable);
  150. async fn to_serializable(&self, store: &impl BlockStore) -> Result<Self::Serializable> {
  151. let key = self.key.to_serializable(store).await?;
  152. let value = self.value.to_serializable(store).await?;
  153. Ok((key, value))
  154. }
  155. async fn from_serializable(
  156. _cid: Option<&Cid>,
  157. (key, value): Self::Serializable,
  158. ) -> Result<Self> {
  159. let key = K::from_serializable(None, key).await?;
  160. let value = V::from_serializable(None, value).await?;
  161. Ok(Pair { key, value })
  162. }
  163. }
  164. impl<K: Clone + CondSync, V: Clone + CondSync, H: Hasher + CondSync> Clone for Pointer<K, V, H> {
  165. fn clone(&self) -> Self {
  166. match self {
  167. Self::Values(arg0) => Self::Values(arg0.clone()),
  168. Self::Link(arg0) => Self::Link(arg0.clone()),
  169. }
  170. }
  171. }
  172. impl<K: Debug + CondSync, V: Debug + CondSync, H: Hasher + CondSync> std::fmt::Debug
  173. for Pointer<K, V, H>
  174. {
  175. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  176. match self {
  177. Self::Values(arg0) => f.debug_tuple("Values").field(arg0).finish(),
  178. Self::Link(arg0) => f.debug_tuple("Link").field(arg0).finish(),
  179. }
  180. }
  181. }
  182. impl<K: CondSync, V: CondSync, H: Hasher + CondSync> Default for Pointer<K, V, H> {
  183. fn default() -> Self {
  184. Pointer::Values(Vec::new())
  185. }
  186. }
  187. impl<K, V, H: Hasher + CondSync> PartialEq for Pointer<K, V, H>
  188. where
  189. K: Storable + PartialEq + CondSync,
  190. V: Storable + PartialEq + CondSync,
  191. K::Serializable: Serialize + DeserializeOwned,
  192. V::Serializable: Serialize + DeserializeOwned,
  193. {
  194. fn eq(&self, other: &Self) -> bool {
  195. match (self, other) {
  196. (Pointer::Values(vals), Pointer::Values(other_vals)) => vals == other_vals,
  197. (Pointer::Link(link), Pointer::Link(other_link)) => link == other_link,
  198. _ => false,
  199. }
  200. }
  201. }
  202. //--------------------------------------------------------------------------------------------------
  203. // Tests
  204. //--------------------------------------------------------------------------------------------------
  #[cfg(test)]
  mod tests {
      use super::*;
      use testresult::TestResult;
      use wnfs_common::MemoryBlockStore;

      /// Storing a `Values` pointer and loading it back from the same store must yield an
      /// equal pointer.
      #[async_std::test]
      async fn pointer_can_encode_decode_as_cbor() -> TestResult {
          let store = &MemoryBlockStore::default();

          let pairs = vec![
              Pair::new(String::from("James"), 4500),
              Pair::new(String::from("Peter"), 2000),
          ];
          let original: Pointer<String, i32, blake3::Hasher> = Pointer::Values(pairs);

          let cid = original.store(store).await?;
          let round_tripped = Pointer::<String, i32, blake3::Hasher>::load(&cid, store).await?;

          assert_eq!(original, round_tripped);

          Ok(())
      }
  }
  #[cfg(test)]
  mod snapshot_tests {
      use super::*;
      use testresult::TestResult;
      use wnfs_common::utils::SnapshotBlockStore;

      /// Stores a `Values` pointer and snapshots the block that was written for it.
      #[async_std::test]
      async fn test_pointer() -> TestResult {
          let store = &SnapshotBlockStore::default();

          let pairs = vec![
              Pair::new(String::from("James"), 4500),
              Pair::new(String::from("Peter"), 2000),
          ];
          let pointer: Pointer<String, i32, blake3::Hasher> = Pointer::Values(pairs);

          let cid = pointer.store(store).await?;
          // The binding stays named `ptr` so the expression recorded by the snapshot macro
          // is unchanged.
          let ptr = store.get_block_snapshot(&cid).await?;

          insta::assert_json_snapshot!(ptr);

          Ok(())
      }
  }