[Python 3] 부모 클래스를 상속받으면서 자식 클래스에 추가 변수 할당하기

노하람·2023년 1월 26일
0

프로젝트 코드를 작성하다 조금 헷갈리는 부분이 있어 확인하고 기록해둡니다.


개념

*args를 조심하자

*args를 사용할 때는 항상 부모 클래스의 __init__ 메서드에서 정의된 파라미터 개수와 같은 개수를 넘겨주어야 합니다.

예를 들어 부모 클래스의 __init__ 메서드에서는 a, b, c 3개의 파라미터를 받고 있다면, 자식 클래스에서도 a, b, c를 넘겨주어야 합니다.

class Parent:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

class Child(Parent):
    def __init__(self, a, b, c, d, e, f): # 추가 변수 할당
        super().__init__(a, b, c)
        self.d = d
        self.e = e
        self.f = f

여기서 *args를 사용하면 이런 오류가 발생할 수 있습니다.

class Parent:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

class Child(Parent):
    def __init__(self, *args, d, e, f):
        super().__init__(*args)
        self.d = d
        self.e = e
        self.f = f

TypeError: __init__() takes 3 positional arguments but 4 were given

메소드 상속을 위해선 super().init() 호출이 필요하다.

super()를 사용하여 부모 클래스의 init() 메소드를 호출하지 않은 경우, 자식 클래스에서는 부모 클래스의 init() 메소드를 호출하지 않아서 부모 클래스에서 정의한 초기화 코드를 실행하지 않습니다. 이 경우, 부모 클래스의 메소드는 상속되지 않습니다.

자식 클래스에서 *args를 사용하여 추가 변수를 할당하면서 부모 클래스의 init() 메소드를 호출하려면 부모 클래스에서 *args를 사용하여 정의된 init() 메소드를 사용해야 합니다. 그렇지 않으면 TypeError가 발생할 수 있습니다.

부모 클래스를 초기화 할 때 *args를 사용하지 않으려면 위나 아래처럼, 명시적으로 부모클래스의 인자를 상속받도록 합시다.

부모 클래스에 args를 활용해 초기화한 후, 자식 클래스에서도 args를 활용한 후 변수를 추가하려면 아래처럼 인덱스를 통해 변수를 할당하면 됩니다.

class ContainerOp:
    def __init__(self, *args):
        self.op_name = args[0]
        self.namespace = args[1]
        self.pipeline_start_date_str = args[2]
        self.data_pvc_name = "data-claim"
        self.forecasting_pvc_name = "forecasting-claim"

class LinearRegressionOp(ContainerOp):
    def __init__(self, *args, ml_objective_type=None, ml_objective_goal=None, 
                ml_objective_metric=None, ml_max_trial=None, ml_parallel_trial=None):
        super().__init__(*args)
        self.ml_objective_type = ml_objective_type if ml_objective_type is not None else "minimize"
        self.ml_objective_goal = ml_objective_goal if ml_objective_goal is not None else 0.01
        self.ml_objective_metric = ml_objective_metric if ml_objective_metric is not None else "mean_absolute_error"
        self.ml_max_trial = ml_max_trial if ml_max_trial is not None else 3
        self.ml_parallel_trial = ml_parallel_trial if ml_parallel_trial is not None else 3

이렇게 사용할 경우 자식 객체를 생성할 때, 인수가 위치 인수와 키워드 인수가 동시에 사용되고 있기 때문에 아래와 같이 3개는 위치 인수, 나머지는 키워드 인수로 정확히 할당해야 정상적으로 사용할 수 있습니다.

ml_operation_class_dict["xgbregressor"] = XgbRegressorOp(
        op_name, namespace, pipeline_start_date_str, 
        ml_objective_type=ml_objective_type, 
        ml_objective_goal=ml_objective_goal,
        ml_objective_metric=ml_objective_metric,
        ml_max_trial=ml_max_trial, 
        ml_parallel_trial=ml_parallel_trial
    )

실제 예시

오류가 났던 코드

# 부모클래스
class ContainerOp:
    def __init__(self, op_name, namespace, pipeline_start_date_str):
        self.op_name = op_name
        self.namespace = namespace
        self.pipeline_start_date_str = pipeline_start_date_str
        self.data_pvc_name = "data-claim"
        self.forecasting_pvc_name = "forecasting-claim"
        
# 자식클래스
class LinearRegressionOp(ContainerOp):
    def __init__(self, *args, ml_objective_type=None, ml_objective_goal=None, 
                ml_objective_metric=None, ml_max_trial=None, ml_parallel_trial=None):
        super(LinearRegressionOp, self).__init__(*args)
        self.ml_objective_type = ml_objective_type if ml_objective_type is not None else "minimize"
        self.ml_objective_goal = ml_objective_goal if ml_objective_goal is not None else 0.01
        self.ml_objective_metric = ml_objective_metric if ml_objective_metric is not None else "mean_absolute_error"
        self.ml_max_trial = ml_max_trial if ml_max_trial is not None else 3
        self.ml_parallel_trial = ml_parallel_trial if ml_parallel_trial is not None else 3

수정 후 코드

  • *args를 사용하지 않고 부모클래스에서 초기화한 변수를 그대로 상속받아 사용
# 부모클래스
class ContainerOp:
    def __init__(self, op_name, namespace, pipeline_start_date_str):
        self.op_name = op_name
        self.namespace = namespace
        self.pipeline_start_date_str = pipeline_start_date_str
        self.data_pvc_name = "data-claim"
        self.forecasting_pvc_name = "forecasting-claim"
        
# 자식클래스
class LinearRegressionOp(ContainerOp):
    def __init__(self, op_name, namespace, pipeline_start_date_str, ml_objective_type=None, ml_objective_goal=None, 
                ml_objective_metric=None, ml_max_trial=None, ml_parallel_trial=None):
        super().__init__(op_name, namespace, pipeline_start_date_str)
        self.ml_objective_type = ml_objective_type if ml_objective_type is not None else "minimize"
        self.ml_objective_goal = ml_objective_goal if ml_objective_goal is not None else 0.01
        self.ml_objective_metric = ml_objective_metric if ml_objective_metric is not None else "mean_absolute_error"
        self.ml_max_trial = ml_max_trial if ml_max_trial is not None else 3
        self.ml_parallel_trial = ml_parallel_trial if ml_parallel_trial is not None else 3
profile
MLOps, MLE 직무로 일하고 있습니다😍

0개의 댓글