| // RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics |
| // |
| // RUN: transform-opt-ch4 %s \ |
| // RUN: --transform-interpreter='entry-point=__transform_main_v2' \ |
| // RUN: --verify-diagnostics |
| |
| // ****************************** IMPORTANT NOTE ****************************** |
| // |
| // If you are changing this file, you may also need to change |
| // mlir/docs/Tutorials/Transform accordingly. |
| // |
| // **************************************************************************** |
| |
| // Original function to optimize. |
| // |
| // Payload IR for the transform scripts in @transforms below: a matmul, |
| // followed by an elementwise bias addition and an elementwise max with 0 |
| // (ReLU), all on 512x512 f32 tensors. `%output` is used as the `outs` |
| // operand of every linalg op. |
| // |
| // Both RUN lines use --verify-diagnostics, so the remarks emitted by the |
| // transform scripts are checked against the `expected-remark @below` |
| // annotations here. Keep each annotation directly above the op it refers to. |
| func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, |
| %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) |
| -> tensor<512x512xf32> { |
| // Matrix-matrix multiplication. Its result is operand 0 of the bias |
| // addition below. |
| // expected-remark @below {{matmul}} |
| %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) |
| outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> |
| |
| // Elementwise addition of the bias. Its result is operand 0 of the ReLU |
| // below. |
| // expected-remark @below {{elementwise binary}} |
| %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } |
| ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) |
| outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> |
| |
| // Elementwise max with 0 (ReLU). The zero is passed as a scalar f32 |
| // operand rather than as a tensor. |
| %c0f = arith.constant 0.0 : f32 |
| // expected-remark @below {{elementwise binary}} |
| %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } |
| ins(%biased, %c0f : tensor<512x512xf32>, f32) |
| outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> |
| func.return %relued : tensor<512x512xf32> |
| } |
| |
| // The module containing named sequences must carry the |
| // `transform.with_named_sequence` attribute, which enables verification of |
| // the named sequences it contains. |
| module @transforms attributes { transform.with_named_sequence } { |
| // Entry point. This takes as the only argument the root operation (typically |
| // pass root) given to the transform interpreter. `__transform_main` is the |
| // default entry point name, used by the first RUN line. |
| // |
| // It matches the matmul and the elementwise ops independently of each other, |
| // then emits a remark at each of them. |
| transform.named_sequence @__transform_main( |
| %root: !transform.any_op {transform.readonly}) { |
| // Collect operations that match the criteria specified in the named |
| // sequence. If the named sequence fails with a silenceable failure on an |
| // operation, that operation is not collected and the failure is silenced |
| // (the message is forwarded to the debug stream). If the named sequence |
| // succeeds, its yielded values are appended to the results of this |
| // operation. |
| %elemwise = transform.collect_matching @match_elemwise in %root |
| : (!transform.any_op) -> !transform.any_op |
| %matmul = transform.collect_matching @match_matmul in %root |
| : (!transform.any_op) -> !transform.any_op |
| |
| // With `failures(propagate)`, a failure inside the included sequence is |
| // propagated to this sequence rather than suppressed. |
| transform.include @print_elemwise failures(propagate) (%elemwise) |
| : (!transform.any_op) -> () |
| transform.include @print_matmul failures(propagate) (%matmul) |
| : (!transform.any_op) -> () |
| |
| transform.yield |
| } |
| |
| // Alternative entry point, selected by the second RUN line through |
| // `entry-point=__transform_main_v2`. Instead of matching ops one at a time, |
| // it matches the whole matmul -> elementwise -> elementwise chain at once. |
| transform.named_sequence @__transform_main_v2( |
| %root: !transform.any_op {transform.readonly}) { |
| // Collect groups of operations that match the criteria specified in the |
| // named sequence. The i-th result gathers the i-th value yielded by |
| // @match_matmul_elemwise across all successful matches. In @fc_relu only |
| // the ReLU op roots a complete chain, so each op is reported once and the |
| // same `expected-remark` annotations serve both entry points. |
| %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root |
| : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) |
| // Combine both elementwise handles so that a single include prints all of |
| // them. |
| %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op |
| |
| transform.include @print_elemwise failures(propagate) (%elemwise) |
| : (!transform.any_op) -> () |
| transform.include @print_matmul failures(propagate) (%matmul) |
| : (!transform.any_op) -> () |
| |
| transform.yield |
| } |
| |
| // This is a matcher sequence. It is given an operation to match and the |
| // match is considered successful unless any nested operation produces a |
| // failure. On success, the values yielded by this sequence are returned to |
| // the caller; `transform.collect_matching` appends them to its results. |
| // |
| // @match_elemwise accepts `linalg.elemwise_binary` ops and yields the |
| // matched op itself. |
| transform.named_sequence @match_elemwise( |
| %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { |
| transform.match.operation_name %entry ["linalg.elemwise_binary"] |
| : !transform.any_op |
| transform.yield %entry : !transform.any_op |
| } |
| // Same as @match_elemwise, but accepts `linalg.matmul` ops. |
| transform.named_sequence @match_matmul( |
| %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { |
| transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op |
| transform.yield %entry : !transform.any_op |
| } |
| |
| // These are action sequences. Each one emits a remark at the payload ops |
| // associated with its handle argument. The remark text must match the |
| // `expected-remark` annotations in @fc_relu. |
| transform.named_sequence @print_elemwise( |
| %elemwise_binary: !transform.any_op {transform.readonly}) { |
| transform.debug.emit_remark_at |
| %elemwise_binary, "elementwise binary" : !transform.any_op |
| transform.yield |
| } |
| transform.named_sequence @print_matmul( |
| %matmul: !transform.any_op {transform.readonly}) { |
| transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op |
| transform.yield |
| } |
| |
| // This is also a matcher sequence. It is similarly given an operation to |
| // match and nested operations must succeed in order for a match to be deemed |
| // successful. It starts matching from the last operation in the use-def chain |
| // and goes back because each operand (use) has exactly one definition. |
| // For example, when given the bias addition of @fc_relu, the match fails at |
| // the check on `%middle`, because that op's first operand comes from the |
| // matmul. |
| transform.named_sequence @match_matmul_elemwise( |
| %last: !transform.any_op {transform.readonly}) |
| -> (!transform.any_op, !transform.any_op, !transform.any_op) { |
| // The last operation must be an elementwise binary. |
| transform.match.operation_name %last ["linalg.elemwise_binary"] |
| : !transform.any_op |
| // Its first operand must be defined by another operation, to which we |
| // will get a handle here. We are guaranteed that the first operand exists |
| // because we know the operation is binary, but even in absence of such a |
| // guarantee, this operation would have produced a silenceable failure when |
| // `%last` does not have enough operands. |
| %middle = transform.get_producer_of_operand %last[0] |
| : (!transform.any_op) -> !transform.any_op |
| // The defining operation must itself be an elementwise binary. |
| transform.match.operation_name %middle ["linalg.elemwise_binary"] |
| : !transform.any_op |
| // And the first operand of that operation must be defined by yet another |
| // operation. |
| %matmul = transform.get_producer_of_operand %middle[0] |
| : (!transform.any_op) -> !transform.any_op |
| // And that operation is a matmul. |
| transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op |
| // We will yield the handles to the matmul and the two elementwise |
| // operations separately. Their order (matmul, middle, last) determines the |
| // order of the results of the `transform.collect_matching` that uses this |
| // matcher. |
| transform.yield %matmul, %middle, %last |
| : !transform.any_op, !transform.any_op, !transform.any_op |
| } |
| } |