fix _convert_simple_rnn #15723
Conversation
Hello @haoyang9804! I've looked through your code, but I'm not so familiar with it. It is good that you added a test and it passes, but I should warn that numpy does not work with bfloat16, while TVM and other frameworks do.
Thanks for the reply, but I think my patch is not related to bfloat16.
    weightList0 = weightList[0].transpose([1, 0])
    assert len(in_data.type_annotation.shape) == 3
    for i in range(in_data.type_annotation.shape[1].value - 1):
        weightList0 = np.hstack((weightList0, weightList[0].transpose([1, 0])))
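The quoted loop repeats the transposed kernel once per timestep by horizontal concatenation. A minimal standalone sketch of what it computes, using made-up shapes (units=4, input_dim=3, timesteps=5 are illustrative, not taken from the PR):

```python
import numpy as np

# Hypothetical shapes for illustration only.
units, input_dim, timesteps = 4, 3, 5
# Keras stores the SimpleRNN kernel with layout (input_dim, units).
kernel = np.arange(input_dim * units, dtype="float32").reshape(input_dim, units)

# Start from the transposed kernel, then append one copy per extra timestep,
# mirroring the loop in the patch.
w = kernel.transpose([1, 0])  # shape (units, input_dim)
for _ in range(timesteps - 1):
    w = np.hstack((w, kernel.transpose([1, 0])))

print(w.shape)  # (4, 15): units x (input_dim * timesteps)

# np.tile produces the same result in one call.
assert np.array_equal(w, np.tile(kernel.transpose([1, 0]), (1, timesteps)))
```

The `np.tile` equivalence shows the loop is just a per-timestep replication of the same weight block.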
Could you please elaborate? I still cannot see any data type issue here.
Could you check your code with bfloat16 weights? numpy.hstack has a dtype argument, and I guess it possibly checks it; if so, numpy fails when the dtype is bfloat16.
Oh I see. I will check it later. Thx
It's true that numpy.hstack does not support bfloat16. But weightList[0] and weightList0 can never be bfloat16, to the best of my understanding. These two variables come from weightList = keras_layer.get_weights(), which returns NumPy arrays. Since numpy does not support bfloat16, the dtype of weightList[0] should never be bfloat16, so this worry seems unnecessary here.
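A quick probe supports this argument: stock NumPy (without extensions such as ml_dtypes) does not define a bfloat16 dtype at all, so arrays returned by keras_layer.get_weights() cannot carry it, and np.hstack on the ordinary float32 arrays works fine. A sketch, assuming a plain NumPy install:

```python
import numpy as np

# Stock NumPy does not understand the name "bfloat16",
# so requesting the dtype raises TypeError.
try:
    np.dtype("bfloat16")
    has_bfloat16 = True
except TypeError:
    has_bfloat16 = False

print(has_bfloat16)  # False on stock NumPy

# np.hstack works as expected on float32 inputs, which is what
# get_weights() returns in practice; the dtype is preserved.
a = np.zeros((2, 3), dtype="float32")
b = np.hstack((a, a))
print(b.dtype, b.shape)  # float32 (2, 6)
```

If a build did register a bfloat16 extension dtype, the first check would pass and the concern about np.hstack would become relevant again.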
vvchernov left a comment
LGTM. Please fix lint and other CI issues
cc @echuraev
Seems that everything is good except for the gpu/docs CI tests. May I ask how to fix them? I have never encountered this issue before, @vvchernov.
Hello @echuraev! Looks like it is a problem from Jenkins. Do you know who can help us?
Hello @haoyang9804! I think it is not your issue. I've rechecked other PRs: #15714 and #15709 have the same issue, but later PRs do not. Possibly the best way is to restart the CI.
@tvm-bot rerun
@haoyang9804, I have restarted CI. In case of further errors, please try to rebase your branch to the latest mainline. |
Fix this issue
Just as @echuraev guessed, _convert_simple_rnn has some logical errors. I'm not entirely sure my fix is correct. All in all, after this fix, running the following bug-triggering script gives a good compilation result, InferType() can successfully infer all types/shapes in the model, and the inference is correct. The compilation result is