33"""
44
55import numpy as np
6- import tensorflow as tf
6+ import torch
77
8- import tensorcircuit as tc
8+ import tyxonq as tq
99
10- ctype , rtype = tc .set_dtype ("complex64" )
11- K = tc .set_backend ("tensorflow " )
10+ ctype , rtype = tq .set_dtype ("complex64" )
11+ K = tq .set_backend ("pytorch " )
1212
1313n = 6
1414nlayers = 6
1515
1616
1717def ansatz (structureo , structuret , preprocess = "direct" ):
18- c = tc .Circuit (n )
18+ c = tq .Circuit (n )
1919 if preprocess == "softmax" :
2020 structureo = K .softmax (structureo , axis = - 1 )
2121 structuret = K .softmax (structuret , axis = - 1 )
@@ -28,30 +28,30 @@ def ansatz(structureo, structuret, preprocess="direct"):
2828 structureo = K .cast (structureo , ctype )
2929 structuret = K .cast (structuret , ctype )
3030
31- structureo = tf .reshape (structureo , shape = [nlayers , n , 7 ])
32- structuret = tf .reshape (structuret , shape = [nlayers , n , 3 ])
31+ structureo = torch .reshape (structureo , shape = [nlayers , n , 7 ])
32+ structuret = torch .reshape (structuret , shape = [nlayers , n , 3 ])
3333
3434 for i in range (n ):
35- c .H (i )
35+ c .h (i )
3636 for j in range (nlayers ):
3737 for i in range (n ):
3838 c .unitary (
3939 i ,
40- unitary = structureo [j , i , 0 ] * tc .gates .i ().tensor
41- + structureo [j , i , 1 ] * tc .gates .x ().tensor
42- + structureo [j , i , 2 ] * tc .gates .y ().tensor
43- + structureo [j , i , 3 ] * tc .gates .z ().tensor
44- + structureo [j , i , 4 ] * tc .gates .h ().tensor
45- + structureo [j , i , 5 ] * tc .gates .s ().tensor
46- + structureo [j , i , 6 ] * tc .gates .sd ().tensor ,
40+ unitary = structureo [j , i , 0 ] * tq .gates .i ().tensor
41+ + structureo [j , i , 1 ] * tq .gates .x ().tensor
42+ + structureo [j , i , 2 ] * tq .gates .y ().tensor
43+ + structureo [j , i , 3 ] * tq .gates .z ().tensor
44+ + structureo [j , i , 4 ] * tq .gates .h ().tensor
45+ + structureo [j , i , 5 ] * tq .gates .s ().tensor
46+ + structureo [j , i , 6 ] * tq .gates .sd ().tensor ,
4747 )
4848 for i in range (n - 1 ):
4949 c .unitary (
5050 i ,
5151 i + 1 ,
52- unitary = structuret [j , i , 0 ] * tc .gates .ii ().tensor
53- + structuret [j , i , 1 ] * tc .gates .cnot ().tensor
54- + structuret [j , i , 2 ] * tc .gates .cz ().tensor ,
52+ unitary = structuret [j , i , 0 ] * tq .gates .ii ().tensor
53+ + structuret [j , i , 1 ] * tq .gates .cnot ().tensor
54+ + structuret [j , i , 2 ] * tq .gates .cz ().tensor ,
5555 )
5656 # loss = K.real(
5757 # sum(
@@ -60,7 +60,7 @@ def ansatz(structureo, structuret, preprocess="direct"):
6060 # )
6161 # )
6262 s = c .state ()
63- loss = - K .real (tc .quantum .entropy (tc .quantum .reduced_density_matrix (s , cut = n // 2 )))
63+ loss = - K .real (tq .quantum .entropy (tq .quantum .reduced_density_matrix (s , cut = n // 2 )))
6464 return loss
6565
6666
@@ -77,6 +77,7 @@ def sampling_from_structure(structures, batch=1):
7777 return r .transpose ()
7878
7979
80+ # warning pytorch might be unable to do this exactly
8081@K .jit
8182def best_from_structure (structures ):
8283 return K .argmax (structures , axis = - 1 )
@@ -89,31 +90,30 @@ def nmf_gradient(structures, oh):
8990 choice = K .argmax (oh , axis = - 1 )
9091 prob = K .softmax (K .real (structures ), axis = - 1 )
9192 indices = K .transpose (
92- K .stack ([K .cast (tf . range (structures .shape [0 ]), "int64" ), choice ])
93+ K .stack ([K .cast (torch . arange (structures .shape [0 ]), "int64" ), choice ])
9394 )
94- prob = tf . gather_nd (prob , indices )
95+ prob = torch . gather (prob , 0 , indices . unsqueeze ( 0 )). squeeze ( 0 )
9596 prob = K .reshape (prob , [- 1 , 1 ])
9697 prob = K .tile (prob , [1 , structures .shape [- 1 ]])
9798
98- return K .real (
99- tf .tensor_scatter_nd_add (
100- tf .cast (- prob , dtype = ctype ),
101- indices ,
102- tf .ones ([structures .shape [0 ]], dtype = ctype ),
103- )
104- )
99+ # warning pytorch might be unable to do this exactly
100+ result = torch .zeros_like (structures , dtype = ctype )
101+ result .scatter_add_ (0 , indices , torch .ones ([structures .shape [0 ]], dtype = ctype ))
102+ return K .real (result - prob )
105103
106104
105+ # warning pytorch might be unable to do this exactly
107106nmf_gradient_vmap = K .jit (K .vmap (nmf_gradient , vectorized_argnums = 1 ))
107+ # warning pytorch might be unable to do this exactly
108108vf = K .jit (K .vmap (ansatz , vectorized_argnums = (0 , 1 )), static_argnums = 2 )
109109
110110
111111def main (stddev = 0.05 , lr = None , epochs = 2000 , debug_step = 50 , batch = 256 , verbose = False ):
112112 so = K .implicit_randn ([nlayers * n , 7 ], stddev = stddev )
113113 st = K .implicit_randn ([nlayers * n , 3 ], stddev = stddev )
114114 if lr is None :
115- lr = tf . keras . optimizers . schedules . ExponentialDecay ( 0.06 , 1000 , 0.5 )
116- structure_opt = tc . backend . optimizer ( tf . keras . optimizers . Adam (lr ) )
115+ lr = 0.06 # Simplified learning rate
116+ structure_opt = torch . optim . Adam ([ so , st ], lr = lr )
117117
118118 avcost = 0
119119 avcost2 = 0
@@ -135,23 +135,28 @@ def main(stddev=0.05, lr=None, epochs=2000, debug_step=50, batch=256, verbose=Fa
135135
136136 # go = [(vs[i] - avcost2) * go[i] for i in range(batch)]
137137 # gt = [(vs[i] - avcost2) * gt[i] for i in range(batch)]
138- # go = tf .math.reduce_mean(go, axis=0)
139- # gt = tf .math.reduce_mean(gt, axis=0)
138+ # go = torch .math.reduce_mean(go, axis=0)
139+ # gt = torch .math.reduce_mean(gt, axis=0)
140140 avcost2 = avcost
141141
142- [so , st ] = structure_opt .update ([go , gt ], [so , st ])
142+ # Update parameters using PyTorch optimizer
143+ structure_opt .zero_grad ()
144+ so .grad = go
145+ st .grad = gt
146+ structure_opt .step ()
147+
143148 # so -= K.reshape(K.mean(so, axis=-1), [-1, 1])
144149 # st -= K.reshape(K.mean(st, axis=-1), [-1, 1])
145150 if epoch % debug_step == 0 or epoch == epochs - 1 :
146151 print ("----------epoch %s-----------" % epoch )
147152 print (
148153 "batched average loss: " ,
149- np .mean (vs ),
154+ np .mean (vs . detach (). cpu (). numpy () ),
150155 "minimum candidate loss: " ,
151- np .min (vs ),
156+ np .min (vs . detach (). cpu (). numpy () ),
152157 )
153- minp1 = tf . math . reduce_min ( tf . math . reduce_max ( tf . math . softmax (st ), axis = - 1 ))
154- minp2 = tf . math . reduce_min ( tf . math . reduce_max ( tf . math . softmax (so ), axis = - 1 ))
158+ minp1 = torch . min ( torch . max ( torch . softmax (st , dim = - 1 ), dim = - 1 )[ 0 ] )
159+ minp2 = torch . min ( torch . max ( torch . softmax (so , dim = - 1 ), dim = - 1 )[ 0 ] )
155160 if minp1 > 0.3 and minp2 > 0.6 :
156161 print ("probability converged" )
157162 break
@@ -161,9 +166,9 @@ def main(stddev=0.05, lr=None, epochs=2000, debug_step=50, batch=256, verbose=Fa
161166 print (st )
162167 print (
163168 "strcuture parameter: \n " ,
164- so .numpy (),
169+ so .detach (). cpu (). numpy (),
165170 "\n " ,
166- st .numpy (),
171+ st .detach (). cpu (). numpy (),
167172 )
168173
169174 cand_preseto = best_from_structure (so )
0 commit comments