@@ -13,13 +13,15 @@ use crate::clarity::types::{
1313use crate :: clarity:: { ClarityName , SymbolicExpressionType } ;
1414use crate :: repl:: settings:: InitialLink ;
1515use std:: collections:: { BTreeMap , HashMap , HashSet } ;
16+ use std:: hash:: { Hash , Hasher } ;
1617use std:: iter:: FromIterator ;
18+ use std:: ops:: { Deref , DerefMut } ;
1719use std:: process;
1820
1921use super :: ast_visitor:: TypedVar ;
2022
2123pub struct ASTDependencyDetector < ' a > {
22- dependencies : HashMap < QualifiedContractIdentifier , HashSet < QualifiedContractIdentifier > > ,
24+ dependencies : HashMap < QualifiedContractIdentifier , DependencySet > ,
2325 current_contract : Option < & ' a QualifiedContractIdentifier > ,
2426 defined_functions :
2527 HashMap < ( & ' a QualifiedContractIdentifier , & ' a ClarityName ) , Vec < TypeSignature > > ,
@@ -46,18 +48,101 @@ pub struct ASTDependencyDetector<'a> {
4648 ) > ,
4749 > ,
4850 params : Option < Vec < TypedVar < ' a > > > ,
51+ top_level : bool ,
4952 preloaded : & ' a BTreeMap < QualifiedContractIdentifier , ContractAST > ,
5053}
5154
55+ #[ derive( Clone , Debug , Eq ) ]
56+ pub struct Dependency {
57+ pub contract_id : QualifiedContractIdentifier ,
58+ pub required_before_publish : bool ,
59+ }
60+
61+ impl PartialEq for Dependency {
62+ fn eq ( & self , other : & Self ) -> bool {
63+ self . contract_id == other. contract_id
64+ }
65+ }
66+
67+ impl Hash for Dependency {
68+ fn hash < H : Hasher > ( & self , state : & mut H ) {
69+ self . contract_id . hash ( state)
70+ }
71+ }
72+
73+ impl PartialOrd for Dependency {
74+ fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
75+ self . contract_id . partial_cmp ( & other. contract_id )
76+ }
77+ }
78+
79+ #[ derive( Debug ) ]
80+ pub struct DependencySet {
81+ set : HashSet < Dependency > ,
82+ }
83+
84+ impl DependencySet {
85+ pub fn new ( ) -> DependencySet {
86+ DependencySet {
87+ set : HashSet :: new ( ) ,
88+ }
89+ }
90+
91+ pub fn add_dependency (
92+ & mut self ,
93+ contract_id : QualifiedContractIdentifier ,
94+ required_before_publish : bool ,
95+ ) {
96+ let dep = Dependency {
97+ contract_id,
98+ required_before_publish,
99+ } ;
100+
101+ // If this dependency is required before publish, then make sure to
102+ // delete any existing dependency so that this overrides it.
103+ if required_before_publish {
104+ self . set . remove ( & dep) ;
105+ }
106+
107+ self . set . insert ( dep) ;
108+ }
109+
110+ pub fn has_dependency ( & self , contract_id : & QualifiedContractIdentifier ) -> Option < bool > {
111+ if let Some ( dep) = self . set . get ( & Dependency {
112+ contract_id : contract_id. clone ( ) ,
113+ required_before_publish : false ,
114+ } ) {
115+ println ! ( "FOUND DEP: {}" , dep. required_before_publish) ;
116+ Some ( dep. required_before_publish )
117+ } else {
118+ None
119+ }
120+ }
121+ }
122+
123+ impl Deref for DependencySet {
124+ type Target = HashSet < Dependency > ;
125+
126+ fn deref ( & self ) -> & Self :: Target {
127+ & self . set
128+ }
129+ }
130+
131+ impl DerefMut for DependencySet {
132+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
133+ & mut self . set
134+ }
135+ }
136+
52137impl < ' a > ASTDependencyDetector < ' a > {
53138 pub fn detect_dependencies (
54139 contract_asts : & ' a HashMap < QualifiedContractIdentifier , ContractAST > ,
55140 preloaded : & ' a BTreeMap < QualifiedContractIdentifier , ContractAST > ,
56141 ) -> Result <
57- HashMap < QualifiedContractIdentifier , HashSet < QualifiedContractIdentifier > > ,
142+ HashMap < QualifiedContractIdentifier , DependencySet > ,
58143 (
59144 // Dependencies detected
60- HashMap < QualifiedContractIdentifier , HashSet < QualifiedContractIdentifier > > ,
145+ HashMap < QualifiedContractIdentifier , DependencySet > ,
61146 // Unresolved dependencies detected
62147 Vec < QualifiedContractIdentifier > ,
63148 ) ,
@@ -70,6 +155,7 @@ impl<'a> ASTDependencyDetector<'a> {
70155 pending_function_checks : HashMap :: new ( ) ,
71156 pending_trait_checks : HashMap :: new ( ) ,
72157 params : None ,
158+ top_level : true ,
73159 preloaded,
74160 } ;
75161
@@ -86,7 +172,7 @@ impl<'a> ASTDependencyDetector<'a> {
86172 for ( contract_identifier, ast) in contract_asts {
87173 detector
88174 . dependencies
89- . insert ( contract_identifier. clone ( ) , HashSet :: new ( ) ) ;
175+ . insert ( contract_identifier. clone ( ) , DependencySet :: new ( ) ) ;
90176 detector. current_contract = Some ( contract_identifier) ;
91177 traverse ( & mut detector, & ast. expressions ) ;
92178 }
@@ -112,7 +198,7 @@ impl<'a> ASTDependencyDetector<'a> {
112198 }
113199
114200 pub fn order_contracts (
115- dependencies : & HashMap < QualifiedContractIdentifier , HashSet < QualifiedContractIdentifier > > ,
201+ dependencies : & HashMap < QualifiedContractIdentifier , DependencySet > ,
116202 ) -> CheckResult < Vec < & QualifiedContractIdentifier > > {
117203 let mut lookup = BTreeMap :: new ( ) ;
118204 let mut reverse_lookup = Vec :: new ( ) ;
@@ -134,10 +220,10 @@ impl<'a> ASTDependencyDetector<'a> {
134220 let contract_id = lookup. get ( contract) . unwrap ( ) ;
135221 graph. add_node ( * contract_id) ;
136222 for dep in contract_dependencies. iter ( ) {
137- let dep_id = match lookup. get ( dep) {
223+ let dep_id = match lookup. get ( & dep. contract_id ) {
138224 Some ( id) => id,
139225 None => {
140- return Err ( CheckErrors :: NoSuchContract ( dep. to_string ( ) ) . into ( ) ) ;
226+ return Err ( CheckErrors :: NoSuchContract ( dep. contract_id . to_string ( ) ) . into ( ) ) ;
141227 }
142228 } ;
143229 graph. add_directed_edge ( * contract_id, * dep_id) ;
@@ -172,10 +258,10 @@ impl<'a> ASTDependencyDetector<'a> {
172258 return ;
173259 }
174260 if let Some ( set) = self . dependencies . get_mut ( from) {
175- set. insert ( to. clone ( ) ) ;
261+ set. add_dependency ( to. clone ( ) , self . top_level ) ;
176262 } else {
177- let mut set = HashSet :: new ( ) ;
178- set. insert ( to. clone ( ) ) ;
263+ let mut set = DependencySet :: new ( ) ;
264+ set. add_dependency ( to. clone ( ) , self . top_level ) ;
179265 self . dependencies . insert ( from. clone ( ) , set) ;
180266 }
181267 }
@@ -315,9 +401,11 @@ impl<'a> ASTVisitor<'a> for ASTDependencyDetector<'a> {
315401 body : & ' a SymbolicExpression ,
316402 ) -> bool {
317403 self . params = parameters. clone ( ) ;
404+ self . top_level = false ;
318405 let res =
319406 self . traverse_expr ( body) && self . visit_define_private ( expr, name, parameters, body) ;
320407 self . params = None ;
408+ self . top_level = true ;
321409 res
322410 }
323411
@@ -350,9 +438,11 @@ impl<'a> ASTVisitor<'a> for ASTDependencyDetector<'a> {
350438 body : & ' a SymbolicExpression ,
351439 ) -> bool {
352440 self . params = parameters. clone ( ) ;
441+ self . top_level = false ;
353442 let res =
354443 self . traverse_expr ( body) && self . visit_define_read_only ( expr, name, parameters, body) ;
355444 self . params = None ;
445+ self . top_level = true ;
356446 res
357447 }
358448
@@ -385,9 +475,11 @@ impl<'a> ASTVisitor<'a> for ASTDependencyDetector<'a> {
385475 body : & ' a SymbolicExpression ,
386476 ) -> bool {
387477 self . params = parameters. clone ( ) ;
478+ self . top_level = false ;
388479 let res =
389480 self . traverse_expr ( body) && self . visit_define_public ( expr, name, parameters, body) ;
390481 self . params = None ;
482+ self . top_level = true ;
391483 res
392484 }
393485
@@ -769,7 +861,7 @@ mod tests {
769861 let dependencies =
770862 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
771863 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
772- assert ! ( dependencies[ & test_identifier] . contains ( & foo) ) ;
864+ assert ! ( ! dependencies[ & test_identifier] . has_dependency ( & foo) . unwrap ( ) ) ;
773865 }
774866
775867 // This test is disabled because it is currently not possible to refer to a
@@ -815,7 +907,7 @@ mod tests {
815907 let dependencies =
816908 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
817909 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
818- assert ! ( dependencies[ & test_identifier] . contains ( & bar) ) ;
910+ assert ! ( ! dependencies[ & test_identifier] . has_dependency ( & bar) . unwrap ( ) ) ;
819911 }
820912
821913 #[ test]
@@ -859,7 +951,8 @@ mod tests {
859951 let dependencies =
860952 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
861953 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
862- assert ! ( dependencies[ & test_identifier] . contains( & bar) ) ;
954+ println ! ( "{:?}" , dependencies[ & test_identifier] ) ;
955+ assert ! ( dependencies[ & test_identifier] . has_dependency( & bar) . unwrap( ) ) ;
863956 }
864957
865958 #[ test]
@@ -913,8 +1006,10 @@ mod tests {
9131006 let dependencies =
9141007 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
9151008 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 2 ) ;
916- assert ! ( dependencies[ & test_identifier] . contains( & bar) ) ;
917- assert ! ( dependencies[ & test_identifier] . contains( & my_trait) ) ;
1009+ assert ! ( !dependencies[ & test_identifier] . has_dependency( & bar) . unwrap( ) ) ;
1010+ assert ! ( dependencies[ & test_identifier]
1011+ . has_dependency( & my_trait)
1012+ . unwrap( ) ) ;
9181013 }
9191014
9201015 #[ test]
@@ -952,7 +1047,9 @@ mod tests {
9521047 let dependencies =
9531048 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
9541049 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
955- assert ! ( dependencies[ & test_identifier] . contains( & other) ) ;
1050+ assert ! ( dependencies[ & test_identifier]
1051+ . has_dependency( & other)
1052+ . unwrap( ) ) ;
9561053 }
9571054
9581055 #[ test]
@@ -990,7 +1087,9 @@ mod tests {
9901087 let dependencies =
9911088 ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
9921089 assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
993- assert ! ( dependencies[ & test_identifier] . contains( & other) ) ;
1090+ assert ! ( dependencies[ & test_identifier]
1091+ . has_dependency( & other)
1092+ . unwrap( ) ) ;
9941093 }
9951094
9961095 #[ test]
@@ -1042,4 +1141,36 @@ mod tests {
10421141 Err ( ( _, unresolved) ) => assert_eq ! ( unresolved[ 0 ] . name. as_str( ) , "bar" ) ,
10431142 }
10441143 }
1144+
1145+ #[ test]
1146+ fn contract_call_top_level ( ) {
1147+ let mut session = Session :: new ( SessionSettings :: default ( ) ) ;
1148+ let mut contracts = HashMap :: new ( ) ;
1149+ let snippet1 = "
1150+ (define-public (hello (a int))
1151+ (ok u0)
1152+ )"
1153+ . to_string ( ) ;
1154+ let foo = match session. build_ast ( & snippet1, Some ( "foo" ) ) {
1155+ Ok ( ( contract_identifier, ast, _) ) => {
1156+ contracts. insert ( contract_identifier. clone ( ) , ast) ;
1157+ contract_identifier
1158+ }
1159+ Err ( _) => panic ! ( "expected success" ) ,
1160+ } ;
1161+
1162+ let snippet = "(contract-call? .foo hello 4)" . to_string ( ) ;
1163+ let test_identifier = match session. build_ast ( & snippet, Some ( "test" ) ) {
1164+ Ok ( ( contract_identifier, ast, _) ) => {
1165+ contracts. insert ( contract_identifier. clone ( ) , ast) ;
1166+ contract_identifier
1167+ }
1168+ Err ( _) => panic ! ( "expected success" ) ,
1169+ } ;
1170+
1171+ let dependencies =
1172+ ASTDependencyDetector :: detect_dependencies ( & contracts, & BTreeMap :: new ( ) ) . unwrap ( ) ;
1173+ assert_eq ! ( dependencies[ & test_identifier] . len( ) , 1 ) ;
1174+ assert ! ( dependencies[ & test_identifier] . has_dependency( & foo) . unwrap( ) ) ;
1175+ }
10451176}
0 commit comments